Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
scripts/test_ltx2_audio_conversion.py (new file, 104 lines)
@@ -0,0 +1,104 @@
import argparse
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download


def download_checkpoint(
    repo_id="diffusers-internal-dev/new-ltx-model",
    filename="ltx-av-step-1932500-interleaved-new-vae.safetensors",
):
    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
    return ckpt_path


def convert_state_dict(state_dict: dict) -> dict:
    converted = {}
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            continue
        new_key = key
        if new_key.startswith("decoder."):
            new_key = new_key[len("decoder.") :]
        converted[f"decoder.{new_key}"] = value
    return converted
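
# Note on the key mapping above: keys that already carry a "decoder." prefix keep it,
# and any un-prefixed keys are moved under "decoder." so that every tensor lands on the
# `decoder` submodule of `AutoencoderKLLTX2Audio`. Hypothetical example keys:
#   "conv_in.conv.weight"                  -> "decoder.conv_in.conv.weight"
#   "per_channel_statistics.std-of-means"  -> "decoder.per_channel_statistics.std-of-means"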


def load_original_decoder(device: torch.device, dtype: torch.dtype):
    from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
    from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER
    from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator

    checkpoint_path = download_checkpoint()

    # The code below comes from `ltx-pipelines/src/ltx_pipelines/txt2vid.py`
    decoder = Builder(
        model_path=checkpoint_path,
        model_class_configurator=AudioDecoderConfigurator,
        model_sd_key_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
    ).build(device=device)

    decoder.eval()
    return decoder


def build_diffusers_decoder():
    from diffusers.models.autoencoders import AutoencoderKLLTX2Audio

    with torch.device("meta"):
        model = AutoencoderKLLTX2Audio()

    model.eval()
    return model
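
# `build_diffusers_decoder` creates the model on the meta device, so no memory is
# allocated for weights until `load_state_dict(..., assign=True)` in `main` materializes
# the parameters directly from the converted checkpoint tensors; `strict=False` is
# presumably there to tolerate any buffer/key mismatches between the two layouts.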


@torch.no_grad()
def main() -> None:
    parser = argparse.ArgumentParser(description="Validate LTX2 audio decoder conversion.")
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"])
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--output-path", type=Path, required=True)
    args = parser.parse_args()

    device = torch.device(args.device)
    dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}
    dtype = dtype_map[args.dtype]

    original_decoder = load_original_decoder(device, dtype)
    diffusers_model = build_diffusers_decoder()

    converted_state_dict = convert_state_dict(original_decoder.state_dict())
    diffusers_model.load_state_dict(converted_state_dict, assign=True, strict=False)

    per_channel_len = original_decoder.per_channel_statistics.get_buffer("std-of-means").numel()
    latent_channels = diffusers_model.decoder.latent_channels
    mel_bins_for_match = per_channel_len // latent_channels if per_channel_len % latent_channels == 0 else None

    levels = len(diffusers_model.decoder.channel_multipliers)
    latent_height = diffusers_model.decoder.resolution // (2 ** (levels - 1))
    latent_width = mel_bins_for_match or latent_height

    dummy = torch.randn(
        args.batch,
        diffusers_model.decoder.latent_channels,
        latent_height,
        latent_width,
        device=device,
        dtype=dtype,
    )

    original_out = original_decoder(dummy)
    diffusers_out = diffusers_model.decode(dummy).sample

    torch.testing.assert_close(diffusers_out, original_out, rtol=1e-4, atol=1e-4)
    max_diff = (diffusers_out - original_out).abs().max().item()
    print(f"Conversion successful. Max diff: {max_diff:.6f}")

    diffusers_model.to(dtype).save_pretrained(args.output_path)
    print(f"Serialized model to {args.output_path}")


if __name__ == "__main__":
    main()
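
# Example invocation (paths illustrative):
#   python scripts/test_ltx2_audio_conversion.py --device cuda --dtype bfloat16 \
#       --output-path /tmp/ltx2-audio-vae
# The serialized directory can then be reloaded in diffusers via
# `AutoencoderKLLTX2Audio.from_pretrained("/tmp/ltx2-audio-vae", torch_dtype=torch.bfloat16)`.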

src/diffusers/models/autoencoders/__init__.py
@@ -10,6 +10,7 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py (new file, 633 lines)
@@ -0,0 +1,633 @@
# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import namedtuple
from typing import Optional, Set, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput


LATENT_DOWNSAMPLE_FACTOR = 4
SUPPORTED_CAUSAL_AXES = {"none", "width", "height", "width-compatibility"}


AudioLatentShape = namedtuple(
    "AudioLatentShape",
    [
        "batch",
        "channels",
        "frames",
        "mel_bins",
    ],
)


def _resolve_causality_axis(causality_axis: Optional[str] = None) -> Optional[str]:
    normalized = "none" if causality_axis is None else str(causality_axis).lower()
    if normalized not in SUPPORTED_CAUSAL_AXES:
        raise NotImplementedError(
            f"Unsupported causality_axis '{causality_axis}'. Supported: {sorted(SUPPORTED_CAUSAL_AXES)}"
        )
    return None if normalized == "none" else normalized


def make_conv2d(
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Tuple[int, int]],
    stride: int = 1,
    padding: Optional[Tuple[int, int, int, int]] = None,
    dilation: int = 1,
    groups: int = 1,
    bias: bool = True,
    causality_axis: Optional[str] = None,
) -> nn.Module:
    if causality_axis is not None:
        return LTX2AudioCausalConv2d(
            in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis
        )
    if padding is None:
        padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size)

    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        bias,
    )


class LTX2AudioCausalConv2d(nn.Module):
    """
    A causal 2D convolution that pads asymmetrically along the causal axis.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: int = 1,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        causality_axis: str = "height",
    ) -> None:
        super().__init__()

        self.causality_axis = causality_axis
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        dilation = (dilation, dilation) if isinstance(dilation, int) else dilation

        pad_h = (kernel_size[0] - 1) * dilation[0]
        pad_w = (kernel_size[1] - 1) * dilation[1]

        if self.causality_axis == "none":
            padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
        elif self.causality_axis in {"width", "width-compatibility"}:
            padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
        elif self.causality_axis == "height":
            padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
        else:
            raise ValueError(f"Invalid causality_axis: {causality_axis}")

        self.padding = padding
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.pad(x, self.padding)
        return self.conv(x)
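
# Padding illustration: for a 3x3 kernel with dilation 1 and causality_axis="height",
# `F.pad` receives (w_left, w_right, h_top, h_bottom) = (1, 1, 2, 0), so the receptive
# field is shifted entirely toward earlier rows and an output row never depends on
# later positions along the causal (time) dimension.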


class LTX2AudioPixelNorm(nn.Module):
    """
    Per-pixel (per-location) RMS normalization layer.
    """

    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
        rms = torch.sqrt(mean_sq + self.eps)
        return x / rms
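
# Unlike GroupNorm, this RMS-style norm has no learned affine parameters: every
# (time, mel) location is simply divided by the RMS over its channel vector. It is the
# normalization used whenever a causal axis is configured, because the causal ResNet
# blocks below reject GroupNorm.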


def build_normalization_layer(in_channels: int, *, num_groups: int = 32, normtype: str = "group") -> nn.Module:
    if normtype == "group":
        return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
    if normtype == "pixel":
        return LTX2AudioPixelNorm(dim=1, eps=1e-6)
    raise ValueError(f"Invalid normalization type: {normtype}")


class LTX2AudioAttnBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        norm_type: str = "group",
    ) -> None:
        super().__init__()
        self.in_channels = in_channels

        self.norm = build_normalization_layer(in_channels, normtype=norm_type)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h_ = self.norm(x)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        batch, channels, height, width = q.shape
        q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous()
        k = k.reshape(batch, channels, height * width).contiguous()
        attn = torch.bmm(q, k) * (int(channels) ** (-0.5))
        attn = torch.nn.functional.softmax(attn, dim=2)

        v = v.reshape(batch, channels, height * width)
        attn = attn.permute(0, 2, 1).contiguous()
        h_ = torch.bmm(v, attn).reshape(batch, channels, height, width)

        h_ = self.proj_out(h_)
        return x + h_
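
# This is single-head, full self-attention over all time/mel positions: q, k, v come
# from 1x1 convolutions, the attention weights are softmax(q @ k^T / sqrt(C)), and the
# attended values are projected and added back to the input as a residual.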


class LTX2AudioResnetBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        norm_type: str = "group",
        causality_axis: str = "height",
    ) -> None:
        super().__init__()
        self.causality_axis = causality_axis

        if self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group":
            raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
        self.non_linearity = nn.SiLU()
        self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        if temb_channels > 0:
            self.temb_proj = nn.Linear(temb_channels, out_channels)
        self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
                )
            else:
                self.nin_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
                )

    def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        h = self.norm1(x)
        h = self.non_linearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = self.non_linearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)

        return x + h
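
# Note: the decoder below builds these blocks with `temb_channels=0` and calls them with
# `temb=None`, so the timestep-embedding projection is never created or used in the
# audio VAE decoder path.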


class LTX2AudioUpsample(nn.Module):
    def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis
        if self.with_conv:
            self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        if self.causality_axis is None or self.causality_axis == "none":
            pass
        elif self.causality_axis == "height":
            x = x[:, :, 1:, :]
        elif self.causality_axis == "width":
            x = x[:, :, :, 1:]
        elif self.causality_axis == "width-compatibility":
            pass
        else:
            raise ValueError(f"Invalid causality_axis: {self.causality_axis}")

        return x
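
# With a causal axis set, each upsampling stage doubles the causal dimension and then
# drops the first element (T -> 2 * T - 1). Over the two upsampling stages of the
# default decoder config this gives 4 * T - 3 frames, matching the
# `target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1)` adjustment in `LTX2AudioDecoder.forward`.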


class LTX2AudioPerChannelStatistics(nn.Module):
    """
    Per-channel statistics for normalizing and denormalizing the latent representation. These statistics are
    computed over the entire dataset and stored in the model checkpoint as part of the AudioVAE state dict.
    """

    def __init__(self, latent_channels: int = 128) -> None:
        super().__init__()
        # Sayak notes: `empty` always causes problems in CI. Should we consider using `torch.ones`?
        self.register_buffer("std-of-means", torch.empty(latent_channels))
        self.register_buffer("mean-of-means", torch.empty(latent_channels))

    def denormalize(self, x: torch.Tensor) -> torch.Tensor:
        return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
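
# The buffers are registered under the dashed names used by the original checkpoint
# ("std-of-means" / "mean-of-means"), which is why they are read back via `get_buffer`
# instead of attribute access. The decoder applies them to the patchified latent of
# shape (batch, time, channels * mel_bins); with the default config, the 128 entries
# correspond to 8 latent channels x 16 latent mel bins.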


class LTX2AudioAudioPatchifier:
    """
    Patchifier for spectrogram/audio latents.
    """

    def __init__(
        self,
        patch_size: int,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
        is_causal: bool = True,
    ):
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.is_causal = is_causal
        self._patch_size = (1, patch_size, patch_size)

    def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor:
        batch, channels, time, freq = audio_latents.shape
        return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)

    def unpatchify(self, audio_latents: torch.Tensor, output_shape: AudioLatentShape) -> torch.Tensor:
        batch, time, _ = audio_latents.shape
        channels = output_shape.channels
        freq = output_shape.mel_bins
        return audio_latents.view(batch, time, channels, freq).permute(0, 2, 1, 3)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        return self._patch_size
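
# Shape walk-through (illustrative): `patchify` maps (batch, channels, time, mel_bins)
# to (batch, time, channels * mel_bins) by moving time forward and flattening channels
# with mel bins; `unpatchify` inverts this using the channel and mel sizes recorded in
# the target `AudioLatentShape`.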


class LTX2AudioDecoder(nn.Module):
    """
    Symmetric decoder that reconstructs audio spectrograms from latent features.

    The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and
    causal convolutions.
    """

    def __init__(
        self,
        base_channels: int,
        output_channels: int,
        num_res_blocks: int,
        attn_resolutions: Set[int],
        in_channels: int,
        resolution: int,
        latent_channels: int,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        norm_type: str = "group",
        causality_axis: Optional[str] = "width",
        dropout: float = 0.0,
        mid_block_add_attention: bool = True,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        is_causal: bool = True,
        mel_bins: Optional[int] = None,
    ) -> None:
        super().__init__()

        resolved_causality_axis = _resolve_causality_axis(causality_axis)

        self.per_channel_statistics = LTX2AudioPerChannelStatistics(latent_channels=base_channels)
        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.is_causal = is_causal
        self.mel_bins = mel_bins
        self.patchifier = LTX2AudioAudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=sample_rate,
            hop_length=mel_hop_length,
            is_causal=is_causal,
        )

        self.base_channels = base_channels
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.out_ch = output_channels
        self.give_pre_end = False
        self.tanh_out = False
        self.norm_type = norm_type
        self.latent_channels = latent_channels
        self.channel_multipliers = ch_mult
        self.attn_resolutions = attn_resolutions
        self.causality_axis = resolved_causality_axis

        base_block_channels = base_channels * self.channel_multipliers[-1]
        base_resolution = resolution // (2 ** (self.num_resolutions - 1))
        self.z_shape = (1, latent_channels, base_resolution, base_resolution)

        self.conv_in = make_conv2d(
            latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )
        self.non_linearity = nn.SiLU()
        self.mid = self._build_mid_layers(base_block_channels, dropout, mid_block_add_attention)
        self.up, final_block_channels = self._build_up_path(
            initial_block_channels=base_block_channels,
            dropout=dropout,
            resamp_with_conv=True,
        )

        self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
        self.conv_out = make_conv2d(
            final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )

    def _adjust_output_shape(self, decoded_output: torch.Tensor, target_shape: AudioLatentShape) -> torch.Tensor:
        _, _, current_time, current_freq = decoded_output.shape
        target_channels = target_shape.channels
        target_time = target_shape.frames
        target_freq = target_shape.mel_bins

        decoded_output = decoded_output[
            :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
        ]

        time_padding_needed = target_time - decoded_output.shape[2]
        freq_padding_needed = target_freq - decoded_output.shape[3]

        if time_padding_needed > 0 or freq_padding_needed > 0:
            padding = (
                0,
                max(freq_padding_needed, 0),
                0,
                max(time_padding_needed, 0),
            )
            decoded_output = F.pad(decoded_output, padding)

        decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]

        return decoded_output

    def forward(
        self,
        sample: torch.Tensor,
    ) -> torch.Tensor:
        latent_shape = AudioLatentShape(
            batch=sample.shape[0],
            channels=sample.shape[1],
            frames=sample.shape[2],
            mel_bins=sample.shape[3],
        )

        sample_patched = self.patchifier.patchify(sample)
        sample_denormalized = self.per_channel_statistics.denormalize(sample_patched)
        sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)

        target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR

        if self.causality_axis is not None:
            target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)

        target_shape = AudioLatentShape(
            batch=latent_shape.batch,
            channels=self.out_ch,
            frames=target_frames,
            mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
        )

        hidden_features = self.conv_in(sample)
        hidden_features = self._run_mid_layers(hidden_features)
        hidden_features = self._run_upsampling_path(hidden_features)
        decoded_output = self._finalize_output(hidden_features)

        decoded_output = self._adjust_output_shape(decoded_output, target_shape)

        return decoded_output
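
    # Decoding pipeline: denormalize the latent with the per-channel statistics in the
    # patchified (batch, time, channels * mel_bins) layout, run conv_in -> mid blocks ->
    # upsampling path -> conv_out, then crop or zero-pad to the target
    # (out_ch, frames, mel_bins) shape. With a causal axis, the target frame count is
    # 4 * latent_frames - 3.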

    def _build_mid_layers(self, channels: int, dropout: float, add_attention: bool) -> nn.Module:
        mid = nn.Module()
        mid.block_1 = LTX2AudioResnetBlock(
            in_channels=channels,
            out_channels=channels,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
        )
        mid.attn_1 = LTX2AudioAttnBlock(channels, norm_type=self.norm_type) if add_attention else nn.Identity()
        mid.block_2 = LTX2AudioResnetBlock(
            in_channels=channels,
            out_channels=channels,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
        )
        return mid

    def _build_up_path(
        self, initial_block_channels: int, dropout: float, resamp_with_conv: bool
    ) -> tuple[nn.ModuleList, int]:
        up_modules = nn.ModuleList()
        block_in = initial_block_channels
        curr_res = self.resolution // (2 ** (self.num_resolutions - 1))

        for level in reversed(range(self.num_resolutions)):
            stage = nn.Module()
            stage.block = nn.ModuleList()
            stage.attn = nn.ModuleList()
            block_out = self.base_channels * self.channel_multipliers[level]

            for _ in range(self.num_res_blocks + 1):
                stage.block.append(
                    LTX2AudioResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                        norm_type=self.norm_type,
                        causality_axis=self.causality_axis,
                    )
                )
                block_in = block_out
                if self.attn_resolutions:
                    if curr_res in self.attn_resolutions:
                        stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))

            if level != 0:
                stage.upsample = LTX2AudioUpsample(block_in, resamp_with_conv, causality_axis=self.causality_axis)
                curr_res *= 2

            up_modules.insert(0, stage)

        return up_modules, block_in
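
    # Stages are built from the deepest level outward but inserted at index 0, so
    # `self.up` stays indexed by level with `self.up[0]` as the highest-resolution stage;
    # `_run_upsampling_path` below still iterates the levels in reverse, deepest first.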

    def _run_mid_layers(self, features: torch.Tensor) -> torch.Tensor:
        features = self.mid.block_1(features, temb=None)
        features = self.mid.attn_1(features)
        return self.mid.block_2(features, temb=None)

    def _run_upsampling_path(self, features: torch.Tensor) -> torch.Tensor:
        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for block_idx, block in enumerate(stage.block):
                features = block(features, temb=None)
                if stage.attn:
                    features = stage.attn[block_idx](features)

            if level != 0 and hasattr(stage, "upsample"):
                features = stage.upsample(features)

        return features

    def _finalize_output(self, features: torch.Tensor) -> torch.Tensor:
        if self.give_pre_end:
            return features

        hidden = self.norm_out(features)
        hidden = self.non_linearity(hidden)
        decoded = self.conv_out(hidden)
        return torch.tanh(decoded) if self.tanh_out else decoded


class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
    r"""
    LTX2 audio VAE. Currently, only the decoder is implemented.
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        base_channels: int = 128,
        output_channels: int = 2,
        ch_mult: Tuple[int] = (1, 2, 4),
        num_res_blocks: int = 2,
        attn_resolutions: Optional[Tuple[int]] = None,
        in_channels: int = 2,
        resolution: int = 256,
        latent_channels: int = 8,
        norm_type: str = "pixel",
        causality_axis: Optional[str] = "height",
        dropout: float = 0.0,
        mid_block_add_attention: bool = False,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        is_causal: bool = True,
        mel_bins: Optional[int] = 64,
    ) -> None:
        super().__init__()

        resolved_causality_axis = _resolve_causality_axis(causality_axis)
        attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions

        self.decoder = LTX2AudioDecoder(
            base_channels=base_channels,
            output_channels=output_channels,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolution_set,
            in_channels=in_channels,
            resolution=resolution,
            latent_channels=latent_channels,
            norm_type=norm_type,
            causality_axis=resolved_causality_axis,
            dropout=dropout,
            mid_block_add_attention=mid_block_add_attention,
            sample_rate=sample_rate,
            mel_hop_length=mel_hop_length,
            is_causal=is_causal,
            mel_bins=mel_bins,
        )

        self.use_slicing = False

    @apply_forward_hook
    def encode(self, x: torch.Tensor, return_dict: bool = True):
        raise NotImplementedError("AutoencoderKLLTX2Audio does not implement encoding.")

    def _decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)

    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z)

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "This model doesn't have an encoder yet, so we don't implement its `forward()`. Please use `decode()`."
        )
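

# A minimal usage sketch (illustrative; shapes assume the default config above, and a
# real checkpoint would normally be loaded with `AutoencoderKLLTX2Audio.from_pretrained`):
#
#     vae = AutoencoderKLLTX2Audio()
#     latents = torch.randn(1, 8, 32, 16)  # (batch, latent_channels, latent frames, latent mel bins)
#     spectrogram = vae.decode(latents).sample  # -> (1, 2, 125, 64)
#
# The latent mel-bin count of 16 is implied by the 128-entry per-channel statistics
# (8 channels x 16 bins), and 32 latent frames decode to 4 * 32 - 3 = 125 spectrogram
# frames under the causal "height" axis.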