
Support LTX 2.0 audio VAE encoder

Daniel Gu
2026-01-07 04:16:03 +01:00
parent d01a242cdb
commit 79cf6d7ba4
3 changed files with 340 additions and 23 deletions

View File

@@ -148,10 +148,7 @@ LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
     "per_channel_statistics.mean-of-stds": remove_keys_inplace,
 }
-LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {
-    "encoder": remove_keys_inplace,
-    "per_channel_statistics": convert_ltx2_audio_vae_per_channel_statistics,
-}
+LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
 LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
@@ -499,6 +496,7 @@ def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
             "mel_hop_length": 160,
             "is_causal": True,
             "mel_bins": 64,
+            "double_z": True,
         },
     }
     rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT

View File

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Set, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -21,8 +21,9 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils.accelerate_utils import apply_forward_hook
+from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
-from .vae import AutoencoderMixin, DecoderOutput
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
 
 
 LATENT_DOWNSAMPLE_FACTOR = 4
@@ -219,6 +220,40 @@ class LTX2AudioResnetBlock(nn.Module):
         return x + h
 
 
+class LTX2AudioDownsample(nn.Module):
+    def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
+        super().__init__()
+        self.with_conv = with_conv
+        self.causality_axis = causality_axis
+        if self.with_conv:
+            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.with_conv:
+            # Padding tuple is in the order: (left, right, top, bottom).
+            if self.causality_axis == "none":
+                pad = (0, 1, 0, 1)
+            elif self.causality_axis == "width":
+                pad = (2, 0, 0, 1)
+            elif self.causality_axis == "height":
+                pad = (0, 1, 2, 0)
+            elif self.causality_axis == "width-compatibility":
+                pad = (1, 0, 0, 1)
+            else:
+                raise ValueError(
+                    f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`,"
+                    f" and `width-compatibility`."
+                )
+            x = F.pad(x, pad, mode="constant", value=0)
+            x = self.conv(x)
+        else:
+            # with_conv=False implies that causality_axis is "none"
+            x = F.avg_pool2d(x, kernel_size=2, stride=2)
+        return x
+
+
 class LTX2AudioUpsample(nn.Module):
     def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
         super().__init__()
@@ -282,6 +317,156 @@ class LTX2AudioAudioPatchifier:
         return self._patch_size
 
 
+class LTX2AudioEncoder(nn.Module):
+    def __init__(
+        self,
+        base_channels: int = 128,
+        output_channels: int = 1,
+        num_res_blocks: int = 2,
+        attn_resolutions: Optional[Tuple[int, ...]] = None,
+        in_channels: int = 2,
+        resolution: int = 256,
+        latent_channels: int = 8,
+        ch_mult: Tuple[int, ...] = (1, 2, 4),
+        norm_type: str = "group",
+        causality_axis: Optional[str] = "width",
+        dropout: float = 0.0,
+        mid_block_add_attention: bool = False,
+        sample_rate: int = 16000,
+        mel_hop_length: int = 160,
+        is_causal: bool = True,
+        mel_bins: Optional[int] = 64,
+        double_z: bool = True,
+    ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.mel_hop_length = mel_hop_length
+        self.is_causal = is_causal
+        self.mel_bins = mel_bins
+        self.base_channels = base_channels
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.out_ch = output_channels
+        self.give_pre_end = False
+        self.tanh_out = False
+        self.norm_type = norm_type
+        self.latent_channels = latent_channels
+        self.channel_multipliers = ch_mult
+        self.attn_resolutions = attn_resolutions
+        self.causality_axis = causality_axis
+
+        base_block_channels = base_channels
+        base_resolution = resolution
+        self.z_shape = (1, latent_channels, base_resolution, base_resolution)
+
+        if self.causality_axis is not None:
+            self.conv_in = LTX2AudioCausalConv2d(
+                in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
+            )
+        else:
+            self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
+
+        self.down = nn.ModuleList()
+        block_in = base_block_channels
+        curr_res = self.resolution
+        for level in range(self.num_resolutions):
+            stage = nn.Module()
+            stage.block = nn.ModuleList()
+            stage.attn = nn.ModuleList()
+            block_out = self.base_channels * self.channel_multipliers[level]
+            for _ in range(self.num_res_blocks):
+                stage.block.append(
+                    LTX2AudioResnetBlock(
+                        in_channels=block_in,
+                        out_channels=block_out,
+                        temb_channels=self.temb_ch,
+                        dropout=dropout,
+                        norm_type=self.norm_type,
+                        causality_axis=self.causality_axis,
+                    )
+                )
+                block_in = block_out
+                if self.attn_resolutions:
+                    if curr_res in self.attn_resolutions:
+                        stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
+            if level != self.num_resolutions - 1:
+                stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis)
+                curr_res = curr_res // 2
+            self.down.append(stage)
+
+        self.mid = nn.Module()
+        self.mid.block_1 = LTX2AudioResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+            norm_type=self.norm_type,
+            causality_axis=self.causality_axis,
+        )
+        if mid_block_add_attention:
+            self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)
+        else:
+            self.mid.attn_1 = nn.Identity()
+        self.mid.block_2 = LTX2AudioResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+            norm_type=self.norm_type,
+            causality_axis=self.causality_axis,
+        )
+
+        final_block_channels = block_in
+        z_channels = 2 * latent_channels if double_z else latent_channels
+
+        if self.norm_type == "group":
+            self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
+        elif self.norm_type == "pixel":
+            self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
+        else:
+            raise ValueError(f"Invalid normalization type: {self.norm_type}")
+        self.non_linearity = nn.SiLU()
+
+        if self.causality_axis is not None:
+            self.conv_out = LTX2AudioCausalConv2d(
+                final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
+            )
+        else:
+            self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # hidden_states expected shape: (batch_size, channels, time, num_mel_bins)
+        hidden_states = self.conv_in(hidden_states)
+
+        for level in range(self.num_resolutions):
+            stage = self.down[level]
+            for block_idx, block in enumerate(stage.block):
+                hidden_states = block(hidden_states, temb=None)
+                if stage.attn:
+                    hidden_states = stage.attn[block_idx](hidden_states)
+            if level != self.num_resolutions - 1 and hasattr(stage, "downsample"):
+                hidden_states = stage.downsample(hidden_states)
+
+        hidden_states = self.mid.block_1(hidden_states, temb=None)
+        hidden_states = self.mid.attn_1(hidden_states)
+        hidden_states = self.mid.block_2(hidden_states, temb=None)
+
+        hidden_states = self.norm_out(hidden_states)
+        hidden_states = self.non_linearity(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+        return hidden_states
+
+
 class LTX2AudioDecoder(nn.Module):
     """
     Symmetric decoder that reconstructs audio spectrograms from latent features.
@@ -292,22 +477,22 @@ class LTX2AudioDecoder(nn.Module):
     def __init__(
         self,
-        base_channels: int,
-        output_channels: int,
-        num_res_blocks: int,
-        attn_resolutions: Set[int],
-        in_channels: int,
-        resolution: int,
-        latent_channels: int,
-        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
+        base_channels: int = 128,
+        output_channels: int = 1,
+        num_res_blocks: int = 2,
+        attn_resolutions: Optional[Tuple[int, ...]] = None,
+        in_channels: int = 2,
+        resolution: int = 256,
+        latent_channels: int = 8,
+        ch_mult: Tuple[int, ...] = (1, 2, 4),
         norm_type: str = "group",
         causality_axis: Optional[str] = "width",
         dropout: float = 0.0,
-        mid_block_add_attention: bool = True,
+        mid_block_add_attention: bool = False,
         sample_rate: int = 16000,
         mel_hop_length: int = 160,
         is_causal: bool = True,
-        mel_bins: Optional[int] = None,
+        mel_bins: Optional[int] = 64,
     ) -> None:
         super().__init__()
@@ -493,9 +678,9 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
         self,
         base_channels: int = 128,
         output_channels: int = 2,
-        ch_mult: Tuple[int] = (1, 2, 4),
+        ch_mult: Tuple[int, ...] = (1, 2, 4),
         num_res_blocks: int = 2,
-        attn_resolutions: Optional[Tuple[int]] = None,
+        attn_resolutions: Optional[Tuple[int, ...]] = None,
         in_channels: int = 2,
         resolution: int = 256,
         latent_channels: int = 8,
@@ -507,6 +692,7 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
         mel_hop_length: int = 160,
         is_causal: bool = True,
         mel_bins: Optional[int] = 64,
+        double_z: bool = True,
     ) -> None:
         super().__init__()
@@ -516,6 +702,26 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
         attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions
 
+        self.encoder = LTX2AudioEncoder(
+            base_channels=base_channels,
+            output_channels=output_channels,
+            ch_mult=ch_mult,
+            num_res_blocks=num_res_blocks,
+            attn_resolutions=attn_resolution_set,
+            in_channels=in_channels,
+            resolution=resolution,
+            latent_channels=latent_channels,
+            norm_type=norm_type,
+            causality_axis=causality_axis,
+            dropout=dropout,
+            mid_block_add_attention=mid_block_add_attention,
+            sample_rate=sample_rate,
+            mel_hop_length=mel_hop_length,
+            is_causal=is_causal,
+            mel_bins=mel_bins,
+            double_z=double_z,
+        )
+
         self.decoder = LTX2AudioDecoder(
             base_channels=base_channels,
             output_channels=output_channels,
@@ -548,9 +754,21 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
         self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
         self.use_slicing = False
 
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        return self.encoder(x)
+
     @apply_forward_hook
     def encode(self, x: torch.Tensor, return_dict: bool = True):
-        raise NotImplementedError("AutoencoderKLLTX2Audio does not implement encoding.")
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
+        posterior = DiagonalGaussianDistribution(h)
+
+        if not return_dict:
+            return (posterior,)
+        return AutoencoderKLOutput(latent_dist=posterior)
 
     def _decode(self, z: torch.Tensor) -> torch.Tensor:
         return self.decoder(z)
@@ -568,7 +786,20 @@
         return DecoderOutput(sample=decoded)
 
-    def forward(self, *args, **kwargs):
-        raise NotImplementedError(
-            "This model doesn't have an encoder yet so we don't implement its `forward()`. Please use `decode()`."
-        )
+    def forward(
+        self,
+        sample: torch.Tensor,
+        sample_posterior: bool = False,
+        return_dict: bool = True,
+        generator: Optional[torch.Generator] = None,
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        posterior = self.encode(sample).latent_dist
+        if sample_posterior:
+            z = posterior.sample(generator=generator)
+        else:
+            z = posterior.mode()
+        dec = self.decode(z)
+        if not return_dict:
+            return (dec.sample,)
+        return dec

View File

@@ -0,0 +1,88 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from diffusers import AutoencoderKLLTX2Audio
+
+from ...testing_utils import (
+    floats_tensor,
+    torch_device,
+)
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
+
+
+class AutoencoderKLLTX2AudioTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
+    model_class = AutoencoderKLLTX2Audio
+    main_input_name = "sample"
+    base_precision = 1e-2
+
+    def get_autoencoder_kl_ltx_audio_config(self):
+        return {
+            "in_channels": 2,  # stereo
+            "output_channels": 2,
+            "latent_channels": 4,
+            "base_channels": 16,
+            "ch_mult": (1, 2, 4),
+            "resolution": 16,
+            "attn_resolutions": None,
+            "num_res_blocks": 2,
+            "norm_type": "pixel",
+            "causality_axis": "height",
+            "mid_block_add_attention": False,
+            "sample_rate": 16000,
+            "mel_hop_length": 160,
+            "mel_bins": 16,
+            "is_causal": True,
+            "double_z": True,
+        }
+
+    @property
+    def dummy_input(self):
+        batch_size = 2
+        num_channels = 2
+        num_frames = 8
+        num_mel_bins = 16
+
+        spectrogram = floats_tensor((batch_size, num_channels, num_frames, num_mel_bins)).to(torch_device)
+        input_dict = {"sample": spectrogram}
+        return input_dict
+
+    @property
+    def input_shape(self):
+        return (2, 5, 16)
+
+    @property
+    def output_shape(self):
+        return (2, 5, 16)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = self.get_autoencoder_kl_ltx_audio_config()
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    # Overriding as output shape is not the same as input shape for LTX 2.0 audio VAE
+    def test_output(self):
+        super().test_output(expected_output_shape=(2, 2, 5, 16))
+
+    @unittest.skip("Unsupported test.")
+    def test_outputs_equivalence(self):
+        pass
+
+    @unittest.skip("AutoencoderKLLTX2Audio does not support `norm_num_groups` because it does not use GroupNorm.")
+    def test_forward_with_norm_groups(self):
+        pass