From b34ddb1736377a9b2e01dea5408b99a8cc147f28 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 22 Dec 2025 12:23:31 +0530 Subject: [PATCH 01/11] start audio decoder. --- .../autoencoders/autoencoder_kl_ltx2_audio.py | 655 ++++++++++++++++++ 1 file changed, 655 insertions(+) create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py new file mode 100644 index 0000000000..98d8a53e23 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -0,0 +1,655 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple +from typing import Optional, Set, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils.accelerate_utils import apply_forward_hook +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput + + +LATENT_DOWNSAMPLE_FACTOR = 4 +SUPPORTED_CAUSAL_AXES = {"none", "width", "height", "width-compatibility"} + + +AudioLatentShape = namedtuple( + "AudioLatentShape", + [ + "batch", + "channels", + "frames", + "mel_bins", + ], +) + + +def _resolve_causality_axis(causality_axis: Optional[str] = None) -> Optional[str]: + normalized = "none" if causality_axis is None else str(causality_axis).lower() + if normalized not in SUPPORTED_CAUSAL_AXES: + raise NotImplementedError( + f"Unsupported causality_axis '{causality_axis}'. Supported: {sorted(SUPPORTED_CAUSAL_AXES)}" + ) + return None if normalized == "none" else normalized + + +def make_conv2d( + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: int = 1, + padding: Optional[Tuple[int, int, int, int]] = None, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causality_axis: Optional[str] = None, +) -> nn.Module: + if causality_axis is not None: + return LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis + ) + if padding is None: + padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size) + + return nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + +class LTX2AudioCausalConv2d(nn.Module): + """ + A causal 2D convolution that pads asymmetrically along the causal axis. 
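+
+    With ``causality_axis="height"`` and a 3x3 kernel, for example, the input is padded by
+    ``(1, 1, 2, 0)`` (left, right, top, bottom) before the convolution, so each output row
+    depends only on the current and earlier rows along the causal axis.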
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: int = 1, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, + causality_axis: str = "height", + ) -> None: + super().__init__() + + self.causality_axis = causality_axis + kernel_size = nn.modules.utils._pair(kernel_size) + dilation = nn.modules.utils._pair(dilation) + + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + if self.causality_axis == "none": + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + elif self.causality_axis in {"width", "width-compatibility"}: + padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + elif self.causality_axis == "height": + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + else: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + self.padding = padding + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, self.padding) + return self.conv(x) + + +class LTX2AudioPixelNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. + """ + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + super().__init__() + self.dim = dim + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True) + rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + +def build_normalization_layer(in_channels: int, *, num_groups: int = 32, normtype: str = "group") -> nn.Module: + if normtype == "group": + return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if normtype == "pixel": + return LTX2AudioPixelNorm(dim=1, eps=1e-6) + raise ValueError(f"Invalid normalization type: {normtype}") + + +class LTX2AudioAttnBlock(nn.Module): + def __init__( + self, + in_channels: int, + norm_type: str = "group", + ) -> None: + super().__init__() + self.in_channels = in_channels + + self.norm = build_normalization_layer(in_channels, normtype=norm_type) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = self.norm(x) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + batch, channels, height, width = q.shape + q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous() + k = k.reshape(batch, channels, height * width).contiguous() + attn = torch.bmm(q, k) * (int(channels) ** (-0.5)) + attn = torch.nn.functional.softmax(attn, dim=2) + + v = v.reshape(batch, channels, height * width) + attn = attn.permute(0, 2, 1).contiguous() + h_ = torch.bmm(v, attn).reshape(batch, channels, height, width) + + h_ = self.proj_out(h_) + return x + h_ + + +class LTX2AudioResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + norm_type: str = "group", + causality_axis: str = "height", + ) -> None: + super().__init__() + self.causality_axis = causality_axis + + if 
self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group": + raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = build_normalization_layer(in_channels, normtype=norm_type) + self.non_linearity = nn.SiLU() + self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if temb_channels > 0: + self.temb_proj = nn.Linear(temb_channels, out_channels) + self.norm2 = build_normalization_layer(out_channels, normtype=norm_type) + self.dropout = nn.Dropout(dropout) + self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + + def forward( + self, + x: torch.Tensor, + temb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + h = self.norm1(x) + h = self.non_linearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.non_linearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x) + + return x + h + + +class LTX2AudioUpsample(nn.Module): + def __init__( + self, + in_channels: int, + with_conv: bool, + causality_axis: Optional[str] = "height", + ) -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + if self.causality_axis is None or self.causality_axis == "none": + pass + elif self.causality_axis == "height": + x = x[:, :, 1:, :] + elif self.causality_axis == "width": + x = x[:, :, :, 1:] + elif self.causality_axis == "width-compatibility": + pass + else: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +class LTX2AudioPerChannelStatistics(nn.Module): + """ + Per-channel statistics for normalizing and denormalizing the latent representation. This statics is computed over + the entire dataset and stored in model's checkpoint under AudioVAE state_dict + """ + + def __init__(self, latent_channels: int = 128) -> None: + super().__init__() + self.register_buffer("std-of-means", torch.empty(latent_channels)) + self.register_buffer("mean-of-means", torch.empty(latent_channels)) + + def un_normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x) + + +class LTX2AudioAudioPatchifier: + """ + Patchifier for spectrogram/audio latents. 
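+
+    ``patchify`` flattens latents from ``(batch, channels, frames, mel_bins)`` to
+    ``(batch, frames, channels * mel_bins)``; ``unpatchify`` restores the original layout
+    from the target ``AudioLatentShape``.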
+ """ + + def __init__( + self, + patch_size: int, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + is_causal: bool = True, + ): + self.hop_length = hop_length + self.sample_rate = sample_rate + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.is_causal = is_causal + self._patch_size = (1, patch_size, patch_size) + + def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor: + batch, channels, time, freq = audio_latents.shape + return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq) + + def unpatchify( + self, + audio_latents: torch.Tensor, + output_shape: AudioLatentShape, + ) -> torch.Tensor: + batch, time, _ = audio_latents.shape + channels = output_shape.channels + freq = output_shape.mel_bins + return audio_latents.view(batch, time, channels, freq).permute(0, 2, 1, 3) + + @property + def patch_size(self) -> Tuple[int, int, int]: + return self._patch_size + + +class LTX2AudioDecoder(nn.Module): + """ + Symmetric decoder that reconstructs audio spectrograms from latent features. + + The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal + convolutions. + """ + + def __init__( + self, + base_channels: int, + output_channels: int, + num_res_blocks: int, + attn_resolutions: Set[int], + in_channels: int, + resolution: int, + latent_channels: int, + ch_mult: Tuple[int, ...] = (1, 2, 4, 8), + norm_type: str = "group", + causality_axis: Optional[str] = "width", + dropout: float = 0.0, + mid_block_add_attention: bool = True, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: Optional[int] = None, + ) -> None: + super().__init__() + + resolved_causality_axis = _resolve_causality_axis(causality_axis) + + self.per_channel_statistics = LTX2AudioPerChannelStatistics(latent_channels=base_channels) + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + self.patchifier = LTX2AudioAudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.base_channels = base_channels + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = output_channels + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.latent_channels = latent_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = resolved_causality_axis + + base_block_channels = base_channels * self.channel_multipliers[-1] + base_resolution = resolution // (2 ** (self.num_resolutions - 1)) + self.z_shape = (1, latent_channels, base_resolution, base_resolution) + + self.conv_in = make_conv2d( + latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + self.non_linearity = nn.SiLU() + self.mid = self._build_mid_layers(base_block_channels, dropout, mid_block_add_attention) + self.up, final_block_channels = self._build_up_path( + initial_block_channels=base_block_channels, + dropout=dropout, + resamp_with_conv=True, + ) + + self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type) + self.conv_out = make_conv2d( + final_block_channels, output_channels, kernel_size=3, stride=1, 
causality_axis=self.causality_axis + ) + + def _adjust_output_shape( + self, + decoded_output: torch.Tensor, + target_shape: AudioLatentShape, + ) -> torch.Tensor: + _, _, current_time, current_freq = decoded_output.shape + target_channels = target_shape.channels + target_time = target_shape.frames + target_freq = target_shape.mel_bins + + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + if time_padding_needed > 0 or freq_padding_needed > 0: + padding = ( + 0, + max(freq_padding_needed, 0), + 0, + max(time_padding_needed, 0), + ) + decoded_output = F.pad(decoded_output, padding) + + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + def forward( + self, + sample: torch.Tensor, + ) -> torch.Tensor: + latent_shape = AudioLatentShape( + batch=sample.shape[0], + channels=sample.shape[1], + frames=sample.shape[2], + mel_bins=sample.shape[3], + ) + + sample_patched = self.patchifier.patchify(sample) + sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched) + sample = self.patchifier.unpatchify(sample_denormalized, latent_shape) + + target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR + + if self.causality_axis is not None: + target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) + + target_shape = AudioLatentShape( + batch=latent_shape.batch, + channels=self.out_ch, + frames=target_frames, + mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins, + ) + + hidden_features = self.conv_in(sample) + hidden_features = self._run_mid_layers(hidden_features) + hidden_features = self._run_upsampling_path(hidden_features) + decoded_output = self._finalize_output(hidden_features) + + decoded_output = self._adjust_output_shape(decoded_output, target_shape) + + return decoded_output + + def _build_mid_layers(self, channels: int, dropout: float, add_attention: bool) -> nn.Module: + mid = nn.Module() + mid.block_1 = LTX2AudioResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + mid.attn_1 = LTX2AudioAttnBlock(channels, norm_type=self.norm_type) if add_attention else nn.Identity() + mid.block_2 = LTX2AudioResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + return mid + + def _build_up_path( + self, + initial_block_channels: int, + dropout: float, + resamp_with_conv: bool, + ) -> tuple[nn.ModuleList, int]: + up_modules = nn.ModuleList() + block_in = initial_block_channels + curr_res = self.resolution // (2 ** (self.num_resolutions - 1)) + + for level in reversed(range(self.num_resolutions)): + stage = nn.Module() + stage.block = nn.ModuleList() + stage.attn = nn.ModuleList() + block_out = self.base_channels * self.channel_multipliers[level] + + for _ in range(self.num_res_blocks + 1): + stage.block.append( + LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + ) + block_in = block_out + if curr_res in self.attn_resolutions: + stage.attn.append(LTX2AudioAttnBlock(block_in, 
norm_type=self.norm_type)) + + if level != 0: + stage.upsample = LTX2AudioUpsample(block_in, resamp_with_conv, causality_axis=self.causality_axis) + curr_res *= 2 + + up_modules.insert(0, stage) + + return up_modules, block_in + + def _run_mid_layers(self, features: torch.Tensor) -> torch.Tensor: + features = self.mid.block_1(features, temb=None) + features = self.mid.attn_1(features) + return self.mid.block_2(features, temb=None) + + def _run_upsampling_path(self, features: torch.Tensor) -> torch.Tensor: + for level in reversed(range(self.num_resolutions)): + stage = self.up[level] + for block_idx, block in enumerate(stage.block): + features = block(features, temb=None) + if stage.attn: + features = stage.attn[block_idx](features) + + if level != 0 and hasattr(stage, "upsample"): + features = stage.upsample(features) + + return features + + def _finalize_output(self, features: torch.Tensor) -> torch.Tensor: + if self.give_pre_end: + return features + + hidden = self.norm_out(features) + hidden = self.non_linearity(hidden) + decoded = self.conv_out(hidden) + return torch.tanh(decoded) if self.tanh_out else decoded + + +class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): + r""" + LTX2 audio VAE. Currently, only implements the decoder. + """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + base_channels: int = 128, + output_channels: int = 2, + ch_mult: Tuple[int] = (1, 2, 4), + num_res_blocks: int = 2, + attn_resolutions: Tuple[int] = (8, 16, 32), + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + norm_type: str = "pixel", + causality_axis: Optional[str] = "height", + dropout: float = 0.0, + mid_block_add_attention: bool = True, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: Optional[int] = None, + ) -> None: + super().__init__() + + resolved_causality_axis = _resolve_causality_axis(causality_axis) + attn_resolution_set = set(attn_resolutions) + + self.decoder = LTX2AudioDecoder( + base_channels=base_channels, + output_channels=output_channels, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolution_set, + in_channels=in_channels, + resolution=resolution, + latent_channels=latent_channels, + norm_type=norm_type, + causality_axis=resolved_causality_axis, + dropout=dropout, + mid_block_add_attention=mid_block_add_attention, + sample_rate=sample_rate, + mel_hop_length=mel_hop_length, + is_causal=is_causal, + mel_bins=mel_bins, + ) + + self.use_slicing = False + + @apply_forward_hook + def encode( + self, + x: torch.Tensor, + return_dict: bool = True, + ): + raise NotImplementedError("AutoencoderKLLTX2Audio does not implement encoding.") + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + return self.decoder(z) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "This model doesn't have an encoder yet so we don't implement its `forward()`. Please use `decode()`." 
+ ) From f4c2435d61f03e6e97bcbafec1ece6b5bcf50357 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 22 Dec 2025 12:25:36 +0530 Subject: [PATCH 02/11] init registration. --- src/diffusers/models/autoencoders/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 56df27f93c..032bbe4123 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -10,6 +10,7 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo +from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage From e54cd6bb1d40f806a7b227500da7514a091e07d2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 22 Dec 2025 13:03:40 +0530 Subject: [PATCH 03/11] up --- scripts/test_ltx2_audio_conversion.py | 106 ++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 scripts/test_ltx2_audio_conversion.py diff --git a/scripts/test_ltx2_audio_conversion.py b/scripts/test_ltx2_audio_conversion.py new file mode 100644 index 0000000000..251d0b64e9 --- /dev/null +++ b/scripts/test_ltx2_audio_conversion.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +""" +Quick check that an LTX2 audio decoder checkpoint converts cleanly to the diffusers +`AutoencoderKLLTX2Audio` layout and produces matching outputs on dummy data. +""" + +import argparse +import sys +from pathlib import Path + +import torch + + +def convert_state_dict(state_dict: dict) -> dict: + converted = {} + for key, value in state_dict.items(): + if not isinstance(value, torch.Tensor): + continue + new_key = key + if new_key.startswith("decoder."): + new_key = new_key[len("decoder.") :] + converted[f"decoder.{new_key}"] = value + return converted + + +def load_original_decoder(original_repo: Path, device: torch.device, dtype: torch.dtype, checkpoint_path: Path | None): + ltx_core_src = original_repo / "ltx-core" / "src" + if not ltx_core_src.exists(): + raise FileNotFoundError(f"ltx-core sources not found under {ltx_core_src}") + sys.path.insert(0, str(ltx_core_src)) + + from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator + + decoder = VAEDecoderConfigurator.from_config({}).to(device=device, dtype=dtype) + + if checkpoint_path is not None: + raw_state = torch.load(checkpoint_path, map_location=device) + state_dict = raw_state.get("state_dict", raw_state) + decoder_state: dict[str, torch.Tensor] = {} + for key, value in state_dict.items(): + if not isinstance(value, torch.Tensor): + continue + trimmed = key + if trimmed.startswith("audio_vae.decoder."): + trimmed = trimmed[len("audio_vae.decoder.") :] + elif trimmed.startswith("decoder."): + trimmed = trimmed[len("decoder.") :] + decoder_state[trimmed] = value + decoder.load_state_dict(decoder_state, strict=False) + + decoder.eval() + return decoder + + +def build_diffusers_decoder(device: torch.device, dtype: torch.dtype): + from diffusers.models.autoencoders.autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio + + model = AutoencoderKLLTX2Audio().to(device=device, dtype=dtype) + model.eval() + return model + + +def main() -> None: + parser = 
argparse.ArgumentParser(description="Validate LTX2 audio decoder conversion.") + parser.add_argument( + "--original-repo", + type=Path, + default=Path("/Users/sayakpaul/Downloads/ltx-2"), + help="Path to the original ltx-2 repository (needed to import ltx-core).", + ) + parser.add_argument( + "--checkpoint", + type=Path, + default=None, + help="Optional path to an original checkpoint containing decoder weights.", + ) + parser.add_argument("--device", type=str, default="cpu") + parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16", "float16"]) + parser.add_argument("--batch", type=int, default=2) + args = parser.parse_args() + + device = torch.device(args.device) + dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16} + dtype = dtype_map[args.dtype] + + original_decoder = load_original_decoder(args.original_repo, device, dtype, args.checkpoint) + diffusers_model = build_diffusers_decoder(device, dtype) + + converted_state = convert_state_dict(original_decoder.state_dict()) + diffusers_model.load_state_dict(converted_state, strict=False) + + levels = len(diffusers_model.decoder.channel_multipliers) + latent_size = diffusers_model.decoder.resolution // (2 ** (levels - 1)) + dummy = torch.randn(args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device, dtype=dtype) + + with torch.no_grad(): + original_out = original_decoder(dummy) + diffusers_out = diffusers_model.decode(dummy).sample + + torch.testing.assert_close(diffusers_out, original_out, rtol=1e-4, atol=1e-4) + max_diff = (diffusers_out - original_out).abs().max().item() + print(f"Conversion successful. Max diff: {max_diff:.6f}") + + +if __name__ == "__main__": + main() From 907896d533ae7089c30cd98790975c4ad5dd6b48 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 22 Dec 2025 13:41:41 +0530 Subject: [PATCH 04/11] simplify and clean up --- scripts/test_ltx2_audio_conversion.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/test_ltx2_audio_conversion.py b/scripts/test_ltx2_audio_conversion.py index 251d0b64e9..649b6d06d6 100644 --- a/scripts/test_ltx2_audio_conversion.py +++ b/scripts/test_ltx2_audio_conversion.py @@ -91,7 +91,9 @@ def main() -> None: levels = len(diffusers_model.decoder.channel_multipliers) latent_size = diffusers_model.decoder.resolution // (2 ** (levels - 1)) - dummy = torch.randn(args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device, dtype=dtype) + dummy = torch.randn( + args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device, dtype=dtype + ) with torch.no_grad(): original_out = original_decoder(dummy) From 4904fd6fa520894d586ec740bc2a10177e306883 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 22 Dec 2025 13:46:58 +0530 Subject: [PATCH 05/11] up --- scripts/test_ltx2_audio_conversion.py | 86 ++++++++++++--------------- 1 file changed, 38 insertions(+), 48 deletions(-) diff --git a/scripts/test_ltx2_audio_conversion.py b/scripts/test_ltx2_audio_conversion.py index 649b6d06d6..f9554782c9 100644 --- a/scripts/test_ltx2_audio_conversion.py +++ b/scripts/test_ltx2_audio_conversion.py @@ -1,14 +1,19 @@ -#!/usr/bin/env python -""" -Quick check that an LTX2 audio decoder checkpoint converts cleanly to the diffusers -`AutoencoderKLLTX2Audio` layout and produces matching outputs on dummy data. 
-""" - import argparse -import sys from pathlib import Path +import safetensors.torch import torch +from huggingface_hub import hf_hub_download + + +def download_checkpoint( + repo_id="diffusers-internal-dev/new-ltx-model", + filename="ltx-av-step-1932500-interleaved-new-vae.safetensors", + device="cuda", +): + ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) + ckpt = safetensors.torch.load_file(ckpt_path, device=device)["audio_vae"] + return ckpt def convert_state_dict(state_dict: dict) -> dict: @@ -23,71 +28,57 @@ def convert_state_dict(state_dict: dict) -> dict: return converted -def load_original_decoder(original_repo: Path, device: torch.device, dtype: torch.dtype, checkpoint_path: Path | None): - ltx_core_src = original_repo / "ltx-core" / "src" - if not ltx_core_src.exists(): - raise FileNotFoundError(f"ltx-core sources not found under {ltx_core_src}") - sys.path.insert(0, str(ltx_core_src)) - +def load_original_decoder(device: torch.device, dtype: torch.dtype): from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator - decoder = VAEDecoderConfigurator.from_config({}).to(device=device, dtype=dtype) + with torch.device("meta"): + decoder = VAEDecoderConfigurator.from_config({}).to(device=device, dtype=dtype) + original_state_dict = download_checkpoint(device) - if checkpoint_path is not None: - raw_state = torch.load(checkpoint_path, map_location=device) - state_dict = raw_state.get("state_dict", raw_state) - decoder_state: dict[str, torch.Tensor] = {} - for key, value in state_dict.items(): - if not isinstance(value, torch.Tensor): - continue - trimmed = key - if trimmed.startswith("audio_vae.decoder."): - trimmed = trimmed[len("audio_vae.decoder.") :] - elif trimmed.startswith("decoder."): - trimmed = trimmed[len("decoder.") :] - decoder_state[trimmed] = value - decoder.load_state_dict(decoder_state, strict=False) + decoder_state_dict = {} + for key, value in original_state_dict.items(): + if not isinstance(value, torch.Tensor): + continue + trimmed = key + if trimmed.startswith("audio_vae.decoder."): + trimmed = trimmed[len("audio_vae.decoder.") :] + elif trimmed.startswith("decoder."): + trimmed = trimmed[len("decoder.") :] + decoder_state_dict[trimmed] = value + decoder.load_state_dict(decoder_state_dict, strict=True, assign=True) decoder.eval() return decoder def build_diffusers_decoder(device: torch.device, dtype: torch.dtype): - from diffusers.models.autoencoders.autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio + from diffusers.models.autoencoders import AutoencoderKLLTX2Audio + + with torch.device("meta"): + model = AutoencoderKLLTX2Audio().to(device=device, dtype=dtype) - model = AutoencoderKLLTX2Audio().to(device=device, dtype=dtype) model.eval() return model +@torch.no_grad() def main() -> None: parser = argparse.ArgumentParser(description="Validate LTX2 audio decoder conversion.") - parser.add_argument( - "--original-repo", - type=Path, - default=Path("/Users/sayakpaul/Downloads/ltx-2"), - help="Path to the original ltx-2 repository (needed to import ltx-core).", - ) - parser.add_argument( - "--checkpoint", - type=Path, - default=None, - help="Optional path to an original checkpoint containing decoder weights.", - ) parser.add_argument("--device", type=str, default="cpu") - parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16", "float16"]) + parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"]) parser.add_argument("--batch", type=int, 
default=2) + parser.add_argument("--output-path", type=Path, required=True) args = parser.parse_args() device = torch.device(args.device) dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16} dtype = dtype_map[args.dtype] - original_decoder = load_original_decoder(args.original_repo, device, dtype, args.checkpoint) + original_decoder = load_original_decoder(device, dtype) diffusers_model = build_diffusers_decoder(device, dtype) converted_state = convert_state_dict(original_decoder.state_dict()) - diffusers_model.load_state_dict(converted_state, strict=False) + diffusers_model.load_state_dict(converted_state, assign=True, strict=True) levels = len(diffusers_model.decoder.channel_multipliers) latent_size = diffusers_model.decoder.resolution // (2 ** (levels - 1)) @@ -95,9 +86,8 @@ def main() -> None: args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device, dtype=dtype ) - with torch.no_grad(): - original_out = original_decoder(dummy) - diffusers_out = diffusers_model.decode(dummy).sample + original_out = original_decoder(dummy) + diffusers_out = diffusers_model.decode(dummy).sample torch.testing.assert_close(diffusers_out, original_out, rtol=1e-4, atol=1e-4) max_diff = (diffusers_out - original_out).abs().max().item() From 5f0f2a03f72fc59a606b1d7e03960b5c9a086102 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 22 Dec 2025 10:06:39 +0000 Subject: [PATCH 06/11] up --- scripts/log.txt | 32 +++++++++++ scripts/test_ltx2_audio_conversion.py | 57 +++++++++---------- .../autoencoders/autoencoder_kl_ltx2_audio.py | 20 +++++-- 3 files changed, 74 insertions(+), 35 deletions(-) create mode 100644 scripts/log.txt diff --git a/scripts/log.txt b/scripts/log.txt new file mode 100644 index 0000000000..aa3046d42a --- /dev/null +++ b/scripts/log.txt @@ -0,0 +1,32 @@ +ddconfig={'double_z': True, 'mel_bins': 64, 'z_channels': 8, 'resolution': 256, 'downsample_time': False, 'in_channels': 2, 'out_ch': 2, 'ch': 128, 'ch_mult': [1, 2, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0, 'mid_block_add_attention': False, 'norm_type': 'pixel', 'causality_axis': 'height'}, sample_rate=16000, mel_hop_length=160, is_causal=True, mel_bins=64 +mid_block_add_attention=False, attn_resolutions=[] +k='mid.block_1.conv1.conv.weight' +k='mid.block_1.conv1.conv.bias' +k='mid.block_1.conv2.conv.weight' +k='mid.block_1.conv2.conv.bias' +k='mid.block_2.conv1.conv.weight' +k='mid.block_2.conv1.conv.bias' +k='mid.block_2.conv2.conv.weight' +k='mid.block_2.conv2.conv.bias' +Traceback (most recent call last): + File "/fsx/sayak/diffusers-new-model-addition-ltx2/scripts/test_ltx2_audio_conversion.py", line 97, in + main() + File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context + return func(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^ + File "/fsx/sayak/diffusers-new-model-addition-ltx2/scripts/test_ltx2_audio_conversion.py", line 85, in main + original_out = original_decoder(dummy) + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File 
"/fsx/sayak/ltx-2/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py", line 206, in forward + sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/fsx/sayak/ltx-2/ltx-core/src/ltx_core/model/audio_vae/ops.py", line 27, in un_normalize + return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) + ~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +RuntimeError: The size of tensor a (512) must match the size of tensor b (128) at non-singleton dimension 2 diff --git a/scripts/test_ltx2_audio_conversion.py b/scripts/test_ltx2_audio_conversion.py index f9554782c9..6a124f74df 100644 --- a/scripts/test_ltx2_audio_conversion.py +++ b/scripts/test_ltx2_audio_conversion.py @@ -1,7 +1,6 @@ import argparse from pathlib import Path -import safetensors.torch import torch from huggingface_hub import hf_hub_download @@ -9,11 +8,9 @@ from huggingface_hub import hf_hub_download def download_checkpoint( repo_id="diffusers-internal-dev/new-ltx-model", filename="ltx-av-step-1932500-interleaved-new-vae.safetensors", - device="cuda", ): ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) - ckpt = safetensors.torch.load_file(ckpt_path, device=device)["audio_vae"] - return ckpt + return ckpt_path def convert_state_dict(state_dict: dict) -> dict: @@ -28,34 +25,33 @@ def convert_state_dict(state_dict: dict) -> dict: return converted -def load_original_decoder(device: torch.device, dtype: torch.dtype): - from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator - - with torch.device("meta"): - decoder = VAEDecoderConfigurator.from_config({}).to(device=device, dtype=dtype) - original_state_dict = download_checkpoint(device) - - decoder_state_dict = {} - for key, value in original_state_dict.items(): - if not isinstance(value, torch.Tensor): - continue - trimmed = key - if trimmed.startswith("audio_vae.decoder."): - trimmed = trimmed[len("audio_vae.decoder.") :] - elif trimmed.startswith("decoder."): - trimmed = trimmed[len("decoder.") :] - decoder_state_dict[trimmed] = value - decoder.load_state_dict(decoder_state_dict, strict=True, assign=True) +def load_original_decoder(device: torch.device): + from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder + from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator + from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER + + checkpoint_path = download_checkpoint() + + # The code below comes from `ltx-pipelines/src/ltx_pipelines/txt2vid.py` + decoder = Builder( + model_path=checkpoint_path, + model_class_configurator=AudioDecoderConfigurator, + model_sd_key_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, + ).build(device=device) + state_dict = decoder.state_dict() + for k, v in state_dict.items(): + if "mid" in k: + print(f"{k=}") decoder.eval() return decoder -def build_diffusers_decoder(device: torch.device, dtype: torch.dtype): +def build_diffusers_decoder(): from diffusers.models.autoencoders import AutoencoderKLLTX2Audio with torch.device("meta"): - model = AutoencoderKLLTX2Audio().to(device=device, dtype=dtype) + model = AutoencoderKLLTX2Audio() model.eval() return model @@ -74,16 +70,16 @@ def main() -> None: dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16} dtype = dtype_map[args.dtype] - original_decoder = load_original_decoder(device, dtype) - diffusers_model = 
build_diffusers_decoder(device, dtype) + original_decoder = load_original_decoder(device) + diffusers_model = build_diffusers_decoder() - converted_state = convert_state_dict(original_decoder.state_dict()) - diffusers_model.load_state_dict(converted_state, assign=True, strict=True) + converted_state_dict = convert_state_dict(original_decoder.state_dict()) + diffusers_model.load_state_dict(converted_state_dict, assign=True, strict=True) levels = len(diffusers_model.decoder.channel_multipliers) latent_size = diffusers_model.decoder.resolution // (2 ** (levels - 1)) dummy = torch.randn( - args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device, dtype=dtype + args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device ) original_out = original_decoder(dummy) @@ -93,6 +89,9 @@ def main() -> None: max_diff = (diffusers_out - original_out).abs().max().item() print(f"Conversion successful. Max diff: {max_diff:.6f}") + diffusers_model.to(dtype).save_pretrained(args.output_path) + print(f"Serialized model to {args.output_path}") + if __name__ == "__main__": main() diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index 98d8a53e23..457cbf5bce 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -533,8 +533,9 @@ class LTX2AudioDecoder(nn.Module): ) ) block_in = block_out - if curr_res in self.attn_resolutions: - stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)) + if self.attn_resolutions: + if curr_res in self.attn_resolutions: + stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)) if level != 0: stage.upsample = LTX2AudioUpsample(block_in, resamp_with_conv, causality_axis=self.causality_axis) @@ -579,6 +580,13 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): _supports_gradient_checkpointing = False + # { + # 'double_z': True, 'mel_bins': 64, 'z_channels': 8, 'resolution': 256, 'downsample_time': False, + # 'in_channels': 2, 'out_ch': 2, 'ch': 128, 'ch_mult': [1, 2, 4], 'num_res_blocks': 2, + # 'attn_resolutions': [], 'dropout': 0.0, 'mid_block_add_attention': False, + # 'norm_type': 'pixel', 'causality_axis': 'height' + # } + # sample_rate=16000, mel_hop_length=160, is_causal=True, mel_bins=64 @register_to_config def __init__( self, @@ -586,23 +594,23 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): output_channels: int = 2, ch_mult: Tuple[int] = (1, 2, 4), num_res_blocks: int = 2, - attn_resolutions: Tuple[int] = (8, 16, 32), + attn_resolutions: Optional[Tuple[int]] = None, in_channels: int = 2, resolution: int = 256, latent_channels: int = 8, norm_type: str = "pixel", causality_axis: Optional[str] = "height", dropout: float = 0.0, - mid_block_add_attention: bool = True, + mid_block_add_attention: bool = False, sample_rate: int = 16000, mel_hop_length: int = 160, is_causal: bool = True, - mel_bins: Optional[int] = None, + mel_bins: Optional[int] = 64, ) -> None: super().__init__() resolved_causality_axis = _resolve_causality_axis(causality_axis) - attn_resolution_set = set(attn_resolutions) + attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions self.decoder = LTX2AudioDecoder( base_channels=base_channels, From 58257eb0e0f1a8ac07ff4854009f35c1b2bad444 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 22 Dec 2025 
15:45:56 +0530 Subject: [PATCH 07/11] up --- scripts/test_ltx2_audio_conversion.py | 31 ++++++++++++------- .../autoencoders/autoencoder_kl_ltx2_audio.py | 7 ----- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/scripts/test_ltx2_audio_conversion.py b/scripts/test_ltx2_audio_conversion.py index 6a124f74df..8d07a6f9b1 100644 --- a/scripts/test_ltx2_audio_conversion.py +++ b/scripts/test_ltx2_audio_conversion.py @@ -25,13 +25,13 @@ def convert_state_dict(state_dict: dict) -> dict: return converted -def load_original_decoder(device: torch.device): +def load_original_decoder(device: torch.device, dtype: torch.dtype): from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder - from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER - + from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator + checkpoint_path = download_checkpoint() - + # The code below comes from `ltx-pipelines/src/ltx_pipelines/txt2vid.py` decoder = Builder( model_path=checkpoint_path, @@ -39,10 +39,6 @@ def load_original_decoder(device: torch.device): model_sd_key_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, ).build(device=device) - state_dict = decoder.state_dict() - for k, v in state_dict.items(): - if "mid" in k: - print(f"{k=}") decoder.eval() return decoder @@ -70,16 +66,27 @@ def main() -> None: dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16} dtype = dtype_map[args.dtype] - original_decoder = load_original_decoder(device) + original_decoder = load_original_decoder(device, dtype) diffusers_model = build_diffusers_decoder() converted_state_dict = convert_state_dict(original_decoder.state_dict()) - diffusers_model.load_state_dict(converted_state_dict, assign=True, strict=True) + diffusers_model.load_state_dict(converted_state_dict, assign=True, strict=False) + + per_channel_len = original_decoder.per_channel_statistics.get_buffer("std-of-means").numel() + latent_channels = diffusers_model.decoder.latent_channels + mel_bins_for_match = per_channel_len // latent_channels if per_channel_len % latent_channels == 0 else None levels = len(diffusers_model.decoder.channel_multipliers) - latent_size = diffusers_model.decoder.resolution // (2 ** (levels - 1)) + latent_height = diffusers_model.decoder.resolution // (2 ** (levels - 1)) + latent_width = mel_bins_for_match or latent_height + dummy = torch.randn( - args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device + args.batch, + diffusers_model.decoder.latent_channels, + latent_height, + latent_width, + device=device, + dtype=dtype, ) original_out = original_decoder(dummy) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index 457cbf5bce..e7960c3e14 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -580,13 +580,6 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): _supports_gradient_checkpointing = False - # { - # 'double_z': True, 'mel_bins': 64, 'z_channels': 8, 'resolution': 256, 'downsample_time': False, - # 'in_channels': 2, 'out_ch': 2, 'ch': 128, 'ch_mult': [1, 2, 4], 'num_res_blocks': 2, - # 'attn_resolutions': [], 'dropout': 0.0, 
'mid_block_add_attention': False, - # 'norm_type': 'pixel', 'causality_axis': 'height' - # } - # sample_rate=16000, mel_hop_length=160, is_causal=True, mel_bins=64 @register_to_config def __init__( self, From 059999a3f7ad3fe3077f61812e3b3de91136f4bb Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 22 Dec 2025 10:24:55 +0000 Subject: [PATCH 08/11] up --- scripts/log.txt | 32 ------------------- .../autoencoders/autoencoder_kl_ltx2_audio.py | 22 +++++++------ 2 files changed, 12 insertions(+), 42 deletions(-) delete mode 100644 scripts/log.txt diff --git a/scripts/log.txt b/scripts/log.txt deleted file mode 100644 index aa3046d42a..0000000000 --- a/scripts/log.txt +++ /dev/null @@ -1,32 +0,0 @@ -ddconfig={'double_z': True, 'mel_bins': 64, 'z_channels': 8, 'resolution': 256, 'downsample_time': False, 'in_channels': 2, 'out_ch': 2, 'ch': 128, 'ch_mult': [1, 2, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0, 'mid_block_add_attention': False, 'norm_type': 'pixel', 'causality_axis': 'height'}, sample_rate=16000, mel_hop_length=160, is_causal=True, mel_bins=64 -mid_block_add_attention=False, attn_resolutions=[] -k='mid.block_1.conv1.conv.weight' -k='mid.block_1.conv1.conv.bias' -k='mid.block_1.conv2.conv.weight' -k='mid.block_1.conv2.conv.bias' -k='mid.block_2.conv1.conv.weight' -k='mid.block_2.conv1.conv.bias' -k='mid.block_2.conv2.conv.weight' -k='mid.block_2.conv2.conv.bias' -Traceback (most recent call last): - File "/fsx/sayak/diffusers-new-model-addition-ltx2/scripts/test_ltx2_audio_conversion.py", line 97, in - main() - File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context - return func(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^ - File "/fsx/sayak/diffusers-new-model-addition-ltx2/scripts/test_ltx2_audio_conversion.py", line 85, in main - original_out = original_decoder(dummy) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/fsx/sayak/ltx-2/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py", line 206, in forward - sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/fsx/sayak/ltx-2/ltx-core/src/ltx_core/model/audio_vae/ops.py", line 27, in un_normalize - return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) - ~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -RuntimeError: The size of tensor a (512) must match the size of tensor b (128) at non-singleton dimension 2 diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index e7960c3e14..1385b414b9 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -99,8 +99,9 @@ class LTX2AudioCausalConv2d(nn.Module): super().__init__() self.causality_axis = causality_axis - kernel_size = nn.modules.utils._pair(kernel_size) - dilation = nn.modules.utils._pair(dilation) + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + dilation 
= (dilation, dilation) if isinstance(dilation, int) else dilation + pad_h = (kernel_size[0] - 1) * dilation[0] pad_w = (kernel_size[1] - 1) * dilation[1] @@ -232,7 +233,7 @@ class LTX2AudioResnetBlock(nn.Module): def forward( self, x: torch.Tensor, - temb: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None ) -> torch.Tensor: h = self.norm1(x) h = self.non_linearity(h) @@ -257,7 +258,7 @@ class LTX2AudioUpsample(nn.Module): self, in_channels: int, with_conv: bool, - causality_axis: Optional[str] = "height", + causality_axis: Optional[str] = "height" ) -> None: super().__init__() self.with_conv = with_conv @@ -291,10 +292,11 @@ class LTX2AudioPerChannelStatistics(nn.Module): def __init__(self, latent_channels: int = 128) -> None: super().__init__() + # Sayak notes: `empty` always causes problems in CI. Should we consider using `torch.ones`? self.register_buffer("std-of-means", torch.empty(latent_channels)) self.register_buffer("mean-of-means", torch.empty(latent_channels)) - def un_normalize(self, x: torch.Tensor) -> torch.Tensor: + def denormalize(self, x: torch.Tensor) -> torch.Tensor: return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) def normalize(self, x: torch.Tensor) -> torch.Tensor: @@ -327,7 +329,7 @@ class LTX2AudioAudioPatchifier: def unpatchify( self, audio_latents: torch.Tensor, - output_shape: AudioLatentShape, + output_shape: AudioLatentShape ) -> torch.Tensor: batch, time, _ = audio_latents.shape channels = output_shape.channels @@ -421,7 +423,7 @@ class LTX2AudioDecoder(nn.Module): def _adjust_output_shape( self, decoded_output: torch.Tensor, - target_shape: AudioLatentShape, + target_shape: AudioLatentShape ) -> torch.Tensor: _, _, current_time, current_freq = decoded_output.shape target_channels = target_shape.channels @@ -460,7 +462,7 @@ class LTX2AudioDecoder(nn.Module): ) sample_patched = self.patchifier.patchify(sample) - sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched) + sample_denormalized = self.per_channel_statistics.denormalize(sample_patched) sample = self.patchifier.unpatchify(sample_denormalized, latent_shape) target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR @@ -509,7 +511,7 @@ class LTX2AudioDecoder(nn.Module): self, initial_block_channels: int, dropout: float, - resamp_with_conv: bool, + resamp_with_conv: bool ) -> tuple[nn.ModuleList, int]: up_modules = nn.ModuleList() block_in = initial_block_channels @@ -630,7 +632,7 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): def encode( self, x: torch.Tensor, - return_dict: bool = True, + return_dict: bool = True ): raise NotImplementedError("AutoencoderKLLTX2Audio does not implement encoding.") From 8134da6a56d2fe3fde82af00f079b0615d9768e8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 22 Dec 2025 15:55:29 +0530 Subject: [PATCH 09/11] up --- .../autoencoders/autoencoder_kl_ltx2_audio.py | 37 +++---------------- 1 file changed, 6 insertions(+), 31 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index 1385b414b9..e3c0ef2c3d 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -102,7 +102,6 @@ class LTX2AudioCausalConv2d(nn.Module): kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size dilation = (dilation, dilation) if isinstance(dilation, int) else 
dilation - pad_h = (kernel_size[0] - 1) * dilation[0] pad_w = (kernel_size[1] - 1) * dilation[1] @@ -230,11 +229,7 @@ class LTX2AudioResnetBlock(nn.Module): in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis ) - def forward( - self, - x: torch.Tensor, - temb: Optional[torch.Tensor] = None - ) -> torch.Tensor: + def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: h = self.norm1(x) h = self.non_linearity(h) h = self.conv1(h) @@ -254,12 +249,7 @@ class LTX2AudioResnetBlock(nn.Module): class LTX2AudioUpsample(nn.Module): - def __init__( - self, - in_channels: int, - with_conv: bool, - causality_axis: Optional[str] = "height" - ) -> None: + def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None: super().__init__() self.with_conv = with_conv self.causality_axis = causality_axis @@ -326,11 +316,7 @@ class LTX2AudioAudioPatchifier: batch, channels, time, freq = audio_latents.shape return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq) - def unpatchify( - self, - audio_latents: torch.Tensor, - output_shape: AudioLatentShape - ) -> torch.Tensor: + def unpatchify(self, audio_latents: torch.Tensor, output_shape: AudioLatentShape) -> torch.Tensor: batch, time, _ = audio_latents.shape channels = output_shape.channels freq = output_shape.mel_bins @@ -420,11 +406,7 @@ class LTX2AudioDecoder(nn.Module): final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis ) - def _adjust_output_shape( - self, - decoded_output: torch.Tensor, - target_shape: AudioLatentShape - ) -> torch.Tensor: + def _adjust_output_shape(self, decoded_output: torch.Tensor, target_shape: AudioLatentShape) -> torch.Tensor: _, _, current_time, current_freq = decoded_output.shape target_channels = target_shape.channels target_time = target_shape.frames @@ -508,10 +490,7 @@ class LTX2AudioDecoder(nn.Module): return mid def _build_up_path( - self, - initial_block_channels: int, - dropout: float, - resamp_with_conv: bool + self, initial_block_channels: int, dropout: float, resamp_with_conv: bool ) -> tuple[nn.ModuleList, int]: up_modules = nn.ModuleList() block_in = initial_block_channels @@ -629,11 +608,7 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): self.use_slicing = False @apply_forward_hook - def encode( - self, - x: torch.Tensor, - return_dict: bool = True - ): + def encode(self, x: torch.Tensor, return_dict: bool = True): raise NotImplementedError("AutoencoderKLLTX2Audio does not implement encoding.") def _decode(self, z: torch.Tensor) -> torch.Tensor: From 5f7e43d17fe6edf60fe4dcd8b0d8320e84a259ac Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 23 Dec 2025 02:08:51 +0100 Subject: [PATCH 10/11] Add imports for LTX 2.0 Audio VAE --- src/diffusers/__init__.py | 2 ++ src/diffusers/models/__init__.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 71cad3425f..8c6761a07e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -194,6 +194,7 @@ else: "AutoencoderKLHunyuanVideo", "AutoencoderKLHunyuanVideo15", "AutoencoderKLLTXVideo", + "AutoencoderKLLTX2Audio", "AutoencoderKLLTX2Video", "AutoencoderKLMagvit", "AutoencoderKLMochi", @@ -929,6 +930,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, AutoencoderKLLTXVideo, + AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, AutoencoderKLMagvit, 
AutoencoderKLMochi, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 3f4e49015b..d3bcb3bcee 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -42,6 +42,7 @@ if is_torch_available(): _import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] _import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"] + _import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"] _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"] @@ -154,6 +155,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, AutoencoderKLLTXVideo, + AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, AutoencoderKLMagvit, AutoencoderKLMochi, From d303e2a6ff841919531facf302fd0e724ae57d33 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 23 Dec 2025 02:48:08 +0100 Subject: [PATCH 11/11] Conversion script for LTX 2.0 Audio VAE Decoder --- scripts/convert_ltx2_to_diffusers.py | 80 +++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 2 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index f2e879c065..eb130a3549 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -8,7 +8,7 @@ import torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from diffusers import AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel +from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel from diffusers.utils.import_utils import is_accelerate_available from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder @@ -62,6 +62,8 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = { "per_channel_statistics.std-of-means": "latents_std", } +LTX_2_0_AUDIO_VAE_RENAME_DICT = {} + LTX_2_0_VOCODER_RENAME_DICT = { "ups": "upsamplers", "resblocks": "resnets", @@ -96,6 +98,15 @@ def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) return +def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None: + if key.startswith("per_channel_statistics"): + new_key = ".".join(["decoder", key]) + param = state_dict.pop(key) + state_dict[new_key] = param + + return + + LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = { "video_embeddings_connector": remove_keys_inplace, "audio_embeddings_connector": remove_keys_inplace, @@ -107,6 +118,11 @@ LTX_2_0_VAE_SPECIAL_KEYS_REMAP = { "per_channel_statistics.mean-of-stds": remove_keys_inplace, } +LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = { + "encoder": remove_keys_inplace, + "per_channel_statistics": convert_ltx2_audio_vae_per_channel_statistics, +} + LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {} @@ -325,6 +341,60 @@ def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> return vae +def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "base_channels": 128, + "output_channels": 2, + "ch_mult": (1, 2, 4), + "num_res_blocks": 2, + "attn_resolutions": None, + "in_channels": 
2, + "resolution": 256, + "latent_channels": 8, + "norm_type": "pixel", + "causality_axis": "height", + "dropout": 0.0, + "mid_block_add_attention": False, + "sample_rate": 16000, + "mel_hop_length": 160, + "is_causal": True, + "mel_bins": 64, + }, + } + rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config(version) + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vae = AutoencoderKLLTX2Audio.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True, assign=True) + return vae + + def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "2.0": config = { @@ -513,7 +583,13 @@ def main(args): vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae")) if args.audio_vae or args.full_pipeline: - pass + if args.audio_vae_filename is not None: + original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename) + elif combined_ckpt is not None: + original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix) + audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version) + if not args.full_pipeline: + audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae")) if args.dit or args.full_pipeline: if args.dit_filename is not None: