From 280e34781457a774652fab1c0400c431a73d6546 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 30 Dec 2025 08:05:56 +0530 Subject: [PATCH 1/2] Refactor Audio VAE to be simpler and remove helpers (#7) * remove resolve causality axes stuff. * remove a bunch of helpers. * remove adjust output shape helper. * remove the use of audiolatentshape. * move normalization and patchify out of pipeline. * fix * up * up * Remove unpatchify and patchify ops before audio latents denormalization (#9) --------- Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- scripts/convert_ltx2_to_diffusers.py | 5 +- scripts/test_ltx2_audio_conversion.py | 15 + .../autoencoders/autoencoder_kl_ltx2_audio.py | 362 ++++++++---------- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 13 +- 4 files changed, 179 insertions(+), 216 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 479a569817..d1384c1dca 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -70,7 +70,10 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = { "per_channel_statistics.std-of-means": "latents_std", } -LTX_2_0_AUDIO_VAE_RENAME_DICT = {} +LTX_2_0_AUDIO_VAE_RENAME_DICT = { + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", +} LTX_2_0_VOCODER_RENAME_DICT = { "ups": "upsamplers", diff --git a/scripts/test_ltx2_audio_conversion.py b/scripts/test_ltx2_audio_conversion.py index 8d07a6f9b1..a6ba16ed9e 100644 --- a/scripts/test_ltx2_audio_conversion.py +++ b/scripts/test_ltx2_audio_conversion.py @@ -22,6 +22,9 @@ def convert_state_dict(state_dict: dict) -> dict: if new_key.startswith("decoder."): new_key = new_key[len("decoder.") :] converted[f"decoder.{new_key}"] = value + + converted["latents_mean"] = converted.pop("decoder.per_channel_statistics.mean-of-means") + converted["latents_std"] = converted.pop("decoder.per_channel_statistics.std-of-means") return converted @@ -87,9 +90,21 @@ def main() -> None: latent_width, device=device, dtype=dtype, + generator=torch.Generator(device).manual_seed(42) ) original_out = original_decoder(dummy) + + from diffusers.pipelines.ltx2.pipeline_ltx2 import LTX2Pipeline + + _, a_channels, a_time, a_freq = dummy.shape + dummy = dummy.permute(0, 2, 1, 3).reshape(-1, a_time, a_channels * a_freq) + dummy = LTX2Pipeline._denormalize_audio_latents( + dummy, + diffusers_model.latents_mean, + diffusers_model.latents_std, + ) + dummy = dummy.view(-1, a_time, a_channels, a_freq).permute(0, 2, 1, 3) diffusers_out = diffusers_model.decode(dummy).sample torch.testing.assert_close(diffusers_out, original_out, rtol=1e-4, atol=1e-4) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index 90ddf2aa6e..8cdcfa1a74 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from typing import Optional, Set, Tuple, Union import torch @@ -27,57 +26,6 @@ from .vae import AutoencoderMixin, DecoderOutput LATENT_DOWNSAMPLE_FACTOR = 4 -SUPPORTED_CAUSAL_AXES = {"none", "width", "height", "width-compatibility"} - - -AudioLatentShape = namedtuple( - "AudioLatentShape", - [ - "batch", - "channels", - "frames", - "mel_bins", - ], -) - - -def _resolve_causality_axis(causality_axis: Optional[str] = None) -> Optional[str]: - normalized = "none" if causality_axis is None else str(causality_axis).lower() - if normalized not in SUPPORTED_CAUSAL_AXES: - raise NotImplementedError( - f"Unsupported causality_axis '{causality_axis}'. Supported: {sorted(SUPPORTED_CAUSAL_AXES)}" - ) - return None if normalized == "none" else normalized - - -def make_conv2d( - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - stride: int = 1, - padding: Optional[Tuple[int, int, int, int]] = None, - dilation: int = 1, - groups: int = 1, - bias: bool = True, - causality_axis: Optional[str] = None, -) -> nn.Module: - if causality_axis is not None: - return LTX2AudioCausalConv2d( - in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis - ) - if padding is None: - padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size) - - return nn.Conv2d( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - bias, - ) class LTX2AudioCausalConv2d(nn.Module): @@ -147,14 +95,6 @@ class LTX2AudioPixelNorm(nn.Module): return x / rms -def build_normalization_layer(in_channels: int, *, num_groups: int = 32, normtype: str = "group") -> nn.Module: - if normtype == "group": - return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) - if normtype == "pixel": - return LTX2AudioPixelNorm(dim=1, eps=1e-6) - raise ValueError(f"Invalid normalization type: {normtype}") - - class LTX2AudioAttnBlock(nn.Module): def __init__( self, @@ -164,7 +104,12 @@ class LTX2AudioAttnBlock(nn.Module): super().__init__() self.in_channels = in_channels - self.norm = build_normalization_layer(in_channels, normtype=norm_type) + if norm_type == "group": + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == "pixel": + self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) @@ -211,23 +156,49 @@ class LTX2AudioResnetBlock(nn.Module): self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut - self.norm1 = build_normalization_layer(in_channels, normtype=norm_type) + if norm_type == "group": + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == "pixel": + self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") self.non_linearity = nn.SiLU() - self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if causality_axis is not None: + self.conv1 = LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels > 0: self.temb_proj = nn.Linear(temb_channels, out_channels) - self.norm2 = build_normalization_layer(out_channels, normtype=norm_type) + if norm_type == "group": + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + elif norm_type == "pixel": + self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") self.dropout = nn.Dropout(dropout) - self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if causality_axis is not None: + self.conv2 = LTX2AudioCausalConv2d( + out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = make_conv2d( - in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis - ) + if causality_axis is not None: + self.conv_shortcut = LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: - self.nin_shortcut = make_conv2d( - in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis - ) + if causality_axis is not None: + self.nin_shortcut = LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: h = self.norm1(x) @@ -254,7 +225,12 @@ class LTX2AudioUpsample(nn.Module): self.with_conv = with_conv self.causality_axis = causality_axis if self.with_conv: - self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if causality_axis is not None: + self.conv = LTX2AudioCausalConv2d( + in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) def forward(self, x: torch.Tensor) -> torch.Tensor: x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") @@ -273,26 +249,6 @@ class LTX2AudioUpsample(nn.Module): return x - -class LTX2AudioPerChannelStatistics(nn.Module): - """ - Per-channel statistics for normalizing and denormalizing the latent representation. This statics is computed over - the entire dataset and stored in model's checkpoint under AudioVAE state_dict - """ - - def __init__(self, latent_channels: int = 128) -> None: - super().__init__() - # Sayak notes: `empty` always causes problems in CI. Should we consider using `torch.ones`? - self.register_buffer("std-of-means", torch.empty(latent_channels)) - self.register_buffer("mean-of-means", torch.empty(latent_channels)) - - def denormalize(self, x: torch.Tensor) -> torch.Tensor: - return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) - - def normalize(self, x: torch.Tensor) -> torch.Tensor: - return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x) - - class LTX2AudioAudioPatchifier: """ Patchifier for spectrogram/audio latents. @@ -316,11 +272,9 @@ class LTX2AudioAudioPatchifier: batch, channels, time, freq = audio_latents.shape return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq) - def unpatchify(self, audio_latents: torch.Tensor, output_shape: AudioLatentShape) -> torch.Tensor: + def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor: batch, time, _ = audio_latents.shape - channels = output_shape.channels - freq = output_shape.mel_bins - return audio_latents.view(batch, time, channels, freq).permute(0, 2, 1, 3) + return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3) @property def patch_size(self) -> Tuple[int, int, int]: @@ -356,9 +310,6 @@ class LTX2AudioDecoder(nn.Module): ) -> None: super().__init__() - resolved_causality_axis = _resolve_causality_axis(causality_axis) - - self.per_channel_statistics = LTX2AudioPerChannelStatistics(latent_channels=base_channels) self.sample_rate = sample_rate self.mel_hop_length = mel_hop_length self.is_causal = is_causal @@ -384,116 +335,43 @@ class LTX2AudioDecoder(nn.Module): self.latent_channels = latent_channels self.channel_multipliers = ch_mult self.attn_resolutions = attn_resolutions - self.causality_axis = resolved_causality_axis + self.causality_axis = causality_axis base_block_channels = base_channels * self.channel_multipliers[-1] base_resolution = resolution // (2 ** (self.num_resolutions - 1)) self.z_shape = (1, latent_channels, base_resolution, base_resolution) - self.conv_in = make_conv2d( - latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis - ) - self.non_linearity = nn.SiLU() - self.mid = self._build_mid_layers(base_block_channels, dropout, mid_block_add_attention) - self.up, final_block_channels = self._build_up_path( - initial_block_channels=base_block_channels, - dropout=dropout, - resamp_with_conv=True, - ) - - self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type) - self.conv_out = make_conv2d( - final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis - ) - - def _adjust_output_shape(self, decoded_output: torch.Tensor, target_shape: AudioLatentShape) -> torch.Tensor: - _, _, current_time, current_freq = decoded_output.shape - target_channels = target_shape.channels - target_time = target_shape.frames - target_freq = target_shape.mel_bins - - decoded_output = decoded_output[ - :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) - ] - - time_padding_needed = target_time - decoded_output.shape[2] - freq_padding_needed = target_freq - decoded_output.shape[3] - - if time_padding_needed > 0 or freq_padding_needed > 0: - padding = ( - 0, - max(freq_padding_needed, 0), - 0, - max(time_padding_needed, 0), - ) - decoded_output = F.pad(decoded_output, padding) - - decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] - - return decoded_output - - def forward( - self, - sample: torch.Tensor, - ) -> torch.Tensor: - latent_shape = AudioLatentShape( - batch=sample.shape[0], - channels=sample.shape[1], - frames=sample.shape[2], - mel_bins=sample.shape[3], - ) - - sample_patched = self.patchifier.patchify(sample) - sample_denormalized = self.per_channel_statistics.denormalize(sample_patched) - sample = self.patchifier.unpatchify(sample_denormalized, latent_shape) - - target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR - if self.causality_axis is not None: - target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) - - target_shape = AudioLatentShape( - batch=latent_shape.batch, - channels=self.out_ch, - frames=target_frames, - mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins, - ) - - hidden_features = self.conv_in(sample) - hidden_features = self._run_mid_layers(hidden_features) - hidden_features = self._run_upsampling_path(hidden_features) - decoded_output = self._finalize_output(hidden_features) - - decoded_output = self._adjust_output_shape(decoded_output, target_shape) - - return decoded_output - - def _build_mid_layers(self, channels: int, dropout: float, add_attention: bool) -> nn.Module: - mid = nn.Module() - mid.block_1 = LTX2AudioResnetBlock( - in_channels=channels, - out_channels=channels, + self.conv_in = LTX2AudioCausalConv2d( + latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1) + self.non_linearity = nn.SiLU() + self.mid = nn.Module() + self.mid.block_1 = LTX2AudioResnetBlock( + in_channels=base_block_channels, + out_channels=base_block_channels, temb_channels=self.temb_ch, dropout=dropout, norm_type=self.norm_type, causality_axis=self.causality_axis, ) - mid.attn_1 = LTX2AudioAttnBlock(channels, norm_type=self.norm_type) if add_attention else nn.Identity() - mid.block_2 = LTX2AudioResnetBlock( - in_channels=channels, - out_channels=channels, + if mid_block_add_attention: + self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = LTX2AudioResnetBlock( + in_channels=base_block_channels, + out_channels=base_block_channels, temb_channels=self.temb_ch, dropout=dropout, norm_type=self.norm_type, causality_axis=self.causality_axis, ) - return mid - def _build_up_path( - self, initial_block_channels: int, dropout: float, resamp_with_conv: bool - ) -> tuple[nn.ModuleList, int]: - up_modules = nn.ModuleList() - block_in = initial_block_channels + self.up = nn.ModuleList() + block_in = base_block_channels curr_res = self.resolution // (2 ** (self.num_resolutions - 1)) for level in reversed(range(self.num_resolutions)): @@ -519,39 +397,89 @@ class LTX2AudioDecoder(nn.Module): stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)) if level != 0: - stage.upsample = LTX2AudioUpsample(block_in, resamp_with_conv, causality_axis=self.causality_axis) + stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis) curr_res *= 2 - up_modules.insert(0, stage) + self.up.insert(0, stage) - return up_modules, block_in + final_block_channels = block_in - def _run_mid_layers(self, features: torch.Tensor) -> torch.Tensor: - features = self.mid.block_1(features, temb=None) - features = self.mid.attn_1(features) - return self.mid.block_2(features, temb=None) + if self.norm_type == "group": + self.norm_out = nn.GroupNorm( + num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True + ) + elif self.norm_type == "pixel": + self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {self.norm_type}") + + if self.causality_axis is not None: + self.conv_out = LTX2AudioCausalConv2d( + final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1) + + def forward( + self, + sample: torch.Tensor, + ) -> torch.Tensor: + _, _, frames, mel_bins = sample.shape + + target_frames = frames * LATENT_DOWNSAMPLE_FACTOR + + if self.causality_axis is not None: + target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) + + target_channels = self.out_ch + target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins + + hidden_features = self.conv_in(sample) + hidden_features = self.mid.block_1(hidden_features, temb=None) + hidden_features = self.mid.attn_1(hidden_features) + hidden_features = self.mid.block_2(hidden_features, temb=None) - def _run_upsampling_path(self, features: torch.Tensor) -> torch.Tensor: for level in reversed(range(self.num_resolutions)): stage = self.up[level] for block_idx, block in enumerate(stage.block): - features = block(features, temb=None) + hidden_features = block(hidden_features, temb=None) if stage.attn: - features = stage.attn[block_idx](features) + hidden_features = stage.attn[block_idx](hidden_features) if level != 0 and hasattr(stage, "upsample"): - features = stage.upsample(features) + hidden_features = stage.upsample(hidden_features) - return features - - def _finalize_output(self, features: torch.Tensor) -> torch.Tensor: if self.give_pre_end: - return features + return hidden_features - hidden = self.norm_out(features) + hidden = self.norm_out(hidden_features) hidden = self.non_linearity(hidden) - decoded = self.conv_out(hidden) - return torch.tanh(decoded) if self.tanh_out else decoded + decoded_output = self.conv_out(hidden) + decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output + + _, _, current_time, current_freq = decoded_output.shape + target_time = target_frames + target_freq = target_mel_bins + + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + if time_padding_needed > 0 or freq_padding_needed > 0: + padding = ( + 0, + max(freq_padding_needed, 0), + 0, + max(time_padding_needed, 0), + ) + decoded_output = F.pad(decoded_output, padding) + + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): @@ -583,7 +511,10 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): ) -> None: super().__init__() - resolved_causality_axis = _resolve_causality_axis(causality_axis) + supported_causality_axes = {"none", "width", "height", "width-compatibility"} + if causality_axis not in supported_causality_axes: + raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}") + attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions self.decoder = LTX2AudioDecoder( @@ -596,7 +527,7 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): resolution=resolution, latent_channels=latent_channels, norm_type=norm_type, - causality_axis=resolved_causality_axis, + causality_axis=causality_axis, dropout=dropout, mid_block_add_attention=mid_block_add_attention, sample_rate=sample_rate, @@ -605,6 +536,13 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): mel_bins=mel_bins, ) + # Per-channel statistics for normalizing and denormalizing the latent representation. This statics is computed over + # the entire dataset and stored in model's checkpoint under AudioVAE state_dict + latents_std = torch.zeros((base_channels, )) + latents_mean = torch.ones((base_channels, )) + self.register_buffer("latents_mean", latents_mean, persistent=True) + self.register_buffer("latents_std", latents_std, persistent=True) + # TODO: calculate programmatically instead of hardcoding self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4 # TODO: confirm whether the mel compression ratio below is correct diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index a4ee5cb150..fe31d02ec0 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -516,6 +516,12 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix latents = latents * latents_std / scaling_factor + latents_mean return latents + @staticmethod + def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents * latents_std) + latents_mean + @staticmethod def _pack_audio_latents( latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None @@ -1038,10 +1044,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix video = self.vae.decode(latents, timestep, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) - audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) audio_latents = audio_latents.to(self.audio_vae.dtype) - # NOTE: currently, unlike the video VAE, we denormalize the audio latents inside the audio VAE decoder's - # decode method + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] audio = self.vocoder(generated_mel_spectrograms) From 46822c43dbe9dde816f467c0ca4aa6fb126f5998 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 30 Dec 2025 09:06:07 +0530 Subject: [PATCH 2/2] Add support for I2V (#8) * start i2v. * up * up * up * up * up * remove uniform strategy code. * remove unneeded code. --- src/diffusers/__init__.py | 2 + .../models/transformers/transformer_ltx2.py | 13 +- src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/ltx2/__init__.py | 2 + src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 10 - .../ltx2/pipeline_ltx2_image2video.py | 1138 +++++++++++++++++ 6 files changed, 1152 insertions(+), 17 deletions(-) create mode 100644 src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ea429c2e41..2e99ea8063 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -538,6 +538,7 @@ else: "LTXLatentUpsamplePipeline", "LTXPipeline", "LTX2Pipeline", + "LTX2ImageToVideoPipeline", "LucyEditPipeline", "Lumina2Pipeline", "Lumina2Text2ImgPipeline", @@ -1245,6 +1246,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: LTXLatentUpsamplePipeline, LTXPipeline, LTX2Pipeline, + LTX2ImageToVideoPipeline, LucyEditPipeline, Lumina2Pipeline, Lumina2Text2ImgPipeline, diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 3d2d079608..1f685fdc3a 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -1051,6 +1051,7 @@ class LTX2VideoTransformer3DModel( encoder_hidden_states: torch.Tensor, audio_encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, + audio_timestep: Optional[torch.LongTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, audio_encoder_attention_mask: Optional[torch.Tensor] = None, num_frames: Optional[int] = None, @@ -1073,8 +1074,7 @@ class LTX2VideoTransformer3DModel( Input patchified audio latents of shape (batch_size, num_audio_tokens, audio_in_channels). encoder_hidden_states (`torch.Tensor`): Input text embeddings of shape TODO. - timesteps (`torch.Tensor`): - Timestep information of shape (batch_size, num_train_timesteps). + TODO for the rest. Returns: `AudioVisualModelOutput` or `tuple`: @@ -1097,6 +1097,9 @@ class LTX2VideoTransformer3DModel( "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) + # Determine timestep for audio. + audio_timestep = audio_timestep if audio_timestep is not None else timestep + # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 @@ -1143,7 +1146,7 @@ class LTX2VideoTransformer3DModel( embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) temb_audio, audio_embedded_timestep = self.audio_time_embed( - timestep.flatten(), + audio_timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, ) @@ -1165,12 +1168,12 @@ class LTX2VideoTransformer3DModel( video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1]) audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( - timestep.flatten(), + audio_timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, ) audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate( - timestep.flatten() * timestep_cross_attn_gate_scale_factor, + audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor, batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, ) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ef9430043b..eaf444d5ec 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -288,7 +288,7 @@ else: "LTXConditionPipeline", "LTXLatentUpsamplePipeline", ] - _import_structure["ltx2"] = ["LTX2Pipeline"] + _import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline"] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] @@ -720,7 +720,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: LEditsPPPipelineStableDiffusionXL, ) from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline - from .ltx2 import LTX2Pipeline + from .ltx2 import LTX2Pipeline, LTX2ImageToVideoPipeline from .lucy import LucyEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py index d23123089f..a97c836e0c 100644 --- a/src/diffusers/pipelines/ltx2/__init__.py +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable: _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] + _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] _import_structure["text_encoder"] = ["LTX2AudioVisualTextEncoder"] _import_structure["vocoder"] = ["LTX2Vocoder"] @@ -35,6 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_ltx2 import LTX2Pipeline + from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline from .text_encoder import LTX2AudioVisualTextEncoder from .vocoder import LTX2Vocoder diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index fe31d02ec0..2617e5cacb 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -496,16 +496,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) return latents - @staticmethod - def _normalize_latents( - latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 - ) -> torch.Tensor: - # Normalize latents across the channel dimension [B, C, F, H, W] - latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - latents = (latents - latents_mean) * scaling_factor / latents_std - return latents - @staticmethod def _denormalize_latents( latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py new file mode 100644 index 0000000000..9f0755bb31 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -0,0 +1,1138 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import Any, Callable, Dict, List, Optional, Union +import inspect +import numpy as np +import torch + +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import LTX2PipelineOutput +from ..pipeline_utils import DiffusionPipeline +from .text_encoder import LTX2AudioVisualTextEncoder +from .vocoder import LTX2Vocoder +from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.transformers import LTX2VideoTransformer3DModel +from transformers import GemmaTokenizer, GemmaTokenizerFast +from ...video_processor import VideoProcessor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2ImageToVideoPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" + ... ) + >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background." + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> video = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=704, + ... height=480, + ... num_frames=121, + ... num_inference_steps=40, + ... ).frames[0] + >>> export_to_video(video, "output.mp4", fps=24) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): + r""" + Pipeline for image-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + TODO + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae->audio_vae->vocoder" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + text_encoder: LTX2AudioVisualTextEncoder, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + transformer: LTX2VideoTransformer3DModel, + vocoder: LTX2Vocoder, + ): + super().__init__() + + self.register_modules( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + vocoder=vocoder, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + # TODO: check whether the MEL compression ratio logic here is corrct + self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio, resample="bilinear") + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 + ) + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`str` or `torch.device`): + torch device to place the resulting embeddings on + dtype: (`torch.dtype`): + torch dtype to cast the prompt embeds to + max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.base_text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if getattr(self, "tokenizer", None) is not None: + # Gemma expects left padding for chat-style prompts + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + prompt = [p.strip() for p in prompt] + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + + prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self.text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask.to(device), + padding_side=self.tokenizer.padding_side, + scale_factor=scale_factor, + ) + prompt_embeds = prompt_embeds.to(dtype=dtype) + audio_prompt_embeds = audio_prompt_embeds.to(dtype=dtype) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + _, audio_seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, audio_seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, audio_prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + audio_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_audio_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + return prompt_embeds, audio_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents + def _pack_audio_latents( + latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None + ) -> torch.Tensor: + # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins + if patch_size is not None and patch_size_t is not None: + # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (a ndim=3 tnesor). + # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size. + batch_size, num_channels, latent_length, latent_mel_bins = latents.shape + post_patch_latent_length = latent_length / patch_size_t + post_patch_mel_bins = latent_mel_bins / patch_size + latents = latents.reshape( + batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size + ) + latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + else: + # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel) + # patch_size of M (all mel bins constitutes a single patch) and a patch_size_t of 1. + latents = latents.transpose(1, 2).flatten(2, 3) # [B, C, L, M] --> [B, L, C * M] + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_audio_latents + def _unpack_audio_latents( + latents: torch.Tensor, + latent_length: int, + num_mel_bins: int, + patch_size: Optional[int] = None, + patch_size_t: Optional[int] = None, + ) -> torch.Tensor: + # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M], + # where L is the latent audio length and M is the number of mel bins. + if patch_size is not None and patch_size_t is not None: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size) + latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + else: + # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1. + latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + def prepare_latents( + self, + image: Optional[torch.Tensor] = None, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 704, + num_frames: int = 161, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + mask_shape = (batch_size, 1, num_frames, height, width) + + if latents is not None: + conditioning_mask = latents.new_zeros(mask_shape) + conditioning_mask[:, :, 0] = 1.0 + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}." + ) + return latents.to(device=device, dtype=dtype), conditioning_mask + + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i], "argmax") + for i in range(batch_size) + ] + else: + init_latents = [ + retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator, "argmax") for img in image + ] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + init_latents = init_latents.repeat(1, 1, num_frames, 1, 1) + + # First condition is image latents and those should be kept clean. + conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype) + conditioning_mask[:, :, 0] = 1.0 + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # Interpolation. + latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask) + + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + return latents, conditioning_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.prepare_audio_latents + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + num_mel_bins: int = 64, + num_frames: int = 121, + frame_rate: float = 25.0, + sampling_rate: int = 16000, + hop_length: int = 160, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + duration_s = num_frames / frame_rate + latents_per_second = float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) + latent_length = int(duration_s * latents_per_second) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_length + + # TODO: confirm whether this logic is correct + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_audio_latents(latents) + return latents, latent_length + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: float = 25.0, + num_inference_steps: int = 40, + timesteps: List[int] = None, + guidance_scale: float = 3.0, + guidance_rescale: float = 0.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + audio_latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + audio_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_audio_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 1024, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to `768`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, *optional*, defaults to `121`): + The number of video frames to generate + frame_rate (`float`, *optional*, defaults to `25.0`): + The frames per second (FPS) of the generated video. + num_inference_steps (`int`, *optional*, defaults to 40): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to `3.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + audio_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + audio_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings for audio processing. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_audio_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings for audio processing. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `1024`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + audio_prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_audio_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + audio_prompt_embeds=audio_prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + negative_audio_prompt_embeds=negative_audio_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + audio_prompt_embeds = torch.cat([negative_audio_prompt_embeds, audio_prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare latent variables + if latents is None: + image = self.video_processor.preprocess(image, height=height, width=width) + image = image.to(device=device, dtype=prompt_embeds.dtype) + + num_channels_latents = self.transformer.config.in_channels + latents, conditioning_mask = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + if self.do_classifier_free_guidance: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + audio_latents, audio_num_frames = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + num_mel_bins=num_mel_bins, + num_frames=num_frames, # Video frames, audio frames will be calculated from this + frame_rate=frame_rate, + sampling_rate=self.audio_sampling_rate, + hop_length=self.audio_hop_length, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + + # 5. Prepare timesteps + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + video_sequence_length = latent_num_frames * latent_height * latent_width + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + + # For now, duplicate the scheduler for use with the audio latents + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio / frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device, fps=frame_rate + ) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + + timestep = t.expand(latent_model_input.shape[0]) + video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) + + with self.transformer.cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + audio_encoder_attention_mask=prompt_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + # rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video = noise_pred_video.float() + noise_pred_audio = noise_pred_audio.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (noise_pred_video_text - noise_pred_video_uncond) + + noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond) + + if self.guidance_rescale > 0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred_video = rescale_noise_cfg( + noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + ) + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + noise_pred_video = self._unpack_latents( + noise_pred_video, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + + noise_pred_video = noise_pred_video[:, :, 1:] + noise_latents = latents[:, :, 1:] + pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0] + + latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in + # the step method (such as _step_index) + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + video = latents + audio = audio_latents + else: + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + audio_latents = audio_latents.to(self.audio_vae.dtype) + # NOTE: currently, unlike the video VAE, we denormalize the audio latents inside the audio VAE decoder's + # decode method + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video, audio) + + return LTX2PipelineOutput(frames=video, audio=audio)