Add Audio VAE logic to T2V pipeline
@@ -605,6 +605,10 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
             mel_bins=mel_bins,
         )
 
+        # TODO: calculate programmatically instead of hardcoding
+        self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR  # 4
+        # TODO: confirm whether the mel compression ratio below is correct
+        self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
         self.use_slicing = False
 
     @apply_forward_hook
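On the first TODO above: one way to compute the ratio programmatically is to take the product of the per-stage temporal strides. A minimal sketch, assuming the strides are available as a list (`temporal_downsample_factors` is a hypothetical name, not an existing config attribute):

import math

# Hypothetical: derive the compression ratio from per-block temporal strides
# instead of hardcoding LATENT_DOWNSAMPLE_FACTOR.
temporal_downsample_factors = [2, 2]  # assumed per-block strides
temporal_compression_ratio = math.prod(temporal_downsample_factors)  # 2 * 2 = 4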
@@ -21,7 +21,7 @@ from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTo
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
-from ...models.autoencoders import AutoencoderKLLTX2Video
+from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
 from ...models.transformers import LTX2VideoTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -201,7 +201,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         self,
         scheduler: FlowMatchEulerDiscreteScheduler,
         vae: AutoencoderKLLTX2Video,
-        audio_vae: AutoencoderKLLTX2Video,
+        audio_vae: AutoencoderKLLTX2Audio,
         text_encoder: LTX2AudioVisualTextEncoder,
         tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
         transformer: LTX2VideoTransformer3DModel,
@@ -225,6 +225,13 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         self.vae_temporal_compression_ratio = (
             self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
         )
+        # TODO: check whether the mel compression ratio logic here is correct
+        self.audio_vae_mel_compression_ratio = (
+            self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
+        )
+        self.audio_vae_temporal_compression_ratio = (
+            self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
+        )
         self.transformer_spatial_patch_size = (
             self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
         )
@@ -232,6 +239,13 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
         )
 
+        self.audio_sampling_rate = (
+            self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000
+        )
+        self.audio_hop_length = (
+            self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160
+        )
+
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
@@ -487,9 +501,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         if patch_size is not None and patch_size_t is not None:
             # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (an ndim=3 tensor).
             # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size.
-            batch_size, num_channels, latent_length, num_mel_bins = latents.shape
+            batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
             post_patch_latent_length = latent_length // patch_size_t
-            post_patch_mel_bins = num_mel_bins // patch_size
+            post_patch_mel_bins = latent_mel_bins // patch_size
             latents = latents.reshape(
                 batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size
             )
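For reference, a pack step like this typically follows the reshape with a permute and flatten to reach the [B, L // p_t * M // p, C * p_t * p] layout named in the comment. A minimal shape-level sketch; the permute/flatten order is an assumption, since the diff only shows the reshape:

import torch

B, C, L, M = 2, 8, 120, 16   # batch, channels, latent length, latent mel bins
p_t, p = 2, 2                # temporal and mel patch sizes

latents = torch.randn(B, C, L, M)
latents = latents.reshape(B, C, L // p_t, p_t, M // p, p)
latents = latents.permute(0, 2, 4, 1, 3, 5)      # [B, L//p_t, M//p, C, p_t, p]
latents = latents.flatten(3, 5).flatten(1, 2)    # [B, L//p_t * M//p, C*p_t*p]
assert latents.shape == (B, (L // p_t) * (M // p), C * p_t * p)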
@@ -556,12 +570,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         self,
         batch_size: int = 1,
         num_channels_latents: int = 8,
-        num_mel_bins: int = 16,
+        num_mel_bins: int = 64,
         num_frames: int = 121,
         frame_rate: float = 25.0,
         sampling_rate: int = 16000,
         hop_length: int = 160,
-        audio_latent_scale_factor: int = 4,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
         generator: Optional[torch.Generator] = None,
@@ -571,10 +584,13 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             return latents.to(device=device, dtype=dtype)
 
         duration_s = num_frames / frame_rate
-        latents_per_second = float(sampling_rate) / float(hop_length) / float(audio_latent_scale_factor)
+        latents_per_second = float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
         latent_length = int(duration_s * latents_per_second)
 
-        shape = (batch_size, num_channels_latents, latent_length, num_mel_bins)
+        # TODO: confirm whether this logic is correct
+        latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
+
+        shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
 
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
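For concreteness, here is the arithmetic above evaluated at the method defaults (121 frames at 25 fps, 16 kHz audio, hop length 160, temporal and mel compression both 4, 64 mel bins):

num_frames, frame_rate = 121, 25.0
sampling_rate, hop_length = 16000, 160
temporal_compression, mel_compression = 4, 4
num_mel_bins = 64

duration_s = num_frames / frame_rate                                     # 4.84 s
latents_per_second = sampling_rate / hop_length / temporal_compression   # 100 / 4 = 25.0
latent_length = int(duration_s * latents_per_second)                     # 121
latent_mel_bins = num_mel_bins // mel_compression                        # 16
# -> latents have shape (batch_size, num_channels_latents, 121, 16)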
@@ -792,6 +808,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
         # 4. Prepare latent variables
+        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+        latent_height = height // self.vae_spatial_compression_ratio
+        latent_width = width // self.vae_spatial_compression_ratio
+        video_sequence_length = latent_num_frames * latent_height * latent_width
+
         num_channels_latents = self.transformer.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
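A worked example of the latent bookkeeping above; the temporal compression ratio follows the constructor default earlier in this diff, while the spatial ratio of 32 and the 512x704 resolution are assumptions for illustration:

num_frames, height, width = 121, 512, 704
vae_temporal_compression_ratio = 8    # constructor default above
vae_spatial_compression_ratio = 32    # assumed; not shown in this diff

latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1   # 16
latent_height = height // vae_spatial_compression_ratio                      # 16
latent_width = width // vae_spatial_compression_ratio                        # 22
video_sequence_length = latent_num_frames * latent_height * latent_width     # 5632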
@@ -805,15 +826,20 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             latents,
         )
 
+        num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
+        latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
+
+        num_channels_latents_audio = (
+            self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
+        )
         audio_latents, audio_num_frames = self.prepare_audio_latents(
             batch_size * num_videos_per_prompt,
-            num_channels_latents=8,  # TODO: get from audio VAE
-            num_mel_bins=16,  # TODO: get from audio VAE
+            num_channels_latents=num_channels_latents_audio,
+            num_mel_bins=num_mel_bins,
             num_frames=num_frames,  # Video frames; audio frames will be calculated from this
             frame_rate=frame_rate,
-            sampling_rate=self.transformer.config.audio_sampling_rate,
-            hop_length=self.transformer.config.audio_hop_length,
-            audio_latent_scale_factor=4,  # TODO: get from audio VAE
+            sampling_rate=self.audio_sampling_rate,
+            hop_length=self.audio_hop_length,
             dtype=torch.float32,
             device=device,
             generator=generator,
@@ -821,10 +847,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         )
 
         # 5. Prepare timesteps
-        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
-        latent_height = height // self.vae_spatial_compression_ratio
-        latent_width = width // self.vae_spatial_compression_ratio
-        video_sequence_length = latent_num_frames * latent_height * latent_width
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
         mu = calculate_shift(
             video_sequence_length,
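The sigma schedule kept above is a linear ramp from 1.0 down to 1/num_inference_steps; for example, with four steps:

import numpy as np

num_inference_steps = 4
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
# array([1.0, 0.75, 0.5, 0.25])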
@@ -964,10 +986,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             video = self.vae.decode(latents, timestep, return_dict=False)[0]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
 
-            # TODO: get num_mel_bins from audio VAE or vocoder?
-            audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=16)
-            # TODO: apply audio VAE decoder
-            audio = self.vocoder(audio_latents)
+            audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
+            # NOTE: currently, unlike the video VAE, we denormalize the audio latents inside the audio VAE decoder's
+            # decode method
+            generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
+            waveforms = self.vocoder(generated_mel_spectrograms)
 
         # Offload all models
         self.maybe_free_model_hooks()
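A shape-level sketch of the audio decode path above; the concrete sizes are assumptions for illustration, as the diff itself does not pin them down:

# audio_latents (packed):      [B, (L // p_t) * (M_latent // p), C * p_t * p]
# after _unpack_audio_latents: [B, C, L, M_latent], e.g. [1, 8, 121, 16]
# after audio_vae.decode:      mel spectrogram with num_mel_bins bins, e.g. 64
# after vocoder:               waveform samples at audio_sampling_rate (16 kHz default)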
@@ -975,4 +998,4 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         if not return_dict:
-            return (video, audio)
+            return (video, waveforms)
 
-        return LTX2PipelineOutput(frames=video, audio=audio)
+        return LTX2PipelineOutput(frames=video, audio=waveforms)
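End to end, the pipeline now returns both modalities. A hypothetical usage sketch (the checkpoint id is deliberately left unspecified, since this commit does not name one):

import torch

pipe = LTX2Pipeline.from_pretrained("...", torch_dtype=torch.bfloat16)  # checkpoint id elided
out = pipe(prompt="a dog barking in a park", num_frames=121, frame_rate=25.0)
video_frames = out.frames  # decoded video frames
waveforms = out.audio      # vocoder waveform output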