From 8f1ddb1b1ea6ffdd6a2f2202e07e61319fa3593f Mon Sep 17 00:00:00 2001
From: Daniel Gu
Date: Wed, 7 Jan 2026 01:58:25 +0100
Subject: [PATCH] Get latent upsampler working with video latents

---
 scripts/ltx2_test_latent_upsampler.py     | 55 ++++++++++++++++---
 .../ltx2/pipeline_ltx2_latent_upsample.py | 34 ++++++++++++
 2 files changed, 80 insertions(+), 9 deletions(-)

diff --git a/scripts/ltx2_test_latent_upsampler.py b/scripts/ltx2_test_latent_upsampler.py
index 745f2c8d1f..c485e6ce3f 100644
--- a/scripts/ltx2_test_latent_upsampler.py
+++ b/scripts/ltx2_test_latent_upsampler.py
@@ -53,6 +53,7 @@ def parse_args():
     parser.add_argument("--dtype", type=str, default="bf16")
     parser.add_argument("--cpu_offload", action="store_true")
     parser.add_argument("--vae_tiling", action="store_true")
+    parser.add_argument("--use_video_latents", action="store_true")
 
     parser.add_argument(
         "--output_dir",
@@ -83,6 +84,10 @@ def main(args):
 
     image = load_image(args.image_path)
 
+    first_stage_output_type = "pil"
+    if args.use_video_latents:
+        first_stage_output_type = "latent"
+
     video, audio = pipeline(
         image=image,
         prompt=args.prompt,
@@ -94,15 +99,39 @@
         num_inference_steps=args.num_inference_steps,
         guidance_scale=args.guidance_scale,
         generator=torch.Generator(device=args.device).manual_seed(args.seed),
-        output_type="pil",
+        output_type=first_stage_output_type,
         return_dict=False,
     )
 
+    if args.use_video_latents:
+        # Manually convert the audio latents to a waveform
+        audio = audio.to(pipeline.audio_vae.dtype)
+        audio = pipeline._denormalize_audio_latents(
+            audio, pipeline.audio_vae.latents_mean, pipeline.audio_vae.latents_std
+        )
+
+        sampling_rate = pipeline.audio_sampling_rate
+        hop_length = pipeline.audio_hop_length
+        audio_vae_temporal_scale = pipeline.audio_vae_temporal_compression_ratio
+        duration_s = args.num_frames / args.frame_rate
+        latents_per_second = (
+            float(sampling_rate) / float(hop_length) / float(audio_vae_temporal_scale)
+        )
+        audio_latent_frames = int(duration_s * latents_per_second)
+        latent_mel_bins = pipeline.audio_vae.config.mel_bins // pipeline.audio_vae_mel_compression_ratio
+        audio = pipeline._unpack_audio_latents(audio, audio_latent_frames, latent_mel_bins)
+
+        audio = pipeline.audio_vae.decode(audio, return_dict=False)[0]
+        audio = pipeline.vocoder(audio)
+
+    # Get some pipeline configs for upsampling
+    spatial_patch_size = pipeline.transformer_spatial_patch_size
+    temporal_patch_size = pipeline.transformer_temporal_patch_size
+
     # upsample_pipeline = LTX2LatentUpsamplePipeline.from_pretrained(
     #     args.model_id, revision=args.revision, torch_dtype=args.dtype,
     # )
     output_sampling_rate = pipeline.vocoder.config.output_sampling_rate
-    pipeline.to(device="cpu")
     del pipeline  # Otherwise there might be an OOM error?
     torch.cuda.empty_cache()
     gc.collect()
@@ -124,13 +153,21 @@ def main(args):
     if args.vae_tiling:
         upsample_pipeline.enable_vae_tiling()
 
-    video = upsample_pipeline(
-        video=video,
-        height=args.height,
-        width=args.width,
-        output_type="np",
-        return_dict=False,
-    )[0]
+    upsample_kwargs = {
+        "height": args.height,
+        "width": args.width,
+        "output_type": "np",
+        "return_dict": False,
+    }
+    if args.use_video_latents:
+        upsample_kwargs["latents"] = video
+        upsample_kwargs["num_frames"] = args.num_frames
+        upsample_kwargs["spatial_patch_size"] = spatial_patch_size
+        upsample_kwargs["temporal_patch_size"] = temporal_patch_size
+    else:
+        upsample_kwargs["video"] = video
+
+    video = upsample_pipeline(**upsample_kwargs)[0]
 
     # Convert video to uint8 (but keep as NumPy array)
     video = (video * 255).round().astype("uint8")
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
index 97c9c90b60..8f35829233 100644
--- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
@@ -67,12 +67,25 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
         self,
         video: Optional[torch.Tensor] = None,
         batch_size: int = 1,
+        num_frames: int = 121,
+        height: int = 512,
+        width: int = 768,
+        spatial_patch_size: int = 1,
+        temporal_patch_size: int = 1,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if latents is not None:
+            if latents.ndim == 3:
+                # Convert token seq [B, S, D] to latent video [B, C, F, H, W]
+                latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+                latent_height = height // self.vae_spatial_compression_ratio
+                latent_width = width // self.vae_spatial_compression_ratio
+                latents = self._unpack_latents(
+                    latents, latent_num_frames, latent_height, latent_width, spatial_patch_size, temporal_patch_size
+                )
             return latents.to(device=device, dtype=dtype)
 
         video = video.to(device=device, dtype=self.vae.dtype)
@@ -175,6 +188,19 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
         latents = latents * latents_std / scaling_factor + latents_mean
         return latents
 
+    @staticmethod
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents
+    def _unpack_latents(
+        latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
+    ) -> torch.Tensor:
+        # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
+        # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
+        # what happens in the `_pack_latents` method.
+        batch_size = latents.size(0)
+        latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
+        latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+        return latents
+
     def enable_vae_slicing(self):
         r"""
         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
@@ -246,6 +272,9 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
         video: Optional[List[PipelineImageInput]] = None,
         height: int = 512,
         width: int = 768,
+        num_frames: int = 121,
+        spatial_patch_size: int = 1,
+        temporal_patch_size: int = 1,
         latents: Optional[torch.Tensor] = None,
         decode_timestep: Union[float, List[float]] = 0.0,
         decode_noise_scale: Optional[Union[float, List[float]]] = None,
@@ -286,6 +315,11 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
         latents = self.prepare_latents(
             video=video,
             batch_size=batch_size,
+            num_frames=num_frames,
+            height=height,
+            width=width,
+            spatial_patch_size=spatial_patch_size,
+            temporal_patch_size=temporal_patch_size,
             dtype=torch.float32,
             device=device,
             generator=generator,
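Reviewer note (not part of the patch): below is a minimal, self-contained sketch of the token-to-video reshape that the copied `_unpack_latents` helper performs, handy for checking expected shapes before passing first-stage latents to the upsample pipeline. The tensor sizes are arbitrary toy values, not real LTX2 latent dimensions, and the free function `unpack_latents` is just a copy of the method's body for illustration.

```python
import torch


def unpack_latents(latents, num_frames, height, width, patch_size=1, patch_size_t=1):
    # Same reshape as the `_unpack_latents` method added in the patch:
    # token sequence [B, S, D] -> latent video [B, C, F, H, W].
    batch_size = latents.size(0)
    latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
    latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
    return latents


# Toy latent grid: 8 channels, 5 latent frames, a 4x6 latent spatial grid (illustrative values only).
batch, channels, frames, height, width = 1, 8, 5, 4, 6
packed = torch.randn(batch, frames * height * width, channels)  # packed token sequence [B, S, D]
video_latents = unpack_latents(packed, frames, height, width)
print(video_latents.shape)  # torch.Size([1, 8, 5, 4, 6]) == [B, C, F, H, W]
```

In the patch itself, `prepare_latents` first converts the pixel-space `num_frames`, `height`, and `width` into latent sizes via the VAE temporal/spatial compression ratios before calling the helper, so callers pass the original video dimensions rather than latent ones.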