
Merge pull request #2 from huggingface/ltx-2-video-vae

LTX 2.0 Video VAE Implementation
dg845
2025-12-19 16:36:38 -08:00
committed by GitHub
7 changed files with 1979 additions and 3 deletions
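
The conversion script updated below saves each converted component to its own subfolder of the output path ("transformer", "vae", "vocoder"). A minimal sketch of loading the converted models back afterwards, assuming the branch from this PR is installed; the local directory name and dtype choices are illustrative, not part of this commit:

import torch
from diffusers import AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder

output_path = "./ltx2-converted"  # hypothetical output directory used when running the conversion script
vae = AutoencoderKLLTX2Video.from_pretrained(output_path, subfolder="vae", torch_dtype=torch.bfloat16)
vocoder = LTX2Vocoder.from_pretrained(output_path, subfolder="vocoder", torch_dtype=torch.float32)
transformer = LTX2VideoTransformer3DModel.from_pretrained(output_path, subfolder="transformer", torch_dtype=torch.bfloat16)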


@@ -8,8 +8,9 @@ import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from diffusers import LTX2VideoTransformer3DModel
from diffusers import AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
CTX = init_empty_weights if is_accelerate_available() else nullcontext
@@ -35,6 +36,39 @@ LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
"k_norm": "norm_k",
}
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
# Encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.1",
"down_blocks.3": "down_blocks.1.downsamplers.0",
"down_blocks.4": "down_blocks.2",
"down_blocks.5": "down_blocks.2.downsamplers.0",
"down_blocks.6": "down_blocks.3",
"down_blocks.7": "down_blocks.3.downsamplers.0",
"down_blocks.8": "mid_block",
# Decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
# Common
# For all 3D ResNets
"res_blocks": "resnets",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_0_VOCODER_RENAME_DICT = {
"ups": "upsamplers",
"resblocks": "resnets",
"conv_pre": "conv_in",
"conv_post": "conv_out",
}
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
@@ -68,6 +102,13 @@ LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"adaln_single": convert_ltx2_transformer_adaln_single,
}
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_inplace,
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
}
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
@@ -180,6 +221,157 @@ def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str)
return transformer
def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (256, 512, 1024, 2048),
"down_block_types": (
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
"decoder_spatial_padding_mode": "reflect",
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
},
}
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (256, 512, 1024, 2048),
"down_block_types": (
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
"decoder_spatial_padding_mode": "reflect",
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
},
}
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
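# With the config above, the advertised compression ratios appear to decompose as follows (assuming each
# downsampling step halves the corresponding dimension): spatially, patch_size * 2**3 = 4 * 8 = 32, from the
# one "spatial" and two "spatiotemporal" entries in `downsample_type`; temporally, patch_size_t * 2**3
# = 1 * 8 = 8, from the one "temporal" and two "spatiotemporal" entries. These match
# `spatial_compression_ratio` (32) and `temporal_compression_ratio` (8).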
def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> AutoencoderKLLTX2Video:
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vae = AutoencoderKLLTX2Video.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
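# A worked example of the remapping above (the key used here is hypothetical and only for illustration):
# the entries of LTX_2_0_VIDEO_VAE_RENAME_DICT are applied as sequential substring replacements, e.g.
#
#     key = "encoder.down_blocks.3.res_blocks.1.conv1.weight"
#     for old, new in LTX_2_0_VIDEO_VAE_RENAME_DICT.items():
#         key = key.replace(old, new)
#
# first rewrites "down_blocks.3" to "down_blocks.1.downsamplers.0" and then "res_blocks" to "resnets",
# giving "encoder.down_blocks.1.downsamplers.0.resnets.1.conv1.weight".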
def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"in_channels": 128,
"hidden_channels": 1024,
"out_channels": 2,
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
"upsample_factors": [6, 5, 2, 2, 2],
"resnet_kernel_sizes": [3, 7, 11],
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"leaky_relu_negative_slope": 0.1,
"output_sampling_rate": 24000,
}
}
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> LTX2Vocoder:
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vocoder = LTX2Vocoder.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
return vocoder
def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
@@ -312,7 +504,13 @@ def main(args):
combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename)
if args.vae or args.full_pipeline:
pass
if args.vae_filename is not None:
original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
elif combined_ckpt is not None:
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
if not args.full_pipeline:
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
if args.audio_vae or args.full_pipeline:
pass
@@ -327,7 +525,13 @@ def main(args):
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
if args.vocoder or args.full_pipeline:
pass
if args.vocoder_filename is not None:
original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename)
elif combined_ckpt is not None:
original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix)
vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version)
if not args.full_pipeline:
vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
if args.full_pipeline:
pass


@@ -194,6 +194,7 @@ else:
"AutoencoderKLHunyuanVideo",
"AutoencoderKLHunyuanVideo15",
"AutoencoderKLLTXVideo",
"AutoencoderKLLTX2Video",
"AutoencoderKLMagvit",
"AutoencoderKLMochi",
"AutoencoderKLQwenImage",
@@ -928,6 +929,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTXVideo,
AutoencoderKLLTX2Video,
AutoencoderKLMagvit,
AutoencoderKLMochi,
AutoencoderKLQwenImage,


@@ -41,6 +41,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
@@ -153,6 +154,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTXVideo,
AutoencoderKLLTX2Video,
AutoencoderKLMagvit,
AutoencoderKLMochi,
AutoencoderKLQwenImage,


@@ -10,6 +10,7 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage

File diff suppressed because it is too large


@@ -0,0 +1,173 @@
import math
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
class ResBlock(nn.Module):
def __init__(
self,
channels: int,
kernel_size: int = 3,
stride: int = 1,
dilations: Tuple[int, ...] = (1, 3, 5),
leaky_relu_negative_slope: float = 0.1,
padding_mode: str = "same",
):
super().__init__()
self.dilations = dilations
self.negative_slope = leaky_relu_negative_slope
self.convs1 = nn.ModuleList(
[
nn.Conv1d(
channels,
channels,
kernel_size,
stride=stride,
dilation=dilation,
padding=padding_mode
)
for dilation in dilations
]
)
self.convs2 = nn.ModuleList(
[
nn.Conv1d(
channels,
channels,
kernel_size,
stride=stride,
dilation=1,
padding=padding_mode
)
for _ in range(len(dilations))
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for conv1, conv2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, negative_slope=self.negative_slope)
xt = conv1(xt)
xt = F.leaky_relu(xt, negative_slope=self.negative_slope)
xt = conv2(xt)
x = x + xt
return x
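# Note: with the default kernel_size=3 and dilations (1, 3, 5), the dilated convolutions in `convs1` have
# effective kernel sizes dilation * (kernel_size - 1) + 1 = 3, 7, and 11, so stacking them widens the temporal
# receptive field, while `padding="same"` (with stride 1) keeps the sequence length unchanged and the residual
# connections preserve the input signal.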
class LTX2Vocoder(ModelMixin, ConfigMixin):
r"""
LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms.
"""
@register_to_config
def __init__(
self,
in_channels: int = 128,
hidden_channels: int = 1024,
out_channels: int = 2,
upsample_kernel_sizes: List[int] = [16, 15, 8, 4, 4],
upsample_factors: List[int] = [6, 5, 2, 2, 2],
resnet_kernel_sizes: List[int] = [3, 7, 11],
resnet_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
leaky_relu_negative_slope: float = 0.1,
output_sampling_rate: int = 24000,
):
super().__init__()
self.num_upsample_layers = len(upsample_kernel_sizes)
self.resnets_per_upsample = len(resnet_kernel_sizes)
self.out_channels = out_channels
self.total_upsample_factor = math.prod(upsample_factors)
self.negative_slope = leaky_relu_negative_slope
if self.num_upsample_layers != len(upsample_factors):
raise ValueError(
f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length"
f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively."
)
if self.resnets_per_upsample != len(resnet_dilations):
raise ValueError(
f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length"
f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively."
)
self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)
self.upsamplers = nn.ModuleList()
self.resnets = nn.ModuleList()
input_channels = hidden_channels
for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
output_channels = input_channels // 2
self.upsamplers.append(
nn.ConvTranspose1d(
input_channels, # hidden_channels // (2 ** i)
output_channels, # hidden_channels // (2 ** (i + 1))
kernel_size,
stride=stride,
padding=(kernel_size - stride) // 2,
)
)
for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
self.resnets.append(
ResBlock(
output_channels,
kernel_size,
dilations=dilations,
leaky_relu_negative_slope=leaky_relu_negative_slope,
)
)
input_channels = output_channels
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)
def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
r"""
Forward pass of the vocoder.
Args:
hidden_states (`torch.Tensor`):
Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last`
is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is
`True`.
time_last (`bool`, *optional*, defaults to `False`):
Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension.
Returns:
`torch.Tensor`:
Audio waveform tensor of shape `(batch_size, out_channels, audio_length)`.
"""
# Ensure that the time/frame dimension is last
if not time_last:
hidden_states = hidden_states.transpose(2, 3)
# Combine channels and frequency (mel bins) dimensions
hidden_states = hidden_states.flatten(1, 2)
hidden_states = self.conv_in(hidden_states)
for i in range(self.num_upsample_layers):
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
hidden_states = self.upsamplers[i](hidden_states)
# Run all resnets in parallel on hidden_states
start = i * self.resnets_per_upsample
end = (i + 1) * self.resnets_per_upsample
resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0)
hidden_states = torch.mean(resnet_outputs, dim=0)
# NOTE: unlike the earlier leaky ReLUs, which use `self.negative_slope` (0.1 by default), this one uses the
# default `F.leaky_relu` negative slope of 0.01. It is unclear whether this is intended.
hidden_states = self.conv_out(hidden_states)
hidden_states = torch.tanh(hidden_states)
return hidden_states
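
A minimal smoke test of the vocoder defined above, assuming the branch from this PR is installed. The input layout follows the `forward` docstring; treating the input as a single spectrogram channel with 128 mel bins (so that the flattened channel dimension matches the default `in_channels=128`) is an assumption for illustration.

import torch
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder

vocoder = LTX2Vocoder()  # default config: 128 input channels, stereo output at 24 kHz
mel = torch.randn(1, 1, 10, 128)  # (batch_size, num_channels, time, num_mel_bins), with time_last=False
with torch.no_grad():
    waveform = vocoder(mel)
print(waveform.shape)  # torch.Size([1, 2, 2400]); time is upsampled by prod(upsample_factors) = 240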


@@ -0,0 +1,105 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import AutoencoderKLLTX2Video
from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTX2Video
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_ltx_video_config(self):
return {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 8,
"block_out_channels": (8, 8, 8, 8),
"decoder_block_out_channels": (16, 32, 64),
"layers_per_block": (1, 1, 1, 1, 1),
"decoder_layers_per_block": (1, 1, 1, 1),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"patch_size": 1,
"patch_size_t": 1,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
# The full model uses `reflect`, but that mode does not have a deterministic backward implementation, so use `zeros`
"decoder_spatial_padding_mode": "zeros",
}
@property
def dummy_input(self):
batch_size = 2
num_frames = 9
num_channels = 3
sizes = (16, 16)
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
input_dict = {"sample": image}
return input_dict
@property
def input_shape(self):
return (3, 9, 16, 16)
@property
def output_shape(self):
return (3, 9, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_ltx_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"LTX2VideoEncoder3d",
"LTX2VideoDecoder3d",
"LTX2VideoDownBlock3D",
"LTX2VideoMidBlock3d",
"LTX2VideoUpBlock3d",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
pass
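
For reference, a rough sketch of what the shared model tests exercise for this class, assuming the forward signature matches how the test mixin calls it (a `sample` tensor in, a reconstruction with the per-sample `output_shape` above out); the snippet is illustrative and not part of the test suite:

import torch
from diffusers import AutoencoderKLLTX2Video

config = AutoencoderKLLTX2VideoTests().get_autoencoder_kl_ltx_video_config()
vae = AutoencoderKLLTX2Video(**config)
sample = torch.rand(2, 3, 9, 16, 16)  # (batch_size, channels, frames, height, width), as in dummy_input
with torch.no_grad():
    output = vae(sample=sample)  # reconstruction with per-sample shape (3, 9, 16, 16), possibly wrapped in an output dataclass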