Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Tests for T2V and I2V (#6)
* add ltx2 pipeline tests.
* up
* up
* up
* up
* remove content
* style
* Denormalize audio latents in I2V pipeline (analogous to T2V change)
* Initial refactor to put video and audio text encoder connectors in transformer
* Get LTX 2 transformer tests working after connector refactor
* up
* up
* i2v tests.
* up
* Address review comments
* Calculate RoPE double-precision freqs using torch instead of np
* Further simplify LTX 2 RoPE freq calc
* revert unneeded changes.
* up
* up
* update to split style rope.
* up

---------

Co-authored-by: Daniel Gu <dgu8957@gmail.com>
@@ -1,5 +1,4 @@
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
@@ -8,9 +7,15 @@ import safetensors.torch
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoModel, AutoTokenizer, Gemma3ForConditionalGeneration
|
||||
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
|
||||
|
||||
from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, FlowMatchEulerDiscreteScheduler, LTX2Pipeline, LTX2VideoTransformer3DModel
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTX2Pipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
@@ -186,7 +191,7 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str,
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 8,
|
||||
"cross_attention_dim": 16,
|
||||
"vae_scale_factors": (8, 32 ,32),
|
||||
"vae_scale_factors": (8, 32, 32),
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
@@ -229,7 +234,7 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attention_dim": 4096,
|
||||
"vae_scale_factors": (8, 32 ,32),
|
||||
"vae_scale_factors": (8, 32, 32),
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
@@ -257,7 +262,7 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str,
|
||||
"causal_offset": 1,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1000,
|
||||
"rope_type": "split"
|
||||
"rope_type": "split",
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
|
||||
@@ -307,7 +312,7 @@ def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str,
|
||||
"rope_type": "split",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
|
||||
special_keys_remap = {}
|
||||
|
||||
@@ -541,7 +546,7 @@ def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any
|
||||
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"leaky_relu_negative_slope": 0.1,
|
||||
"output_sampling_rate": 24000,
|
||||
}
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
|
||||
@@ -574,7 +579,6 @@ def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> D
|
||||
return vocoder
|
||||
|
||||
|
||||
|
||||
def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
|
||||
@@ -757,7 +761,7 @@ def main(args):
|
||||
transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
|
||||
|
||||
|
||||
if args.connectors or args.full_pipeline:
|
||||
if args.dit_filename is not None:
|
||||
original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
|
||||
@@ -810,6 +814,6 @@ def main(args):
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
main(args)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
@@ -211,6 +210,6 @@ def main(args):
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from PIL import Image
|
||||
|
||||
import av # Needs to be installed separately (`pip install av`)
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline
|
||||
|
||||
@@ -131,7 +130,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--negative_prompt",
|
||||
type=str,
|
||||
default="shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."
|
||||
default="shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static.",
|
||||
)
|
||||
|
||||
parser.add_argument("--num_inference_steps", type=int, default=40)
|
||||
@@ -166,7 +165,9 @@ def parse_args():
|
||||
|
||||
def main(args):
|
||||
pipeline = LTX2ImageToVideoPipeline.from_pretrained(
|
||||
args.model_id, revision=args.revision, torch_dtype=args.dtype,
|
||||
args.model_id,
|
||||
revision=args.revision,
|
||||
torch_dtype=args.dtype,
|
||||
)
|
||||
if args.cpu_offload:
|
||||
pipeline.enable_model_cpu_offload()
|
||||
@@ -201,6 +202,6 @@ def main(args):
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
||||
@@ -90,7 +90,7 @@ def main() -> None:
|
||||
latent_width,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
generator=torch.Generator(device).manual_seed(42)
|
||||
generator=torch.Generator(device).manual_seed(42),
|
||||
)
|
||||
|
||||
original_out = original_decoder(dummy)
|
||||
|
||||
@@ -193,9 +193,9 @@ else:
|
||||
"AutoencoderKLHunyuanImageRefiner",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLHunyuanVideo15",
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLLTX2Audio",
|
||||
"AutoencoderKLLTX2Video",
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLMagvit",
|
||||
"AutoencoderKLMochi",
|
||||
"AutoencoderKLQwenImage",
|
||||
@@ -237,8 +237,8 @@ else:
|
||||
"Kandinsky3UNet",
|
||||
"Kandinsky5Transformer3DModel",
|
||||
"LatteTransformer3DModel",
|
||||
"LTXVideoTransformer3DModel",
|
||||
"LTX2VideoTransformer3DModel",
|
||||
"LTXVideoTransformer3DModel",
|
||||
"Lumina2Transformer2DModel",
|
||||
"LuminaNextDiT2DModel",
|
||||
"MochiTransformer3DModel",
|
||||
@@ -533,12 +533,13 @@ else:
|
||||
"LDMTextToImagePipeline",
|
||||
"LEditsPPPipelineStableDiffusion",
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LTX2ImageToVideoPipeline",
|
||||
"LTX2Pipeline",
|
||||
"LTX2Pipeline",
|
||||
"LTXConditionPipeline",
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXLatentUpsamplePipeline",
|
||||
"LTXPipeline",
|
||||
"LTX2Pipeline",
|
||||
"LTX2ImageToVideoPipeline",
|
||||
"LucyEditPipeline",
|
||||
"Lumina2Pipeline",
|
||||
"Lumina2Text2ImgPipeline",
|
||||
@@ -931,9 +932,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
AutoencoderKLQwenImage,
|
||||
@@ -975,8 +976,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Kandinsky3UNet,
|
||||
Kandinsky5Transformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
LTX2VideoTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
Lumina2Transformer2DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
MochiTransformer3DModel,
|
||||
@@ -1241,12 +1242,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LDMTextToImagePipeline,
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LTX2ImageToVideoPipeline,
|
||||
LTX2Pipeline,
|
||||
LTXConditionPipeline,
|
||||
LTXImageToVideoPipeline,
|
||||
LTXLatentUpsamplePipeline,
|
||||
LTXPipeline,
|
||||
LTX2Pipeline,
|
||||
LTX2ImageToVideoPipeline,
|
||||
LucyEditPipeline,
|
||||
Lumina2Pipeline,
|
||||
Lumina2Text2ImgPipeline,
|
||||
|
||||
@@ -154,9 +154,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
AutoencoderKLQwenImage,
|
||||
@@ -213,8 +213,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanVideoTransformer3DModel,
|
||||
Kandinsky5Transformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
LTX2VideoTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
Lumina2Transformer2DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
MochiTransformer3DModel,
|
||||
|
||||
@@ -10,8 +10,8 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
|
||||
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
|
||||
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
|
||||
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
|
||||
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
|
||||
from .autoencoder_kl_magvit import AutoencoderKLMagvit
|
||||
from .autoencoder_kl_mochi import AutoencoderKLMochi
|
||||
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
|
||||
|
||||
@@ -25,7 +25,6 @@ from ..activations import get_activation
|
||||
from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
@@ -33,8 +32,8 @@ class PerChannelRMSNorm(nn.Module):
|
||||
"""
|
||||
Per-pixel (per-location) RMS normalization layer.
|
||||
|
||||
For each element along the chosen dimension, this layer normalizes the tensor
|
||||
by the root-mean-square of its values across that dimension:
|
||||
For each element along the chosen dimension, this layer normalizes the tensor by the root-mean-square of its values
|
||||
across that dimension:
|
||||
|
||||
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
|
||||
"""
|
||||
@@ -174,9 +173,7 @@ class LTX2VideoResnetBlock3d(nn.Module):
|
||||
if in_channels != out_channels:
|
||||
self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
|
||||
# LTX 2.0 uses a normal nn.Conv3d here rather than LTXVideoCausalConv3d
|
||||
self.conv_shortcut = nn.Conv3d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1
|
||||
)
|
||||
self.conv_shortcut = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1)
|
||||
|
||||
self.per_channel_scale1 = None
|
||||
self.per_channel_scale2 = None
|
||||
@@ -953,7 +950,10 @@ class LTX2VideoDecoder3d(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, causal: Optional[bool] = None,
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
causal: Optional[bool] = None,
|
||||
) -> torch.Tensor:
|
||||
causal = causal or self.is_causal
|
||||
|
||||
@@ -1279,7 +1279,8 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
if temb is not None:
|
||||
decoded_slices = [
|
||||
self._decode(z_slice, t_slice, causal=causal).sample for z_slice, t_slice in (z.split(1), temb.split(1))
|
||||
self._decode(z_slice, t_slice, causal=causal).sample
|
||||
for z_slice, t_slice in (z.split(1), temb.split(1))
|
||||
]
|
||||
else:
|
||||
decoded_slices = [self._decode(z_slice, causal=causal).sample for z_slice in z.split(1)]
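
For readers unfamiliar with use_slicing, the list comprehensions above decode the batch one sample at a time to bound peak memory. A rough sketch of the pattern (generic code with a stand-in decoder, not the exact _decode signature):

import torch

def decode_in_slices(decode_fn, z: torch.Tensor) -> torch.Tensor:
    # Split the batch into size-1 chunks, decode each independently, then re-concatenate.
    decoded = [decode_fn(z_slice) for z_slice in z.split(1)]
    return torch.cat(decoded, dim=0)

# Example with a stand-in decoder that just doubles the latent.
z = torch.randn(4, 8, 2, 4, 4)
out = decode_in_slices(lambda s: 2 * s, z)
assert out.shape == z.shape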
|
||||
|
||||
@@ -249,6 +249,7 @@ class LTX2AudioUpsample(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LTX2AudioAudioPatchifier:
|
||||
"""
|
||||
Patchifier for spectrogram/audio latents.
|
||||
@@ -405,9 +406,7 @@ class LTX2AudioDecoder(nn.Module):
|
||||
final_block_channels = block_in
|
||||
|
||||
if self.norm_type == "group":
|
||||
self.norm_out = nn.GroupNorm(
|
||||
num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True
|
||||
)
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
|
||||
elif self.norm_type == "pixel":
|
||||
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
@@ -538,8 +537,8 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
|
||||
# Per-channel statistics for normalizing and denormalizing the latent representation. This statics is computed over
|
||||
# the entire dataset and stored in model's checkpoint under AudioVAE state_dict
|
||||
latents_std = torch.zeros((base_channels, ))
|
||||
latents_mean = torch.ones((base_channels, ))
|
||||
latents_std = torch.zeros((base_channels,))
|
||||
latents_mean = torch.ones((base_channels,))
|
||||
self.register_buffer("latents_mean", latents_mean, persistent=True)
|
||||
self.register_buffer("latents_std", latents_std, persistent=True)
|
||||
|
||||
|
||||
@@ -22,16 +22,21 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
BaseOutput,
|
||||
is_torch_version,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PixArtAlphaTextProjection, PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings, PixArtAlphaTextProjection
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle, RMSNorm
|
||||
from ..normalization import RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -44,6 +49,7 @@ def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, tor
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
return out
|
||||
|
||||
|
||||
def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
||||
cos, sin = freqs
|
||||
|
||||
@@ -65,7 +71,7 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten
|
||||
|
||||
# (..., 2, r)
|
||||
split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float
|
||||
first_x = split_x[..., :1, :] # (..., 1, r)
|
||||
first_x = split_x[..., :1, :] # (..., 1, r)
|
||||
second_x = split_x[..., 1:, :] # (..., 1, r)
|
||||
|
||||
cos_u = cos.unsqueeze(-2) # broadcast to (..., 1, r) against (..., 2, r)
|
||||
@@ -89,6 +95,7 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten
|
||||
|
||||
ROTARY_FN_MAP = {"interleaved": apply_interleaved_rotary_emb, "split": apply_split_rotary_emb}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioVisualModelOutput(BaseOutput):
|
||||
r"""
|
||||
@@ -192,7 +199,9 @@ class LTX2AudioVideoAttnProcessor:
|
||||
|
||||
if query_rotary_emb is not None:
|
||||
query = ROTARY_FN_MAP[attn.rope_type](query, query_rotary_emb)
|
||||
key = ROTARY_FN_MAP[attn.rope_type](key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb)
|
||||
key = ROTARY_FN_MAP[attn.rope_type](
|
||||
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
|
||||
)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
@@ -368,7 +377,7 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type
|
||||
rope_type=rope_type,
|
||||
)
|
||||
|
||||
self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
@@ -381,7 +390,7 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type
|
||||
rope_type=rope_type,
|
||||
)
|
||||
|
||||
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
|
||||
@@ -396,7 +405,7 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type
|
||||
rope_type=rope_type,
|
||||
)
|
||||
|
||||
# Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
|
||||
@@ -481,7 +490,9 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape(
|
||||
batch_size, temb_audio.size(1), num_audio_ada_params, -1
|
||||
)
|
||||
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = audio_ada_values.unbind(dim=2)
|
||||
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = (
|
||||
audio_ada_values.unbind(dim=2)
|
||||
)
|
||||
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
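
The reshape/unbind above is the usual AdaLN-single pattern: a timestep-conditioned embedding is folded into (shift, scale, gate) triples for the attention and MLP branches. A minimal hedged sketch of the mechanism with made-up sizes (not the block's real dimensions):

import torch

batch, tokens, dim, num_params = 2, 16, 32, 6
table = torch.zeros(1, 1, num_params, dim)               # learned scale/shift table
temb = torch.randn(batch, tokens, num_params * dim)      # timestep-conditioned embedding

ada = table + temb.reshape(batch, tokens, num_params, dim)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada.unbind(dim=2)

hidden = torch.randn(batch, tokens, dim)
modulated = hidden * (1 + scale_msa) + shift_msa          # pre-attention modulation
attn_out = modulated                                      # stand-in for attention(modulated)
hidden = hidden + gate_msa * attn_out                     # gated residual update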
|
||||
|
||||
attn_audio_hidden_states = self.audio_attn1(
|
||||
@@ -550,8 +561,12 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
|
||||
if use_a2v_cross_attn:
|
||||
# Audio-to-Video Cross Attention: Q: Video; K,V: Audio
|
||||
mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(2)
|
||||
mod_norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_a2v_ca_scale.squeeze(2)) + audio_a2v_ca_shift.squeeze(2)
|
||||
mod_norm_hidden_states = norm_hidden_states * (
|
||||
1 + video_a2v_ca_scale.squeeze(2)
|
||||
) + video_a2v_ca_shift.squeeze(2)
|
||||
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
|
||||
1 + audio_a2v_ca_scale.squeeze(2)
|
||||
) + audio_a2v_ca_shift.squeeze(2)
|
||||
|
||||
a2v_attn_hidden_states = self.audio_to_video_attn(
|
||||
mod_norm_hidden_states,
|
||||
@@ -565,8 +580,12 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
|
||||
if use_v2a_cross_attn:
|
||||
# Video-to-Audio Cross Attention: Q: Audio; K,V: Video
|
||||
mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(2)
|
||||
mod_norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_v2a_ca_scale.squeeze(2)) + audio_v2a_ca_shift.squeeze(2)
|
||||
mod_norm_hidden_states = norm_hidden_states * (
|
||||
1 + video_v2a_ca_scale.squeeze(2)
|
||||
) + video_v2a_ca_shift.squeeze(2)
|
||||
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
|
||||
1 + audio_v2a_ca_scale.squeeze(2)
|
||||
) + audio_v2a_ca_shift.squeeze(2)
|
||||
|
||||
v2a_attn_hidden_states = self.video_to_audio_attn(
|
||||
mod_norm_audio_hidden_states,
|
||||
@@ -596,9 +615,10 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
|
||||
|
||||
Args:
|
||||
causal_offset (`int`, *optional*, defaults to `1`):
|
||||
Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where
|
||||
the VAE treats the very first frame differently), but could also be 0 (for non-causal modeling).
|
||||
Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where the VAE
|
||||
treats the very first frame differently), but could also be 0 (for non-causal modeling).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
@@ -658,9 +678,9 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
|
||||
fps: float = 25.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original
|
||||
pixel space video grid (num_frames, height, width). This will ultimately have shape (batch_size, 3,
|
||||
num_patches, 2) where
|
||||
Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original pixel
|
||||
space video grid (num_frames, height, width). This will ultimately have shape (batch_size, 3, num_patches, 2)
|
||||
where
|
||||
- axis 1 (size 3) enumerates (frame, height, width) dimensions (e.g. idx 0 corresponds to frames)
|
||||
- axis 3 (size 2) stores `[start, end)` indices within each dimension
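
To make the described (batch_size, 3, num_patches, 2) layout concrete, here is a small illustration for a single patch. The latent position, patch sizes, and VAE scale factors below are hypothetical values chosen for the example, not values read from a checkpoint:

import torch

# One latent patch at latent position (f, h, w) = (1, 2, 3), VAE compression (8, 32, 32),
# transformer patch size 1 in every dimension.
vae_scale = (8, 32, 32)
latent_pos = (1, 2, 3)
patch_size = (1, 1, 1)

bounds = torch.tensor(
    [[p * s, (p + k) * s] for p, s, k in zip(latent_pos, vae_scale, patch_size)]
)  # shape (3, 2): rows are (frame, height, width), columns are [start, end) in pixel space
bounds = bounds[None, :, None, :]  # -> (batch_size=1, 3, num_patches=1, 2)
print(bounds.shape, bounds[0, :, 0])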
|
||||
|
||||
@@ -727,8 +747,8 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
|
||||
shift: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Create per-dimension bounds [inclusive start, exclusive end) of start and end timestamps for each latent
|
||||
frame. This will ultimately have shape (batch_size, 3, num_patches, 2) where
|
||||
Create per-dimension bounds [inclusive start, exclusive end) of start and end timestamps for each latent frame.
|
||||
This will ultimately have shape (batch_size, 3, num_patches, 2) where
|
||||
- axis 1 (size 1) represents the temporal dimension
|
||||
- axis 3 (size 2) stores `[start, end)` indices within each dimension
|
||||
|
||||
@@ -763,7 +783,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
|
||||
# Handle first frame causal offset, ensuring non-negative timestamps
|
||||
grid_start_mel = (grid_start_mel + self.causal_offset - audio_scale_factor).clip(min=0)
|
||||
# Convert mel bins back into seconds
|
||||
grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate
|
||||
grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate
|
||||
|
||||
# 3. Calculate start timstamps in seconds with respect to the original spectrogram grid
|
||||
grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor
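
A quick worked example of the mel-frame-to-seconds conversion used here, reusing the toy test configuration from this PR (sample_rate=16000, mel_hop_length=160); the temporal scale factor of 4 is an assumption, and real checkpoints use different values:

sampling_rate = 16000
hop_length = 160
seconds_per_mel_frame = hop_length / sampling_rate         # 0.01 s per spectrogram frame

audio_scale_factor = 4                                      # hypothetical temporal compression
grid_start_mel = 8 * audio_scale_factor                     # latent frame index 8 -> mel frame 32
grid_start_s = grid_start_mel * hop_length / sampling_rate
print(grid_start_s)                                         # 0.32 s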
|
||||
@@ -862,7 +882,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
|
||||
sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % num_rope_elems])
|
||||
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
|
||||
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
|
||||
|
||||
|
||||
elif self.rope_type == "split":
|
||||
expected_freqs = self.dim // 2
|
||||
current_freqs = freqs.shape[-1]
|
||||
@@ -1087,7 +1107,7 @@ class LTX2VideoTransformer3DModel(
|
||||
modality="audio",
|
||||
double_precision=rope_double_precision,
|
||||
rope_type=rope_type,
|
||||
num_attention_heads=audio_num_attention_heads
|
||||
num_attention_heads=audio_num_attention_heads,
|
||||
)
|
||||
|
||||
# 5. Transformer Blocks
|
||||
@@ -1154,7 +1174,7 @@ class LTX2VideoTransformer3DModel(
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
Input text embeddings of shape TODO.
|
||||
TODO for the rest.
|
||||
|
||||
|
||||
Returns:
|
||||
`AudioVisualModelOutput` or `tuple`:
|
||||
If `return_dict` is `True`, returns a structured output of type `AudioVisualModelOutput`, otherwise a
|
||||
@@ -1204,14 +1224,18 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device)
|
||||
|
||||
video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device)
|
||||
audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :], device=audio_hidden_states.device)
|
||||
audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(
|
||||
audio_coords[:, 0:1, :], device=audio_hidden_states.device
|
||||
)
|
||||
|
||||
# 2. Patchify input projections
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
audio_hidden_states = self.audio_proj_in(audio_hidden_states)
|
||||
|
||||
# 3. Prepare timestep embeddings and modulation parameters
|
||||
timestep_cross_attn_gate_scale_factor = self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier
|
||||
timestep_cross_attn_gate_scale_factor = (
|
||||
self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier
|
||||
)
|
||||
|
||||
# 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters
|
||||
# temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer
|
||||
@@ -1243,7 +1267,9 @@ class LTX2VideoTransformer3DModel(
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
video_cross_attn_scale_shift = video_cross_attn_scale_shift.view(batch_size, -1, video_cross_attn_scale_shift.shape[-1])
|
||||
video_cross_attn_scale_shift = video_cross_attn_scale_shift.view(
|
||||
batch_size, -1, video_cross_attn_scale_shift.shape[-1]
|
||||
)
|
||||
video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])
|
||||
|
||||
audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
|
||||
@@ -1256,7 +1282,9 @@ class LTX2VideoTransformer3DModel(
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=audio_hidden_states.dtype,
|
||||
)
|
||||
audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view(batch_size, -1, audio_cross_attn_scale_shift.shape[-1])
|
||||
audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view(
|
||||
batch_size, -1, audio_cross_attn_scale_shift.shape[-1]
|
||||
)
|
||||
audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
|
||||
|
||||
# 4. Prepare prompt embeddings
|
||||
|
||||
@@ -720,7 +720,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
|
||||
from .ltx2 import LTX2Pipeline, LTX2ImageToVideoPipeline
|
||||
from .ltx2 import LTX2ImageToVideoPipeline, LTX2Pipeline
|
||||
from .lucy import LucyEditPipeline
|
||||
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
|
||||
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
|
||||
|
||||
@@ -22,9 +22,9 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["connectors"] = ["LTX2TextConnectors"]
|
||||
_import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
|
||||
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
|
||||
_import_structure["connectors"] = ["LTX2TextConnectors"]
|
||||
_import_structure["vocoder"] = ["LTX2Vocoder"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
@@ -35,9 +35,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .connectors import LTX2TextConnectors
|
||||
from .pipeline_ltx2 import LTX2Pipeline
|
||||
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
|
||||
from .connectors import LTX2TextConnectors
|
||||
from .vocoder import LTX2Vocoder
|
||||
|
||||
else:
|
||||
|
||||
@@ -9,6 +9,7 @@ from ...models.attention import FeedForward
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
|
||||
|
||||
|
||||
class LTX2RotaryPosEmbed1d(nn.Module):
|
||||
"""
|
||||
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
|
||||
@@ -21,12 +22,12 @@ class LTX2RotaryPosEmbed1d(nn.Module):
|
||||
theta: float = 10000.0,
|
||||
double_precision: bool = True,
|
||||
rope_type: str = "interleaved",
|
||||
num_attention_heads: int = 32
|
||||
num_attention_heads: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
if rope_type not in ["interleaved", "split"]:
|
||||
raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.")
|
||||
|
||||
|
||||
self.dim = dim
|
||||
self.base_seq_len = base_seq_len
|
||||
self.theta = theta
|
||||
@@ -69,7 +70,7 @@ class LTX2RotaryPosEmbed1d(nn.Module):
|
||||
sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems])
|
||||
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
|
||||
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
|
||||
|
||||
|
||||
elif self.rope_type == "split":
|
||||
expected_freqs = self.dim // 2
|
||||
current_freqs = freqs.shape[-1]
|
||||
@@ -116,7 +117,7 @@ class LTX2TransformerBlock1d(nn.Module):
|
||||
kv_heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
rope_type=rope_type
|
||||
rope_type=rope_type,
|
||||
)
|
||||
|
||||
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
@@ -159,7 +160,7 @@ class LTX2ConnectorTransformer1d(nn.Module):
|
||||
rope_double_precision: bool = True,
|
||||
eps: float = 1e-6,
|
||||
causal_temporal_positioning: bool = False,
|
||||
rope_type: str = "interleaved"
|
||||
rope_type: str = "interleaved",
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -173,12 +174,12 @@ class LTX2ConnectorTransformer1d(nn.Module):
|
||||
self.learnable_registers = torch.nn.Parameter(init_registers)
|
||||
|
||||
self.rope = LTX2RotaryPosEmbed1d(
|
||||
self.inner_dim,
|
||||
base_seq_len=rope_base_seq_len,
|
||||
theta=rope_theta,
|
||||
self.inner_dim,
|
||||
base_seq_len=rope_base_seq_len,
|
||||
theta=rope_theta,
|
||||
double_precision=rope_double_precision,
|
||||
rope_type=rope_type,
|
||||
num_attention_heads=num_attention_heads
|
||||
num_attention_heads=num_attention_heads,
|
||||
)
|
||||
|
||||
self.transformer_blocks = torch.nn.ModuleList(
|
||||
@@ -187,7 +188,7 @@ class LTX2ConnectorTransformer1d(nn.Module):
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
rope_type=rope_type
|
||||
rope_type=rope_type,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
@@ -253,8 +254,8 @@ class LTX2ConnectorTransformer1d(nn.Module):
|
||||
|
||||
class LTX2TextConnectors(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and
|
||||
audio streams.
|
||||
Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio
|
||||
streams.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
@@ -287,7 +288,7 @@ class LTX2TextConnectors(ModelMixin, ConfigMixin):
|
||||
rope_theta=rope_theta,
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
rope_type=rope_type
|
||||
rope_type=rope_type,
|
||||
)
|
||||
self.audio_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=audio_connector_num_attention_heads,
|
||||
@@ -298,7 +299,7 @@ class LTX2TextConnectors(ModelMixin, ConfigMixin):
|
||||
rope_theta=rope_theta,
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
rope_type=rope_type
|
||||
rope_type=rope_type,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -674,7 +674,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
duration_s = num_frames / frame_rate
|
||||
latents_per_second = float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
|
||||
latents_per_second = (
|
||||
float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
|
||||
)
|
||||
latent_length = int(duration_s * latents_per_second)
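
A worked example of this audio latent sizing, using the toy test configuration from this PR (num_frames=5, frame_rate=25, sample_rate=16000, mel_hop_length=160); the temporal compression ratio of 4 is an assumed value:

num_frames, frame_rate = 5, 25.0
sampling_rate, hop_length, temporal_compression = 16000, 160, 4

duration_s = num_frames / frame_rate                                    # 0.2 s of video
latents_per_second = sampling_rate / hop_length / temporal_compression  # 25 audio latent frames per second
latent_length = int(duration_s * latents_per_second)                    # 5 audio latent frames
print(duration_s, latents_per_second, latent_length)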
|
||||
|
||||
if latents is not None:
|
||||
@@ -995,7 +997,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
|
||||
audio_latent_model_input = torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
|
||||
audio_latent_model_input = (
|
||||
torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
|
||||
)
|
||||
audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
@@ -1026,10 +1030,14 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
|
||||
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (noise_pred_video_text - noise_pred_video_uncond)
|
||||
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
|
||||
noise_pred_video_text - noise_pred_video_uncond
|
||||
)
|
||||
|
||||
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
|
||||
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
|
||||
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
|
||||
noise_pred_audio_text - noise_pred_audio_uncond
|
||||
)
|
||||
|
||||
if self.guidance_rescale > 0:
|
||||
# Based on 3.4. in https://huggingface.co/papers/2305.08891
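
The rescale step referenced here follows Section 3.4 of the cited paper: the classifier-free-guidance result is rescaled so its per-sample standard deviation matches that of the text-conditioned prediction, then blended back with the unrescaled result. A hedged sketch of that formula (not necessarily the pipeline's exact helper):

import torch

def rescale_noise_cfg(noise_cfg: torch.Tensor, noise_pred_text: torch.Tensor, guidance_rescale: float) -> torch.Tensor:
    # Match the std of the CFG output to the std of the conditional prediction,
    # then interpolate with the original CFG output by `guidance_rescale`.
    dims = list(range(1, noise_cfg.ndim))
    std_text = noise_pred_text.std(dim=dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=dims, keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg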
|
||||
|
||||
@@ -13,25 +13,26 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
|
||||
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from .connectors import LTX2TextConnectors
|
||||
from .pipeline_output import LTX2PipelineOutput
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .vocoder import LTX2Vocoder
|
||||
from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
|
||||
from ...models.transformers import LTX2VideoTransformer3DModel
|
||||
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .connectors import LTX2TextConnectors
|
||||
from .pipeline_output import LTX2PipelineOutput
|
||||
from .vocoder import LTX2Vocoder
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -86,6 +87,7 @@ def retrieve_latents(
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
@@ -665,7 +667,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
|
||||
shape = (batch_size, num_channels_latents, num_frames, height, width)
|
||||
mask_shape = (batch_size, 1, num_frames, height, width)
|
||||
|
||||
|
||||
if latents is not None:
|
||||
conditioning_mask = latents.new_zeros(mask_shape)
|
||||
conditioning_mask[:, :, 0] = 1.0
|
||||
@@ -697,7 +699,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
init_latents = torch.cat(init_latents, dim=0).to(dtype)
|
||||
init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
|
||||
init_latents = init_latents.repeat(1, 1, num_frames, 1, 1)
|
||||
|
||||
|
||||
# First condition is image latents and those should be kept clean.
|
||||
conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype)
|
||||
conditioning_mask[:, :, 0] = 1.0
|
||||
@@ -731,7 +733,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
duration_s = num_frames / frame_rate
|
||||
latents_per_second = float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
|
||||
latents_per_second = (
|
||||
float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
|
||||
)
|
||||
latent_length = int(duration_s * latents_per_second)
|
||||
|
||||
if latents is not None:
|
||||
@@ -982,7 +986,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
|
||||
|
||||
|
||||
num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
|
||||
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
|
||||
|
||||
@@ -1063,12 +1067,14 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
|
||||
audio_latent_model_input = torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
|
||||
audio_latent_model_input = (
|
||||
torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
|
||||
)
|
||||
audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)
|
||||
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
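
The line above is what keeps the conditioning image clean during I2V sampling: positions where conditioning_mask is 1 get an effective timestep of 0, while all other latent positions follow the sampled timestep t. A tiny illustration with made-up shapes:

import torch

t = torch.tensor([800.0, 800.0])                      # per-sample timestep (batch of 2)
conditioning_mask = torch.tensor([[1.0, 0.0, 0.0],
                                  [1.0, 0.0, 0.0]])   # first latent frame is the clean image
video_timestep = t.unsqueeze(-1) * (1 - conditioning_mask)
print(video_timestep)                                 # -> [[0., 800., 800.], [0., 800., 800.]]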
|
||||
|
||||
|
||||
with self.transformer.cache_context("cond_uncond"):
|
||||
noise_pred_video, noise_pred_audio = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
@@ -1095,10 +1101,14 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
|
||||
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (noise_pred_video_text - noise_pred_video_uncond)
|
||||
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
|
||||
noise_pred_video_text - noise_pred_video_uncond
|
||||
)
|
||||
|
||||
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
|
||||
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
|
||||
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
|
||||
noise_pred_audio_text - noise_pred_audio_uncond
|
||||
)
|
||||
|
||||
if self.guidance_rescale > 0:
|
||||
# Based on 3.4. in https://huggingface.co/papers/2305.08891
|
||||
|
||||
@@ -25,32 +25,18 @@ class ResBlock(nn.Module):
|
||||
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding_mode
|
||||
)
|
||||
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode)
|
||||
for dilation in dilations
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
dilation=1,
|
||||
padding=padding_mode
|
||||
)
|
||||
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode)
|
||||
for _ in range(len(dilations))
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for conv1, conv2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, negative_slope=self.negative_slope)
|
||||
@@ -127,7 +113,7 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
input_channels = output_channels
|
||||
|
||||
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)
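
A note on the output convolution above: with stride 1, kernel size 7 and padding 3, a 1D convolution preserves the temporal length (L_out = L_in + 2*3 - 7 + 1 = L_in), so the vocoder's final projection only changes the channel count. A quick shape check with arbitrary channel sizes:

import torch
import torch.nn as nn

conv_out = nn.Conv1d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)
x = torch.randn(1, 16, 240)       # (batch, channels, samples)
print(conv_out(x).shape)          # torch.Size([1, 2, 240])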
|
||||
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
|
||||
r"""
|
||||
Forward pass of the vocoder.
|
||||
|
||||
@@ -502,6 +502,36 @@ class AutoencoderKLHunyuanVideo15(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLLTX2Audio(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLLTX2Video(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLLTXVideo(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -1132,6 +1162,21 @@ class LatteTransformer3DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LTX2VideoTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LTXVideoTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -1802,6 +1802,36 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTX2ImageToVideoPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTX2Pipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTXConditionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderKLLTX2Video
|
||||
|
||||
from ...testing_utils import (
|
||||
|
||||
@@ -52,9 +52,9 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
sequence_length = 16
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
|
||||
audio_hidden_states = torch.randn(
|
||||
(batch_size, audio_num_frames, audio_num_channels * num_mel_bins)
|
||||
).to(torch_device)
|
||||
audio_hidden_states = torch.randn((batch_size, audio_num_frames, audio_num_channels * num_mel_bins)).to(
|
||||
torch_device
|
||||
)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
|
||||
|
||||
tests/pipelines/ltx2/__init__.py (new file, 0 lines)
tests/pipelines/ltx2/test_ltx2.py (new file, 239 lines)
@@ -0,0 +1,239 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTX2Pipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx2 import LTX2TextConnectors
|
||||
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
|
||||
|
||||
from ...testing_utils import enable_full_determinism
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = LTX2Pipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"audio_latents",
|
||||
"output_type",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
test_attention_slicing = False
|
||||
test_xformers_attention = False
|
||||
supports_dduf = False
|
||||
|
||||
base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3"
|
||||
|
||||
def get_dummy_components(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id)
|
||||
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id)
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer = LTX2VideoTransformer3DModel(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=8,
|
||||
cross_attention_dim=16,
|
||||
audio_in_channels=4,
|
||||
audio_out_channels=4,
|
||||
audio_num_attention_heads=2,
|
||||
audio_attention_head_dim=4,
|
||||
audio_cross_attention_dim=8,
|
||||
num_layers=2,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
caption_channels=text_encoder.config.text_config.hidden_size,
|
||||
rope_double_precision=False,
|
||||
rope_type="split",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
connectors = LTX2TextConnectors(
|
||||
caption_channels=text_encoder.config.text_config.hidden_size,
|
||||
text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1,
|
||||
video_connector_num_attention_heads=4,
|
||||
video_connector_attention_head_dim=8,
|
||||
video_connector_num_layers=1,
|
||||
video_connector_num_learnable_registers=None,
|
||||
audio_connector_num_attention_heads=4,
|
||||
audio_connector_attention_head_dim=8,
|
||||
audio_connector_num_layers=1,
|
||||
audio_connector_num_learnable_registers=None,
|
||||
connector_rope_base_seq_len=32,
|
||||
rope_theta=10000.0,
|
||||
rope_double_precision=False,
|
||||
causal_temporal_positioning=False,
|
||||
rope_type="split",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLLTX2Video(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
latent_channels=4,
|
||||
block_out_channels=(8,),
|
||||
decoder_block_out_channels=(8,),
|
||||
layers_per_block=(1,),
|
||||
decoder_layers_per_block=(1, 1),
|
||||
spatio_temporal_scaling=(True,),
|
||||
decoder_spatio_temporal_scaling=(True,),
|
||||
decoder_inject_noise=(False, False),
|
||||
downsample_type=("spatial",),
|
||||
upsample_residual=(False,),
|
||||
upsample_factor=(1,),
|
||||
timestep_conditioning=False,
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
encoder_causal=True,
|
||||
decoder_causal=False,
|
||||
)
|
||||
vae.use_framewise_encoding = False
|
||||
vae.use_framewise_decoding = False
|
||||
|
||||
torch.manual_seed(0)
|
||||
audio_vae = AutoencoderKLLTX2Audio(
|
||||
base_channels=4,
|
||||
output_channels=2,
|
||||
ch_mult=(1,),
|
||||
num_res_blocks=1,
|
||||
attn_resolutions=None,
|
||||
in_channels=2,
|
||||
resolution=32,
|
||||
latent_channels=2,
|
||||
norm_type="pixel",
|
||||
causality_axis="height",
|
||||
dropout=0.0,
|
||||
mid_block_add_attention=False,
|
||||
sample_rate=16000,
|
||||
mel_hop_length=160,
|
||||
is_causal=True,
|
||||
mel_bins=8,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vocoder = LTX2Vocoder(
|
||||
in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins,
|
||||
hidden_channels=32,
|
||||
out_channels=2,
|
||||
upsample_kernel_sizes=[4, 4],
|
||||
upsample_factors=[2, 2],
|
||||
resnet_kernel_sizes=[3],
|
||||
resnet_dilations=[[1, 3, 5]],
|
||||
leaky_relu_negative_slope=0.1,
|
||||
output_sampling_rate=16000,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"audio_vae": audio_vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"connectors": connectors,
|
||||
"vocoder": vocoder,
|
||||
}
|
||||
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "a robot dancing",
|
||||
"negative_prompt": "",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 1.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"num_frames": 5,
|
||||
"frame_rate": 25.0,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
output = pipe(**inputs)
|
||||
video = output.frames
|
||||
audio = output.audio
|
||||
|
||||
self.assertEqual(video.shape, (1, 5, 3, 32, 32))
|
||||
self.assertEqual(audio.shape[0], 1)
|
||||
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
|
||||
|
||||
# fmt: off
|
||||
expected_video_slice = torch.tensor(
|
||||
[
|
||||
0.4331, 0.6203, 0.3245, 0.7294, 0.4822, 0.5703, 0.2999, 0.7700, 0.4961, 0.4242, 0.4581, 0.4351, 0.1137, 0.4437, 0.6304, 0.3184
|
||||
]
|
||||
)
|
||||
expected_audio_slice = torch.tensor(
|
||||
[
|
||||
0.0229, 0.0503, 0.1220, 0.1083, 0.1745, 0.1075, 0.1779, 0.0974, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
video = video.flatten()
|
||||
audio = audio.flatten()
|
||||
generated_video_slice = torch.cat([video[:8], video[-8:]])
|
||||
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
|
||||
|
||||
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
|
||||
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)
|
||||
tests/pipelines/ltx2/test_ltx2_image2video.py (new file, 241 lines)
@@ -0,0 +1,241 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTX2ImageToVideoPipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx2 import LTX2TextConnectors
|
||||
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
|
||||
|
||||
from ...testing_utils import enable_full_determinism
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = LTX2ImageToVideoPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"audio_latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
test_attention_slicing = False
|
||||
test_xformers_attention = False
|
||||
supports_dduf = False
|
||||
|
||||
base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3"
|
||||
|
||||
def get_dummy_components(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id)
|
||||
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id)
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer = LTX2VideoTransformer3DModel(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=8,
|
||||
cross_attention_dim=16,
|
||||
audio_in_channels=4,
|
||||
audio_out_channels=4,
|
||||
audio_num_attention_heads=2,
|
||||
audio_attention_head_dim=4,
|
||||
audio_cross_attention_dim=8,
|
||||
num_layers=2,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
caption_channels=text_encoder.config.text_config.hidden_size,
|
||||
rope_double_precision=False,
|
||||
rope_type="split",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
connectors = LTX2TextConnectors(
|
||||
caption_channels=text_encoder.config.text_config.hidden_size,
|
||||
text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1,
|
||||
video_connector_num_attention_heads=4,
|
||||
video_connector_attention_head_dim=8,
|
||||
video_connector_num_layers=1,
|
||||
video_connector_num_learnable_registers=None,
|
||||
audio_connector_num_attention_heads=4,
|
||||
audio_connector_attention_head_dim=8,
|
||||
audio_connector_num_layers=1,
|
||||
audio_connector_num_learnable_registers=None,
|
||||
connector_rope_base_seq_len=32,
|
||||
rope_theta=10000.0,
|
||||
rope_double_precision=False,
|
||||
causal_temporal_positioning=False,
|
||||
rope_type="split",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLLTX2Video(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
latent_channels=4,
|
||||
block_out_channels=(8,),
|
||||
decoder_block_out_channels=(8,),
|
||||
layers_per_block=(1,),
|
||||
decoder_layers_per_block=(1, 1),
|
||||
spatio_temporal_scaling=(True,),
|
||||
decoder_spatio_temporal_scaling=(True,),
|
||||
decoder_inject_noise=(False, False),
|
||||
downsample_type=("spatial",),
|
||||
upsample_residual=(False,),
|
||||
upsample_factor=(1,),
|
||||
timestep_conditioning=False,
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
encoder_causal=True,
|
||||
decoder_causal=False,
|
||||
)
|
||||
vae.use_framewise_encoding = False
|
||||
vae.use_framewise_decoding = False
|
||||
|
||||
torch.manual_seed(0)
|
||||
audio_vae = AutoencoderKLLTX2Audio(
|
||||
base_channels=4,
|
||||
output_channels=2,
|
||||
ch_mult=(1,),
|
||||
num_res_blocks=1,
|
||||
attn_resolutions=None,
|
||||
in_channels=2,
|
||||
resolution=32,
|
||||
latent_channels=2,
|
||||
norm_type="pixel",
|
||||
causality_axis="height",
|
||||
dropout=0.0,
|
||||
mid_block_add_attention=False,
|
||||
sample_rate=16000,
|
||||
mel_hop_length=160,
|
||||
is_causal=True,
|
||||
mel_bins=8,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vocoder = LTX2Vocoder(
|
||||
in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins,
|
||||
hidden_channels=32,
|
||||
out_channels=2,
|
||||
upsample_kernel_sizes=[4, 4],
|
||||
upsample_factors=[2, 2],
|
||||
resnet_kernel_sizes=[3],
|
||||
resnet_dilations=[[1, 3, 5]],
|
||||
leaky_relu_negative_slope=0.1,
|
||||
output_sampling_rate=16000,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"audio_vae": audio_vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"connectors": connectors,
|
||||
"vocoder": vocoder,
|
||||
}
|
||||
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
image = torch.rand((1, 3, 32, 32), generator=generator, device=device)
|
||||
|
||||
inputs = {
|
||||
"image": image,
|
||||
"prompt": "a robot dancing",
|
||||
"negative_prompt": "",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 1.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"num_frames": 5,
|
||||
"frame_rate": 25.0,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
output = pipe(**inputs)
|
||||
video = output.frames
|
||||
audio = output.audio
|
||||
|
||||
self.assertEqual(video.shape, (1, 5, 3, 32, 32))
|
||||
self.assertEqual(audio.shape[0], 1)
|
||||
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
|
||||
|
||||
# fmt: off
|
||||
expected_video_slice = torch.tensor(
|
||||
[
|
||||
0.3573, 0.8382, 0.3581, 0.6114, 0.3682, 0.7969, 0.2552, 0.6399, 0.3113, 0.1497, 0.3249, 0.5395, 0.3498, 0.4526, 0.4536, 0.4555
|
||||
]
|
||||
)
|
||||
expected_audio_slice = torch.tensor(
|
||||
[
|
||||
0.0229, 0.0503, 0.1220, 0.1083, 0.1745, 0.1075, 0.1779, 0.0974, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
video = video.flatten()
|
||||
audio = audio.flatten()
|
||||
generated_video_slice = torch.cat([video[:8], video[-8:]])
|
||||
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
|
||||
|
||||
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
|
||||
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)
|
||||