
Add script to test full LTX2Pipeline T2V inference

Daniel Gu
2025-12-23 08:43:37 +01:00
parent 595f485ad8
commit 3bf736979f


@@ -0,0 +1,213 @@
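# Example invocation (illustrative; the script filename is assumed, and the argument values
# shown are this script's own defaults, so they can be omitted):
#   python ltx2_t2v_inference.py --model_id diffusers-internal-dev/new-ltx-model \
#       --num_frames 121 --frame_rate 25.0 --num_inference_steps 40 --seed 42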
import argparse
import os
from fractions import Fraction
from typing import Optional
import av # Needs to be installed separately (`pip install av`)
import torch
from diffusers import LTX2Pipeline
# Video export functions copied from original LTX 2.0 code
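# Overview of the export path below: RGB frames are encoded with libx264 and muxed into an
# MP4 container via PyAV; the audio waveform is wrapped in a packed s16 stereo AudioFrame,
# resampled to the AAC encoder's native format (typically planar float), then encoded and
# muxed into the same container.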
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
    """
    Prepare the audio stream for writing.
    """
    audio_stream = container.add_stream("aac", rate=audio_sample_rate)
    audio_stream.codec_context.sample_rate = audio_sample_rate
    audio_stream.codec_context.layout = "stereo"
    audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
    return audio_stream
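# The AAC encoder generally expects planar float ("fltp") frames, while the tensor audio is
# ingested as packed s16; the helper below bridges the two by resampling into the encoder's
# format/layout/rate before encoding.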
def _resample_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
    cc = audio_stream.codec_context
    # Use the encoder's format/layout/rate as the *target*
    target_format = cc.format or "fltp"  # AAC → usually fltp
    target_layout = cc.layout or "stereo"
    target_rate = cc.sample_rate or frame_in.sample_rate
    audio_resampler = av.audio.resampler.AudioResampler(
        format=target_format,
        layout=target_layout,
        rate=target_rate,
    )
    audio_next_pts = 0
    for rframe in audio_resampler.resample(frame_in):
        if rframe.pts is None:
            rframe.pts = audio_next_pts
        audio_next_pts += rframe.samples
        rframe.sample_rate = frame_in.sample_rate
        container.mux(audio_stream.encode(rframe))
    # flush audio encoder
    for packet in audio_stream.encode():
        container.mux(packet)
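# Expected input below: a stereo waveform tensor, either float in [-1, 1] or int16, shaped
# (num_samples, 2); a (2, num_samples) tensor is transposed, and anything else raises.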
def _write_audio(
    container: av.container.Container,
    audio_stream: av.audio.AudioStream,
    samples: torch.Tensor,
    audio_sample_rate: int,
) -> None:
    if samples.ndim == 1:
        samples = samples[:, None]
    if samples.shape[1] != 2 and samples.shape[0] == 2:
        samples = samples.T
    if samples.shape[1] != 2:
        raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
    # Convert to int16 packed for ingestion; resampler converts to encoder fmt.
    if samples.dtype != torch.int16:
        samples = torch.clip(samples, -1.0, 1.0)
        samples = (samples * 32767.0).to(torch.int16)
    frame_in = av.AudioFrame.from_ndarray(
        samples.contiguous().reshape(1, -1).cpu().numpy(),
        format="s16",
        layout="stereo",
    )
    frame_in.sample_rate = audio_sample_rate
    _resample_audio(container, audio_stream, frame_in)
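# encode_video muxes a frames-first uint8 RGB clip, shaped (F, H, W, C), together with an
# optional stereo waveform into a single MP4 file.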
def encode_video(
    video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
) -> None:
    video_np = video.cpu().numpy()
    _, height, width, _ = video_np.shape
    container = av.open(output_path, mode="w")
    stream = container.add_stream("libx264", rate=int(fps))
    stream.width = width
    stream.height = height
    stream.pix_fmt = "yuv420p"
    if audio is not None:
        if audio_sample_rate is None:
            raise ValueError("audio_sample_rate is required when audio is provided")
        audio_stream = _prepare_audio_stream(container, audio_sample_rate)
    for frame_array in video_np:
        frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
        for packet in stream.encode(frame):
            container.mux(packet)
    # Flush encoder
    for packet in stream.encode():
        container.mux(packet)
    if audio is not None:
        _write_audio(container, audio_stream, audio, audio_sample_rate)
    container.close()
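# Minimal usage sketch for encode_video (illustrative only; dummy tensors rather than real
# pipeline output):
#   dummy_video = torch.zeros(25, 512, 768, 3, dtype=torch.uint8)  # 1 s of black 512x768 frames
#   dummy_audio = torch.zeros(24000, 2)                            # 1 s of silent stereo audio in [-1, 1]
#   encode_video(dummy_video, fps=25, audio=dummy_audio, audio_sample_rate=24000, output_path="dummy.mp4")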
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="diffusers-internal-dev/new-ltx-model")
    parser.add_argument("--revision", type=str, default="main")
    parser.add_argument(
        "--prompt",
        type=str,
        default="A video of a dog dancing to energetic electronic dance music",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=(
            "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
            "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
            "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
            "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
            "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
            "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
            "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
            "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
            "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
            "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
            "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
        ),
    )
    parser.add_argument("--num_inference_steps", type=int, default=40)
    parser.add_argument("--height", type=int, default=512)
    parser.add_argument("--width", type=int, default=768)
    parser.add_argument("--num_frames", type=int, default=121)
    parser.add_argument("--frame_rate", type=float, default=25.0)
    parser.add_argument("--guidance_scale", type=float, default=3.0)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--dtype", type=str, default="bf16")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/home/daniel_gu/samples",
        help="Output directory for generated video",
    )
    parser.add_argument(
        "--output_filename",
        type=str,
        default="ltx2_sample_video.mp4",
        help="Filename of the exported generated video",
    )
    args = parser.parse_args()
    args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
    return args
def main(args):
    pipeline = LTX2Pipeline.from_pretrained(
        args.model_id,
        revision=args.revision,
        torch_dtype=args.dtype,
    )
    pipeline.to(device=args.device)
    video, audio = pipeline(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        frame_rate=args.frame_rate,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        generator=torch.Generator(device=args.device).manual_seed(args.seed),
        output_type="np",
    )
    # Convert video to uint8 (but keep it as a NumPy array for now)
    video = (video * 255).round().astype("uint8")
    video = torch.from_numpy(video)
    # Video is already frames-first; permute to channels-last so the shape is (B, F, H, W, C)
    video = video.permute(0, 1, 3, 4, 2)
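    # Note: the permute above assumes the pipeline returns video as a (batch, frames, channels,
    # height, width) array in [0, 1] and audio as a batch of waveform tensors; this mirrors the
    # indexing used below rather than a documented output contract.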
    encode_video(
        video[0],
        fps=args.frame_rate,
        audio=audio[0].float().cpu(),
        audio_sample_rate=pipeline.vocoder.config.output_sampling_rate,  # should be 24000
        output_path=os.path.join(args.output_dir, args.output_filename),
    )
if __name__ == "__main__":
    args = parse_args()
    main(args)