1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add test script for LTX 2.0 latent upsampling

This commit is contained in:
Daniel Gu
2026-01-06 05:58:31 +01:00
parent 84c0b2fb84
commit a7d6916afc

View File

@@ -0,0 +1,147 @@
import argparse
import gc
import os
import torch
from diffusers import AutoencoderKLLTX2Video
from diffusers.utils import load_image
from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.export_utils import encode_video
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, default="diffusers-internal-dev/new-ltx-model")
parser.add_argument("--revision", type=str, default="main")
parser.add_argument("--image_path", required=True, type=str)
parser.add_argument(
"--prompt",
type=str,
default=(
"An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart "
"in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in "
"slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless "
"motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep "
"darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and "
"scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground "
"dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity "
"motion, cinematic lighting, and a breath-taking, movie-like shot."
),
)
parser.add_argument(
"--negative_prompt",
type=str,
default=(
"shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion "
"artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."
),
)
parser.add_argument("--num_inference_steps", type=int, default=40)
parser.add_argument("--height", type=int, default=512)
parser.add_argument("--width", type=int, default=768)
parser.add_argument("--num_frames", type=int, default=121)
parser.add_argument("--frame_rate", type=float, default=25.0)
parser.add_argument("--guidance_scale", type=float, default=3.0)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--apply_scheduler_fix", action="store_true")
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--dtype", type=str, default="bf16")
parser.add_argument("--cpu_offload", action="store_true")
parser.add_argument(
"--output_dir",
type=str,
default="samples",
help="Output directory for generated video",
)
parser.add_argument(
"--output_filename",
type=str,
default="ltx2_i2v_video_upsampled.mp4",
help="Filename of the exported generated video",
)
args = parser.parse_args()
args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
return args
def main(args):
pipeline = LTX2ImageToVideoPipeline.from_pretrained(
args.model_id, revision=args.revision, torch_dtype=args.dtype,
)
if args.cpu_offload:
pipeline.enable_model_cpu_offload()
else:
pipeline.to(device=args.device)
image = load_image(args.image_path)
video, audio = pipeline(
image=image,
prompt=args.prompt,
negative_prompt=args.negative_prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
frame_rate=args.frame_rate,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
generator=torch.Generator(device=args.device).manual_seed(args.seed),
output_type="pil",
return_dict=False,
)
# upsample_pipeline = LTX2LatentUpsamplePipeline.from_pretrained(
# args.model_id, revision=args.revision, torch_dtype=args.dtype,
# )
output_sampling_rate = pipeline.vocoder.config.output_sampling_rate
pipeline.to(device="cpu")
del pipeline # Otherwise there might be an OOM error?
torch.cuda.empty_cache()
gc.collect()
vae = AutoencoderKLLTX2Video.from_pretrained(
args.model_id,
subfolder="vae",
revision=args.revision,
torch_dtype=args.dtype,
)
latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
args.model_id,
subfolder="latent_upsampler",
revision=args.revision,
torch_dtype=args.dtype,
)
upsample_pipeline = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
upsample_pipeline.to(device=args.device)
video = upsample_pipeline(
video=video,
height=args.height,
width=args.width,
output_type="np",
return_dict=False,
)[0]
# Convert video to uint8 (but keep as NumPy array)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],
fps=args.frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=output_sampling_rate, # should be 24000
output_path=os.path.join(args.output_dir, args.output_filename),
)
if __name__ == '__main__':
args = parse_args()
main(args)