
Add latent upsample pipeline docstring and example

Daniel Gu
2026-01-07 08:03:41 +01:00
parent 0637b549a0
commit 32df138fef


@@ -18,17 +18,68 @@ import torch
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLLTX2Video
from ...utils import deprecate, get_logger
from ...utils import deprecate, get_logger, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from ..ltx.pipeline_output import LTXPipelineOutput
from ..pipeline_utils import DiffusionPipeline
from .latent_upsampler import LTX2LatentUpsamplerModel

logger = get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline
>>> from diffusers.utils import load_image
>>> from diffusers.pipelines.ltx2.export_utils import encode_video
>>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> image = load_image(
... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
... )
>>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background."
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
>>> video, audio = pipe(
... image=image,
... prompt=prompt,
... negative_prompt=negative_prompt,
... width=768,
... height=512,
... num_frames=121,
... frame_rate=25.0,
... num_inference_steps=40,
... guidance_scale=3.0,
... output_type="pil",
... return_dict=False,
... )
>>> upsample_pipe = LTX2LatentUpsamplePipeline.from_pretrained(
... "Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16
... )
>>> upsample_pipe.to("cuda")
>>> video = upsample_pipe(
... video=video,
... width=768,
... height=512,
... output_type="pil",
... return_dict=False,
... )[0]
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(video[0], fps=25.0, audio=audio[0].float().cpu(), output_path="output.mp4")
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
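For context, the hunk cuts off mid-signature here. The helper being referenced resolves latents out of a VAE encoder output; a paraphrase of the upstream Stable Diffusion img2img helper it is copied from, shown only for reference (not part of this diff):

```py
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    # Sample the posterior, take its mode, or read precomputed latents,
    # depending on what the encoder output exposes.
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
```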
@@ -267,6 +318,7 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]")

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
video: Optional[List[PipelineImageInput]] = None,
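The `@replace_example_docstring(EXAMPLE_DOC_STRING)` decorator added above splices the example block into `__call__`'s docstring at its bare `Examples:` line, which is why that placeholder appears in the docstring below. A minimal sketch of that mechanism, assuming the placeholder convention this docstring uses (the real helper lives in `diffusers.utils`):

```py
import re

def replace_example_docstring(example_docstring):
    def docstring_decorator(fn):
        lines = (fn.__doc__ or "").split("\n")
        # Find the bare "Examples:" (or "Example:") placeholder line.
        i = 0
        while i < len(lines) and re.search(r"^\s*Examples?:\s*$", lines[i]) is None:
            i += 1
        if i >= len(lines):
            raise ValueError(f"`{fn.__name__}` has no 'Examples:' placeholder in its docstring.")
        # Swap the placeholder line for the full example block.
        lines[i] = example_docstring
        fn.__doc__ = "\n".join(lines)
        return fn

    return docstring_decorator
```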
@@ -284,6 +336,57 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
r"""
Function invoked when calling the pipeline for generation.

Args:
video (`List[PipelineImageInput]`, *optional*):
The video to be upsampled (such as an LTX 2.0 first-stage output). If not supplied, `latents` should be
supplied.
height (`int`, *optional*, defaults to `512`):
The height in pixels of the input video (not the generated video, which will have a larger resolution).
width (`int`, *optional*, defaults to `768`):
The width in pixels of the input video (not the generated video, which will have a larger resolution).
num_frames (`int`, *optional*, defaults to `121`):
The number of frames in the input video.
spatial_patch_size (`int`, *optional*, defaults to `1`):
The spatial patch size of the video latents. Used when `latents` is supplied if unpacking is necessary.
temporal_patch_size (`int`, *optional*, defaults to `1`):
The temporal patch size of the video latents. Used when `latents` is supplied if unpacking is
necessary.
latents (`torch.Tensor`, *optional*):
Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a
patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape `(batch_size,
latent_channels, latent_frames, latent_height, latent_width)`.
decode_timestep (`float`, defaults to `0.0`):
The timestep at which the generated video is decoded.
decode_noise_scale (`float`, defaults to `None`):
The interpolation factor between random noise and denoised latents at the decode timestep.
adain_factor (`float`, *optional*, defaults to `0.0`):
Adaptive Instance Normalization (AdaIN) blending factor between the upsampled and original latents.
Should be in [-10.0, 10.0]; supplying 0.0 (the default) means that AdaIN is not performed.
tone_map_compression_ratio (`float`, *optional*, defaults to `0.0`):
The compression strength for tone mapping, which will reduce the dynamic range of the latent values.
This is useful for regularizing high-variance latents or for conditioning outputs during generation.
Should be in [0, 1], where 0.0 (the default) means tone mapping is not applied and 1.0 corresponds to
the full compression effect.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.

Examples:

Returns:
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is the upsampled video.
"""
self.check_inputs(
video=video,
height=height,
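A note on `adain_factor`: AdaIN here means renormalizing the upsampled latents to the per-channel statistics of the originals and blending by the given factor. A minimal sketch of that operation under the standard AdaIN formulation (hypothetical helper name; not necessarily the exact code in this file):

```py
import torch

def adain_blend(upsampled: torch.Tensor, reference: torch.Tensor, factor: float) -> torch.Tensor:
    # Both tensors are video latents of shape (batch, channels, frames, height, width).
    dims = (2, 3, 4)  # per-channel statistics over frames, height, width
    up_mean = upsampled.mean(dim=dims, keepdim=True)
    up_std = upsampled.std(dim=dims, keepdim=True).clamp_min(1e-6)
    ref_mean = reference.mean(dim=dims, keepdim=True)
    ref_std = reference.std(dim=dims, keepdim=True).clamp_min(1e-6)
    # Renormalize the upsampled latents to the reference statistics ...
    normalized = (upsampled - up_mean) / up_std * ref_std + ref_mean
    # ... then interpolate: factor 0.0 leaves the latents untouched, matching the default.
    return torch.lerp(upsampled, normalized, factor)
```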
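Since `latents` can stand in for `video`, a first-stage pipeline can hand its latents to the upsampler without a decode/re-encode round trip. A hypothetical sketch under that assumption (checkpoint and argument values reuse the example above; `output_type="latent"` support on the first stage and the tuple layout of its outputs are assumptions, not confirmed by this diff):

```py
import torch
from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline
from diffusers.utils import load_image

pipe = LTX2ImageToVideoPipeline.from_pretrained(
    "Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16
).to("cuda")
upsample_pipe = LTX2LatentUpsamplePipeline.from_pretrained(
    "Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image(
    "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
)

# Assumption: the first stage can return raw latents instead of decoded frames.
latents = pipe(
    image=image,
    prompt="A young girl stands calmly as a house fire rages in the background.",
    width=768,
    height=512,
    num_frames=121,
    output_type="latent",
    return_dict=False,
)[0]

# height/width/num_frames describe the *input* video encoded by the latents,
# per the docstring above; the upsampled output will be larger.
video = upsample_pipe(
    latents=latents,
    width=768,
    height=512,
    num_frames=121,
    adain_factor=0.25,  # blend upsampled latent statistics toward the originals
    tone_map_compression_ratio=0.1,  # mildly compress the latent dynamic range
    output_type="np",
    return_dict=False,
)[0]
```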