From 79cf6d7ba451e7c84540b57850ec3f99c2fce9ef Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 7 Jan 2026 04:16:03 +0100 Subject: [PATCH] Support LTX 2.0 audio VAE encoder --- scripts/convert_ltx2_to_diffusers.py | 6 +- .../autoencoders/autoencoder_kl_ltx2_audio.py | 269 ++++++++++++++++-- .../test_models_autoencoder_kl_ltx2_audio.py | 88 ++++++ 3 files changed, 340 insertions(+), 23 deletions(-) create mode 100644 tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 4ec654d9d7..eb0b010075 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -148,10 +148,7 @@ LTX_2_0_VAE_SPECIAL_KEYS_REMAP = { "per_channel_statistics.mean-of-stds": remove_keys_inplace, } -LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = { - "encoder": remove_keys_inplace, - "per_channel_statistics": convert_ltx2_audio_vae_per_channel_statistics, -} +LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {} LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {} @@ -499,6 +496,7 @@ def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "mel_hop_length": 160, "is_causal": True, "mel_bins": 64, + "double_z": True, }, } rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index 091d55645a..dc09f44d82 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Set, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn as nn @@ -21,8 +21,9 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...utils.accelerate_utils import apply_forward_hook +from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import AutoencoderMixin, DecoderOutput +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution LATENT_DOWNSAMPLE_FACTOR = 4 @@ -219,6 +220,40 @@ class LTX2AudioResnetBlock(nn.Module): return x + h +class LTX2AudioDownsample(nn.Module): + def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.with_conv: + # Padding tuple is in the order: (left, right, top, bottom). + if self.causality_axis == "none": + pad = (0, 1, 0, 1) + elif self.causality_axis == "width": + pad = (2, 0, 0, 1) + elif self.causality_axis == "height": + pad = (0, 1, 2, 0) + elif self.causality_axis == "width-compatibility": + pad = (1, 0, 0, 1) + else: + raise ValueError( + f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`," + f" and `width-compatibility`." 
+ ) + + x = F.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + # with_conv=False implies that causality_axis is "none" + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + class LTX2AudioUpsample(nn.Module): def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None: super().__init__() @@ -282,6 +317,156 @@ class LTX2AudioAudioPatchifier: return self._patch_size +class LTX2AudioEncoder(nn.Module): + def __init__( + self, + base_channels: int = 128, + output_channels: int = 1, + num_res_blocks: int = 2, + attn_resolutions: Optional[Tuple[int, ...]] = None, + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + ch_mult: Tuple[int, ...] = (1, 2, 4), + norm_type: str = "group", + causality_axis: Optional[str] = "width", + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: Optional[int] = 64, + double_z: bool = True, + ): + super().__init__() + + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + + self.base_channels = base_channels + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = output_channels + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.latent_channels = latent_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = causality_axis + + base_block_channels = base_channels + base_resolution = resolution + self.z_shape = (1, latent_channels, base_resolution, base_resolution) + + if self.causality_axis is not None: + self.conv_in = LTX2AudioCausalConv2d( + in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1) + + self.down = nn.ModuleList() + block_in = base_block_channels + curr_res = self.resolution + + for level in range(self.num_resolutions): + stage = nn.Module() + stage.block = nn.ModuleList() + stage.attn = nn.ModuleList() + block_out = self.base_channels * self.channel_multipliers[level] + + for _ in range(self.num_res_blocks): + stage.block.append( + LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + ) + block_in = block_out + if self.attn_resolutions: + if curr_res in self.attn_resolutions: + stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)) + + if level != self.num_resolutions - 1: + stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis) + curr_res = curr_res // 2 + + self.down.append(stage) + + self.mid = nn.Module() + self.mid.block_1 = LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + 
causality_axis=self.causality_axis, + ) + + final_block_channels = block_in + z_channels = 2 * latent_channels if double_z else latent_channels + if self.norm_type == "group": + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True) + elif self.norm_type == "pixel": + self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {self.norm_type}") + self.non_linearity = nn.SiLU() + + if self.causality_axis is not None: + self.conv_out = LTX2AudioCausalConv2d( + final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # hidden_states expected shape: (batch_size, channels, time, num_mel_bins) + hidden_states = self.conv_in(hidden_states) + + for level in range(self.num_resolutions): + stage = self.down[level] + for block_idx, block in enumerate(stage.block): + hidden_states = block(hidden_states, temb=None) + if stage.attn: + hidden_states = stage.attn[block_idx](hidden_states) + + if level != self.num_resolutions - 1 and hasattr(stage, "downsample"): + hidden_states = stage.downsample(hidden_states) + + hidden_states = self.mid.block_1(hidden_states, temb=None) + hidden_states = self.mid.attn_1(hidden_states) + hidden_states = self.mid.block_2(hidden_states, temb=None) + + hidden_states = self.norm_out(hidden_states) + hidden_states = self.non_linearity(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + class LTX2AudioDecoder(nn.Module): """ Symmetric decoder that reconstructs audio spectrograms from latent features. @@ -292,22 +477,22 @@ class LTX2AudioDecoder(nn.Module): def __init__( self, - base_channels: int, - output_channels: int, - num_res_blocks: int, - attn_resolutions: Set[int], - in_channels: int, - resolution: int, - latent_channels: int, - ch_mult: Tuple[int, ...] = (1, 2, 4, 8), + base_channels: int = 128, + output_channels: int = 1, + num_res_blocks: int = 2, + attn_resolutions: Optional[Tuple[int, ...]] = None, + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + ch_mult: Tuple[int, ...] = (1, 2, 4), norm_type: str = "group", causality_axis: Optional[str] = "width", dropout: float = 0.0, - mid_block_add_attention: bool = True, + mid_block_add_attention: bool = False, sample_rate: int = 16000, mel_hop_length: int = 160, is_causal: bool = True, - mel_bins: Optional[int] = None, + mel_bins: Optional[int] = 64, ) -> None: super().__init__() @@ -493,9 +678,9 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): self, base_channels: int = 128, output_channels: int = 2, - ch_mult: Tuple[int] = (1, 2, 4), + ch_mult: Tuple[int, ...] 
= (1, 2, 4),
         num_res_blocks: int = 2,
-        attn_resolutions: Optional[Tuple[int]] = None,
+        attn_resolutions: Optional[Tuple[int, ...]] = None,
         in_channels: int = 2,
         resolution: int = 256,
         latent_channels: int = 8,
@@ -507,6 +692,7 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
         mel_hop_length: int = 160,
         is_causal: bool = True,
         mel_bins: Optional[int] = 64,
+        double_z: bool = True,
     ) -> None:
         super().__init__()
 
@@ -516,6 +702,26 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
 
         attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions
 
+        self.encoder = LTX2AudioEncoder(
+            base_channels=base_channels,
+            output_channels=output_channels,
+            ch_mult=ch_mult,
+            num_res_blocks=num_res_blocks,
+            attn_resolutions=attn_resolution_set,
+            in_channels=in_channels,
+            resolution=resolution,
+            latent_channels=latent_channels,
+            norm_type=norm_type,
+            causality_axis=causality_axis,
+            dropout=dropout,
+            mid_block_add_attention=mid_block_add_attention,
+            sample_rate=sample_rate,
+            mel_hop_length=mel_hop_length,
+            is_causal=is_causal,
+            mel_bins=mel_bins,
+            double_z=double_z,
+        )
+
         self.decoder = LTX2AudioDecoder(
             base_channels=base_channels,
             output_channels=output_channels,
@@ -548,9 +754,21 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
         self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
         self.use_slicing = False
 
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        return self.encoder(x)
+
     @apply_forward_hook
     def encode(self, x: torch.Tensor, return_dict: bool = True):
-        raise NotImplementedError("AutoencoderKLLTX2Audio does not implement encoding.")
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
+        posterior = DiagonalGaussianDistribution(h)
+
+        if not return_dict:
+            return (posterior,)
+        return AutoencoderKLOutput(latent_dist=posterior)
 
     def _decode(self, z: torch.Tensor) -> torch.Tensor:
         return self.decoder(z)
@@ -568,7 +786,20 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
 
         return DecoderOutput(sample=decoded)
 
-    def forward(self, *args, **kwargs):
-        raise NotImplementedError(
-            "This model doesn't have an encoder yet so we don't implement its `forward()`. Please use `decode()`."
-        )
+    def forward(
+        self,
+        sample: torch.Tensor,
+        sample_posterior: bool = False,
+        return_dict: bool = True,
+        generator: Optional[torch.Generator] = None,
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        posterior = self.encode(sample).latent_dist
+        if sample_posterior:
+            z = posterior.sample(generator=generator)
+        else:
+            z = posterior.mode()
+        dec = self.decode(z)
+        if not return_dict:
+            return (dec.sample,)
+        return dec
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py b/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py
new file mode 100644
index 0000000000..3c10330e20
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py
@@ -0,0 +1,88 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from diffusers import AutoencoderKLLTX2Audio
+
+from ...testing_utils import (
+    floats_tensor,
+    torch_device,
+)
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
+
+
+class AutoencoderKLLTX2AudioTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
+    model_class = AutoencoderKLLTX2Audio
+    main_input_name = "sample"
+    base_precision = 1e-2
+
+    def get_autoencoder_kl_ltx2_audio_config(self):
+        return {
+            "in_channels": 2,  # stereo
+            "output_channels": 2,
+            "latent_channels": 4,
+            "base_channels": 16,
+            "ch_mult": (1, 2, 4),
+            "resolution": 16,
+            "attn_resolutions": None,
+            "num_res_blocks": 2,
+            "norm_type": "pixel",
+            "causality_axis": "height",
+            "mid_block_add_attention": False,
+            "sample_rate": 16000,
+            "mel_hop_length": 160,
+            "mel_bins": 16,
+            "is_causal": True,
+            "double_z": True,
+        }
+
+    @property
+    def dummy_input(self):
+        batch_size = 2
+        num_channels = 2
+        num_frames = 8
+        num_mel_bins = 16
+
+        spectrogram = floats_tensor((batch_size, num_channels, num_frames, num_mel_bins)).to(torch_device)
+
+        input_dict = {"sample": spectrogram}
+        return input_dict
+
+    @property
+    def input_shape(self):
+        return (2, 5, 16)
+
+    @property
+    def output_shape(self):
+        return (2, 5, 16)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = self.get_autoencoder_kl_ltx2_audio_config()
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    # Overridden because the output shape differs from the input shape for the LTX 2.0 audio VAE
+    def test_output(self):
+        super().test_output(expected_output_shape=(2, 2, 5, 16))
+
+    @unittest.skip("Unsupported test.")
+    def test_outputs_equivalence(self):
+        pass
+
+    @unittest.skip("AutoencoderKLLTX2Audio does not support `norm_num_groups`; the GroupNorm group count is fixed at 32.")
+    def test_forward_with_norm_groups(self):
+        pass
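
Usage note (not part of the patch): the sketch below shows a minimal encode/decode round trip with the newly added encoder. The constructor arguments are copied from the tiny configuration in the new test file, the random tensor is only a placeholder for real log-mel spectrogram features, and the decoded time dimension can differ from the input one (the test above expects an output of shape (2, 2, 5, 16) for an input of shape (2, 2, 8, 16)).

import torch

from diffusers import AutoencoderKLLTX2Audio

# Tiny configuration taken from the test file added in this patch.
vae = AutoencoderKLLTX2Audio(
    in_channels=2,
    output_channels=2,
    latent_channels=4,
    base_channels=16,
    ch_mult=(1, 2, 4),
    resolution=16,
    attn_resolutions=None,
    num_res_blocks=2,
    norm_type="pixel",
    causality_axis="height",
    mid_block_add_attention=False,
    sample_rate=16000,
    mel_hop_length=160,
    mel_bins=16,
    is_causal=True,
    double_z=True,
).eval()

# Placeholder for a (batch, channels, mel frames, mel bins) log-mel spectrogram.
spectrogram = torch.randn(2, 2, 8, 16)

with torch.no_grad():
    posterior = vae.encode(spectrogram).latent_dist  # DiagonalGaussianDistribution
    latents = posterior.mode()                       # or posterior.sample(generator=...)
    reconstruction = vae.decode(latents).sample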