
Add Audio VAE logic to T2V pipeline

Daniel Gu
2025-12-23 03:51:22 +01:00
parent ae3b6e7cc2
commit 54bfc5d617
2 changed files with 49 additions and 22 deletions


@@ -605,6 +605,10 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
mel_bins=mel_bins,
)
+# TODO: calculate programmatically instead of hardcoding
+self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4
+# TODO: confirm whether the mel compression ratio below is correct
+self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
self.use_slicing = False
@apply_forward_hook
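Note: a minimal sketch of how the two compression ratios defined above are expected to map mel-spectrogram dimensions onto latent dimensions downstream in the pipeline. `audio_latent_shape` is a hypothetical helper for illustration only, and the 484-frame / 64-bin spectrogram is an assumed example, not values read from a real config:

LATENT_DOWNSAMPLE_FACTOR = 4  # same hardcoded factor as above

def audio_latent_shape(mel_frames: int, mel_bins: int = 64,
                       temporal_compression_ratio: int = LATENT_DOWNSAMPLE_FACTOR,
                       mel_compression_ratio: int = LATENT_DOWNSAMPLE_FACTOR) -> tuple:
    # Temporal compression shrinks the mel-frame axis; mel compression shrinks the bin axis.
    # Illustrative helper, not part of AutoencoderKLLTX2Audio.
    return mel_frames // temporal_compression_ratio, mel_bins // mel_compression_ratio

print(audio_latent_shape(484, 64))  # -> (121, 16)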


@@ -21,7 +21,7 @@ from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTo
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
-from ...models.autoencoders import AutoencoderKLLTX2Video
+from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
from ...models.transformers import LTX2VideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -201,7 +201,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLLTX2Video,
-audio_vae: AutoencoderKLLTX2Video,
+audio_vae: AutoencoderKLLTX2Audio,
text_encoder: LTX2AudioVisualTextEncoder,
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
transformer: LTX2VideoTransformer3DModel,
@@ -225,6 +225,13 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
self.vae_temporal_compression_ratio = (
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
)
+# TODO: check whether the mel compression ratio logic here is correct
+self.audio_vae_mel_compression_ratio = (
+self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
+)
+self.audio_vae_temporal_compression_ratio = (
+self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
+)
self.transformer_spatial_patch_size = (
self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
)
@@ -232,6 +239,13 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
)
+self.audio_sampling_rate = (
+self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000
+)
+self.audio_hop_length = (
+self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160
+)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
@@ -487,9 +501,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
if patch_size is not None and patch_size_t is not None:
# Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (an ndim=3 tensor).
# dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size.
-batch_size, num_channels, latent_length, num_mel_bins = latents.shape
+batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
post_patch_latent_length = latent_length / patch_size_t
-post_patch_mel_bins = num_mel_bins / patch_size
+post_patch_mel_bins = latent_mel_bins / patch_size
latents = latents.reshape(
batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size
)
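Note: as a sanity check on the packing comment above, a standalone sketch of the reshape/permute that turns a [B, C, L, M] audio latent into the described [B, L // p_t * M // p, C * p_t * p] patch sequence. The permute order and the example dimensions are assumptions for illustration, not copied from the pipeline:

import torch

B, C, L, M = 2, 8, 12, 16  # batch, latent channels, latent length, latent mel bins (illustrative)
p_t, p = 1, 1              # temporal and mel patch sizes (illustrative)

x = torch.randn(B, C, L, M)
x = x.reshape(B, C, L // p_t, p_t, M // p, p)
# Move the channel and patch dims to the end, then fold them into the feature axis.
x = x.permute(0, 2, 4, 1, 3, 5).reshape(B, (L // p_t) * (M // p), C * p_t * p)
print(x.shape)  # torch.Size([2, 192, 8])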
@@ -556,12 +570,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
self,
batch_size: int = 1,
num_channels_latents: int = 8,
-num_mel_bins: int = 16,
+num_mel_bins: int = 64,
num_frames: int = 121,
frame_rate: float = 25.0,
sampling_rate: int = 16000,
hop_length: int = 160,
-audio_latent_scale_factor: int = 4,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
@@ -571,10 +584,13 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
return latents.to(device=device, dtype=dtype)
duration_s = num_frames / frame_rate
-latents_per_second = float(sampling_rate) / float(hop_length) / float(audio_latent_scale_factor)
+latents_per_second = float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
latent_length = int(duration_s * latents_per_second)
-shape = (batch_size, num_channels_latents, latent_length, num_mel_bins)
+# TODO: confirm whether this logic is correct
+latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
+shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
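Note: to make the audio latent shape computed above concrete, here is the arithmetic worked through with the defaults visible in this diff (121 frames at 25 fps, 16 kHz audio, hop length 160, 64 mel bins, 4x temporal and mel compression). A worked example, not values read from a checkpoint:

num_frames, frame_rate = 121, 25.0
sampling_rate, hop_length = 16000, 160
temporal_compression_ratio, mel_compression_ratio = 4, 4
num_mel_bins = 64

duration_s = num_frames / frame_rate                                           # 4.84 s of audio
latents_per_second = sampling_rate / hop_length / temporal_compression_ratio   # 100 / 4 = 25.0
latent_length = int(duration_s * latents_per_second)                           # 121 latent steps
latent_mel_bins = num_mel_bins // mel_compression_ratio                        # 16
print((latent_length, latent_mel_bins))                                        # (121, 16)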
@@ -792,6 +808,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare latent variables
+latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+latent_height = height // self.vae_spatial_compression_ratio
+latent_width = width // self.vae_spatial_compression_ratio
+video_sequence_length = latent_num_frames * latent_height * latent_width
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
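Note: for reference, the video latent grid computed just above works out as follows for an illustrative 121-frame, 512x768 request. The 8x temporal default appears earlier in this diff; the 32x spatial factor and the resolution are assumptions for this example only:

num_frames, height, width = 121, 512, 768
vae_temporal_compression_ratio, vae_spatial_compression_ratio = 8, 32  # 32x spatial is assumed

latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1   # 16
latent_height = height // vae_spatial_compression_ratio                      # 16
latent_width = width // vae_spatial_compression_ratio                        # 24
video_sequence_length = latent_num_frames * latent_height * latent_width     # 6144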
@@ -805,15 +826,20 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
latents,
)
+num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
+latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
+num_channels_latents_audio = (
+self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
+)
audio_latents, audio_num_frames = self.prepare_audio_latents(
batch_size * num_videos_per_prompt,
-num_channels_latents=8, # TODO: get from audio VAE
-num_mel_bins=16, # TODO: get from audio VAE
+num_channels_latents=num_channels_latents_audio,
+num_mel_bins=num_mel_bins,
num_frames=num_frames, # Video frames, audio frames will be calculated from this
frame_rate=frame_rate,
-sampling_rate=self.transformer.config.audio_sampling_rate,
-hop_length=self.transformer.config.audio_hop_length,
-audio_latent_scale_factor=4, # TODO: get from audio VAE
+sampling_rate=self.audio_sampling_rate,
+hop_length=self.audio_hop_length,
dtype=torch.float32,
device=device,
generator=generator,
@@ -821,10 +847,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
)
# 5. Prepare timesteps
-latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
-latent_height = height // self.vae_spatial_compression_ratio
-latent_width = width // self.vae_spatial_compression_ratio
-video_sequence_length = latent_num_frames * latent_height * latent_width
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift(
video_sequence_length,
@@ -964,10 +986,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
-# TODO: get num_mel_bins from audio VAE or vocoder?
-audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=16)
-# TODO: apply audio VAE decoder
-audio = self.vocoder(audio_latents)
+audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
+# NOTE: currently, unlike the video VAE, we denormalize the audio latents inside the audio VAE decoder's
+# decode method
+generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
+waveforms = self.vocoder(generated_mel_spectrograms)
# Offload all models
self.maybe_free_model_hooks()
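Note: putting the new decode path together, the packed audio latents are unpacked, decoded to a mel spectrogram by the audio VAE (which, per the NOTE above, also denormalizes the latents inside decode), and then vocoded to a waveform. A minimal sketch; `decode_audio` is a hypothetical helper and the tensor shapes in the comments are assumptions:

import torch

def decode_audio(audio_latents: torch.Tensor, audio_vae, vocoder) -> torch.Tensor:
    # audio_latents: assumed [B, C_latent, latent_length, latent_mel_bins], already unpacked.
    mel = audio_vae.decode(audio_latents, return_dict=False)[0]  # assumed [B, C, mel_frames, mel_bins]
    return vocoder(mel)  # waveform, assumed [B, num_samples]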
@@ -975,4 +998,4 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
if not return_dict:
return (video, audio)
-return LTX2PipelineOutput(frames=video, audio=audio)
+return LTX2PipelineOutput(frames=video, audio=waveforms)