diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py
index e3c0ef2c3d..90ddf2aa6e 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py
@@ -605,6 +605,10 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
             mel_bins=mel_bins,
         )
 
+        # TODO: calculate programmatically instead of hardcoding
+        self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR  # 4
+        # TODO: confirm whether the mel compression ratio below is correct
+        self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
         self.use_slicing = False
 
     @apply_forward_hook
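Reviewer note: a quick sanity check of what the two hardcoded ratios above assert, namely that the audio VAE downsamples a mel spectrogram 4x along both the time and mel axes. The channel counts below are illustrative assumptions; the 64 mel bins and the resulting 121 x 16 latent grid match the fallback values used later in this diff.

```python
# Sketch only: the shape contract implied by temporal_compression_ratio = 4 and
# mel_compression_ratio = 4. Channel counts (1 in, 8 latent) are assumptions.
LATENT_DOWNSAMPLE_FACTOR = 4
batch, in_channels, mel_frames, mel_bins = 1, 1, 484, 64  # [B, C, T, M] mel input

latent_shape = (
    batch,
    8,  # assumed latent_channels, matching the pipeline fallback in this diff
    mel_frames // LATENT_DOWNSAMPLE_FACTOR,  # temporal: 484 -> 121
    mel_bins // LATENT_DOWNSAMPLE_FACTOR,    # mel:       64 -> 16
)
print(latent_shape)  # (1, 8, 121, 16)
```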
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
index 9373b21401..99160a38be 100644
--- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
@@ -21,7 +21,7 @@ from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTo
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
-from ...models.autoencoders import AutoencoderKLLTX2Video
+from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
 from ...models.transformers import LTX2VideoTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -201,7 +201,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         self,
         scheduler: FlowMatchEulerDiscreteScheduler,
         vae: AutoencoderKLLTX2Video,
-        audio_vae: AutoencoderKLLTX2Video,
+        audio_vae: AutoencoderKLLTX2Audio,
         text_encoder: LTX2AudioVisualTextEncoder,
         tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
         transformer: LTX2VideoTransformer3DModel,
@@ -225,6 +225,13 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         self.vae_temporal_compression_ratio = (
             self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
         )
+        # TODO: check whether the mel compression ratio logic here is correct
+        self.audio_vae_mel_compression_ratio = (
+            self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
+        )
+        self.audio_vae_temporal_compression_ratio = (
+            self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
+        )
         self.transformer_spatial_patch_size = (
             self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
         )
@@ -232,6 +239,13 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
         )
 
+        self.audio_sampling_rate = (
+            self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000
+        )
+        self.audio_hop_length = (
+            self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160
+        )
+
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
@@ -487,9 +501,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         if patch_size is not None and patch_size_t is not None:
             # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (an ndim=3 tensor).
             # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size.
-            batch_size, num_channels, latent_length, num_mel_bins = latents.shape
+            batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
             post_patch_latent_length = latent_length // patch_size_t
-            post_patch_mel_bins = num_mel_bins // patch_size
+            post_patch_mel_bins = latent_mel_bins // patch_size
             latents = latents.reshape(
                 batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size
             )
@@ -556,12 +570,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         self,
         batch_size: int = 1,
         num_channels_latents: int = 8,
-        num_mel_bins: int = 16,
+        num_mel_bins: int = 64,
         num_frames: int = 121,
         frame_rate: float = 25.0,
         sampling_rate: int = 16000,
         hop_length: int = 160,
-        audio_latent_scale_factor: int = 4,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
         generator: Optional[torch.Generator] = None,
@@ -571,10 +584,13 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             return latents.to(device=device, dtype=dtype)
 
         duration_s = num_frames / frame_rate
-        latents_per_second = float(sampling_rate) / float(hop_length) / float(audio_latent_scale_factor)
+        latents_per_second = float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
         latent_length = int(duration_s * latents_per_second)
 
-        shape = (batch_size, num_channels_latents, latent_length, num_mel_bins)
+        # TODO: confirm whether this logic is correct
+        latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
+
+        shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
 
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -792,6 +808,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
         # 4. Prepare latent variables
+        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+        latent_height = height // self.vae_spatial_compression_ratio
+        latent_width = width // self.vae_spatial_compression_ratio
+        video_sequence_length = latent_num_frames * latent_height * latent_width
+
         num_channels_latents = self.transformer.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
@@ -805,15 +826,20 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             latents,
         )
 
+        num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
+        latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
+
+        num_channels_latents_audio = (
+            self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
+        )
         audio_latents, audio_num_frames = self.prepare_audio_latents(
             batch_size * num_videos_per_prompt,
-            num_channels_latents=8,  # TODO: get from audio VAE
-            num_mel_bins=16,  # TODO: get from audio VAE
+            num_channels_latents=num_channels_latents_audio,
+            num_mel_bins=num_mel_bins,
             num_frames=num_frames,  # Video frames, audio frames will be calculated from this
             frame_rate=frame_rate,
-            sampling_rate=self.transformer.config.audio_sampling_rate,
-            hop_length=self.transformer.config.audio_hop_length,
-            audio_latent_scale_factor=4,  # TODO: get from audio VAE
+            sampling_rate=self.audio_sampling_rate,
+            hop_length=self.audio_hop_length,
             dtype=torch.float32,
             device=device,
             generator=generator,
@@ -821,10 +847,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         )
 
         # 5. Prepare timesteps
-        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
-        latent_height = height // self.vae_spatial_compression_ratio
-        latent_width = width // self.vae_spatial_compression_ratio
-        video_sequence_length = latent_num_frames * latent_height * latent_width
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
         mu = calculate_shift(
             video_sequence_length,
@@ -964,10 +986,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             video = self.vae.decode(latents, timestep, return_dict=False)[0]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
 
-        # TODO: get num_mel_bins from audio VAE or vocoder?
-        audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=16)
-        # TODO: apply audio VAE decoder
-        audio = self.vocoder(audio_latents)
+        audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
+        # NOTE: currently, unlike the video VAE, we denormalize the audio latents inside the audio VAE decoder's
+        # decode method
+        generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
+        waveforms = self.vocoder(generated_mel_spectrograms)
 
         # Offload all models
         self.maybe_free_model_hooks()
@@ -975,4 +998,4 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         if not return_dict:
-            return (video, audio)
+            return (video, waveforms)
 
-        return LTX2PipelineOutput(frames=video, audio=audio)
+        return LTX2PipelineOutput(frames=video, audio=waveforms)
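Reviewer note: the shape arithmetic spread across these pipeline hunks can be checked end to end. Below is a self-contained sketch using the fallback values visible in this diff (25 fps, 121 frames, 16 kHz audio, hop length 160, audio compression ratios 4/4, video temporal ratio 8); the video spatial ratio of 32 and the 512x768 resolution are assumptions for illustration only.

```python
# Worked example of the latent shape arithmetic introduced in the hunks above.
num_frames, frame_rate = 121, 25.0
height, width = 512, 768                       # illustrative resolution

# Video latents (the block moved from "Prepare timesteps" to step 4):
vae_temporal_compression_ratio = 8             # diff fallback
vae_spatial_compression_ratio = 32             # assumed, not shown in this diff
latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1  # 16
latent_height = height // vae_spatial_compression_ratio                     # 16
latent_width = width // vae_spatial_compression_ratio                       # 24
video_sequence_length = latent_num_frames * latent_height * latent_width    # 6144

# Audio latents (prepare_audio_latents hunks):
sampling_rate, hop_length = 16000, 160
audio_temporal_compression_ratio = 4
audio_mel_compression_ratio = 4
num_mel_bins = 64

duration_s = num_frames / frame_rate                                         # 4.84 s
latents_per_second = sampling_rate / hop_length / audio_temporal_compression_ratio  # 25.0
latent_length = int(duration_s * latents_per_second)                         # 121
latent_mel_bins = num_mel_bins // audio_mel_compression_ratio                # 16

# latent_mel_bins == 16 recovers the value the old code hardcoded.
print(video_sequence_length, latent_length, latent_mel_bins)  # 6144 121 16
```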
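Similarly, the `_pack_audio_latents` hunk shows only the initial `reshape`; the sketch below completes the packing to the shape named in its comment, [B, L // p_t * M // p, C * p_t * p]. The trailing `permute`/`flatten` is an assumed completion that reproduces that target shape, not code from this PR, and the patch sizes are illustrative (the pipeline's fallbacks are 1, 1).

```python
import torch

# Standalone check of the packing described in `_pack_audio_latents`:
# [B, C, L, M] -> [B, (L // p_t) * (M // p), C * p_t * p], an ndim=3 tensor.
batch_size, num_channels, latent_length, latent_mel_bins = 1, 8, 121, 16
patch_size, patch_size_t = 2, 1  # illustrative values

latents = torch.randn(batch_size, num_channels, latent_length, latent_mel_bins)

post_patch_latent_length = latent_length // patch_size_t
post_patch_mel_bins = latent_mel_bins // patch_size
# This reshape is the line shown in the diff; -1 resolves to num_channels.
latents = latents.reshape(
    batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size
)
# Assumed completion: bring the (L', M') token grid forward, then fold
# (C, p_t, p) into the feature dim and (L', M') into the sequence dim.
latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
print(latents.shape)  # torch.Size([1, 968, 16]) == [B, 121 * 8, 8 * 1 * 2]
```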