1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add export_utils file for exporting LTX 2.0 videos with audio

This commit is contained in:
Daniel Gu
2026-01-06 02:17:39 +01:00
parent bff989110c
commit cb50cacba5
3 changed files with 140 additions and 0 deletions

View File

@@ -0,0 +1,134 @@
# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fractions import Fraction
from typing import Optional
import torch
from ...utils import is_av_available
_CAN_USE_AV = is_av_available()
if _CAN_USE_AV:
import av
else:
raise ImportError(
"PyAV is required to use LTX 2.0 video export utilities. You can install it with `pip install av`"
)
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
"""
Prepare the audio stream for writing.
"""
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
audio_stream.codec_context.sample_rate = audio_sample_rate
audio_stream.codec_context.layout = "stereo"
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
return audio_stream
def _resample_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
cc = audio_stream.codec_context
# Use the encoder's format/layout/rate as the *target*
target_format = cc.format or "fltp" # AAC → usually fltp
target_layout = cc.layout or "stereo"
target_rate = cc.sample_rate or frame_in.sample_rate
audio_resampler = av.audio.resampler.AudioResampler(
format=target_format,
layout=target_layout,
rate=target_rate,
)
audio_next_pts = 0
for rframe in audio_resampler.resample(frame_in):
if rframe.pts is None:
rframe.pts = audio_next_pts
audio_next_pts += rframe.samples
rframe.sample_rate = frame_in.sample_rate
container.mux(audio_stream.encode(rframe))
# flush audio encoder
for packet in audio_stream.encode():
container.mux(packet)
def _write_audio(
container: av.container.Container,
audio_stream: av.audio.AudioStream,
samples: torch.Tensor,
audio_sample_rate: int,
) -> None:
if samples.ndim == 1:
samples = samples[:, None]
if samples.shape[1] != 2 and samples.shape[0] == 2:
samples = samples.T
if samples.shape[1] != 2:
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
if samples.dtype != torch.int16:
samples = torch.clip(samples, -1.0, 1.0)
samples = (samples * 32767.0).to(torch.int16)
frame_in = av.AudioFrame.from_ndarray(
samples.contiguous().reshape(1, -1).cpu().numpy(),
format="s16",
layout="stereo",
)
frame_in.sample_rate = audio_sample_rate
_resample_audio(container, audio_stream, frame_in)
def encode_video(
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
) -> None:
video_np = video.cpu().numpy()
_, height, width, _ = video_np.shape
container = av.open(output_path, mode="w")
stream = container.add_stream("libx264", rate=int(fps))
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
if audio is not None:
if audio_sample_rate is None:
raise ValueError("audio_sample_rate is required when audio is provided")
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
for frame_array in video_np:
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
for packet in stream.encode(frame):
container.mux(packet)
# Flush encoder
for packet in stream.encode():
container.mux(packet)
if audio is not None:
_write_audio(container, audio_stream, audio, audio_sample_rate)
container.close()

View File

@@ -66,6 +66,7 @@ from .import_utils import (
is_accelerate_version,
is_aiter_available,
is_aiter_version,
is_av_available,
is_better_profanity_available,
is_bitsandbytes_available,
is_bitsandbytes_version,

View File

@@ -230,6 +230,7 @@ _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_at
_aiter_available, _aiter_version = _is_package_available("aiter")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
_av_available, _av_version = _is_package_available("av")
def is_torch_available():
@@ -420,6 +421,10 @@ def is_kornia_available():
return _kornia_available
def is_av_available():
return _av_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the