diff --git a/scripts/test_ltx2_audio_conversion.py b/scripts/test_ltx2_audio_conversion.py new file mode 100644 index 0000000000..8d07a6f9b1 --- /dev/null +++ b/scripts/test_ltx2_audio_conversion.py @@ -0,0 +1,104 @@ +import argparse +from pathlib import Path + +import torch +from huggingface_hub import hf_hub_download + + +def download_checkpoint( + repo_id="diffusers-internal-dev/new-ltx-model", + filename="ltx-av-step-1932500-interleaved-new-vae.safetensors", +): + ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) + return ckpt_path + + +def convert_state_dict(state_dict: dict) -> dict: + converted = {} + for key, value in state_dict.items(): + if not isinstance(value, torch.Tensor): + continue + new_key = key + if new_key.startswith("decoder."): + new_key = new_key[len("decoder.") :] + converted[f"decoder.{new_key}"] = value + return converted + + +def load_original_decoder(device: torch.device, dtype: torch.dtype): + from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder + from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER + from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator + + checkpoint_path = download_checkpoint() + + # The code below comes from `ltx-pipelines/src/ltx_pipelines/txt2vid.py` + decoder = Builder( + model_path=checkpoint_path, + model_class_configurator=AudioDecoderConfigurator, + model_sd_key_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, + ).build(device=device) + + decoder.eval() + return decoder + + +def build_diffusers_decoder(): + from diffusers.models.autoencoders import AutoencoderKLLTX2Audio + + with torch.device("meta"): + model = AutoencoderKLLTX2Audio() + + model.eval() + return model + + +@torch.no_grad() +def main() -> None: + parser = argparse.ArgumentParser(description="Validate LTX2 audio decoder conversion.") + parser.add_argument("--device", type=str, default="cpu") + parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"]) + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--output-path", type=Path, required=True) + args = parser.parse_args() + + device = torch.device(args.device) + dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16} + dtype = dtype_map[args.dtype] + + original_decoder = load_original_decoder(device, dtype) + diffusers_model = build_diffusers_decoder() + + converted_state_dict = convert_state_dict(original_decoder.state_dict()) + diffusers_model.load_state_dict(converted_state_dict, assign=True, strict=False) + + per_channel_len = original_decoder.per_channel_statistics.get_buffer("std-of-means").numel() + latent_channels = diffusers_model.decoder.latent_channels + mel_bins_for_match = per_channel_len // latent_channels if per_channel_len % latent_channels == 0 else None + + levels = len(diffusers_model.decoder.channel_multipliers) + latent_height = diffusers_model.decoder.resolution // (2 ** (levels - 1)) + latent_width = mel_bins_for_match or latent_height + + dummy = torch.randn( + args.batch, + diffusers_model.decoder.latent_channels, + latent_height, + latent_width, + device=device, + dtype=dtype, + ) + + original_out = original_decoder(dummy) + diffusers_out = diffusers_model.decode(dummy).sample + + torch.testing.assert_close(diffusers_out, original_out, rtol=1e-4, atol=1e-4) + max_diff = (diffusers_out - original_out).abs().max().item() + print(f"Conversion successful. Max diff: {max_diff:.6f}") + + diffusers_model.to(dtype).save_pretrained(args.output_path) + print(f"Serialized model to {args.output_path}") + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index ca0cac1a57..38d52f0eb5 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -10,6 +10,7 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo +from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py new file mode 100644 index 0000000000..e3c0ef2c3d --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -0,0 +1,633 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple +from typing import Optional, Set, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils.accelerate_utils import apply_forward_hook +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput + + +LATENT_DOWNSAMPLE_FACTOR = 4 +SUPPORTED_CAUSAL_AXES = {"none", "width", "height", "width-compatibility"} + + +AudioLatentShape = namedtuple( + "AudioLatentShape", + [ + "batch", + "channels", + "frames", + "mel_bins", + ], +) + + +def _resolve_causality_axis(causality_axis: Optional[str] = None) -> Optional[str]: + normalized = "none" if causality_axis is None else str(causality_axis).lower() + if normalized not in SUPPORTED_CAUSAL_AXES: + raise NotImplementedError( + f"Unsupported causality_axis '{causality_axis}'. Supported: {sorted(SUPPORTED_CAUSAL_AXES)}" + ) + return None if normalized == "none" else normalized + + +def make_conv2d( + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: int = 1, + padding: Optional[Tuple[int, int, int, int]] = None, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causality_axis: Optional[str] = None, +) -> nn.Module: + if causality_axis is not None: + return LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis + ) + if padding is None: + padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size) + + return nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + +class LTX2AudioCausalConv2d(nn.Module): + """ + A causal 2D convolution that pads asymmetrically along the causal axis. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: int = 1, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, + causality_axis: str = "height", + ) -> None: + super().__init__() + + self.causality_axis = causality_axis + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + dilation = (dilation, dilation) if isinstance(dilation, int) else dilation + + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + if self.causality_axis == "none": + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + elif self.causality_axis in {"width", "width-compatibility"}: + padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + elif self.causality_axis == "height": + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + else: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + self.padding = padding + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, self.padding) + return self.conv(x) + + +class LTX2AudioPixelNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. + """ + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + super().__init__() + self.dim = dim + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True) + rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + +def build_normalization_layer(in_channels: int, *, num_groups: int = 32, normtype: str = "group") -> nn.Module: + if normtype == "group": + return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if normtype == "pixel": + return LTX2AudioPixelNorm(dim=1, eps=1e-6) + raise ValueError(f"Invalid normalization type: {normtype}") + + +class LTX2AudioAttnBlock(nn.Module): + def __init__( + self, + in_channels: int, + norm_type: str = "group", + ) -> None: + super().__init__() + self.in_channels = in_channels + + self.norm = build_normalization_layer(in_channels, normtype=norm_type) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = self.norm(x) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + batch, channels, height, width = q.shape + q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous() + k = k.reshape(batch, channels, height * width).contiguous() + attn = torch.bmm(q, k) * (int(channels) ** (-0.5)) + attn = torch.nn.functional.softmax(attn, dim=2) + + v = v.reshape(batch, channels, height * width) + attn = attn.permute(0, 2, 1).contiguous() + h_ = torch.bmm(v, attn).reshape(batch, channels, height, width) + + h_ = self.proj_out(h_) + return x + h_ + + +class LTX2AudioResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + norm_type: str = "group", + causality_axis: str = "height", + ) -> None: + super().__init__() + self.causality_axis = causality_axis + + if self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group": + raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = build_normalization_layer(in_channels, normtype=norm_type) + self.non_linearity = nn.SiLU() + self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if temb_channels > 0: + self.temb_proj = nn.Linear(temb_channels, out_channels) + self.norm2 = build_normalization_layer(out_channels, normtype=norm_type) + self.dropout = nn.Dropout(dropout) + self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + + def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + h = self.norm1(x) + h = self.non_linearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.non_linearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x) + + return x + h + + +class LTX2AudioUpsample(nn.Module): + def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + if self.causality_axis is None or self.causality_axis == "none": + pass + elif self.causality_axis == "height": + x = x[:, :, 1:, :] + elif self.causality_axis == "width": + x = x[:, :, :, 1:] + elif self.causality_axis == "width-compatibility": + pass + else: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +class LTX2AudioPerChannelStatistics(nn.Module): + """ + Per-channel statistics for normalizing and denormalizing the latent representation. This statics is computed over + the entire dataset and stored in model's checkpoint under AudioVAE state_dict + """ + + def __init__(self, latent_channels: int = 128) -> None: + super().__init__() + # Sayak notes: `empty` always causes problems in CI. Should we consider using `torch.ones`? + self.register_buffer("std-of-means", torch.empty(latent_channels)) + self.register_buffer("mean-of-means", torch.empty(latent_channels)) + + def denormalize(self, x: torch.Tensor) -> torch.Tensor: + return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x) + + +class LTX2AudioAudioPatchifier: + """ + Patchifier for spectrogram/audio latents. + """ + + def __init__( + self, + patch_size: int, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + is_causal: bool = True, + ): + self.hop_length = hop_length + self.sample_rate = sample_rate + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.is_causal = is_causal + self._patch_size = (1, patch_size, patch_size) + + def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor: + batch, channels, time, freq = audio_latents.shape + return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq) + + def unpatchify(self, audio_latents: torch.Tensor, output_shape: AudioLatentShape) -> torch.Tensor: + batch, time, _ = audio_latents.shape + channels = output_shape.channels + freq = output_shape.mel_bins + return audio_latents.view(batch, time, channels, freq).permute(0, 2, 1, 3) + + @property + def patch_size(self) -> Tuple[int, int, int]: + return self._patch_size + + +class LTX2AudioDecoder(nn.Module): + """ + Symmetric decoder that reconstructs audio spectrograms from latent features. + + The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal + convolutions. + """ + + def __init__( + self, + base_channels: int, + output_channels: int, + num_res_blocks: int, + attn_resolutions: Set[int], + in_channels: int, + resolution: int, + latent_channels: int, + ch_mult: Tuple[int, ...] = (1, 2, 4, 8), + norm_type: str = "group", + causality_axis: Optional[str] = "width", + dropout: float = 0.0, + mid_block_add_attention: bool = True, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: Optional[int] = None, + ) -> None: + super().__init__() + + resolved_causality_axis = _resolve_causality_axis(causality_axis) + + self.per_channel_statistics = LTX2AudioPerChannelStatistics(latent_channels=base_channels) + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + self.patchifier = LTX2AudioAudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.base_channels = base_channels + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = output_channels + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.latent_channels = latent_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = resolved_causality_axis + + base_block_channels = base_channels * self.channel_multipliers[-1] + base_resolution = resolution // (2 ** (self.num_resolutions - 1)) + self.z_shape = (1, latent_channels, base_resolution, base_resolution) + + self.conv_in = make_conv2d( + latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + self.non_linearity = nn.SiLU() + self.mid = self._build_mid_layers(base_block_channels, dropout, mid_block_add_attention) + self.up, final_block_channels = self._build_up_path( + initial_block_channels=base_block_channels, + dropout=dropout, + resamp_with_conv=True, + ) + + self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type) + self.conv_out = make_conv2d( + final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + + def _adjust_output_shape(self, decoded_output: torch.Tensor, target_shape: AudioLatentShape) -> torch.Tensor: + _, _, current_time, current_freq = decoded_output.shape + target_channels = target_shape.channels + target_time = target_shape.frames + target_freq = target_shape.mel_bins + + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + if time_padding_needed > 0 or freq_padding_needed > 0: + padding = ( + 0, + max(freq_padding_needed, 0), + 0, + max(time_padding_needed, 0), + ) + decoded_output = F.pad(decoded_output, padding) + + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + def forward( + self, + sample: torch.Tensor, + ) -> torch.Tensor: + latent_shape = AudioLatentShape( + batch=sample.shape[0], + channels=sample.shape[1], + frames=sample.shape[2], + mel_bins=sample.shape[3], + ) + + sample_patched = self.patchifier.patchify(sample) + sample_denormalized = self.per_channel_statistics.denormalize(sample_patched) + sample = self.patchifier.unpatchify(sample_denormalized, latent_shape) + + target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR + + if self.causality_axis is not None: + target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) + + target_shape = AudioLatentShape( + batch=latent_shape.batch, + channels=self.out_ch, + frames=target_frames, + mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins, + ) + + hidden_features = self.conv_in(sample) + hidden_features = self._run_mid_layers(hidden_features) + hidden_features = self._run_upsampling_path(hidden_features) + decoded_output = self._finalize_output(hidden_features) + + decoded_output = self._adjust_output_shape(decoded_output, target_shape) + + return decoded_output + + def _build_mid_layers(self, channels: int, dropout: float, add_attention: bool) -> nn.Module: + mid = nn.Module() + mid.block_1 = LTX2AudioResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + mid.attn_1 = LTX2AudioAttnBlock(channels, norm_type=self.norm_type) if add_attention else nn.Identity() + mid.block_2 = LTX2AudioResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + return mid + + def _build_up_path( + self, initial_block_channels: int, dropout: float, resamp_with_conv: bool + ) -> tuple[nn.ModuleList, int]: + up_modules = nn.ModuleList() + block_in = initial_block_channels + curr_res = self.resolution // (2 ** (self.num_resolutions - 1)) + + for level in reversed(range(self.num_resolutions)): + stage = nn.Module() + stage.block = nn.ModuleList() + stage.attn = nn.ModuleList() + block_out = self.base_channels * self.channel_multipliers[level] + + for _ in range(self.num_res_blocks + 1): + stage.block.append( + LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + ) + block_in = block_out + if self.attn_resolutions: + if curr_res in self.attn_resolutions: + stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)) + + if level != 0: + stage.upsample = LTX2AudioUpsample(block_in, resamp_with_conv, causality_axis=self.causality_axis) + curr_res *= 2 + + up_modules.insert(0, stage) + + return up_modules, block_in + + def _run_mid_layers(self, features: torch.Tensor) -> torch.Tensor: + features = self.mid.block_1(features, temb=None) + features = self.mid.attn_1(features) + return self.mid.block_2(features, temb=None) + + def _run_upsampling_path(self, features: torch.Tensor) -> torch.Tensor: + for level in reversed(range(self.num_resolutions)): + stage = self.up[level] + for block_idx, block in enumerate(stage.block): + features = block(features, temb=None) + if stage.attn: + features = stage.attn[block_idx](features) + + if level != 0 and hasattr(stage, "upsample"): + features = stage.upsample(features) + + return features + + def _finalize_output(self, features: torch.Tensor) -> torch.Tensor: + if self.give_pre_end: + return features + + hidden = self.norm_out(features) + hidden = self.non_linearity(hidden) + decoded = self.conv_out(hidden) + return torch.tanh(decoded) if self.tanh_out else decoded + + +class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): + r""" + LTX2 audio VAE. Currently, only implements the decoder. + """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + base_channels: int = 128, + output_channels: int = 2, + ch_mult: Tuple[int] = (1, 2, 4), + num_res_blocks: int = 2, + attn_resolutions: Optional[Tuple[int]] = None, + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + norm_type: str = "pixel", + causality_axis: Optional[str] = "height", + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: Optional[int] = 64, + ) -> None: + super().__init__() + + resolved_causality_axis = _resolve_causality_axis(causality_axis) + attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions + + self.decoder = LTX2AudioDecoder( + base_channels=base_channels, + output_channels=output_channels, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolution_set, + in_channels=in_channels, + resolution=resolution, + latent_channels=latent_channels, + norm_type=norm_type, + causality_axis=resolved_causality_axis, + dropout=dropout, + mid_block_add_attention=mid_block_add_attention, + sample_rate=sample_rate, + mel_hop_length=mel_hop_length, + is_causal=is_causal, + mel_bins=mel_bins, + ) + + self.use_slicing = False + + @apply_forward_hook + def encode(self, x: torch.Tensor, return_dict: bool = True): + raise NotImplementedError("AutoencoderKLLTX2Audio does not implement encoding.") + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + return self.decoder(z) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "This model doesn't have an encoder yet so we don't implement its `forward()`. Please use `decode()`." + )