Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)

Get latent upsampler working with video latents

Daniel Gu
2026-01-07 01:58:25 +01:00
parent 245d056c7d
commit 8f1ddb1b1e
2 changed files with 80 additions and 9 deletions


@@ -53,6 +53,7 @@ def parse_args():
parser.add_argument("--dtype", type=str, default="bf16")
parser.add_argument("--cpu_offload", action="store_true")
parser.add_argument("--vae_tiling", action="store_true")
parser.add_argument("--use_video_latents", action="store_true")
parser.add_argument(
"--output_dir",
@@ -83,6 +84,10 @@ def main(args):
    image = load_image(args.image_path)
    first_stage_output_type = "pil"
    if args.use_video_latents:
        first_stage_output_type = "latent"
    video, audio = pipeline(
        image=image,
        prompt=args.prompt,
@@ -94,15 +99,39 @@ def main(args):
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        generator=torch.Generator(device=args.device).manual_seed(args.seed),
        output_type="pil",
        output_type=first_stage_output_type,
        return_dict=False,
    )
    if args.use_video_latents:
        # Manually convert the audio latents to a waveform
        audio = audio.to(pipeline.audio_vae.dtype)
        audio = pipeline._denormalize_audio_latents(
            audio, pipeline.audio_vae.latents_mean, pipeline.audio_vae.latents_std
        )
        sampling_rate = pipeline.audio_sampling_rate
        hop_length = pipeline.audio_hop_length
        audio_vae_temporal_scale = pipeline.audio_vae_temporal_compression_ratio
        duration_s = args.num_frames / args.frame_rate
        latents_per_second = (
            float(sampling_rate) / float(hop_length) / float(audio_vae_temporal_scale)
        )
        audio_latent_frames = int(duration_s * latents_per_second)
        latent_mel_bins = pipeline.audio_vae.config.mel_bins // pipeline.audio_vae_mel_compression_ratio
        audio = pipeline._unpack_audio_latents(audio, audio_latent_frames, latent_mel_bins)
        audio = pipeline.audio_vae.decode(audio, return_dict=False)[0]
        audio = pipeline.vocoder(audio)
    # Get some pipeline configs for upsampling
    spatial_patch_size = pipeline.transformer_spatial_patch_size
    temporal_patch_size = pipeline.transformer_temporal_patch_size
    # upsample_pipeline = LTX2LatentUpsamplePipeline.from_pretrained(
    #     args.model_id, revision=args.revision, torch_dtype=args.dtype,
    # )
    output_sampling_rate = pipeline.vocoder.config.output_sampling_rate
    pipeline.to(device="cpu")
    del pipeline  # Otherwise there might be an OOM error?
    torch.cuda.empty_cache()
    gc.collect()
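For reference, the audio branch above boils down to a small piece of arithmetic for the number of audio latent frames. A worked sketch follows; every numeric value (sampling rate, hop length, compression ratio, frame settings) is an illustrative assumption, not a value read from an LTX2 checkpoint.

# Worked sketch of the audio latent frame count (all values assumed for illustration).
sampling_rate = 44100               # assumed audio VAE sampling rate
hop_length = 512                    # assumed mel hop length
audio_vae_temporal_scale = 4        # assumed audio VAE temporal compression ratio
num_frames, frame_rate = 121, 24    # assumed video settings

duration_s = num_frames / frame_rate                                         # ~5.04 s of audio
latents_per_second = sampling_rate / hop_length / audio_vae_temporal_scale   # ~21.5
audio_latent_frames = int(duration_s * latents_per_second)                   # 108 latent frames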
@@ -124,13 +153,21 @@ def main(args):
    if args.vae_tiling:
        upsample_pipeline.enable_vae_tiling()
    video = upsample_pipeline(
        video=video,
        height=args.height,
        width=args.width,
        output_type="np",
        return_dict=False,
    )[0]
    upsample_kwargs = {
        "height": args.height,
        "width": args.width,
        "output_type": "np",
        "return_dict": False,
    }
    if args.use_video_latents:
        upsample_kwargs["latents"] = video
        upsample_kwargs["num_frames"] = args.num_frames
        upsample_kwargs["spatial_patch_size"] = spatial_patch_size
        upsample_kwargs["temporal_patch_size"] = temporal_patch_size
    else:
        upsample_kwargs["video"] = video
    video = upsample_pipeline(**upsample_kwargs)[0]
    # Convert video to uint8 (but keep as NumPy array)
    video = (video * 255).round().astype("uint8")


@@ -67,12 +67,25 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
        self,
        video: Optional[torch.Tensor] = None,
        batch_size: int = 1,
        num_frames: int = 121,
        height: int = 512,
        width: int = 768,
        spatial_patch_size: int = 1,
        temporal_patch_size: int = 1,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if latents is not None:
            if latents.ndim == 3:
                # Convert token seq [B, S, D] to latent video [B, C, F, H, W]
                latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
                latent_height = height // self.vae_spatial_compression_ratio
                latent_width = width // self.vae_spatial_compression_ratio
                latents = self._unpack_latents(
                    latents, latent_num_frames, latent_height, latent_width, spatial_patch_size, temporal_patch_size
                )
            return latents.to(device=device, dtype=dtype)

        video = video.to(device=device, dtype=self.vae.dtype)
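As a concrete check of the shape arithmetic in prepare_latents, here is the computation for the default num_frames, height, and width. The compression ratios are assumptions for illustration; the real values come from the loaded VAE config.

num_frames, height, width = 121, 512, 768
vae_temporal_compression_ratio = 8    # assumed for illustration
vae_spatial_compression_ratio = 32    # assumed for illustration

latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1   # 16
latent_height = height // vae_spatial_compression_ratio                      # 16
latent_width = width // vae_spatial_compression_ratio                        # 24
# With patch sizes of 1, the packed sequence length S is 16 * 16 * 24 = 6144 tokens.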
@@ -175,6 +188,19 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
        latents = latents * latents_std / scaling_factor + latents_mean
        return latents

    @staticmethod
    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents
    def _unpack_latents(
        latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
    ) -> torch.Tensor:
        # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
        # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
        # what happens in the `_pack_latents` method.
        batch_size = latents.size(0)
        latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
        latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
        return latents

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
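To sanity-check the _unpack_latents reshape above, the following self-contained sketch packs a latent video into a [B, S, D] token sequence (the inverse operation) and unpacks it again with the same arithmetic; the tensor shape is an illustrative assumption, and the patch sizes match the defaults of 1 used above.

import torch

B, C, F, H, W = 1, 128, 16, 16, 24   # assumed latent video shape
p, pt = 1, 1                         # spatial / temporal patch sizes (defaults above)

video_latents = torch.randn(B, C, F, H, W)

# Pack [B, C, F, H, W] -> [B, S, D] (the inverse of _unpack_latents)
packed = video_latents.reshape(B, C, F // pt, pt, H // p, p, W // p, p)
packed = packed.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
print(packed.shape)  # torch.Size([1, 6144, 128])

# Unpack with the same reshape/permute as _unpack_latents
unpacked = packed.reshape(B, F // pt, H // p, W // p, -1, pt, p, p)
unpacked = unpacked.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
assert torch.equal(unpacked, video_latents)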
@@ -246,6 +272,9 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
        video: Optional[List[PipelineImageInput]] = None,
        height: int = 512,
        width: int = 768,
        num_frames: int = 121,
        spatial_patch_size: int = 1,
        temporal_patch_size: int = 1,
        latents: Optional[torch.Tensor] = None,
        decode_timestep: Union[float, List[float]] = 0.0,
        decode_noise_scale: Optional[Union[float, List[float]]] = None,
@@ -286,6 +315,11 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
        latents = self.prepare_latents(
            video=video,
            batch_size=batch_size,
            num_frames=num_frames,
            height=height,
            width=width,
            spatial_patch_size=spatial_patch_size,
            temporal_patch_size=temporal_patch_size,
            dtype=torch.float32,
            device=device,
            generator=generator,
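Putting the new arguments together, here is a usage sketch of the latent path this commit enables. The checkpoint id and the top-level import are placeholders/assumptions, and first_stage_latents stands for the packed [B, S, D] output of the base LTX2 pipeline run with output_type="latent".

import torch
from diffusers import LTX2LatentUpsamplePipeline  # assumed import path

upsample_pipeline = LTX2LatentUpsamplePipeline.from_pretrained(
    "path/to/ltx2-latent-upsampler",   # placeholder checkpoint id
    torch_dtype=torch.bfloat16,
)
upsample_pipeline.to("cuda")

video = upsample_pipeline(
    latents=first_stage_latents,       # packed [B, S, D] latents from the base pipeline
    num_frames=121,
    height=512,
    width=768,
    spatial_patch_size=1,
    temporal_patch_size=1,
    output_type="np",
    return_dict=False,
)[0]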