diff --git a/scripts/ltx2_test_latent_upsampler.py b/scripts/ltx2_test_latent_upsampler.py
index ace0646e1d..6b2e088f23 100644
--- a/scripts/ltx2_test_latent_upsampler.py
+++ b/scripts/ltx2_test_latent_upsampler.py
@@ -108,19 +108,6 @@ def main(args):
     if args.use_video_latents:
         # Manually convert the audio latents to a waveform
         audio = audio.to(pipeline.audio_vae.dtype)
-        audio = pipeline._denormalize_audio_latents(
-            audio, pipeline.audio_vae.latents_mean, pipeline.audio_vae.latents_std
-        )
-
-        sampling_rate = pipeline.audio_sampling_rate
-        hop_length = pipeline.audio_hop_length
-        audio_vae_temporal_scale = pipeline.audio_vae_temporal_compression_ratio
-        duration_s = args.num_frames / args.frame_rate
-        latents_per_second = float(sampling_rate) / float(hop_length) / float(audio_vae_temporal_scale)
-        audio_latent_frames = int(duration_s * latents_per_second)
-        latent_mel_bins = pipeline.audio_vae.config.mel_bins // pipeline.audio_vae_mel_compression_ratio
-        audio = pipeline._unpack_audio_latents(audio, audio_latent_frames, latent_mel_bins)
-
         audio = pipeline.audio_vae.decode(audio, return_dict=False)[0]
         audio = pipeline.vocoder(audio)
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
index 7cbcca67d2..af06b396cc 100644
--- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
@@ -1070,21 +1070,27 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
                 if XLA_AVAILABLE:
                     xm.mark_step()
 
+        latents = self._unpack_latents(
+            latents,
+            latent_num_frames,
+            latent_height,
+            latent_width,
+            self.transformer_spatial_patch_size,
+            self.transformer_temporal_patch_size,
+        )
+        latents = self._denormalize_latents(
+            latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+        )
+
+        audio_latents = self._denormalize_audio_latents(
+            audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
+        )
+        audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
+
         if output_type == "latent":
             video = latents
             audio = audio_latents
         else:
-            latents = self._unpack_latents(
-                latents,
-                latent_num_frames,
-                latent_height,
-                latent_width,
-                self.transformer_spatial_patch_size,
-                self.transformer_temporal_patch_size,
-            )
-            latents = self._denormalize_latents(
-                latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
-            )
             latents = latents.to(prompt_embeds.dtype)
 
             if not self.vae.config.timestep_conditioning:
@@ -1109,10 +1115,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             video = self.video_processor.postprocess_video(video, output_type=output_type)
 
             audio_latents = audio_latents.to(self.audio_vae.dtype)
-            audio_latents = self._denormalize_audio_latents(
-                audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
-            )
-            audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
             generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
             audio = self.vocoder(generated_mel_spectrograms)
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py
index 0a707806ce..339d5533d5 100644
--- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py
@@ -1166,21 +1166,27 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
                 if XLA_AVAILABLE:
                     xm.mark_step()
 
+        latents = self._unpack_latents(
+            latents,
+            latent_num_frames,
+            latent_height,
+            latent_width,
+            self.transformer_spatial_patch_size,
+            self.transformer_temporal_patch_size,
+        )
+        latents = self._denormalize_latents(
+            latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+        )
+
+        audio_latents = self._denormalize_audio_latents(
+            audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
+        )
+        audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
+
         if output_type == "latent":
             video = latents
             audio = audio_latents
         else:
-            latents = self._unpack_latents(
-                latents,
-                latent_num_frames,
-                latent_height,
-                latent_width,
-                self.transformer_spatial_patch_size,
-                self.transformer_temporal_patch_size,
-            )
-            latents = self._denormalize_latents(
-                latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
-            )
             latents = latents.to(prompt_embeds.dtype)
 
             if not self.vae.config.timestep_conditioning:
@@ -1205,10 +1211,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
             video = self.video_processor.postprocess_video(video, output_type=output_type)
 
             audio_latents = audio_latents.to(self.audio_vae.dtype)
-            audio_latents = self._denormalize_audio_latents(
-                audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
-            )
-            audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
             generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
             audio = self.vocoder(generated_mel_spectrograms)
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
index a144c8b15b..680740b07b 100644
--- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
@@ -154,7 +154,8 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
 
         init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
         init_latents = torch.cat(init_latents, dim=0).to(dtype)
-        init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
+        # NOTE: latent upsampler operates on the unnormalized latents, so don't normalize here
+        # init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
         return init_latents
 
     def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
@@ -275,6 +276,7 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
         spatial_patch_size: int = 1,
         temporal_patch_size: int = 1,
         latents: Optional[torch.Tensor] = None,
+        latents_normalized: bool = False,
         decode_timestep: Union[float, List[float]] = 0.0,
         decode_noise_scale: Optional[Union[float, List[float]]] = None,
         adain_factor: float = 0.0,
@@ -305,6 +307,9 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
                 Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a
                 patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape
                 `(batch_size, latent_channels, latent_frames, latent_height, latent_width)`.
+            latents_normalized (`bool`, *optional*, defaults to `False`):
+                If `latents` are supplied, whether the `latents` are normalized using the VAE latent mean and std. If
+                `True`, the `latents` will be denormalized before being supplied to the latent upsampler.
             decode_timestep (`float`, defaults to `0.0`):
                 The timestep at which generated video is decoded.
             decode_noise_scale (`float`, defaults to `None`):
@@ -362,6 +367,7 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
             video = self.video_processor.preprocess_video(video, height=height, width=width)
             video = video.to(device=device, dtype=torch.float32)
 
+        latents_supplied = latents is not None
         latents = self.prepare_latents(
             video=video,
             batch_size=batch_size,
@@ -376,9 +382,10 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
             latents=latents,
         )
 
-        latents = self._denormalize_latents(
-            latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
-        )
+        if latents_supplied and latents_normalized:
+            latents = self._denormalize_latents(
+                latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+            )
 
         latents = latents.to(self.latent_upsampler.dtype)
         latents_upsampled = self.latent_upsampler(latents)
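
Usage note (illustrative, not part of the patch): a minimal sketch of how the new `latents_normalized` flag is intended to be used. The checkpoint id and the `latents` / `normalized_latents` tensors below are placeholders, and the top-level import assumes the LTX-2 pipelines are re-exported from `diffusers`.

```python
import torch

from diffusers import LTX2LatentUpsamplePipeline

# Placeholder checkpoint id; substitute a real LTX-2 latent-upsampler repo.
upsampler = LTX2LatentUpsamplePipeline.from_pretrained(
    "path/to/ltx2-latent-upsampler", torch_dtype=torch.bfloat16
).to("cuda")

# `latents` stands in for video latents of shape (batch, channels, frames, height, width),
# e.g. what LTX2Pipeline returns with output_type="latent". With this patch those latents are
# already unpacked and denormalized, so the default latents_normalized=False is correct.
upsampled = upsampler(latents=latents, latents_normalized=False)

# Latents that were normalized with the VAE mean/std (e.g. produced by code predating this
# patch) should be passed with latents_normalized=True so the pipeline denormalizes them first.
upsampled = upsampler(latents=normalized_latents, latents_normalized=True)
```

The default of `False` matches what `LTX2Pipeline` and `LTX2ImageToVideoPipeline` now return for `output_type="latent"`, since both pipelines denormalize and unpack the latents before the `output_type == "latent"` branch.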