mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add LTX 2.0 Video Pipelines (#12915)
* Initial LTX 2.0 transformer implementation * Add tests for LTX 2 transformer model * Get LTX 2 transformer tests working * Rename LTX 2 compile test class to have LTX2 * Remove RoPE debug print statements * Get LTX 2 transformer compile tests passing * Fix LTX 2 transformer shape errors * Initial script to convert LTX 2 transformer to diffusers * Add more LTX 2 transformer audio arguments * Allow LTX 2 transformer to be loaded from local path for conversion * Improve dummy inputs and add test for LTX 2 transformer consistency * Fix LTX 2 transformer bugs so consistency test passes * Initial implementation of LTX 2.0 video VAE * Explicitly specify temporal and spatial VAE scale factors when converting * Add initial LTX 2.0 video VAE tests * Add initial LTX 2.0 video VAE tests (part 2) * Get diffusers implementation on par with official LTX 2.0 video VAE implementation * Initial LTX 2.0 vocoder implementation * Use RMSNorm implementation closer to original for LTX 2.0 video VAE * start audio decoder. * init registration. * up * simplify and clean up * up * Initial LTX 2.0 text encoder implementation * Rough initial LTX 2.0 pipeline implementation * up * up * up * up * Add imports for LTX 2.0 Audio VAE * Conversion script for LTX 2.0 Audio VAE Decoder * Add Audio VAE logic to T2V pipeline * Duplicate scheduler for audio latents * Support num_videos_per_prompt for prompt embeddings * LTX 2.0 scheduler and full pipeline conversion * Add script to test full LTX2Pipeline T2V inference * Fix pipeline return bugs * Add LTX 2 text encoder and vocoder to ltx2 subdirectory __init__ * Fix more bugs in LTX2Pipeline.__call__ * Improve CPU offload support * Fix pipeline audio VAE decoding dtype bug * Fix video shape error in full pipeline test script * Get LTX 2 T2V pipeline to produce reasonable outputs * Make LTX 2.0 scheduler more consistent with original code * Fix typo when applying scheduler fix in T2V inference script * Refactor Audio VAE to be simpler and remove helpers (#7) * remove resolve causality axes stuff. * remove a bunch of helpers. * remove adjust output shape helper. * remove the use of audiolatentshape. * move normalization and patchify out of pipeline. * fix * up * up * Remove unpatchify and patchify ops before audio latents denormalization (#9) --------- Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Add support for I2V (#8) * start i2v. * up * up * up * up * up * remove uniform strategy code. * remove unneeded code. * Denormalize audio latents in I2V pipeline (analogous to T2V change) (#11) * test i2v. * Move Video and Audio Text Encoder Connectors to Transformer (#12) * Denormalize audio latents in I2V pipeline (analogous to T2V change) * Initial refactor to put video and audio text encoder connectors in transformer * Get LTX 2 transformer tests working after connector refactor * precompute run_connectors,. * fixes * Address review comments * Calculate RoPE double precisions freqs using torch instead of np * Further simplify LTX 2 RoPE freq calc * Make connectors a separate module (#18) * remove text_encoder.py * address yiyi's comments. * up * up * up * up --------- Co-authored-by: sayakpaul <spsayakpaul@gmail.com> * up (#19) * address initial feedback from lightricks team (#16) * cross_attn_timestep_scale_multiplier to 1000 * implement split rope type. * up * propagate rope_type to rope embed classes as well. 
* up * When using split RoPE, make sure that the output dtype is same as input dtype * Fix apply split RoPE shape error when reshaping x to 4D * Add export_utils file for exporting LTX 2.0 videos with audio * Tests for T2V and I2V (#6) * add ltx2 pipeline tests. * up * up * up * up * remove content * style * Denormalize audio latents in I2V pipeline (analogous to T2V change) * Initial refactor to put video and audio text encoder connectors in transformer * Get LTX 2 transformer tests working after connector refactor * up * up * i2v tests. * up * Address review comments * Calculate RoPE double precisions freqs using torch instead of np * Further simplify LTX 2 RoPE freq calc * revert unneded changes. * up * up * update to split style rope. * up --------- Co-authored-by: Daniel Gu <dgu8957@gmail.com> * up * use export util funcs. * Point original checkpoint to LTX 2.0 official checkpoint * Allow the I2V pipeline to accept image URLs * make style and make quality * remove function map. * remove args. * update docs. * update doc entries. * disable ltx2_consistency test * Simplify LTX 2 RoPE forward by removing coords is None logic * make style and make quality * Support LTX 2.0 audio VAE encoder * Apply suggestions from code review Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * Remove print statement in audio VAE * up * Fix bug when calculating audio RoPE coords * Ltx 2 latent upsample pipeline (#12922) * Initial implementation of LTX 2.0 latent upsampling pipeline * Add new LTX 2.0 spatial latent upsampler logic * Add test script for LTX 2.0 latent upsampling * Add option to enable VAE tiling in upsampling test script * Get latent upsampler working with video latents * Fix typo in BlurDownsample * Add latent upsample pipeline docstring and example * Remove deprecated pipeline VAE slicing/tiling methods * make style and make quality * When returning latents, return unpacked and denormalized latents for T2V and I2V * Add model_cpu_offload_seq for latent upsampling pipeline --------- Co-authored-by: Daniel Gu <dgu8957@gmail.com> * Fix latent upsampler filename in LTX 2 conversion script * Add latent upsample pipeline to LTX 2 docs * Add dummy objects for LTX 2 latent upsample pipeline * Set default FPS to official LTX 2 ckpt default of 24.0 * Set default CFG scale to official LTX 2 ckpt default of 4.0 * Update LTX 2 pipeline example docstrings * make style and make quality * Remove LTX 2 test scripts * Fix LTX 2 upsample pipeline example docstring * Add logic to convert and save a LTX 2 upsampling pipeline * Document LTX2VideoTransformer3DModel forward pass --------- Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
docs/source/en/_toctree.yml
@@ -367,6 +367,8 @@
title: LatteTransformer3DModel
- local: api/models/longcat_image_transformer2d
title: LongCatImageTransformer2DModel
- local: api/models/ltx2_video_transformer3d
title: LTX2VideoTransformer3DModel
- local: api/models/ltx_video_transformer3d
title: LTXVideoTransformer3DModel
- local: api/models/lumina2_transformer2d
@@ -443,6 +445,10 @@
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoder_kl_hunyuan_video15
title: AutoencoderKLHunyuanVideo15
- local: api/models/autoencoderkl_audio_ltx_2
title: AutoencoderKLLTX2Audio
- local: api/models/autoencoderkl_ltx_2
title: AutoencoderKLLTX2Video
- local: api/models/autoencoderkl_ltx_video
title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_magvit
@@ -678,6 +684,8 @@
title: Kandinsky 5.0 Video
- local: api/pipelines/latte
title: Latte
- local: api/pipelines/ltx2
title: LTX-2
- local: api/pipelines/ltx_video
title: LTXVideo
- local: api/pipelines/mochi
docs/source/en/api/models/autoencoderkl_audio_ltx_2.md (new file, 29 lines)
@@ -0,0 +1,29 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLLTX2Audio

The variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. It encodes and decodes audio latent representations.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import AutoencoderKLLTX2Audio

vae = AutoencoderKLLTX2Audio.from_pretrained("Lightricks/LTX-2", subfolder="audio_vae", torch_dtype=torch.float32).to("cuda")
```

## AutoencoderKLLTX2Audio

[[autodoc]] AutoencoderKLLTX2Audio
- encode
- decode
- all
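A minimal round-trip sketch, assuming the standard diffusers VAE `encode`/`decode` API (the class autodocs both methods above); the mel-spectrogram shape and layout below are illustrative assumptions, not taken from this diff:

```python
import torch

from diffusers import AutoencoderKLLTX2Audio

vae = AutoencoderKLLTX2Audio.from_pretrained(
    "Lightricks/LTX-2", subfolder="audio_vae", torch_dtype=torch.float32
).to("cuda")

# Hypothetical 2-channel mel-spectrogram batch; the real input layout is an assumption here.
mel = torch.randn(1, 2, 256, 64, device="cuda")

with torch.no_grad():
    # Sample from the KL posterior, as with other diffusers KL autoencoders.
    latents = vae.encode(mel).latent_dist.sample()
    reconstruction = vae.decode(latents).sample

print(latents.shape, reconstruction.shape)
```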
docs/source/en/api/models/autoencoderkl_ltx_2.md (new file, 29 lines)
@@ -0,0 +1,29 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLLTX2Video

The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import AutoencoderKLLTX2Video

vae = AutoencoderKLLTX2Video.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```

## AutoencoderKLLTX2Video

[[autodoc]] AutoencoderKLLTX2Video
- decode
- encode
- all
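A hedged sketch of using the video VAE directly, assuming it exposes an `enable_tiling()` helper like the existing `AutoencoderKLLTXVideo`; the latent shape is derived from the 128 latent channels and 32x spatial / 8x temporal compression used in the conversion script, not from an official example:

```python
import torch

from diffusers import AutoencoderKLLTX2Video

vae = AutoencoderKLLTX2Video.from_pretrained(
    "Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.bfloat16
).to("cuda")
vae.enable_tiling()  # assumed helper, mirroring other diffusers video VAEs

# Illustrative latents: (batch, 128, latent_frames, height // 32, width // 32)
latents = torch.randn(1, 128, 4, 16, 16, dtype=torch.bfloat16, device="cuda")

with torch.no_grad():
    video = vae.decode(latents).sample  # decoded pixel-space video tensor

print(video.shape)
```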
docs/source/en/api/models/ltx2_video_transformer3d.md (new file, 26 lines)
@@ -0,0 +1,26 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# LTX2VideoTransformer3DModel

A Diffusion Transformer model for 3D data from [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import LTX2VideoTransformer3DModel

transformer = LTX2VideoTransformer3DModel.from_pretrained("Lightricks/LTX-2", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```

## LTX2VideoTransformer3DModel

[[autodoc]] LTX2VideoTransformer3DModel
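A short sketch of overriding the transformer when building the full pipeline, following the usual diffusers component-override pattern; the pipeline class and subfolder names come from this PR, everything else is standard `from_pretrained` usage:

```python
import torch

from diffusers import LTX2Pipeline, LTX2VideoTransformer3DModel

# Load the transformer on its own (for example, in a different dtype or from a local path).
transformer = LTX2VideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-2", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Reuse it when assembling the pipeline instead of loading the subfolder twice.
pipe = LTX2Pipeline.from_pretrained(
    "Lightricks/LTX-2", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()  # CPU offload support is mentioned in the commit message
```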
docs/source/en/api/pipelines/ltx2.md (new file, 43 lines)
@@ -0,0 +1,43 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# LTX-2

LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.

You can find all the original LTX-2 checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.

The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2).

## LTX2Pipeline

[[autodoc]] LTX2Pipeline
- all
- __call__

## LTX2ImageToVideoPipeline

[[autodoc]] LTX2ImageToVideoPipeline
- all
- __call__

## LTX2LatentUpsamplePipeline

[[autodoc]] LTX2LatentUpsamplePipeline
- all
- __call__

## LTX2PipelineOutput

[[autodoc]] pipelines.ltx2.pipeline_output.LTX2PipelineOutput
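A hedged text-to-video sketch: the pipeline name and the 24 fps / CFG 4.0 defaults come from this PR, but the exact call signature and the audio return value are assumptions, since the example docstrings are not shown in this diff. The PR also adds an export utility for muxing the generated audio into the video file; as a fallback, the existing `export_to_video` helper writes the frames alone.

```python
import torch

from diffusers import LTX2Pipeline
from diffusers.utils import export_to_video

pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "A close-up of ocean waves crashing on a rocky shore at sunset"
output = pipe(
    prompt=prompt,
    num_inference_steps=40,
    guidance_scale=4.0,  # the PR sets 4.0 as the official checkpoint's default CFG scale
)

# Export the frames at the official default of 24 fps (audio muxing is handled by the
# PR's export util, whose exact API is not shown here).
export_to_video(output.frames[0], "ltx2_t2v.mp4", fps=24)
```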
scripts/convert_ltx2_to_diffusers.py (new file, 886 lines)
@@ -0,0 +1,886 @@
import argparse
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTX2LatentUpsamplePipeline,
|
||||
LTX2Pipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
|
||||
LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
# Input Patchify Projections
|
||||
"patchify_proj": "proj_in",
|
||||
"audio_patchify_proj": "audio_proj_in",
|
||||
# Modulation Parameters
|
||||
# Handle adaln_single --> time_embed, audioln_single --> audio_time_embed separately as the original keys are
|
||||
# substrings of the other modulation parameters below
|
||||
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
|
||||
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
|
||||
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
|
||||
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
|
||||
# Transformer Blocks
|
||||
# Per-Block Cross Attention Modulation Parameters
|
||||
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
|
||||
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
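# For illustration (hypothetical key, not taken from an actual checkpoint): the rename loop in
# convert_ltx2_transformer applies the substring replacements above, so a key such as
#   "transformer_blocks.0.attn1.q_norm.weight" -> "transformer_blocks.0.attn1.norm_q.weight"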
|
||||
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
|
||||
# Encoder
|
||||
"down_blocks.0": "down_blocks.0",
|
||||
"down_blocks.1": "down_blocks.0.downsamplers.0",
|
||||
"down_blocks.2": "down_blocks.1",
|
||||
"down_blocks.3": "down_blocks.1.downsamplers.0",
|
||||
"down_blocks.4": "down_blocks.2",
|
||||
"down_blocks.5": "down_blocks.2.downsamplers.0",
|
||||
"down_blocks.6": "down_blocks.3",
|
||||
"down_blocks.7": "down_blocks.3.downsamplers.0",
|
||||
"down_blocks.8": "mid_block",
|
||||
# Decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
# Common
|
||||
# For all 3D ResNets
|
||||
"res_blocks": "resnets",
|
||||
"per_channel_statistics.mean-of-means": "latents_mean",
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
|
||||
"per_channel_statistics.mean-of-means": "latents_mean",
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
LTX_2_0_VOCODER_RENAME_DICT = {
|
||||
"ups": "upsamplers",
|
||||
"resblocks": "resnets",
|
||||
"conv_pre": "conv_in",
|
||||
"conv_post": "conv_out",
|
||||
}
|
||||
|
||||
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
|
||||
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
state_dict.pop(key)
|
||||
|
||||
|
||||
def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
# Skip if the key is not a weight or bias
|
||||
if ".weight" not in key and ".bias" not in key:
|
||||
return
|
||||
|
||||
if key.startswith("adaln_single."):
|
||||
new_key = key.replace("adaln_single.", "time_embed.")
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
if key.startswith("audio_adaln_single."):
|
||||
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
return
|
||||
|
||||
|
||||
def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
if key.startswith("per_channel_statistics"):
|
||||
new_key = ".".join(["decoder", key])
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
return
|
||||
|
||||
|
||||
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"video_embeddings_connector": remove_keys_inplace,
|
||||
"audio_embeddings_connector": remove_keys_inplace,
|
||||
"adaln_single": convert_ltx2_transformer_adaln_single,
|
||||
}
|
||||
|
||||
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
|
||||
"connectors.": "",
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
"text_embedding_projection.aggregate_embed": "text_proj_in",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_inplace,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
|
||||
}
|
||||
|
||||
LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
connector_prefixes = (
|
||||
"video_embeddings_connector",
|
||||
"audio_embeddings_connector",
|
||||
"transformer_1d_blocks",
|
||||
"text_embedding_projection.aggregate_embed",
|
||||
"connectors.",
|
||||
"video_connector",
|
||||
"audio_connector",
|
||||
"text_proj_in",
|
||||
)
|
||||
|
||||
transformer_state_dict, connector_state_dict = {}, {}
|
||||
for key, value in state_dict.items():
|
||||
if key.startswith(connector_prefixes):
|
||||
connector_state_dict[key] = value
|
||||
else:
|
||||
transformer_state_dict[key] = value
|
||||
|
||||
return transformer_state_dict, connector_state_dict
|
||||
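# Example of the split (keys are illustrative, not actual checkpoint keys):
#   "connectors.video_connector.proj.weight"   -> connector_state_dict (starts with "connectors.")
#   "transformer_blocks.0.attn1.to_q.weight"   -> transformer_state_dict (no connector prefix)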
|
||||
|
||||
def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
# Produces a transformer of the same size as used in test_models_transformer_ltx2.py
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 8,
|
||||
"cross_attention_dim": 16,
|
||||
"vae_scale_factors": (8, 32, 32),
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
"audio_in_channels": 4,
|
||||
"audio_out_channels": 4,
|
||||
"audio_patch_size": 1,
|
||||
"audio_patch_size_t": 1,
|
||||
"audio_num_attention_heads": 2,
|
||||
"audio_attention_head_dim": 4,
|
||||
"audio_cross_attention_dim": 8,
|
||||
"audio_scale_factor": 4,
|
||||
"audio_pos_embed_max_pos": 20,
|
||||
"audio_sampling_rate": 16000,
|
||||
"audio_hop_length": 160,
|
||||
"num_layers": 2,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"caption_channels": 16,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": False,
|
||||
"causal_offset": 1,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"out_channels": 128,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attention_dim": 4096,
|
||||
"vae_scale_factors": (8, 32, 32),
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
"audio_in_channels": 128,
|
||||
"audio_out_channels": 128,
|
||||
"audio_patch_size": 1,
|
||||
"audio_patch_size_t": 1,
|
||||
"audio_num_attention_heads": 32,
|
||||
"audio_attention_head_dim": 64,
|
||||
"audio_cross_attention_dim": 2048,
|
||||
"audio_scale_factor": 4,
|
||||
"audio_pos_embed_max_pos": 20,
|
||||
"audio_sampling_rate": 16000,
|
||||
"audio_hop_length": 160,
|
||||
"num_layers": 48,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"caption_channels": 3840,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_offset": 1,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1000,
|
||||
"rope_type": "split",
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 16,
|
||||
"text_proj_in_factor": 3,
|
||||
"video_connector_num_attention_heads": 4,
|
||||
"video_connector_attention_head_dim": 8,
|
||||
"video_connector_num_layers": 1,
|
||||
"video_connector_num_learnable_registers": None,
|
||||
"audio_connector_num_attention_heads": 4,
|
||||
"audio_connector_attention_head_dim": 8,
|
||||
"audio_connector_num_layers": 1,
|
||||
"audio_connector_num_learnable_registers": None,
|
||||
"connector_rope_base_seq_len": 32,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": False,
|
||||
"causal_temporal_positioning": False,
|
||||
},
|
||||
}
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 3840,
|
||||
"text_proj_in_factor": 49,
|
||||
"video_connector_num_attention_heads": 30,
|
||||
"video_connector_attention_head_dim": 128,
|
||||
"video_connector_num_layers": 2,
|
||||
"video_connector_num_learnable_registers": 128,
|
||||
"audio_connector_num_attention_heads": 30,
|
||||
"audio_connector_attention_head_dim": 128,
|
||||
"audio_connector_num_layers": 2,
|
||||
"audio_connector_num_learnable_registers": 128,
|
||||
"connector_rope_base_seq_len": 4096,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_temporal_positioning": False,
|
||||
"rope_type": "split",
|
||||
},
|
||||
}
|
||||
|
||||
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
|
||||
special_keys_remap = {}
|
||||
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict)
|
||||
|
||||
with init_empty_weights():
|
||||
transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(transformer_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(transformer_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(transformer_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, transformer_state_dict)
|
||||
|
||||
transformer.load_state_dict(transformer_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
_, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict)
|
||||
if len(connector_state_dict) == 0:
|
||||
raise ValueError("No connector weights found in the provided state dict.")
|
||||
|
||||
with init_empty_weights():
|
||||
connectors = LTX2TextConnectors.from_config(diffusers_config)
|
||||
|
||||
for key in list(connector_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(connector_state_dict, key, new_key)
|
||||
|
||||
for key in list(connector_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, connector_state_dict)
|
||||
|
||||
connectors.load_state_dict(connector_state_dict, strict=True, assign=True)
|
||||
return connectors
|
||||
|
||||
|
||||
def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (256, 512, 1024, 2048),
|
||||
"down_block_types": (
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 6, 6, 2, 2),
|
||||
"decoder_layers_per_block": (5, 5, 5, 5),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
"decoder_spatial_padding_mode": "reflect",
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (256, 512, 1024, 2048),
|
||||
"down_block_types": (
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 6, 6, 2, 2),
|
||||
"decoder_layers_per_block": (5, 5, 5, 5),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
"decoder_spatial_padding_mode": "reflect",
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLLTX2Video.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"base_channels": 128,
|
||||
"output_channels": 2,
|
||||
"ch_mult": (1, 2, 4),
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": None,
|
||||
"in_channels": 2,
|
||||
"resolution": 256,
|
||||
"latent_channels": 8,
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height",
|
||||
"dropout": 0.0,
|
||||
"mid_block_add_attention": False,
|
||||
"sample_rate": 16000,
|
||||
"mel_hop_length": 160,
|
||||
"is_causal": True,
|
||||
"mel_bins": 64,
|
||||
"double_z": True,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLLTX2Audio.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"hidden_channels": 1024,
|
||||
"out_channels": 2,
|
||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||
"upsample_factors": [6, 5, 2, 2, 2],
|
||||
"resnet_kernel_sizes": [3, 7, 11],
|
||||
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"leaky_relu_negative_slope": 0.1,
|
||||
"output_sampling_rate": 24000,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
vocoder = LTX2Vocoder.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vocoder
|
||||
|
||||
|
||||
def get_ltx2_spatial_latent_upsampler_config(version: str):
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"in_channels": 128,
|
||||
"mid_channels": 1024,
|
||||
"num_blocks_per_stage": 4,
|
||||
"dims": 3,
|
||||
"spatial_upsample": True,
|
||||
"temporal_upsample": False,
|
||||
"rational_spatial_scale": 2.0,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported version: {version}")
|
||||
return config
|
||||
|
||||
|
||||
def convert_ltx2_spatial_latent_upsampler(
|
||||
original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype
|
||||
):
|
||||
with init_empty_weights():
|
||||
latent_upsampler = LTX2LatentUpsamplerModel(**config)
|
||||
|
||||
latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
latent_upsampler.to(dtype)
|
||||
return latent_upsampler
|
||||
|
||||
|
||||
def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
|
||||
elif args.checkpoint_path is not None:
|
||||
ckpt_path = args.checkpoint_path
|
||||
else:
|
||||
raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
|
||||
|
||||
original_state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]:
|
||||
if repo_id is None and filename is None:
|
||||
raise ValueError("Please supply at least one of `repo_id` or `filename`")
|
||||
|
||||
if repo_id is not None:
|
||||
if filename is None:
|
||||
raise ValueError("If repo_id is specified, filename must also be specified.")
|
||||
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
||||
else:
|
||||
ckpt_path = filename
|
||||
|
||||
_, ext = os.path.splitext(ckpt_path)
|
||||
if ext in [".safetensors", ".sft"]:
|
||||
state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
else:
|
||||
state_dict = torch.load(ckpt_path, map_location="cpu")
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]:
|
||||
# Ensure that the key prefix ends with a dot (.)
|
||||
if not prefix.endswith("."):
|
||||
prefix = prefix + "."
|
||||
|
||||
model_state_dict = {}
|
||||
for param_name, param in combined_ckpt.items():
|
||||
if param_name.startswith(prefix):
|
||||
model_state_dict[param_name.replace(prefix, "")] = param
|
||||
|
||||
if prefix == "model.diffusion_model.":
|
||||
# Some checkpoints store the text connector projection outside the diffusion model prefix.
|
||||
connector_key = "text_embedding_projection.aggregate_embed.weight"
|
||||
if connector_key in combined_ckpt and connector_key not in model_state_dict:
|
||||
model_state_dict[connector_key] = combined_ckpt[connector_key]
|
||||
|
||||
return model_state_dict
|
||||
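# Example (illustrative keys): with prefix="vae.", a combined-checkpoint entry
#   "vae.encoder.conv_in.weight" becomes "encoder.conv_in.weight" in the returned dict,
# while keys under other prefixes (e.g. "audio_vae.", "vocoder.") are skipped.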
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--original_state_dict_repo_id",
|
||||
default="Lightricks/LTX-2",
|
||||
type=str,
|
||||
help="HF Hub repo id with LTX 2.0 checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Local checkpoint path for LTX 2.0. Will be used if `original_state_dict_repo_id` is not specified.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
type=str,
|
||||
default="2.0",
|
||||
choices=["test", "2.0"],
|
||||
help="Version of the LTX 2.0 model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--combined_filename",
|
||||
default="ltx-2-19b-dev.safetensors",
|
||||
type=str,
|
||||
help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
|
||||
)
|
||||
parser.add_argument("--vae_prefix", default="vae.", type=str)
|
||||
parser.add_argument("--audio_vae_prefix", default="audio_vae.", type=str)
|
||||
parser.add_argument("--dit_prefix", default="model.diffusion_model.", type=str)
|
||||
parser.add_argument("--vocoder_prefix", default="vocoder.", type=str)
|
||||
|
||||
parser.add_argument("--vae_filename", default=None, type=str, help="VAE filename; overrides combined ckpt if set")
|
||||
parser.add_argument(
|
||||
"--audio_vae_filename", default=None, type=str, help="Audio VAE filename; overrides combined ckpt if set"
|
||||
)
|
||||
parser.add_argument("--dit_filename", default=None, type=str, help="DiT filename; overrides combined ckpt if set")
|
||||
parser.add_argument(
|
||||
"--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_encoder_model_id",
|
||||
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
|
||||
type=str,
|
||||
help="HF Hub id for the LTX 2.0 base text encoder model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_id",
|
||||
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
|
||||
type=str,
|
||||
help="HF Hub id for the LTX 2.0 text tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--latent_upsampler_filename",
|
||||
default="ltx-2-spatial-upscaler-x2-1.0.safetensors",
|
||||
type=str,
|
||||
help="Latent upsampler filename",
|
||||
)
|
||||
|
||||
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
|
||||
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
|
||||
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
|
||||
parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model")
|
||||
parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
|
||||
parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder")
|
||||
parser.add_argument("--latent_upsampler", action="store_true", help="Whether to convert the latent upsampler")
|
||||
parser.add_argument(
|
||||
"--full_pipeline",
|
||||
action="store_true",
|
||||
help="Whether to save the pipeline. This will attempt to convert all models (e.g. vae, dit, etc.)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upsample_pipeline",
|
||||
action="store_true",
|
||||
help="Whether to save a latent upsampling pipeline",
|
||||
)
|
||||
|
||||
parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
VARIANT_MAPPING = {
|
||||
"fp32": None,
|
||||
"fp16": "fp16",
|
||||
"bf16": "bf16",
|
||||
}
|
||||
|
||||
|
||||
def main(args):
|
||||
vae_dtype = DTYPE_MAPPING[args.vae_dtype]
|
||||
audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype]
|
||||
dit_dtype = DTYPE_MAPPING[args.dit_dtype]
|
||||
vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype]
|
||||
text_encoder_dtype = DTYPE_MAPPING[args.text_encoder_dtype]
|
||||
|
||||
combined_ckpt = None
|
||||
load_combined_models = any(
|
||||
[
|
||||
args.vae,
|
||||
args.audio_vae,
|
||||
args.dit,
|
||||
args.vocoder,
|
||||
args.text_encoder,
|
||||
args.full_pipeline,
|
||||
args.upsample_pipeline,
|
||||
]
|
||||
)
|
||||
if args.combined_filename is not None and load_combined_models:
|
||||
combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename)
|
||||
|
||||
if args.vae or args.full_pipeline or args.upsample_pipeline:
|
||||
if args.vae_filename is not None:
|
||||
original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
|
||||
vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
|
||||
if not args.full_pipeline and not args.upsample_pipeline:
|
||||
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
|
||||
|
||||
if args.audio_vae or args.full_pipeline:
|
||||
if args.audio_vae_filename is not None:
|
||||
original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix)
|
||||
audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae"))
|
||||
|
||||
if args.dit or args.full_pipeline:
|
||||
if args.dit_filename is not None:
|
||||
original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
|
||||
transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
|
||||
|
||||
if args.connectors or args.full_pipeline:
|
||||
if args.dit_filename is not None:
|
||||
original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
|
||||
connectors = convert_ltx2_connectors(original_connectors_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors"))
|
||||
|
||||
if args.vocoder or args.full_pipeline:
|
||||
if args.vocoder_filename is not None:
|
||||
original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix)
|
||||
vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
|
||||
|
||||
if args.text_encoder or args.full_pipeline:
|
||||
# text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id)
|
||||
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id)
|
||||
if not args.full_pipeline:
|
||||
text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
|
||||
if not args.full_pipeline:
|
||||
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
|
||||
|
||||
if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline:
|
||||
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
|
||||
repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
|
||||
)
|
||||
latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version)
|
||||
latent_upsampler = convert_ltx2_spatial_latent_upsampler(
|
||||
original_latent_upsampler_ckpt,
|
||||
latent_upsampler_config,
|
||||
dtype=vae_dtype,
|
||||
)
|
||||
if not args.full_pipeline and not args.upsample_pipeline:
|
||||
latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler"))
|
||||
|
||||
if args.full_pipeline:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
|
||||
pipe = LTX2Pipeline(
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
audio_vae=audio_vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
connectors=connectors,
|
||||
transformer=transformer,
|
||||
vocoder=vocoder,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.upsample_pipeline:
|
||||
pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
|
||||
|
||||
# Put latent upsampling pipeline in its own subdirectory so it doesn't mess with the full pipeline
|
||||
pipe.save_pretrained(
|
||||
os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
main(args)
|
||||
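Once the script has been run with `--full_pipeline` and `--upsample_pipeline`, the saved outputs can be loaded back as regular diffusers pipelines. A minimal sketch; the output directory is a placeholder, and the `upsample_pipeline` subdirectory name comes from the save logic above:

```python
import torch

from diffusers import LTX2LatentUpsamplePipeline, LTX2Pipeline

# Placeholder path: whatever was passed as --output_path to the conversion script.
output_path = "./ltx2-diffusers"

# The full audio-video pipeline is saved directly under the output path.
pipe = LTX2Pipeline.from_pretrained(output_path, torch_dtype=torch.bfloat16)

# The latent upsampling pipeline is saved into its own subdirectory by the script.
upsample_pipe = LTX2LatentUpsamplePipeline.from_pretrained(
    f"{output_path}/upsample_pipeline", torch_dtype=torch.bfloat16
)
```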
src/diffusers/__init__.py
@@ -193,6 +193,8 @@ else:
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLHunyuanVideo15",
"AutoencoderKLLTX2Audio",
"AutoencoderKLLTX2Video",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
"AutoencoderKLMochi",
@@ -236,6 +238,7 @@ else:
"Kandinsky5Transformer3DModel",
"LatteTransformer3DModel",
"LongCatImageTransformer2DModel",
"LTX2VideoTransformer3DModel",
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
"LuminaNextDiT2DModel",
@@ -538,6 +541,9 @@ else:
"LEditsPPPipelineStableDiffusionXL",
"LongCatImageEditPipeline",
"LongCatImagePipeline",
"LTX2ImageToVideoPipeline",
"LTX2LatentUpsamplePipeline",
"LTX2Pipeline",
"LTXConditionPipeline",
"LTXI2VLongMultiPromptPipeline",
"LTXImageToVideoPipeline",
@@ -939,6 +945,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
@@ -982,6 +990,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LongCatImageTransformer2DModel,
LTX2VideoTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
@@ -1254,6 +1263,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
LongCatImageEditPipeline,
LongCatImagePipeline,
LTX2ImageToVideoPipeline,
LTX2LatentUpsamplePipeline,
LTX2Pipeline,
LTXConditionPipeline,
LTXI2VLongMultiPromptPipeline,
LTXImageToVideoPipeline,

src/diffusers/models/__init__.py
@@ -41,6 +41,8 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
_import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
@@ -104,6 +106,7 @@ if is_torch_available():
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
@@ -153,6 +156,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
@@ -212,6 +217,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LongCatImageTransformer2DModel,
LTX2VideoTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,

src/diffusers/models/autoencoders/__init__.py
@@ -10,6 +10,8 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py (new file, 1521 lines)
File diff suppressed because it is too large.
src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py (new file, 804 lines)
@@ -0,0 +1,804 @@
|
||||
# Copyright 2025 The Lightricks team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||
|
||||
|
||||
class LTX2AudioCausalConv2d(nn.Module):
|
||||
"""
|
||||
A causal 2D convolution that pads asymmetrically along the causal axis.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: int = 1,
|
||||
dilation: Union[int, Tuple[int, int]] = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
causality_axis: str = "height",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.causality_axis = causality_axis
|
||||
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
|
||||
dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
|
||||
|
||||
pad_h = (kernel_size[0] - 1) * dilation[0]
|
||||
pad_w = (kernel_size[1] - 1) * dilation[1]
|
||||
|
||||
if self.causality_axis == "none":
|
||||
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
||||
elif self.causality_axis in {"width", "width-compatibility"}:
|
||||
padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
|
||||
elif self.causality_axis == "height":
|
||||
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
|
||||
else:
|
||||
raise ValueError(f"Invalid causality_axis: {causality_axis}")
|
||||
|
||||
self.padding = padding
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = F.pad(x, self.padding)
|
||||
return self.conv(x)
|
||||
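# Editor's note: a minimal sketch (not part of the original diff) of how the asymmetric padding above
# behaves. With causality_axis="height" and a 3x3 kernel, all of the height padding is applied on the
# "past" side, so an output position along the causal (time) axis never sees future positions:
#
#     conv = LTX2AudioCausalConv2d(in_channels=2, out_channels=4, kernel_size=3, causality_axis="height")
#     x = torch.randn(1, 2, 16, 64)           # (batch, channels, time, mel_bins)
#     assert conv.padding == (1, 1, 2, 0)     # (left, right, top, bottom): full pad before, none after
#     assert conv(x).shape == (1, 4, 16, 64)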
|
||||
|
||||
class LTX2AudioPixelNorm(nn.Module):
|
||||
"""
|
||||
Per-pixel (per-location) RMS normalization layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
|
||||
rms = torch.sqrt(mean_sq + self.eps)
|
||||
return x / rms
|
||||
|
||||
|
||||
class LTX2AudioAttnBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
norm_type: str = "group",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
if norm_type == "group":
|
||||
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
elif norm_type == "pixel":
|
||||
self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {norm_type}")
|
||||
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h_ = self.norm(x)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
batch, channels, height, width = q.shape
|
||||
q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous()
|
||||
k = k.reshape(batch, channels, height * width).contiguous()
|
||||
attn = torch.bmm(q, k) * (int(channels) ** (-0.5))
|
||||
attn = torch.nn.functional.softmax(attn, dim=2)
|
||||
|
||||
v = v.reshape(batch, channels, height * width)
|
||||
attn = attn.permute(0, 2, 1).contiguous()
|
||||
h_ = torch.bmm(v, attn).reshape(batch, channels, height, width)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
return x + h_
|
||||
|
||||
|
||||
class LTX2AudioResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
conv_shortcut: bool = False,
|
||||
dropout: float = 0.0,
|
||||
temb_channels: int = 512,
|
||||
norm_type: str = "group",
|
||||
causality_axis: str = "height",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group":
|
||||
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
if norm_type == "group":
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
elif norm_type == "pixel":
|
||||
self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {norm_type}")
|
||||
self.non_linearity = nn.SiLU()
|
||||
if causality_axis is not None:
|
||||
self.conv1 = LTX2AudioCausalConv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = nn.Linear(temb_channels, out_channels)
|
||||
if norm_type == "group":
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
||||
elif norm_type == "pixel":
|
||||
self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {norm_type}")
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
if causality_axis is not None:
|
||||
self.conv2 = LTX2AudioCausalConv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
if causality_axis is not None:
|
||||
self.conv_shortcut = LTX2AudioCausalConv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
if causality_axis is not None:
|
||||
self.nin_shortcut = LTX2AudioCausalConv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
h = self.norm1(x)
|
||||
h = self.non_linearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = self.non_linearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class LTX2AudioDownsample(nn.Module):
|
||||
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.with_conv:
|
||||
# Padding tuple is in the order: (left, right, top, bottom).
|
||||
if self.causality_axis == "none":
|
||||
pad = (0, 1, 0, 1)
|
||||
elif self.causality_axis == "width":
|
||||
pad = (2, 0, 0, 1)
|
||||
elif self.causality_axis == "height":
|
||||
pad = (0, 1, 2, 0)
|
||||
elif self.causality_axis == "width-compatibility":
|
||||
pad = (1, 0, 0, 1)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`,"
|
||||
f" and `width-compatibility`."
|
||||
)
|
||||
|
||||
x = F.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
# with_conv=False implies that causality_axis is "none"
|
||||
x = F.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class LTX2AudioUpsample(nn.Module):
|
||||
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
if self.with_conv:
|
||||
if causality_axis is not None:
|
||||
self.conv = LTX2AudioCausalConv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
if self.causality_axis is None or self.causality_axis == "none":
|
||||
pass
|
||||
elif self.causality_axis == "height":
|
||||
x = x[:, :, 1:, :]
|
||||
elif self.causality_axis == "width":
|
||||
x = x[:, :, :, 1:]
|
||||
elif self.causality_axis == "width-compatibility":
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LTX2AudioAudioPatchifier:
|
||||
"""
|
||||
Patchifier for spectrogram/audio latents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int,
|
||||
sample_rate: int = 16000,
|
||||
hop_length: int = 160,
|
||||
audio_latent_downsample_factor: int = 4,
|
||||
is_causal: bool = True,
|
||||
):
|
||||
self.hop_length = hop_length
|
||||
self.sample_rate = sample_rate
|
||||
self.audio_latent_downsample_factor = audio_latent_downsample_factor
|
||||
self.is_causal = is_causal
|
||||
self._patch_size = (1, patch_size, patch_size)
|
||||
|
||||
def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor:
|
||||
batch, channels, time, freq = audio_latents.shape
|
||||
return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)
|
||||
|
||||
def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor:
|
||||
batch, time, _ = audio_latents.shape
|
||||
return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)
|
||||
|
||||
@property
|
||||
def patch_size(self) -> Tuple[int, int, int]:
|
||||
return self._patch_size
|
||||
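# Editor's note: an illustrative sketch (not part of the original diff) of the patchify/unpatchify
# round trip above, which folds (channels, mel_bins) into a single feature axis per time step:
#
#     patchifier = LTX2AudioAudioPatchifier(patch_size=1)
#     latents = torch.randn(2, 8, 32, 16)                        # (batch, channels, time, mel_bins)
#     tokens = patchifier.patchify(latents)                      # (2, 32, 8 * 16)
#     restored = patchifier.unpatchify(tokens, channels=8, mel_bins=16)
#     assert torch.equal(restored, latents)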
|
||||
|
||||
class LTX2AudioEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
base_channels: int = 128,
|
||||
output_channels: int = 1,
|
||||
num_res_blocks: int = 2,
|
||||
attn_resolutions: Optional[Tuple[int, ...]] = None,
|
||||
in_channels: int = 2,
|
||||
resolution: int = 256,
|
||||
latent_channels: int = 8,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||
norm_type: str = "group",
|
||||
causality_axis: Optional[str] = "width",
|
||||
dropout: float = 0.0,
|
||||
mid_block_add_attention: bool = False,
|
||||
sample_rate: int = 16000,
|
||||
mel_hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
mel_bins: Optional[int] = 64,
|
||||
double_z: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.mel_hop_length = mel_hop_length
|
||||
self.is_causal = is_causal
|
||||
self.mel_bins = mel_bins
|
||||
|
||||
self.base_channels = base_channels
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.out_ch = output_channels
|
||||
self.give_pre_end = False
|
||||
self.tanh_out = False
|
||||
self.norm_type = norm_type
|
||||
self.latent_channels = latent_channels
|
||||
self.channel_multipliers = ch_mult
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
base_block_channels = base_channels
|
||||
base_resolution = resolution
|
||||
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_in = LTX2AudioCausalConv2d(
|
||||
in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.down = nn.ModuleList()
|
||||
block_in = base_block_channels
|
||||
curr_res = self.resolution
|
||||
|
||||
for level in range(self.num_resolutions):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList()
|
||||
stage.attn = nn.ModuleList()
|
||||
block_out = self.base_channels * self.channel_multipliers[level]
|
||||
|
||||
for _ in range(self.num_res_blocks):
|
||||
stage.block.append(
|
||||
LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if self.attn_resolutions:
|
||||
if curr_res in self.attn_resolutions:
|
||||
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
|
||||
|
||||
if level != self.num_resolutions - 1:
|
||||
stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis)
|
||||
curr_res = curr_res // 2
|
||||
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
final_block_channels = block_in
|
||||
z_channels = 2 * latent_channels if double_z else latent_channels
|
||||
if self.norm_type == "group":
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
|
||||
elif self.norm_type == "pixel":
|
||||
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {self.norm_type}")
|
||||
self.non_linearity = nn.SiLU()
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_out = LTX2AudioCausalConv2d(
|
||||
final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# hidden_states expected shape: (batch_size, channels, time, num_mel_bins)
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
for level in range(self.num_resolutions):
|
||||
stage = self.down[level]
|
||||
for block_idx, block in enumerate(stage.block):
|
||||
hidden_states = block(hidden_states, temb=None)
|
||||
if stage.attn:
|
||||
hidden_states = stage.attn[block_idx](hidden_states)
|
||||
|
||||
if level != self.num_resolutions - 1 and hasattr(stage, "downsample"):
|
||||
hidden_states = stage.downsample(hidden_states)
|
||||
|
||||
hidden_states = self.mid.block_1(hidden_states, temb=None)
|
||||
hidden_states = self.mid.attn_1(hidden_states)
|
||||
hidden_states = self.mid.block_2(hidden_states, temb=None)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.non_linearity(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTX2AudioDecoder(nn.Module):
|
||||
"""
|
||||
Symmetric decoder that reconstructs audio spectrograms from latent features.
|
||||
|
||||
The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal
|
||||
convolutions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_channels: int = 128,
|
||||
output_channels: int = 1,
|
||||
num_res_blocks: int = 2,
|
||||
attn_resolutions: Optional[Tuple[int, ...]] = None,
|
||||
in_channels: int = 2,
|
||||
resolution: int = 256,
|
||||
latent_channels: int = 8,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||
norm_type: str = "group",
|
||||
causality_axis: Optional[str] = "width",
|
||||
dropout: float = 0.0,
|
||||
mid_block_add_attention: bool = False,
|
||||
sample_rate: int = 16000,
|
||||
mel_hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
mel_bins: Optional[int] = 64,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.mel_hop_length = mel_hop_length
|
||||
self.is_causal = is_causal
|
||||
self.mel_bins = mel_bins
|
||||
self.patchifier = LTX2AudioAudioPatchifier(
|
||||
patch_size=1,
|
||||
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||
sample_rate=sample_rate,
|
||||
hop_length=mel_hop_length,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
self.base_channels = base_channels
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.out_ch = output_channels
|
||||
self.give_pre_end = False
|
||||
self.tanh_out = False
|
||||
self.norm_type = norm_type
|
||||
self.latent_channels = latent_channels
|
||||
self.channel_multipliers = ch_mult
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
base_block_channels = base_channels * self.channel_multipliers[-1]
|
||||
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
|
||||
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_in = LTX2AudioCausalConv2d(
|
||||
latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.non_linearity = nn.SiLU()
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = LTX2AudioResnetBlock(
|
||||
in_channels=base_block_channels,
|
||||
out_channels=base_block_channels,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = LTX2AudioResnetBlock(
|
||||
in_channels=base_block_channels,
|
||||
out_channels=base_block_channels,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
block_in = base_block_channels
|
||||
curr_res = self.resolution // (2 ** (self.num_resolutions - 1))
|
||||
|
||||
for level in reversed(range(self.num_resolutions)):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList()
|
||||
stage.attn = nn.ModuleList()
|
||||
block_out = self.base_channels * self.channel_multipliers[level]
|
||||
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
stage.block.append(
|
||||
LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if self.attn_resolutions:
|
||||
if curr_res in self.attn_resolutions:
|
||||
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
|
||||
|
||||
if level != 0:
|
||||
stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis)
|
||||
curr_res *= 2
|
||||
|
||||
self.up.insert(0, stage)
|
||||
|
||||
final_block_channels = block_in
|
||||
|
||||
if self.norm_type == "group":
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
|
||||
elif self.norm_type == "pixel":
|
||||
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {self.norm_type}")
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_out = LTX2AudioCausalConv2d(
|
||||
final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
_, _, frames, mel_bins = sample.shape
|
||||
|
||||
target_frames = frames * LATENT_DOWNSAMPLE_FACTOR
|
||||
|
||||
if self.causality_axis is not None:
|
||||
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
|
||||
|
||||
target_channels = self.out_ch
|
||||
target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins
|
||||
|
||||
hidden_features = self.conv_in(sample)
|
||||
hidden_features = self.mid.block_1(hidden_features, temb=None)
|
||||
hidden_features = self.mid.attn_1(hidden_features)
|
||||
hidden_features = self.mid.block_2(hidden_features, temb=None)
|
||||
|
||||
for level in reversed(range(self.num_resolutions)):
|
||||
stage = self.up[level]
|
||||
for block_idx, block in enumerate(stage.block):
|
||||
hidden_features = block(hidden_features, temb=None)
|
||||
if stage.attn:
|
||||
hidden_features = stage.attn[block_idx](hidden_features)
|
||||
|
||||
if level != 0 and hasattr(stage, "upsample"):
|
||||
hidden_features = stage.upsample(hidden_features)
|
||||
|
||||
if self.give_pre_end:
|
||||
return hidden_features
|
||||
|
||||
hidden = self.norm_out(hidden_features)
|
||||
hidden = self.non_linearity(hidden)
|
||||
decoded_output = self.conv_out(hidden)
|
||||
decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output
|
||||
|
||||
_, _, current_time, current_freq = decoded_output.shape
|
||||
target_time = target_frames
|
||||
target_freq = target_mel_bins
|
||||
|
||||
decoded_output = decoded_output[
|
||||
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
|
||||
]
|
||||
|
||||
time_padding_needed = target_time - decoded_output.shape[2]
|
||||
freq_padding_needed = target_freq - decoded_output.shape[3]
|
||||
|
||||
if time_padding_needed > 0 or freq_padding_needed > 0:
|
||||
padding = (
|
||||
0,
|
||||
max(freq_padding_needed, 0),
|
||||
0,
|
||||
max(time_padding_needed, 0),
|
||||
)
|
||||
decoded_output = F.pad(decoded_output, padding)
|
||||
|
||||
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
|
||||
|
||||
return decoded_output
|
||||
|
||||
|
||||
class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
r"""
|
||||
LTX2 audio VAE for encoding and decoding audio latent representations.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
base_channels: int = 128,
|
||||
output_channels: int = 2,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||
num_res_blocks: int = 2,
|
||||
attn_resolutions: Optional[Tuple[int, ...]] = None,
|
||||
in_channels: int = 2,
|
||||
resolution: int = 256,
|
||||
latent_channels: int = 8,
|
||||
norm_type: str = "pixel",
|
||||
causality_axis: Optional[str] = "height",
|
||||
dropout: float = 0.0,
|
||||
mid_block_add_attention: bool = False,
|
||||
sample_rate: int = 16000,
|
||||
mel_hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
mel_bins: Optional[int] = 64,
|
||||
double_z: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
supported_causality_axes = {"none", "width", "height", "width-compatibility"}
|
||||
if causality_axis not in supported_causality_axes:
|
||||
raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}")
|
||||
|
||||
attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions
|
||||
|
||||
self.encoder = LTX2AudioEncoder(
|
||||
base_channels=base_channels,
|
||||
output_channels=output_channels,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolution_set,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
latent_channels=latent_channels,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
dropout=dropout,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
sample_rate=sample_rate,
|
||||
mel_hop_length=mel_hop_length,
|
||||
is_causal=is_causal,
|
||||
mel_bins=mel_bins,
|
||||
double_z=double_z,
|
||||
)
|
||||
|
||||
self.decoder = LTX2AudioDecoder(
|
||||
base_channels=base_channels,
|
||||
output_channels=output_channels,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolution_set,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
latent_channels=latent_channels,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
dropout=dropout,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
sample_rate=sample_rate,
|
||||
mel_hop_length=mel_hop_length,
|
||||
is_causal=is_causal,
|
||||
mel_bins=mel_bins,
|
||||
)
|
||||
|
||||
# Per-channel statistics for normalizing and denormalizing the latent representation. These statistics are
# computed over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict.
|
||||
latents_std = torch.zeros((base_channels,))
|
||||
latents_mean = torch.ones((base_channels,))
|
||||
self.register_buffer("latents_mean", latents_mean, persistent=True)
|
||||
self.register_buffer("latents_std", latents_std, persistent=True)
|
||||
|
||||
# TODO: calculate programmatically instead of hardcoding
|
||||
self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4
|
||||
# TODO: confirm whether the mel compression ratio below is correct
|
||||
self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.encoder(x)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True):
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
return self.decoder(z)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z)
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
posterior = self.encode(sample).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
if not return_dict:
|
||||
return (dec.sample,)
|
||||
return dec
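# Editor's note: a minimal usage sketch (not part of the original diff), assuming the default config.
# The encoder maps a mel spectrogram shaped (batch, in_channels, time, mel_bins) to a diagonal Gaussian
# over latents, and the decoder maps latents back to a spectrogram for the vocoder:
#
#     audio_vae = AutoencoderKLLTX2Audio()
#     spectrogram = torch.randn(1, 2, 64, 64)          # (batch, channels, mel frames, mel bins)
#     posterior = audio_vae.encode(spectrogram).latent_dist
#     audio_latents = posterior.mode()
#     reconstructed = audio_vae.decode(audio_latents).sample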
@@ -35,6 +35,7 @@ if is_torch_available():
    from .transformer_kandinsky import Kandinsky5Transformer3DModel
    from .transformer_longcat_image import LongCatImageTransformer2DModel
    from .transformer_ltx import LTXVideoTransformer3DModel
    from .transformer_ltx2 import LTX2VideoTransformer3DModel
    from .transformer_lumina2 import Lumina2Transformer2DModel
    from .transformer_mochi import MochiTransformer3DModel
    from .transformer_omnigen import OmniGenTransformer2DModel
1350 src/diffusers/models/transformers/transformer_ltx2.py (new file)
File diff suppressed because it is too large
@@ -290,6 +290,7 @@ else:
        "LTXLatentUpsamplePipeline",
        "LTXI2VLongMultiPromptPipeline",
    ]
    _import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline"]
    _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
    _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
    _import_structure["lucy"] = ["LucyEditPipeline"]
@@ -737,6 +738,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            LTXLatentUpsamplePipeline,
            LTXPipeline,
        )
        from .ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline
        from .lucy import LucyEditPipeline
        from .lumina import LuminaPipeline, LuminaText2ImgPipeline
        from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
58 src/diffusers/pipelines/ltx2/__init__.py (new file)
@@ -0,0 +1,58 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["connectors"] = ["LTX2TextConnectors"]
    _import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"]
    _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
    _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
    _import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"]
    _import_structure["vocoder"] = ["LTX2Vocoder"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .connectors import LTX2TextConnectors
        from .latent_upsampler import LTX2LatentUpsamplerModel
        from .pipeline_ltx2 import LTX2Pipeline
        from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
        from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
        from .vocoder import LTX2Vocoder

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
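# Editor's note (not part of the original diff): with the lazy module above in place, the new pipelines
# resolve through the usual subpackage path, e.g.:
#
#     from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline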
325 src/diffusers/pipelines/ltx2/connectors.py (new file)
@@ -0,0 +1,325 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
|
||||
|
||||
|
||||
class LTX2RotaryPosEmbed1d(nn.Module):
|
||||
"""
|
||||
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
base_seq_len: int = 4096,
|
||||
theta: float = 10000.0,
|
||||
double_precision: bool = True,
|
||||
rope_type: str = "interleaved",
|
||||
num_attention_heads: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
if rope_type not in ["interleaved", "split"]:
|
||||
raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.")
|
||||
|
||||
self.dim = dim
|
||||
self.base_seq_len = base_seq_len
|
||||
self.theta = theta
|
||||
self.double_precision = double_precision
|
||||
self.rope_type = rope_type
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch_size: int,
|
||||
pos: int,
|
||||
device: Union[str, torch.device],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Get 1D position ids
|
||||
grid_1d = torch.arange(pos, dtype=torch.float32, device=device)
|
||||
# Get fractional indices relative to self.base_seq_len
|
||||
grid_1d = grid_1d / self.base_seq_len
|
||||
grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
|
||||
|
||||
# 2. Calculate 1D RoPE frequencies
|
||||
num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2
|
||||
freqs_dtype = torch.float64 if self.double_precision else torch.float32
|
||||
pow_indices = torch.pow(
|
||||
self.theta,
|
||||
torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device),
|
||||
)
|
||||
freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)
|
||||
|
||||
# 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape
|
||||
# (self.dim // 2,).
|
||||
freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2]
|
||||
|
||||
# 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim
|
||||
if self.rope_type == "interleaved":
|
||||
cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
|
||||
if self.dim % num_rope_elems != 0:
|
||||
cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems])
|
||||
sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems])
|
||||
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
|
||||
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
|
||||
|
||||
elif self.rope_type == "split":
|
||||
expected_freqs = self.dim // 2
|
||||
current_freqs = freqs.shape[-1]
|
||||
pad_size = expected_freqs - current_freqs
|
||||
cos_freq = freqs.cos()
|
||||
sin_freq = freqs.sin()
|
||||
|
||||
if pad_size != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
|
||||
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
|
||||
|
||||
cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
|
||||
# Reshape freqs to be compatible with multi-head attention
|
||||
b = cos_freq.shape[0]
|
||||
t = cos_freq.shape[1]
|
||||
|
||||
cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1)
|
||||
sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1)
|
||||
|
||||
cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
|
||||
sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
|
||||
|
||||
return cos_freqs, sin_freqs
|
||||
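# Editor's note: a shape sketch (not part of the original diff) contrasting the two RoPE layouts above,
# here with dim=128 and num_attention_heads=8:
#
#     rope = LTX2RotaryPosEmbed1d(dim=128, num_attention_heads=8, rope_type="interleaved")
#     cos, sin = rope(batch_size=2, pos=16, device="cpu")        # each (2, 16, 128)
#
#     rope = LTX2RotaryPosEmbed1d(dim=128, num_attention_heads=8, rope_type="split")
#     cos, sin = rope(batch_size=2, pos=16, device="cpu")        # each (2, 8, 16, 8), i.e. (B, H, T, dim // (2 * H))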
|
||||
|
||||
class LTX2TransformerBlock1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
eps: float = 1e-6,
|
||||
rope_type: str = "interleaved",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.attn1 = LTX2Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
rope_type=rope_type,
|
||||
)
|
||||
|
||||
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.ff = FeedForward(dim, activation_fn=activation_fn)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb)
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
ff_hidden_states = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + ff_hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTX2ConnectorTransformer1d(nn.Module):
|
||||
"""
|
||||
A 1D sequence transformer for modalities such as text.
|
||||
|
||||
In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 30,
|
||||
attention_head_dim: int = 128,
|
||||
num_layers: int = 2,
|
||||
num_learnable_registers: int | None = 128,
|
||||
rope_base_seq_len: int = 4096,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_double_precision: bool = True,
|
||||
eps: float = 1e-6,
|
||||
causal_temporal_positioning: bool = False,
|
||||
rope_type: str = "interleaved",
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
self.learnable_registers = None
|
||||
if num_learnable_registers is not None:
|
||||
init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0
|
||||
self.learnable_registers = torch.nn.Parameter(init_registers)
|
||||
|
||||
self.rope = LTX2RotaryPosEmbed1d(
|
||||
self.inner_dim,
|
||||
base_seq_len=rope_base_seq_len,
|
||||
theta=rope_theta,
|
||||
double_precision=rope_double_precision,
|
||||
rope_type=rope_type,
|
||||
num_attention_heads=num_attention_heads,
|
||||
)
|
||||
|
||||
self.transformer_blocks = torch.nn.ModuleList(
|
||||
[
|
||||
LTX2TransformerBlock1d(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
rope_type=rope_type,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attn_mask_binarize_threshold: float = -9000.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# hidden_states shape: [batch_size, seq_len, hidden_dim]
|
||||
# attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len]
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
|
||||
# 1. Replace padding with learned registers, if using
|
||||
if self.learnable_registers is not None:
|
||||
if seq_len % self.num_learnable_registers != 0:
|
||||
raise ValueError(
|
||||
f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number"
|
||||
f" of learnable registers {self.num_learnable_registers}"
|
||||
)
|
||||
|
||||
num_register_repeats = seq_len // self.num_learnable_registers
|
||||
registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim]
|
||||
|
||||
binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int()
|
||||
if binary_attn_mask.ndim == 4:
|
||||
binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L]
|
||||
|
||||
hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)]
|
||||
valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]
|
||||
pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]
|
||||
padded_hidden_states = [
|
||||
F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths)
|
||||
]
|
||||
padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D]
|
||||
|
||||
flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1]
|
||||
hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers
|
||||
|
||||
# Overwrite attention_mask with an all-zeros mask if using registers.
|
||||
attention_mask = torch.zeros_like(attention_mask)
|
||||
|
||||
# 2. Calculate 1D RoPE positional embeddings
|
||||
rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device)
|
||||
|
||||
# 3. Run 1D transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb)
|
||||
else:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
|
||||
|
||||
class LTX2TextConnectors(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio
|
||||
streams.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
caption_channels: int,
|
||||
text_proj_in_factor: int,
|
||||
video_connector_num_attention_heads: int,
|
||||
video_connector_attention_head_dim: int,
|
||||
video_connector_num_layers: int,
|
||||
video_connector_num_learnable_registers: int | None,
|
||||
audio_connector_num_attention_heads: int,
|
||||
audio_connector_attention_head_dim: int,
|
||||
audio_connector_num_layers: int,
|
||||
audio_connector_num_learnable_registers: int | None,
|
||||
connector_rope_base_seq_len: int,
|
||||
rope_theta: float,
|
||||
rope_double_precision: bool,
|
||||
causal_temporal_positioning: bool,
|
||||
rope_type: str = "interleaved",
|
||||
):
|
||||
super().__init__()
|
||||
self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False)
|
||||
self.video_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=video_connector_num_attention_heads,
|
||||
attention_head_dim=video_connector_attention_head_dim,
|
||||
num_layers=video_connector_num_layers,
|
||||
num_learnable_registers=video_connector_num_learnable_registers,
|
||||
rope_base_seq_len=connector_rope_base_seq_len,
|
||||
rope_theta=rope_theta,
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
rope_type=rope_type,
|
||||
)
|
||||
self.audio_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=audio_connector_num_attention_heads,
|
||||
attention_head_dim=audio_connector_attention_head_dim,
|
||||
num_layers=audio_connector_num_layers,
|
||||
num_learnable_registers=audio_connector_num_learnable_registers,
|
||||
rope_base_seq_len=connector_rope_base_seq_len,
|
||||
rope_theta=rope_theta,
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
rope_type=rope_type,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False
|
||||
):
|
||||
# Convert to additive attention mask, if necessary
|
||||
if not additive_mask:
|
||||
text_dtype = text_encoder_hidden_states.dtype
|
||||
attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max
|
||||
|
||||
text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
|
||||
|
||||
video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask)
|
||||
|
||||
attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
|
||||
attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
|
||||
video_text_embedding = video_text_embedding * attn_mask
|
||||
new_attn_mask = attn_mask.squeeze(-1)
|
||||
|
||||
audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask)
|
||||
|
||||
return video_text_embedding, audio_text_embedding, new_attn_mask
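# Editor's note: a rough shape sketch (not part of the original diff) of the connector forward pass,
# assuming `caption_channels` matches each connector's inner dimension (heads * head_dim):
#
#     video_emb, audio_emb, keep_mask = connectors(text_hidden_states, attention_mask)
#     # video_emb:  (batch, seq_len, caption_channels), text tokens prepared for the video stream
#     # audio_emb:  (batch, seq_len, caption_channels), text tokens prepared for the audio stream
#     # keep_mask:  (batch, seq_len) int64 mask derived from the video connector's attention mask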
134 src/diffusers/pipelines/ltx2/export_utils.py (new file)
@@ -0,0 +1,134 @@
|
||||
# Copyright 2025 The Lightricks team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import is_av_available
|
||||
|
||||
|
||||
_CAN_USE_AV = is_av_available()
|
||||
if _CAN_USE_AV:
|
||||
import av
|
||||
else:
|
||||
raise ImportError(
|
||||
"PyAV is required to use LTX 2.0 video export utilities. You can install it with `pip install av`"
|
||||
)
|
||||
|
||||
|
||||
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
|
||||
"""
|
||||
Prepare the audio stream for writing.
|
||||
"""
|
||||
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
|
||||
audio_stream.codec_context.sample_rate = audio_sample_rate
|
||||
audio_stream.codec_context.layout = "stereo"
|
||||
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
|
||||
return audio_stream
|
||||
|
||||
|
||||
def _resample_audio(
|
||||
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
|
||||
) -> None:
|
||||
cc = audio_stream.codec_context
|
||||
|
||||
# Use the encoder's format/layout/rate as the *target*
|
||||
target_format = cc.format or "fltp" # AAC → usually fltp
|
||||
target_layout = cc.layout or "stereo"
|
||||
target_rate = cc.sample_rate or frame_in.sample_rate
|
||||
|
||||
audio_resampler = av.audio.resampler.AudioResampler(
|
||||
format=target_format,
|
||||
layout=target_layout,
|
||||
rate=target_rate,
|
||||
)
|
||||
|
||||
audio_next_pts = 0
|
||||
for rframe in audio_resampler.resample(frame_in):
|
||||
if rframe.pts is None:
|
||||
rframe.pts = audio_next_pts
|
||||
audio_next_pts += rframe.samples
|
||||
rframe.sample_rate = frame_in.sample_rate
|
||||
container.mux(audio_stream.encode(rframe))
|
||||
|
||||
# flush audio encoder
|
||||
for packet in audio_stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
|
||||
def _write_audio(
|
||||
container: av.container.Container,
|
||||
audio_stream: av.audio.AudioStream,
|
||||
samples: torch.Tensor,
|
||||
audio_sample_rate: int,
|
||||
) -> None:
|
||||
if samples.ndim == 1:
|
||||
samples = samples[:, None]
|
||||
|
||||
if samples.shape[1] != 2 and samples.shape[0] == 2:
|
||||
samples = samples.T
|
||||
|
||||
if samples.shape[1] != 2:
|
||||
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
|
||||
|
||||
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
|
||||
if samples.dtype != torch.int16:
|
||||
samples = torch.clip(samples, -1.0, 1.0)
|
||||
samples = (samples * 32767.0).to(torch.int16)
|
||||
|
||||
frame_in = av.AudioFrame.from_ndarray(
|
||||
samples.contiguous().reshape(1, -1).cpu().numpy(),
|
||||
format="s16",
|
||||
layout="stereo",
|
||||
)
|
||||
frame_in.sample_rate = audio_sample_rate
|
||||
|
||||
_resample_audio(container, audio_stream, frame_in)
|
||||
|
||||
|
||||
def encode_video(
|
||||
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
|
||||
) -> None:
|
||||
video_np = video.cpu().numpy()
|
||||
|
||||
_, height, width, _ = video_np.shape
|
||||
|
||||
container = av.open(output_path, mode="w")
|
||||
stream = container.add_stream("libx264", rate=int(fps))
|
||||
stream.width = width
|
||||
stream.height = height
|
||||
stream.pix_fmt = "yuv420p"
|
||||
|
||||
if audio is not None:
|
||||
if audio_sample_rate is None:
|
||||
raise ValueError("audio_sample_rate is required when audio is provided")
|
||||
|
||||
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
|
||||
|
||||
for frame_array in video_np:
|
||||
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
||||
for packet in stream.encode(frame):
|
||||
container.mux(packet)
|
||||
|
||||
# Flush encoder
|
||||
for packet in stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
if audio is not None:
|
||||
_write_audio(container, audio_stream, audio, audio_sample_rate)
|
||||
|
||||
container.close()
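# Editor's note: a minimal usage sketch (not part of the original diff). `video` is consumed as
# (frames, height, width, 3) uint8 RGB and `audio` as waveform samples in [-1, 1]:
#
#     video = (torch.rand(48, 256, 384, 3) * 255).to(torch.uint8)
#     audio = torch.rand(2, 48000) * 2.0 - 1.0                   # stereo, 1 s at 48 kHz
#     encode_video(video, fps=24, audio=audio, audio_sample_rate=48000, output_path="sample.mp4")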
285 src/diffusers/pipelines/ltx2/latent_upsampler.py (new file)
@@ -0,0 +1,285 @@
|
||||
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
RATIONAL_RESAMPLER_SCALE_MAPPING = {
|
||||
0.75: (3, 4),
|
||||
1.5: (3, 2),
|
||||
2.0: (2, 1),
|
||||
4.0: (4, 1),
|
||||
}
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.ResBlock
|
||||
class ResBlock(torch.nn.Module):
|
||||
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
|
||||
super().__init__()
|
||||
if mid_channels is None:
|
||||
mid_channels = channels
|
||||
|
||||
Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||
|
||||
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
|
||||
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
|
||||
self.norm2 = torch.nn.GroupNorm(32, channels)
|
||||
self.activation = torch.nn.SiLU()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.activation(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.activation(hidden_states + residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.PixelShuffleND
|
||||
class PixelShuffleND(torch.nn.Module):
|
||||
def __init__(self, dims, upscale_factors=(2, 2, 2)):
|
||||
super().__init__()
|
||||
|
||||
self.dims = dims
|
||||
self.upscale_factors = upscale_factors
|
||||
|
||||
if dims not in [1, 2, 3]:
|
||||
raise ValueError("dims must be 1, 2, or 3")
|
||||
|
||||
def forward(self, x):
|
||||
if self.dims == 3:
|
||||
# spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)
|
||||
return (
|
||||
x.unflatten(1, (-1, *self.upscale_factors[:3]))
|
||||
.permute(0, 1, 5, 2, 6, 3, 7, 4)
|
||||
.flatten(6, 7)
|
||||
.flatten(4, 5)
|
||||
.flatten(2, 3)
|
||||
)
|
||||
elif self.dims == 2:
|
||||
# spatial: b (c p1 p2) h w -> b c (h p1) (w p2)
|
||||
return (
|
||||
x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3)
|
||||
)
|
||||
elif self.dims == 1:
|
||||
# temporal: b (c p1) f h w -> b c (f p1) h w
|
||||
return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3)
|
||||
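# Editor's note (not part of the original diff): for dims=2 with the default upscale_factors, the
# rearrangement above matches torch.nn.PixelShuffle(2):
#
#     x = torch.randn(1, 16, 8, 8)
#     assert torch.equal(PixelShuffleND(2)(x), torch.nn.PixelShuffle(2)(x))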
|
||||
|
||||
class BlurDownsample(torch.nn.Module):
|
||||
"""
|
||||
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. Applies only on H,W.
|
||||
Works for dims=2 or dims=3 (per-frame).
|
||||
"""
|
||||
|
||||
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
|
||||
super().__init__()
|
||||
|
||||
if dims not in (2, 3):
|
||||
raise ValueError(f"`dims` must be either 2 or 3 but is {dims}")
|
||||
if kernel_size < 3 or kernel_size % 2 != 1:
|
||||
raise ValueError(f"`kernel_size` must be an odd number >= 3 but is {kernel_size}")
|
||||
|
||||
self.dims = dims
|
||||
self.stride = stride
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
# 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from
|
||||
# the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and
|
||||
# provides a smooth approximation of a Gaussian filter (often called a "binomial filter").
|
||||
# The 2D kernel is constructed as the outer product and normalized.
|
||||
k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])
|
||||
k2d = k[:, None] @ k[None, :]
|
||||
k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size)
|
||||
self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.stride == 1:
|
||||
return x
|
||||
|
||||
if self.dims == 2:
|
||||
c = x.shape[1]
|
||||
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
|
||||
x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
|
||||
else:
|
||||
# dims == 3: apply per-frame on H,W
|
||||
b, c, f, _, _ = x.shape
|
||||
x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W]
|
||||
|
||||
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
|
||||
x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
|
||||
|
||||
h2, w2 = x.shape[-2:]
|
||||
x = x.unflatten(0, (b, f)).reshape(b, -1, f, h2, w2) # [B * F, C, H, W] --> [B, C, F, H, W]
|
||||
return x
|
||||
|
||||
|
||||
class SpatialRationalResampler(torch.nn.Module):
|
||||
"""
|
||||
Scales the spatial size of the input by a rational number `scale`. For example, `scale = 0.75` will downsample
|
||||
by a factor of 3 / 4, while `scale = 1.5` will upsample by a factor of 3 / 2. This works by first upsampling the
|
||||
input by the (integer) numerator of `scale`, and then performing a blur + stride anti-aliased downsample by the
|
||||
(integer) denominator.
|
||||
"""
|
||||
|
||||
def __init__(self, mid_channels: int = 1024, scale: float = 2.0):
|
||||
super().__init__()
|
||||
self.scale = float(scale)
|
||||
num_denom = RATIONAL_RESAMPLER_SCALE_MAPPING.get(scale, None)
|
||||
if num_denom is None:
|
||||
raise ValueError(
|
||||
f"The supplied `scale` {scale} is not supported; supported scales are {list(RATIONAL_RESAMPLER_SCALE_MAPPING.keys())}"
|
||||
)
|
||||
self.num, self.den = num_denom
|
||||
|
||||
self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
|
||||
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
|
||||
self.blur_down = BlurDownsample(dims=2, stride=self.den)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Expected x shape: [B * F, C, H, W]
|
||||
# b, _, f, h, w = x.shape
|
||||
# x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W]
|
||||
x = self.conv(x)
|
||||
x = self.pixel_shuffle(x)
|
||||
x = self.blur_down(x)
|
||||
# x = x.unflatten(0, (b, f)).reshape(b, -1, f, h, w) # [B * F, C, H, W] --> [B, C, F, H, W]
|
||||
return x
|
||||
|
||||
|
||||
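# Illustrative sketch (not part of the module): the shape arithmetic behind `SpatialRationalResampler`,
# assuming `RATIONAL_RESAMPLER_SCALE_MAPPING` maps 1.5 -> (3, 2), i.e. numerator 3 and denominator 2.
#
#     resampler = SpatialRationalResampler(mid_channels=1024, scale=1.5)
#     x = torch.randn(2, 1024, 32, 32)   # [B * F, C, H, W]
#     y = resampler(x)                   # conv -> x3 pixel shuffle (96 x 96) -> blur + stride-2 downsample
#     print(y.shape)                     # torch.Size([2, 1024, 48, 48]), i.e. 32 * 3 / 2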
class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Model to spatially upsample VAE latents.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `128`):
|
||||
Number of channels in the input latent
|
||||
mid_channels (`int`, defaults to `1024`):
Number of channels in the middle layers
num_blocks_per_stage (`int`, defaults to `4`):
Number of ResBlocks to use in each stage (pre/post upsampling)
dims (`int`, defaults to `3`):
Number of dimensions for convolutions (2 or 3)
spatial_upsample (`bool`, defaults to `True`):
Whether to spatially upsample the latent
temporal_upsample (`bool`, defaults to `False`):
Whether to temporally upsample the latent
rational_spatial_scale (`float`, *optional*, defaults to `2.0`):
If set and `spatial_upsample` is `True`, the latent is spatially resampled by this rational factor using a
`SpatialRationalResampler` instead of a fixed x2 pixel shuffle.
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
mid_channels: int = 1024,
|
||||
num_blocks_per_stage: int = 4,
|
||||
dims: int = 3,
|
||||
spatial_upsample: bool = True,
|
||||
temporal_upsample: bool = False,
|
||||
rational_spatial_scale: Optional[float] = 2.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.num_blocks_per_stage = num_blocks_per_stage
|
||||
self.dims = dims
|
||||
self.spatial_upsample = spatial_upsample
|
||||
self.temporal_upsample = temporal_upsample
|
||||
|
||||
ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||
|
||||
self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
|
||||
self.initial_activation = torch.nn.SiLU()
|
||||
|
||||
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
|
||||
|
||||
if spatial_upsample and temporal_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(3),
|
||||
)
|
||||
elif spatial_upsample:
|
||||
if rational_spatial_scale is not None:
|
||||
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale)
|
||||
else:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(2),
|
||||
)
|
||||
elif temporal_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(1),
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
|
||||
|
||||
self.post_upsample_res_blocks = torch.nn.ModuleList(
|
||||
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||
)
|
||||
|
||||
self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
|
||||
if self.dims == 2:
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
||||
hidden_states = self.initial_conv(hidden_states)
|
||||
hidden_states = self.initial_norm(hidden_states)
|
||||
hidden_states = self.initial_activation(hidden_states)
|
||||
|
||||
for block in self.res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = self.upsampler(hidden_states)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = self.final_conv(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
|
||||
else:
|
||||
hidden_states = self.initial_conv(hidden_states)
|
||||
hidden_states = self.initial_norm(hidden_states)
|
||||
hidden_states = self.initial_activation(hidden_states)
|
||||
|
||||
for block in self.res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
if self.temporal_upsample:
|
||||
hidden_states = self.upsampler(hidden_states)
|
||||
hidden_states = hidden_states[:, :, 1:, :, :]
|
||||
else:
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
||||
hidden_states = self.upsampler(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = self.final_conv(hidden_states)
|
||||
|
||||
return hidden_states
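# Illustrative usage sketch (not part of the module), assuming the default config and that
# `RATIONAL_RESAMPLER_SCALE_MAPPING` maps 2.0 -> (2, 1), i.e. a plain x2 spatial upsample:
#
#     upsampler = LTX2LatentUpsamplerModel()    # in_channels=128, spatial upsample only
#     latents = torch.randn(1, 128, 8, 16, 24)  # [B, C, F, H, W] video latents
#     upscaled = upsampler(latents)
#     print(upscaled.shape)                     # torch.Size([1, 128, 8, 32, 48])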
src/diffusers/pipelines/ltx2/pipeline_ltx2.py (new file, 1141 lines): diff suppressed because it is too large
src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py (new file, 1238 lines): diff suppressed because it is too large
src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py (new file, 442 lines)
@@ -0,0 +1,442 @@
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...models import AutoencoderKLLTX2Video
|
||||
from ...utils import get_logger, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..ltx.pipeline_output import LTXPipelineOutput
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .latent_upsampler import LTX2LatentUpsamplerModel
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline
|
||||
>>> from diffusers.pipelines.ltx2.export_utils import encode_video
|
||||
>>> from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> image = load_image(
|
||||
... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
|
||||
... )
|
||||
>>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background."
|
||||
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
>>> frame_rate = 24.0
|
||||
>>> video, audio = pipe(
|
||||
... image=image,
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... width=768,
|
||||
... height=512,
|
||||
... num_frames=121,
|
||||
... frame_rate=frame_rate,
|
||||
... num_inference_steps=40,
|
||||
... guidance_scale=4.0,
|
||||
... output_type="pil",
|
||||
... return_dict=False,
|
||||
... )
|
||||
|
||||
>>> latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
|
||||
... "Lightricks/LTX-2", subfolder="latent_upsampler", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
|
||||
>>> upsample_pipe.vae.enable_tiling()
|
||||
>>> upsample_pipe.to(device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
>>> video = upsample_pipe(
|
||||
... video=video,
|
||||
... width=768,
|
||||
... height=512,
|
||||
... output_type="np",
|
||||
... return_dict=False,
|
||||
... )[0]
|
||||
>>> video = (video * 255).round().astype("uint8")
|
||||
>>> video = torch.from_numpy(video)
|
||||
|
||||
>>> encode_video(
|
||||
... video[0],
|
||||
... fps=frame_rate,
|
||||
... audio=audio[0].float().cpu(),
|
||||
... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000
|
||||
... output_path="video.mp4",
|
||||
... )
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class LTX2LatentUpsamplePipeline(DiffusionPipeline):
|
||||
model_cpu_offload_seq = "vae->latent_upsampler"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKLLTX2Video,
|
||||
latent_upsampler: LTX2LatentUpsamplerModel,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(vae=vae, latent_upsampler=latent_upsampler)
|
||||
|
||||
self.vae_spatial_compression_ratio = (
|
||||
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
|
||||
)
|
||||
self.vae_temporal_compression_ratio = (
|
||||
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
|
||||
)
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
video: Optional[torch.Tensor] = None,
|
||||
batch_size: int = 1,
|
||||
num_frames: int = 121,
|
||||
height: int = 512,
|
||||
width: int = 768,
|
||||
spatial_patch_size: int = 1,
|
||||
temporal_patch_size: int = 1,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
if latents.ndim == 3:
|
||||
# Convert token seq [B, S, D] to latent video [B, C, F, H, W]
|
||||
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
latents = self._unpack_latents(
|
||||
latents, latent_num_frames, latent_height, latent_width, spatial_patch_size, temporal_patch_size
|
||||
)
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
video = video.to(device=device, dtype=self.vae.dtype)
|
||||
if isinstance(generator, list):
|
||||
if len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0).to(dtype)
|
||||
# NOTE: latent upsampler operates on the unnormalized latents, so don't normalize here
|
||||
# init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
|
||||
return init_latents
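# Illustrative arithmetic (not part of the pipeline code): with the default compression ratios
# (temporal 8, spatial 32), a 121-frame 512x768 input maps to latents with
# (121 - 1) // 8 + 1 = 16 latent frames, 512 // 32 = 16 latent rows and 768 // 32 = 24 latent columns.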
|
||||
|
||||
def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
|
||||
"""
|
||||
Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent
|
||||
tensor.
|
||||
|
||||
Args:
|
||||
latents (`torch.Tensor`):
Input latents to normalize.
|
||||
reference_latents (`torch.Tensor`):
|
||||
The reference latents providing style statistics.
|
||||
factor (`float`):
|
||||
Blending factor between original and transformed latent. Range: -10.0 to 10.0, Default: 1.0
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The transformed latent tensor
|
||||
"""
|
||||
result = latents.clone()
|
||||
|
||||
for i in range(latents.size(0)):
|
||||
for c in range(latents.size(1)):
|
||||
r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order
|
||||
i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
|
||||
|
||||
result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
|
||||
|
||||
result = torch.lerp(latents, result, factor)
|
||||
return result
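# Illustrative sketch (not part of the pipeline code): the per-(batch, channel) statistic matching performed by
# `adain_filter_latent`, written out for a single slice. Shapes are assumptions chosen for the example only.
#
#     ref = torch.randn(4, 8, 8) * 2.0 + 1.0   # reference slice with std ~2 and mean ~1
#     lat = torch.randn(4, 8, 8)               # slice to transform, std ~1 and mean ~0
#     out = (lat - lat.mean()) / lat.std() * ref.std() + ref.mean()
#     # `out` now carries the reference slice's mean/std; `factor` then lerps between `lat` and `out`.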
|
||||
|
||||
def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.Tensor:
|
||||
"""
|
||||
Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually
|
||||
smooth way using a sigmoid-based compression.
|
||||
|
||||
This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially
|
||||
when controlling dynamic behavior with a `compression` factor.
|
||||
|
||||
Args:
|
||||
latents (`torch.Tensor`):
Input latent tensor with arbitrary shape. Expected to be roughly in the [-1, 1] or [0, 1] range.
compression (`float`):
Compression strength in the range [0, 1].
- 0.0: No tone-mapping (identity transform)
- 1.0: Full compression effect

Returns:
`torch.Tensor`:
The tone-mapped latent tensor of the same shape as the input.
"""
|
||||
# Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot
|
||||
scale_factor = compression * 0.75
|
||||
abs_latents = torch.abs(latents)
|
||||
|
||||
# Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0
|
||||
# When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect
|
||||
sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0))
|
||||
scales = 1.0 - 0.8 * scale_factor * sigmoid_term
|
||||
|
||||
filtered = latents * scales
|
||||
return filtered
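# Illustrative arithmetic (not part of the pipeline code): the effect of the sigmoid compression at full strength.
#
#     compression = 1.0                                      # scale_factor = 0.75
#     # |latent| = 3: sigmoid(4 * 0.75 * (3 - 1)) ~ 0.9975, so scale ~ 1 - 0.8 * 0.75 * 0.9975 ~ 0.40
#     # |latent| = 0: sigmoid(4 * 0.75 * (0 - 1)) ~ 0.0474, so scale ~ 1 - 0.8 * 0.75 * 0.0474 ~ 0.97
#     # Large-magnitude latents are therefore compressed strongly while small ones are left nearly untouched.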
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents
|
||||
def _normalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Normalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = (latents - latents_mean) * scaling_factor / latents_std
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents
|
||||
def _denormalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Denormalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents * latents_std / scaling_factor + latents_mean
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents
|
||||
def _unpack_latents(
|
||||
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
|
||||
) -> torch.Tensor:
|
||||
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
|
||||
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
|
||||
# what happens in the `_pack_latents` method.
|
||||
batch_size = latents.size(0)
|
||||
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
|
||||
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
return latents
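# Illustrative sketch (not part of the pipeline code): `_unpack_latents` inverts the packing of a video latent
# [B, C, F, H, W] into a token sequence [B, S, D]; with patch sizes of 1, S = F * H * W and D = C.
#
#     packed = torch.randn(1, 2 * 8 * 12, 128)   # [B, S, D]
#     unpacked = LTX2LatentUpsamplePipeline._unpack_latents(packed, num_frames=2, height=8, width=12)
#     print(unpacked.shape)                      # torch.Size([1, 128, 2, 8, 12])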
|
||||
|
||||
def check_inputs(self, video, height, width, latents, tone_map_compression_ratio):
|
||||
if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
|
||||
if video is not None and latents is not None:
|
||||
raise ValueError("Only one of `video` or `latents` can be provided.")
|
||||
if video is None and latents is None:
|
||||
raise ValueError("One of `video` or `latents` has to be provided.")
|
||||
|
||||
if not (0 <= tone_map_compression_ratio <= 1):
|
||||
raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]")
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
video: Optional[List[PipelineImageInput]] = None,
|
||||
height: int = 512,
|
||||
width: int = 768,
|
||||
num_frames: int = 121,
|
||||
spatial_patch_size: int = 1,
|
||||
temporal_patch_size: int = 1,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
latents_normalized: bool = False,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
adain_factor: float = 0.0,
|
||||
tone_map_compression_ratio: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
video (`List[PipelineImageInput]`, *optional*):
The video to be upsampled (such as an LTX 2.0 first-stage output). If not supplied, `latents` should be
supplied.
|
||||
height (`int`, *optional*, defaults to `512`):
|
||||
The height in pixels of the input video (not the generated video, which will have a larger resolution).
|
||||
width (`int`, *optional*, defaults to `768`):
|
||||
The width in pixels of the input video (not the generated video, which will have a larger resolution).
|
||||
num_frames (`int`, *optional*, defaults to `121`):
|
||||
The number of frames in the input video.
|
||||
spatial_patch_size (`int`, *optional*, defaults to `1`):
|
||||
The spatial patch size of the video latents. Used when `latents` is supplied if unpacking is necessary.
|
||||
temporal_patch_size (`int`, *optional*, defaults to `1`):
|
||||
The temporal patch size of the video latents. Used when `latents` is supplied if unpacking is
|
||||
necessary.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a
|
||||
patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape `(batch_size,
|
||||
latent_channels, latent_frames, latent_height, latent_width)`.
|
||||
latents_normalized (`bool`, *optional*, defaults to `False`):
|
||||
If `latents` are supplied, whether the `latents` are normalized using the VAE latent mean and std. If
|
||||
`True`, the `latents` will be denormalized before being supplied to the latent upsampler.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
The timestep at which the generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
adain_factor (`float`, *optional*, defaults to `0.0`):
|
||||
Adaptive Instance Normalization (AdaIN) blending factor between the upsampled and original latents.
|
||||
Should be in [-10.0, 10.0]; supplying 0.0 (the default) means that AdaIN is not performed.
|
||||
tone_map_compression_ratio (`float`, *optional*, defaults to `0.0`):
|
||||
The compression strength for tone mapping, which will reduce the dynamic range of the latent values.
|
||||
This is useful for regularizing high-variance latents or for conditioning outputs during generation.
|
||||
Should be in [0, 1], where 0.0 (the default) means tone mapping is not applied and 1.0 corresponds to
|
||||
the full compression effect.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated video. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is the upsampled video.
|
||||
"""
|
||||
|
||||
self.check_inputs(
|
||||
video=video,
|
||||
height=height,
|
||||
width=width,
|
||||
latents=latents,
|
||||
tone_map_compression_ratio=tone_map_compression_ratio,
|
||||
)
|
||||
|
||||
if video is not None:
|
||||
# Batched video input is not yet tested/supported. TODO: take a look later
|
||||
batch_size = 1
|
||||
else:
|
||||
batch_size = latents.shape[0]
|
||||
device = self._execution_device
|
||||
|
||||
if video is not None:
|
||||
num_frames = len(video)
|
||||
if num_frames % self.vae_temporal_compression_ratio != 1:
|
||||
num_frames = (
|
||||
num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1
|
||||
)
|
||||
video = video[:num_frames]
|
||||
logger.warning(
|
||||
f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames."
|
||||
)
|
||||
video = self.video_processor.preprocess_video(video, height=height, width=width)
|
||||
video = video.to(device=device, dtype=torch.float32)
|
||||
|
||||
latents_supplied = latents is not None
|
||||
latents = self.prepare_latents(
|
||||
video=video,
|
||||
batch_size=batch_size,
|
||||
num_frames=num_frames,
|
||||
height=height,
|
||||
width=width,
|
||||
spatial_patch_size=spatial_patch_size,
|
||||
temporal_patch_size=temporal_patch_size,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
if latents_supplied and latents_normalized:
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(self.latent_upsampler.dtype)
|
||||
latents_upsampled = self.latent_upsampler(latents)
|
||||
|
||||
if adain_factor > 0.0:
|
||||
latents = self.adain_filter_latent(latents_upsampled, latents, adain_factor)
|
||||
else:
|
||||
latents = latents_upsampled
|
||||
|
||||
if tone_map_compression_ratio > 0.0:
|
||||
latents = self.tone_map_latents(latents, tone_map_compression_ratio)
|
||||
|
||||
if output_type == "latent":
|
||||
latents = self._normalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
video = latents
|
||||
else:
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return LTXPipelineOutput(frames=video)
|
||||
src/diffusers/pipelines/ltx2/pipeline_output.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class LTX2PipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for LTX 2.0 audio-video pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
||||
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`.
|
||||
audio (`torch.Tensor`, `np.ndarray`):
|
||||
The audio generated alongside `frames`. It can be a NumPy array or Torch tensor containing the audio
waveform for each video in the batch.
|
||||
"""
|
||||
|
||||
frames: torch.Tensor
|
||||
audio: torch.Tensor
|
||||
src/diffusers/pipelines/ltx2/vocoder.py (new file, 159 lines)
@@ -0,0 +1,159 @@
import math
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
dilations: Tuple[int, ...] = (1, 3, 5),
|
||||
leaky_relu_negative_slope: float = 0.1,
|
||||
padding_mode: str = "same",
|
||||
):
|
||||
super().__init__()
|
||||
self.dilations = dilations
|
||||
self.negative_slope = leaky_relu_negative_slope
|
||||
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode)
|
||||
for dilation in dilations
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode)
|
||||
for _ in range(len(dilations))
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for conv1, conv2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, negative_slope=self.negative_slope)
|
||||
xt = conv1(xt)
|
||||
xt = F.leaky_relu(xt, negative_slope=self.negative_slope)
|
||||
xt = conv2(xt)
|
||||
x = x + xt
|
||||
return x
|
||||
|
||||
|
||||
class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
hidden_channels: int = 1024,
|
||||
out_channels: int = 2,
|
||||
upsample_kernel_sizes: List[int] = [16, 15, 8, 4, 4],
|
||||
upsample_factors: List[int] = [6, 5, 2, 2, 2],
|
||||
resnet_kernel_sizes: List[int] = [3, 7, 11],
|
||||
resnet_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
leaky_relu_negative_slope: float = 0.1,
|
||||
output_sampling_rate: int = 24000,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_upsample_layers = len(upsample_kernel_sizes)
|
||||
self.resnets_per_upsample = len(resnet_kernel_sizes)
|
||||
self.out_channels = out_channels
|
||||
self.total_upsample_factor = math.prod(upsample_factors)
|
||||
self.negative_slope = leaky_relu_negative_slope
|
||||
|
||||
if self.num_upsample_layers != len(upsample_factors):
|
||||
raise ValueError(
|
||||
f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length"
|
||||
f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively."
|
||||
)
|
||||
|
||||
if self.resnets_per_upsample != len(resnet_dilations):
|
||||
raise ValueError(
|
||||
f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length"
|
||||
f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively."
|
||||
)
|
||||
|
||||
self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)
|
||||
|
||||
self.upsamplers = nn.ModuleList()
|
||||
self.resnets = nn.ModuleList()
|
||||
input_channels = hidden_channels
|
||||
for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
|
||||
output_channels = input_channels // 2
|
||||
self.upsamplers.append(
|
||||
nn.ConvTranspose1d(
|
||||
input_channels, # hidden_channels // (2 ** i)
|
||||
output_channels, # hidden_channels // (2 ** (i + 1))
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - stride) // 2,
|
||||
)
|
||||
)
|
||||
|
||||
for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
|
||||
self.resnets.append(
|
||||
ResBlock(
|
||||
output_channels,
|
||||
kernel_size,
|
||||
dilations=dilations,
|
||||
leaky_relu_negative_slope=leaky_relu_negative_slope,
|
||||
)
|
||||
)
|
||||
input_channels = output_channels
|
||||
|
||||
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
|
||||
r"""
|
||||
Forward pass of the vocoder.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor`):
|
||||
Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last`
|
||||
is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is
|
||||
`True`.
|
||||
time_last (`bool`, *optional*, defaults to `False`):
|
||||
Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Audio waveform tensor of shape (batch_size, out_channels, audio_length)
|
||||
"""
|
||||
|
||||
# Ensure that the time/frame dimension is last
|
||||
if not time_last:
|
||||
hidden_states = hidden_states.transpose(2, 3)
|
||||
# Combine channels and frequency (mel bins) dimensions
|
||||
hidden_states = hidden_states.flatten(1, 2)
|
||||
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
for i in range(self.num_upsample_layers):
|
||||
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
|
||||
hidden_states = self.upsamplers[i](hidden_states)
|
||||
|
||||
# Run all resnets in parallel on hidden_states
|
||||
start = i * self.resnets_per_upsample
|
||||
end = (i + 1) * self.resnets_per_upsample
|
||||
resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0)
|
||||
|
||||
hidden_states = torch.mean(resnet_outputs, dim=0)
|
||||
|
||||
# NOTE: unlike the earlier leaky ReLUs (which use `self.negative_slope`, typically 0.1), this one uses the
# default F.leaky_relu negative slope of 0.01. It is unclear whether this is intended.
|
||||
hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = torch.tanh(hidden_states)
|
||||
|
||||
return hidden_states
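# Illustrative sketch (not part of the module): with the default config the vocoder upsamples the mel frame axis
# by prod([6, 5, 2, 2, 2]) = 240, so T mel frames become 240 * T audio samples. The input shape below assumes
# stereo mel spectrograms with 64 mel bins so that num_channels * num_mel_bins == in_channels == 128.
#
#     vocoder = LTX2Vocoder()
#     mel = torch.randn(1, 2, 80, 64)    # [B, C, T, num_mel_bins], time_last=False
#     waveform = vocoder(mel)
#     print(waveform.shape)              # torch.Size([1, 2, 19200]) == 240 * 80 samples per channel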
|
||||
@@ -66,6 +66,7 @@ from .import_utils import (
|
||||
is_accelerate_version,
|
||||
is_aiter_available,
|
||||
is_aiter_version,
|
||||
is_av_available,
|
||||
is_better_profanity_available,
|
||||
is_bitsandbytes_available,
|
||||
is_bitsandbytes_version,
|
||||
|
||||
@@ -502,6 +502,36 @@ class AutoencoderKLHunyuanVideo15(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLLTX2Audio(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLLTX2Video(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLLTXVideo(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -1147,6 +1177,21 @@ class LongCatImageTransformer2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LTX2VideoTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LTXVideoTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -1877,6 +1877,51 @@ class LongCatImagePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTX2ImageToVideoPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTX2LatentUpsamplePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTX2Pipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTXConditionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -230,6 +230,7 @@ _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_at
|
||||
_aiter_available, _aiter_version = _is_package_available("aiter")
|
||||
_kornia_available, _kornia_version = _is_package_available("kornia")
|
||||
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
|
||||
_av_available, _av_version = _is_package_available("av")
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
@@ -420,6 +421,10 @@ def is_kornia_available():
|
||||
return _kornia_available
|
||||
|
||||
|
||||
def is_av_available():
|
||||
return _av_available
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
FLAX_IMPORT_ERROR = """
|
||||
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderKLLTX2Audio
|
||||
|
||||
from ...testing_utils import (
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
class AutoencoderKLLTX2AudioTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLLTX2Audio
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_ltx_video_config(self):
|
||||
return {
|
||||
"in_channels": 2, # stereo,
|
||||
"output_channels": 2,
|
||||
"latent_channels": 4,
|
||||
"base_channels": 16,
|
||||
"ch_mult": (1, 2, 4),
|
||||
"resolution": 16,
|
||||
"attn_resolutions": None,
|
||||
"num_res_blocks": 2,
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height",
|
||||
"mid_block_add_attention": False,
|
||||
"sample_rate": 16000,
|
||||
"mel_hop_length": 160,
|
||||
"mel_bins": 16,
|
||||
"is_causal": True,
|
||||
"double_z": True,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_channels = 2
|
||||
num_frames = 8
|
||||
num_mel_bins = 16
|
||||
|
||||
spectrogram = floats_tensor((batch_size, num_channels, num_frames, num_mel_bins)).to(torch_device)
|
||||
|
||||
input_dict = {"sample": spectrogram}
|
||||
return input_dict
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (2, 5, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (2, 5, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_ltx_video_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
# Overriding as output shape is not the same as input shape for LTX 2.0 audio VAE
|
||||
def test_output(self):
|
||||
super().test_output(expected_output_shape=(2, 2, 5, 16))
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("AutoencoderKLLTX2Audio does not support `norm_num_groups` because it does not use GroupNorm.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
tests/models/autoencoders/test_models_autoencoder_ltx2_video.py (new file, 103 lines)
@@ -0,0 +1,103 @@
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderKLLTX2Video
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLLTX2Video
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_ltx_video_config(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 8,
|
||||
"block_out_channels": (8, 8, 8, 8),
|
||||
"decoder_block_out_channels": (16, 32, 64),
|
||||
"layers_per_block": (1, 1, 1, 1, 1),
|
||||
"decoder_layers_per_block": (1, 1, 1, 1),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
# Full model uses `reflect` but this does not have deterministic backward implementation, so use `zeros`
|
||||
"decoder_spatial_padding_mode": "zeros",
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
|
||||
input_dict = {"sample": image}
|
||||
return input_dict
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_ltx_video_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"LTX2VideoEncoder3d",
|
||||
"LTX2VideoDecoder3d",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoMidBlock3d",
|
||||
"LTX2VideoUpBlock3d",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
tests/models/transformers/test_models_transformer_ltx2.py (new file, 222 lines)
@@ -0,0 +1,222 @@
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import LTX2VideoTransformer3DModel
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = LTX2VideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
# Common
|
||||
batch_size = 2
|
||||
|
||||
# Video
|
||||
num_frames = 2
|
||||
num_channels = 4
|
||||
height = 16
|
||||
width = 16
|
||||
|
||||
# Audio
|
||||
audio_num_frames = 9
|
||||
audio_num_channels = 2
|
||||
num_mel_bins = 2
|
||||
|
||||
# Text
|
||||
embedding_dim = 16
|
||||
sequence_length = 16
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
|
||||
audio_hidden_states = torch.randn((batch_size, audio_num_frames, audio_num_channels * num_mel_bins)).to(
|
||||
torch_device
|
||||
)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
|
||||
timestep = torch.rand((batch_size,)).to(torch_device) * 1000
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"audio_hidden_states": audio_hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"audio_encoder_hidden_states": audio_encoder_hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_attention_mask": encoder_attention_mask,
|
||||
"num_frames": num_frames,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"audio_num_frames": audio_num_frames,
|
||||
"fps": 25.0,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (512, 4)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (512, 4)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 8,
|
||||
"cross_attention_dim": 16,
|
||||
"audio_in_channels": 4,
|
||||
"audio_out_channels": 4,
|
||||
"audio_num_attention_heads": 2,
|
||||
"audio_attention_head_dim": 4,
|
||||
"audio_cross_attention_dim": 8,
|
||||
"num_layers": 2,
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"caption_channels": 16,
|
||||
"rope_double_precision": False,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"LTX2VideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
# def test_ltx2_consistency(self, seed=0, dtype=torch.float32):
|
||||
# torch.manual_seed(seed)
|
||||
# init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
# # Calculate dummy inputs in a custom manner to ensure compatibility with original code
|
||||
# batch_size = 2
|
||||
# num_frames = 9
|
||||
# latent_frames = 2
|
||||
# text_embedding_dim = 16
|
||||
# text_seq_len = 16
|
||||
# fps = 25.0
|
||||
# sampling_rate = 16000.0
|
||||
# hop_length = 160.0
|
||||
|
||||
# sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") * 1000
|
||||
# timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device)
|
||||
|
||||
# num_channels = 4
|
||||
# latent_height = 4
|
||||
# latent_width = 4
|
||||
# hidden_states = torch.randn(
|
||||
# (batch_size, num_channels, latent_frames, latent_height, latent_width),
|
||||
# generator=torch.manual_seed(seed),
|
||||
# dtype=dtype,
|
||||
# device="cpu",
|
||||
# )
|
||||
# # Patchify video latents (with patch_size (1, 1, 1))
|
||||
# hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1)
|
||||
# hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
|
||||
# encoder_hidden_states = torch.randn(
|
||||
# (batch_size, text_seq_len, text_embedding_dim),
|
||||
# generator=torch.manual_seed(seed),
|
||||
# dtype=dtype,
|
||||
# device="cpu",
|
||||
# )
|
||||
|
||||
# audio_num_channels = 2
|
||||
# num_mel_bins = 2
|
||||
# latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps))
|
||||
# audio_hidden_states = torch.randn(
|
||||
# (batch_size, audio_num_channels, latent_length, num_mel_bins),
|
||||
# generator=torch.manual_seed(seed),
|
||||
# dtype=dtype,
|
||||
# device="cpu",
|
||||
# )
|
||||
# # Patchify audio latents
|
||||
# audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
# audio_encoder_hidden_states = torch.randn(
|
||||
# (batch_size, text_seq_len, text_embedding_dim),
|
||||
# generator=torch.manual_seed(seed),
|
||||
# dtype=dtype,
|
||||
# device="cpu",
|
||||
# )
|
||||
|
||||
# inputs_dict = {
|
||||
# "hidden_states": hidden_states.to(device=torch_device),
|
||||
# "audio_hidden_states": audio_hidden_states.to(device=torch_device),
|
||||
# "encoder_hidden_states": encoder_hidden_states.to(device=torch_device),
|
||||
# "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device),
|
||||
# "timestep": timestep,
|
||||
# "num_frames": latent_frames,
|
||||
# "height": latent_height,
|
||||
# "width": latent_width,
|
||||
# "audio_num_frames": num_frames,
|
||||
# "fps": 25.0,
|
||||
# }
|
||||
|
||||
# model = self.model_class.from_pretrained(
|
||||
# "diffusers-internal-dev/dummy-ltx2",
|
||||
# subfolder="transformer",
|
||||
# device_map="cpu",
|
||||
# )
|
||||
# # torch.manual_seed(seed)
|
||||
# # model = self.model_class(**init_dict)
|
||||
# model.to(torch_device)
|
||||
# model.eval()
|
||||
|
||||
# with attention_backend("native"):
|
||||
# with torch.no_grad():
|
||||
# output = model(**inputs_dict)
|
||||
|
||||
# video_output, audio_output = output.to_tuple()
|
||||
|
||||
# self.assertIsNotNone(video_output)
|
||||
# self.assertIsNotNone(audio_output)
|
||||
|
||||
# # input & output have to have the same shape
|
||||
# video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels)
|
||||
# self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match")
|
||||
# audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins)
|
||||
# self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match")
|
||||
|
||||
# # Check against expected slice
|
||||
# # fmt: off
|
||||
# video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676])
|
||||
# audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692])
|
||||
# # fmt: on
|
||||
|
||||
# video_output_flat = video_output.cpu().flatten().float()
|
||||
# video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]])
|
||||
# self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4))
|
||||
|
||||
# audio_output_flat = audio_output.cpu().flatten().float()
|
||||
# audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]])
|
||||
# self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4))
|
||||
|
||||
|
||||
class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = LTX2VideoTransformer3DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return LTX2TransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
tests/pipelines/ltx2/__init__.py (new file, 0 lines)

tests/pipelines/ltx2/test_ltx2.py (new file, 239 lines)
@@ -0,0 +1,239 @@
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTX2Pipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx2 import LTX2TextConnectors
|
||||
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
|
||||
|
||||
from ...testing_utils import enable_full_determinism
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = LTX2Pipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"audio_latents",
|
||||
"output_type",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
test_attention_slicing = False
|
||||
test_xformers_attention = False
|
||||
supports_dduf = False
|
||||
|
||||
base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3"
|
||||
|
||||
    def get_dummy_components(self):
        tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id)
        text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id)

        torch.manual_seed(0)
        transformer = LTX2VideoTransformer3DModel(
            in_channels=4,
            out_channels=4,
            patch_size=1,
            patch_size_t=1,
            num_attention_heads=2,
            attention_head_dim=8,
            cross_attention_dim=16,
            audio_in_channels=4,
            audio_out_channels=4,
            audio_num_attention_heads=2,
            audio_attention_head_dim=4,
            audio_cross_attention_dim=8,
            num_layers=2,
            qk_norm="rms_norm_across_heads",
            caption_channels=text_encoder.config.text_config.hidden_size,
            rope_double_precision=False,
            rope_type="split",
        )

        torch.manual_seed(0)
        connectors = LTX2TextConnectors(
            caption_channels=text_encoder.config.text_config.hidden_size,
            text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1,
            video_connector_num_attention_heads=4,
            video_connector_attention_head_dim=8,
            video_connector_num_layers=1,
            video_connector_num_learnable_registers=None,
            audio_connector_num_attention_heads=4,
            audio_connector_attention_head_dim=8,
            audio_connector_num_layers=1,
            audio_connector_num_learnable_registers=None,
            connector_rope_base_seq_len=32,
            rope_theta=10000.0,
            rope_double_precision=False,
            causal_temporal_positioning=False,
            rope_type="split",
        )

        torch.manual_seed(0)
        vae = AutoencoderKLLTX2Video(
            in_channels=3,
            out_channels=3,
            latent_channels=4,
            block_out_channels=(8,),
            decoder_block_out_channels=(8,),
            layers_per_block=(1,),
            decoder_layers_per_block=(1, 1),
            spatio_temporal_scaling=(True,),
            decoder_spatio_temporal_scaling=(True,),
            decoder_inject_noise=(False, False),
            downsample_type=("spatial",),
            upsample_residual=(False,),
            upsample_factor=(1,),
            timestep_conditioning=False,
            patch_size=1,
            patch_size_t=1,
            encoder_causal=True,
            decoder_causal=False,
        )
        vae.use_framewise_encoding = False
        vae.use_framewise_decoding = False

        torch.manual_seed(0)
        audio_vae = AutoencoderKLLTX2Audio(
            base_channels=4,
            output_channels=2,
            ch_mult=(1,),
            num_res_blocks=1,
            attn_resolutions=None,
            in_channels=2,
            resolution=32,
            latent_channels=2,
            norm_type="pixel",
            causality_axis="height",
            dropout=0.0,
            mid_block_add_attention=False,
            sample_rate=16000,
            mel_hop_length=160,
            is_causal=True,
            mel_bins=8,
        )

        torch.manual_seed(0)
        vocoder = LTX2Vocoder(
            in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins,
            hidden_channels=32,
            out_channels=2,
            upsample_kernel_sizes=[4, 4],
            upsample_factors=[2, 2],
            resnet_kernel_sizes=[3],
            resnet_dilations=[[1, 3, 5]],
            leaky_relu_negative_slope=0.1,
            output_sampling_rate=16000,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        components = {
            "transformer": transformer,
            "vae": vae,
            "audio_vae": audio_vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "connectors": connectors,
            "vocoder": vocoder,
        }

        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        inputs = {
            "prompt": "a robot dancing",
            "negative_prompt": "",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 1.0,
            "height": 32,
            "width": 32,
            "num_frames": 5,
            "frame_rate": 25.0,
            "max_sequence_length": 16,
            "output_type": "pt",
        }

        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = pipe(**inputs)
        video = output.frames
        audio = output.audio

        self.assertEqual(video.shape, (1, 5, 3, 32, 32))
        self.assertEqual(audio.shape[0], 1)
        self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)

        # fmt: off
        expected_video_slice = torch.tensor(
            [
                0.4331, 0.6203, 0.3245, 0.7294, 0.4822, 0.5703, 0.2999, 0.7700, 0.4961, 0.4242, 0.4581, 0.4351, 0.1137, 0.4437, 0.6304, 0.3184
            ]
        )
        expected_audio_slice = torch.tensor(
            [
                0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
            ]
        )
        # fmt: on

        video = video.flatten()
        audio = audio.flatten()
        generated_video_slice = torch.cat([video[:8], video[-8:]])
        generated_audio_slice = torch.cat([audio[:8], audio[-8:]])

        assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
        assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)
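Both LTX 2 fast tests check numerical stability the same way: the decoded video and audio are flattened, the first and last eight values are concatenated, and the result is compared against a recorded reference slice. The sketch below reproduces that pattern in isolation with synthetic data; the check_slice helper and its inputs are illustrative and not part of the test suite above.

# Illustrative only: mirrors the slice-comparison pattern used in test_inference above.
import torch


def check_slice(tensor: torch.Tensor, expected: torch.Tensor, atol: float = 1e-4, rtol: float = 1e-4) -> bool:
    flat = tensor.flatten()
    generated = torch.cat([flat[:8], flat[-8:]])  # first and last 8 values, as in the tests
    return torch.allclose(expected, generated, atol=atol, rtol=rtol)


# Synthetic stand-in for a decoded video of shape (batch, frames, channels, height, width).
video = torch.linspace(0.0, 1.0, 5 * 3 * 32 * 32).reshape(1, 5, 3, 32, 32)
reference = torch.cat([video.flatten()[:8], video.flatten()[-8:]])
assert check_slice(video, reference)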
241  tests/pipelines/ltx2/test_ltx2_image2video.py   Normal file
@@ -0,0 +1,241 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration

from diffusers import (
    AutoencoderKLLTX2Audio,
    AutoencoderKLLTX2Video,
    FlowMatchEulerDiscreteScheduler,
    LTX2ImageToVideoPipeline,
    LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder

from ...testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin


enable_full_determinism()

class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = LTX2ImageToVideoPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "audio_latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    test_attention_slicing = False
    test_xformers_attention = False
    supports_dduf = False

    base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3"

    def get_dummy_components(self):
        tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id)
        text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id)

        torch.manual_seed(0)
        transformer = LTX2VideoTransformer3DModel(
            in_channels=4,
            out_channels=4,
            patch_size=1,
            patch_size_t=1,
            num_attention_heads=2,
            attention_head_dim=8,
            cross_attention_dim=16,
            audio_in_channels=4,
            audio_out_channels=4,
            audio_num_attention_heads=2,
            audio_attention_head_dim=4,
            audio_cross_attention_dim=8,
            num_layers=2,
            qk_norm="rms_norm_across_heads",
            caption_channels=text_encoder.config.text_config.hidden_size,
            rope_double_precision=False,
            rope_type="split",
        )

        torch.manual_seed(0)
        connectors = LTX2TextConnectors(
            caption_channels=text_encoder.config.text_config.hidden_size,
            text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1,
            video_connector_num_attention_heads=4,
            video_connector_attention_head_dim=8,
            video_connector_num_layers=1,
            video_connector_num_learnable_registers=None,
            audio_connector_num_attention_heads=4,
            audio_connector_attention_head_dim=8,
            audio_connector_num_layers=1,
            audio_connector_num_learnable_registers=None,
            connector_rope_base_seq_len=32,
            rope_theta=10000.0,
            rope_double_precision=False,
            causal_temporal_positioning=False,
            rope_type="split",
        )

        torch.manual_seed(0)
        vae = AutoencoderKLLTX2Video(
            in_channels=3,
            out_channels=3,
            latent_channels=4,
            block_out_channels=(8,),
            decoder_block_out_channels=(8,),
            layers_per_block=(1,),
            decoder_layers_per_block=(1, 1),
            spatio_temporal_scaling=(True,),
            decoder_spatio_temporal_scaling=(True,),
            decoder_inject_noise=(False, False),
            downsample_type=("spatial",),
            upsample_residual=(False,),
            upsample_factor=(1,),
            timestep_conditioning=False,
            patch_size=1,
            patch_size_t=1,
            encoder_causal=True,
            decoder_causal=False,
        )
        vae.use_framewise_encoding = False
        vae.use_framewise_decoding = False

        torch.manual_seed(0)
        audio_vae = AutoencoderKLLTX2Audio(
            base_channels=4,
            output_channels=2,
            ch_mult=(1,),
            num_res_blocks=1,
            attn_resolutions=None,
            in_channels=2,
            resolution=32,
            latent_channels=2,
            norm_type="pixel",
            causality_axis="height",
            dropout=0.0,
            mid_block_add_attention=False,
            sample_rate=16000,
            mel_hop_length=160,
            is_causal=True,
            mel_bins=8,
        )

        torch.manual_seed(0)
        vocoder = LTX2Vocoder(
            in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins,
            hidden_channels=32,
            out_channels=2,
            upsample_kernel_sizes=[4, 4],
            upsample_factors=[2, 2],
            resnet_kernel_sizes=[3],
            resnet_dilations=[[1, 3, 5]],
            leaky_relu_negative_slope=0.1,
            output_sampling_rate=16000,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        components = {
            "transformer": transformer,
            "vae": vae,
            "audio_vae": audio_vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "connectors": connectors,
            "vocoder": vocoder,
        }

        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        image = torch.rand((1, 3, 32, 32), generator=generator, device=device)

        inputs = {
            "image": image,
            "prompt": "a robot dancing",
            "negative_prompt": "",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 1.0,
            "height": 32,
            "width": 32,
            "num_frames": 5,
            "frame_rate": 25.0,
            "max_sequence_length": 16,
            "output_type": "pt",
        }

        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = pipe(**inputs)
        video = output.frames
        audio = output.audio

        self.assertEqual(video.shape, (1, 5, 3, 32, 32))
        self.assertEqual(audio.shape[0], 1)
        self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)

        # fmt: off
        expected_video_slice = torch.tensor(
            [
                0.3573, 0.8382, 0.3581, 0.6114, 0.3682, 0.7969, 0.2552, 0.6399, 0.3113, 0.1497, 0.3249, 0.5395, 0.3498, 0.4526, 0.4536, 0.4555
            ]
        )
        expected_audio_slice = torch.tensor(
            [
                0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
            ]
        )
        # fmt: on

        video = video.flatten()
        audio = audio.flatten()
        generated_video_slice = torch.cat([video[:8], video[-8:]])
        generated_audio_slice = torch.cat([audio[:8], audio[-8:]])

        assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
        assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)