mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Support LTX 2.0 audio VAE encoder
@@ -148,10 +148,7 @@ LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
     "per_channel_statistics.mean-of-stds": remove_keys_inplace,
 }
 
-LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {
-    "encoder": remove_keys_inplace,
-    "per_channel_statistics": convert_ltx2_audio_vae_per_channel_statistics,
-}
+LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
 
 LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
 
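With the encoder now supported by the converter, its keys are no longer stripped, so the audio VAE remap collapses to an empty dict. For context, a minimal sketch of how such a special-keys remap is typically applied during conversion (the loop and handler signature are assumptions, not this script's exact code):

```python
# Sketch of the assumed remap pattern: each special key maps to a handler that
# mutates the loaded `state_dict` in place (e.g. remove_keys_inplace drops matches).
for special_key, handler_fn in LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP.items():
    for key in list(state_dict.keys()):
        if special_key in key:
            handler_fn(key, state_dict)
```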
@@ -499,6 +496,7 @@ def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
         "mel_hop_length": 160,
         "is_causal": True,
         "mel_bins": 64,
+        "double_z": True,
     },
 }
 rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
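The new `double_z` flag follows the usual KL-VAE convention: the encoder's final conv emits `2 * latent_channels` so the output can be split into mean and log-variance. A one-line illustration, mirroring the `z_channels` computation in the encoder below:

```python
# With double_z=True the encoder's conv_out stacks mean and logvar along the
# channel dim, which DiagonalGaussianDistribution later splits in half.
z_channels = 2 * latent_channels if double_z else latent_channels  # e.g. 16 for latent_channels=8
```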
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Set, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -21,8 +21,9 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils.accelerate_utils import apply_forward_hook
+from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
-from .vae import AutoencoderMixin, DecoderOutput
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
 
 
 LATENT_DOWNSAMPLE_FACTOR = 4
@@ -219,6 +220,40 @@ class LTX2AudioResnetBlock(nn.Module):
         return x + h
 
 
+class LTX2AudioDownsample(nn.Module):
+    def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
+        super().__init__()
+        self.with_conv = with_conv
+        self.causality_axis = causality_axis
+
+        if self.with_conv:
+            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.with_conv:
+            # Padding tuple is in the order: (left, right, top, bottom).
+            if self.causality_axis == "none":
+                pad = (0, 1, 0, 1)
+            elif self.causality_axis == "width":
+                pad = (2, 0, 0, 1)
+            elif self.causality_axis == "height":
+                pad = (0, 1, 2, 0)
+            elif self.causality_axis == "width-compatibility":
+                pad = (1, 0, 0, 1)
+            else:
+                raise ValueError(
+                    f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`,"
+                    f" and `width-compatibility`."
+                )
+
+            x = F.pad(x, pad, mode="constant", value=0)
+            x = self.conv(x)
+        else:
+            # with_conv=False implies that causality_axis is "none"
+            x = F.avg_pool2d(x, kernel_size=2, stride=2)
+        return x
+
+
 class LTX2AudioUpsample(nn.Module):
     def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
         super().__init__()
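The asymmetric pads above are what make the downsampling causal: for `causality_axis="height"` (the time axis in this layout), all temporal padding lands before the first frame, so the strided conv never looks ahead in time. A hedged toy check of that behavior:

```python
import torch
import torch.nn.functional as F

# causality_axis="height": pad 2 zero frames *before* time, none after, and the
# mel axis (width) with (left=0, right=1), matching pad = (0, 1, 2, 0) above.
x = torch.randn(1, 8, 10, 6)  # (batch, channels, time, mel_bins)
x = F.pad(x, (0, 1, 2, 0), mode="constant", value=0)  # (left, right, top, bottom)
conv = torch.nn.Conv2d(8, 8, kernel_size=3, stride=2, padding=0)
print(conv(x).shape)  # torch.Size([1, 8, 5, 3]): time 10 -> 5 with no lookahead
```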
@@ -282,6 +317,156 @@ class LTX2AudioAudioPatchifier:
         return self._patch_size
 
 
+class LTX2AudioEncoder(nn.Module):
+    def __init__(
+        self,
+        base_channels: int = 128,
+        output_channels: int = 1,
+        num_res_blocks: int = 2,
+        attn_resolutions: Optional[Tuple[int, ...]] = None,
+        in_channels: int = 2,
+        resolution: int = 256,
+        latent_channels: int = 8,
+        ch_mult: Tuple[int, ...] = (1, 2, 4),
+        norm_type: str = "group",
+        causality_axis: Optional[str] = "width",
+        dropout: float = 0.0,
+        mid_block_add_attention: bool = False,
+        sample_rate: int = 16000,
+        mel_hop_length: int = 160,
+        is_causal: bool = True,
+        mel_bins: Optional[int] = 64,
+        double_z: bool = True,
+    ):
+        super().__init__()
+
+        self.sample_rate = sample_rate
+        self.mel_hop_length = mel_hop_length
+        self.is_causal = is_causal
+        self.mel_bins = mel_bins
+
+        self.base_channels = base_channels
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.out_ch = output_channels
+        self.give_pre_end = False
+        self.tanh_out = False
+        self.norm_type = norm_type
+        self.latent_channels = latent_channels
+        self.channel_multipliers = ch_mult
+        self.attn_resolutions = attn_resolutions
+        self.causality_axis = causality_axis
+
+        base_block_channels = base_channels
+        base_resolution = resolution
+        self.z_shape = (1, latent_channels, base_resolution, base_resolution)
+
+        if self.causality_axis is not None:
+            self.conv_in = LTX2AudioCausalConv2d(
+                in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
+            )
+        else:
+            self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
+
+        self.down = nn.ModuleList()
+        block_in = base_block_channels
+        curr_res = self.resolution
+
+        for level in range(self.num_resolutions):
+            stage = nn.Module()
+            stage.block = nn.ModuleList()
+            stage.attn = nn.ModuleList()
+            block_out = self.base_channels * self.channel_multipliers[level]
+
+            for _ in range(self.num_res_blocks):
+                stage.block.append(
+                    LTX2AudioResnetBlock(
+                        in_channels=block_in,
+                        out_channels=block_out,
+                        temb_channels=self.temb_ch,
+                        dropout=dropout,
+                        norm_type=self.norm_type,
+                        causality_axis=self.causality_axis,
+                    )
+                )
+                block_in = block_out
+                if self.attn_resolutions:
+                    if curr_res in self.attn_resolutions:
+                        stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
+
+            if level != self.num_resolutions - 1:
+                stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis)
+                curr_res = curr_res // 2
+
+            self.down.append(stage)
+
+        self.mid = nn.Module()
+        self.mid.block_1 = LTX2AudioResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+            norm_type=self.norm_type,
+            causality_axis=self.causality_axis,
+        )
+        if mid_block_add_attention:
+            self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)
+        else:
+            self.mid.attn_1 = nn.Identity()
+        self.mid.block_2 = LTX2AudioResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+            norm_type=self.norm_type,
+            causality_axis=self.causality_axis,
+        )
+
+        final_block_channels = block_in
+        z_channels = 2 * latent_channels if double_z else latent_channels
+        if self.norm_type == "group":
+            self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
+        elif self.norm_type == "pixel":
+            self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
+        else:
+            raise ValueError(f"Invalid normalization type: {self.norm_type}")
+        self.non_linearity = nn.SiLU()
+
+        if self.causality_axis is not None:
+            self.conv_out = LTX2AudioCausalConv2d(
+                final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
+            )
+        else:
+            self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # hidden_states expected shape: (batch_size, channels, time, num_mel_bins)
+        hidden_states = self.conv_in(hidden_states)
+
+        for level in range(self.num_resolutions):
+            stage = self.down[level]
+            for block_idx, block in enumerate(stage.block):
+                hidden_states = block(hidden_states, temb=None)
+                if stage.attn:
+                    hidden_states = stage.attn[block_idx](hidden_states)
+
+            if level != self.num_resolutions - 1 and hasattr(stage, "downsample"):
+                hidden_states = stage.downsample(hidden_states)
+
+        hidden_states = self.mid.block_1(hidden_states, temb=None)
+        hidden_states = self.mid.attn_1(hidden_states)
+        hidden_states = self.mid.block_2(hidden_states, temb=None)
+
+        hidden_states = self.norm_out(hidden_states)
+        hidden_states = self.non_linearity(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+
+        return hidden_states
+
+
 class LTX2AudioDecoder(nn.Module):
     """
     Symmetric decoder that reconstructs audio spectrograms from latent features.
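A hedged shape check for the new encoder, with a small config like the one in the new test file (values are illustrative, and the classes are assumed in scope from the module above): `len(ch_mult) - 1 = 2` strided stages give the 4x time/mel reduction matching `LATENT_DOWNSAMPLE_FACTOR`, and `double_z` doubles the channel dim.

```python
import torch

encoder = LTX2AudioEncoder(
    base_channels=16,
    in_channels=2,
    latent_channels=4,
    ch_mult=(1, 2, 4),
    norm_type="pixel",
    causality_axis="height",
    mel_bins=16,
    double_z=True,
)
mel = torch.randn(2, 2, 8, 16)  # (batch, channels, time, num_mel_bins)
z = encoder(mel)
# Two downsamples quarter both time and mel bins; double_z gives 2 * 4 channels.
print(z.shape)  # expected: torch.Size([2, 8, 2, 4])
```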
@@ -292,22 +477,22 @@ class LTX2AudioDecoder(nn.Module):
 
     def __init__(
         self,
-        base_channels: int,
-        output_channels: int,
-        num_res_blocks: int,
-        attn_resolutions: Set[int],
-        in_channels: int,
-        resolution: int,
-        latent_channels: int,
-        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
+        base_channels: int = 128,
+        output_channels: int = 1,
+        num_res_blocks: int = 2,
+        attn_resolutions: Optional[Tuple[int, ...]] = None,
+        in_channels: int = 2,
+        resolution: int = 256,
+        latent_channels: int = 8,
+        ch_mult: Tuple[int, ...] = (1, 2, 4),
         norm_type: str = "group",
         causality_axis: Optional[str] = "width",
         dropout: float = 0.0,
-        mid_block_add_attention: bool = True,
+        mid_block_add_attention: bool = False,
         sample_rate: int = 16000,
         mel_hop_length: int = 160,
         is_causal: bool = True,
-        mel_bins: Optional[int] = None,
+        mel_bins: Optional[int] = 64,
     ) -> None:
         super().__init__()
 
@@ -493,9 +678,9 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
         self,
         base_channels: int = 128,
         output_channels: int = 2,
-        ch_mult: Tuple[int] = (1, 2, 4),
+        ch_mult: Tuple[int, ...] = (1, 2, 4),
         num_res_blocks: int = 2,
-        attn_resolutions: Optional[Tuple[int]] = None,
+        attn_resolutions: Optional[Tuple[int, ...]] = None,
         in_channels: int = 2,
         resolution: int = 256,
         latent_channels: int = 8,
@@ -507,6 +692,7 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
         mel_hop_length: int = 160,
         is_causal: bool = True,
         mel_bins: Optional[int] = 64,
+        double_z: bool = True,
     ) -> None:
         super().__init__()
 
@@ -516,6 +702,26 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
 
         attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions
 
+        self.encoder = LTX2AudioEncoder(
+            base_channels=base_channels,
+            output_channels=output_channels,
+            ch_mult=ch_mult,
+            num_res_blocks=num_res_blocks,
+            attn_resolutions=attn_resolution_set,
+            in_channels=in_channels,
+            resolution=resolution,
+            latent_channels=latent_channels,
+            norm_type=norm_type,
+            causality_axis=causality_axis,
+            dropout=dropout,
+            mid_block_add_attention=mid_block_add_attention,
+            sample_rate=sample_rate,
+            mel_hop_length=mel_hop_length,
+            is_causal=is_causal,
+            mel_bins=mel_bins,
+            double_z=double_z,
+        )
+
         self.decoder = LTX2AudioDecoder(
             base_channels=base_channels,
             output_channels=output_channels,
@@ -548,9 +754,21 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
         self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
         self.use_slicing = False
 
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        return self.encoder(x)
+
     @apply_forward_hook
     def encode(self, x: torch.Tensor, return_dict: bool = True):
-        raise NotImplementedError("AutoencoderKLLTX2Audio does not implement encoding.")
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
+        posterior = DiagonalGaussianDistribution(h)
+
+        if not return_dict:
+            return (posterior,)
+        return AutoencoderKLOutput(latent_dist=posterior)
 
     def _decode(self, z: torch.Tensor) -> torch.Tensor:
         return self.decoder(z)
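For reference, a minimal usage sketch of the now-implemented `encode()` (config values mirror the new test file, not a released checkpoint):

```python
import torch
from diffusers import AutoencoderKLLTX2Audio

vae = AutoencoderKLLTX2Audio(
    in_channels=2,
    output_channels=2,
    latent_channels=4,
    base_channels=16,
    ch_mult=(1, 2, 4),
    norm_type="pixel",
    causality_axis="height",
    mel_bins=16,
)
mel = torch.randn(2, 2, 8, 16)  # (batch, channels, time, num_mel_bins)
posterior = vae.encode(mel).latent_dist  # DiagonalGaussianDistribution
z = posterior.sample()                   # or posterior.mode() for the mean
```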
@@ -568,7 +786,19 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
 
         return DecoderOutput(sample=decoded)
 
-    def forward(self, *args, **kwargs):
-        raise NotImplementedError(
-            "This model doesn't have an encoder yet so we don't implement its `forward()`. Please use `decode()`."
-        )
+    def forward(
+        self,
+        sample: torch.Tensor,
+        sample_posterior: bool = False,
+        return_dict: bool = True,
+        generator: Optional[torch.Generator] = None,
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        posterior = self.encode(sample).latent_dist
+        if sample_posterior:
+            z = posterior.sample(generator=generator)
+        else:
+            z = posterior.mode()
+        dec = self.decode(z)
+        if not return_dict:
+            return (dec.sample,)
+        return dec
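And `forward()` chains the two paths; note, per `test_output` in the new test file below, that the decoded time length does not necessarily match the input's. Continuing the sketch above:

```python
# forward() = encode -> (sample or mode) -> decode.
out = vae(sample=mel, sample_posterior=True, generator=torch.Generator().manual_seed(0))
print(out.sample.shape)  # e.g. torch.Size([2, 2, 5, 16]); time differs from the input's 8
```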
@@ -0,0 +1,88 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from diffusers import AutoencoderKLLTX2Audio
+
+from ...testing_utils import (
+    floats_tensor,
+    torch_device,
+)
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
+
+
+class AutoencoderKLLTX2AudioTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
+    model_class = AutoencoderKLLTX2Audio
+    main_input_name = "sample"
+    base_precision = 1e-2
+
+    def get_autoencoder_kl_ltx_video_config(self):
+        return {
+            "in_channels": 2,  # stereo
+            "output_channels": 2,
+            "latent_channels": 4,
+            "base_channels": 16,
+            "ch_mult": (1, 2, 4),
+            "resolution": 16,
+            "attn_resolutions": None,
+            "num_res_blocks": 2,
+            "norm_type": "pixel",
+            "causality_axis": "height",
+            "mid_block_add_attention": False,
+            "sample_rate": 16000,
+            "mel_hop_length": 160,
+            "mel_bins": 16,
+            "is_causal": True,
+            "double_z": True,
+        }
+
+    @property
+    def dummy_input(self):
+        batch_size = 2
+        num_channels = 2
+        num_frames = 8
+        num_mel_bins = 16
+
+        spectrogram = floats_tensor((batch_size, num_channels, num_frames, num_mel_bins)).to(torch_device)
+
+        input_dict = {"sample": spectrogram}
+        return input_dict
+
+    @property
+    def input_shape(self):
+        return (2, 5, 16)
+
+    @property
+    def output_shape(self):
+        return (2, 5, 16)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = self.get_autoencoder_kl_ltx_video_config()
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    # Overriding as output shape is not the same as input shape for LTX 2.0 audio VAE
+    def test_output(self):
+        super().test_output(expected_output_shape=(2, 2, 5, 16))
+
+    @unittest.skip("Unsupported test.")
+    def test_outputs_equivalence(self):
+        pass
+
+    @unittest.skip("AutoencoderKLLTX2Audio does not support `norm_num_groups` because it does not use GroupNorm.")
+    def test_forward_with_norm_groups(self):
+        pass
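Assuming the conventional diffusers test layout implied by the relative imports above (the exact filename is a guess), the new tests can be run with:

```
pytest tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py -q
```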