From 32df138fef2d9b2a685ec90f8297ef5755705ef5 Mon Sep 17 00:00:00 2001
From: Daniel Gu
Date: Wed, 7 Jan 2026 08:03:41 +0100
Subject: [PATCH] Add latent upsample pipeline docstring and example

---
 .../ltx2/pipeline_ltx2_latent_upsample.py | 107 +++++++++++++++++-
 1 file changed, 105 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
index 8f35829233..fb3d298b82 100644
--- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
@@ -18,17 +18,68 @@ import torch
 
 from ...image_processor import PipelineImageInput
 from ...models import AutoencoderKLLTX2Video
-from ...utils import deprecate, get_logger
+from ...utils import deprecate, get_logger, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
-from ..pipeline_utils import DiffusionPipeline
 from ..ltx.pipeline_output import LTXPipelineOutput
+from ..pipeline_utils import DiffusionPipeline
 from .latent_upsampler import LTX2LatentUpsamplerModel
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline
+        >>> from diffusers.utils import load_image
+        >>> from diffusers.pipelines.ltx2.export_utils import encode_video
+
+        >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16)
+        >>> pipe.to("cuda")
+
+        >>> image = load_image(
+        ...     "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
+        ... )
+        >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background."
+        >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+        >>> video, audio = pipe(
+        ...     image=image,
+        ...     prompt=prompt,
+        ...     negative_prompt=negative_prompt,
+        ...     width=768,
+        ...     height=512,
+        ...     num_frames=121,
+        ...     frame_rate=25.0,
+        ...     num_inference_steps=40,
+        ...     guidance_scale=3.0,
+        ...     output_type="pil",
+        ...     return_dict=False,
+        ... )
+
+        >>> upsample_pipe = LTX2LatentUpsamplePipeline.from_pretrained(
+        ...     "Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16
+        ... )
+        >>> upsample_pipe.to("cuda")
+
+        >>> video = upsample_pipe(
+        ...     video=video,
+        ...     width=768,
+        ...     height=512,
+        ...     output_type="np",
+        ...     return_dict=False,
+        ... )[0]
+
+        >>> video = (video * 255).round().astype("uint8")
+        >>> video = torch.from_numpy(video)
+        >>> encode_video(video[0], fps=25.0, audio=audio[0].float().cpu(), output_path="output.mp4")
+        ```
+"""
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
     encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -267,6 +318,7 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
             raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]")
 
     @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         video: Optional[List[PipelineImageInput]] = None,
@@ -284,6 +336,57 @@
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
     ):
+        r"""
+        Function invoked when calling the pipeline for generation.
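+
+        The input video (or its pre-encoded latents) is upsampled to a higher resolution with the
+        [`LTX2LatentUpsamplerModel`]. The upsampled latents can optionally be blended with the original latents via
+        Adaptive Instance Normalization (`adain_factor`) and compressed via tone mapping
+        (`tone_map_compression_ratio`) before being decoded into frames.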
+
+        Args:
+            video (`List[PipelineImageInput]`, *optional*):
+                The video to be upsampled (such as an LTX 2.0 first-stage output). If not supplied, `latents` should
+                be supplied.
+            height (`int`, *optional*, defaults to `512`):
+                The height in pixels of the input video (not the generated video, which will have a larger resolution).
+            width (`int`, *optional*, defaults to `768`):
+                The width in pixels of the input video (not the generated video, which will have a larger resolution).
+            num_frames (`int`, *optional*, defaults to `121`):
+                The number of frames in the input video.
+            spatial_patch_size (`int`, *optional*, defaults to `1`):
+                The spatial patch size of the video latents. Used when `latents` is supplied if unpacking is necessary.
+            temporal_patch_size (`int`, *optional*, defaults to `1`):
+                The temporal patch size of the video latents. Used when `latents` is supplied if unpacking is
+                necessary.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a
+                patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape `(batch_size,
+                latent_channels, latent_frames, latent_height, latent_width)`.
+            decode_timestep (`float`, defaults to `0.0`):
+                The timestep at which the generated video is decoded.
+            decode_noise_scale (`float`, defaults to `None`):
+                The interpolation factor between random noise and denoised latents at the decode timestep.
+            adain_factor (`float`, *optional*, defaults to `0.0`):
+                Adaptive Instance Normalization (AdaIN) blending factor between the upsampled and original latents.
+                Should be in [-10.0, 10.0]; supplying 0.0 (the default) means that AdaIN is not performed.
+            tone_map_compression_ratio (`float`, *optional*, defaults to `0.0`):
+                The compression strength for tone mapping, which will reduce the dynamic range of the latent values.
+                This is useful for regularizing high-variance latents or for conditioning outputs during generation.
+                Should be in [0, 1], where 0.0 (the default) means tone mapping is not applied and 1.0 corresponds to
+                the full compression effect.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated video. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is the upsampled video.
+        """
+
         self.check_inputs(
             video=video,
             height=height,