Add latent upsample pipeline docstring and example
@@ -18,17 +18,68 @@ import torch

 from ...image_processor import PipelineImageInput
 from ...models import AutoencoderKLLTX2Video
-from ...utils import deprecate, get_logger
+from ...utils import deprecate, get_logger, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
-from ..pipeline_utils import DiffusionPipeline
+from ..ltx.pipeline_output import LTXPipelineOutput
+from ..pipeline_utils import DiffusionPipeline
 from .latent_upsampler import LTX2LatentUpsamplerModel


 logger = get_logger(__name__)  # pylint: disable=invalid-name


+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline
+        >>> from diffusers.utils import load_image
+        >>> from diffusers.pipelines.ltx2.export_utils import encode_video
+
+        >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16)
+        >>> pipe.to("cuda")
+
+        >>> image = load_image(
+        ...     "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
+        ... )
+        >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background."
+        >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+        >>> video, audio = pipe(
+        ...     image=image,
+        ...     prompt=prompt,
+        ...     negative_prompt=negative_prompt,
+        ...     width=768,
+        ...     height=512,
+        ...     num_frames=121,
+        ...     frame_rate=25.0,
+        ...     num_inference_steps=40,
+        ...     guidance_scale=3.0,
+        ...     output_type="pil",
+        ...     return_dict=False,
+        ... )
+
+        >>> upsample_pipe = LTX2LatentUpsamplePipeline.from_pretrained(
+        ...     "Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16
+        ... )
+        >>> upsample_pipe.to("cuda")
+
+        >>> video = upsample_pipe(
+        ...     video=video,
+        ...     width=768,
+        ...     height=512,
+        ...     output_type="np",
+        ...     return_dict=False,
+        ... )[0]
+
+        >>> video = (video * 255).round().astype("uint8")
+        >>> video = torch.from_numpy(video)
+        >>> encode_video(video[0], fps=25.0, audio=audio[0].float().cpu(), output_path="output.mp4")
+        ```
+"""


 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
     encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
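The hunk ends at the signature of `retrieve_latents`, so its body falls outside the diff context. Per the "Copied from" marker it follows the standard diffusers helper, which dispatches on the encoder output roughly as below (a sketch of the referenced upstream helper for context, not part of this commit's changes; it relies on the file's existing `torch` and `Optional` imports):

```py
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    # VAE encode() calls in diffusers return either an object holding a
    # latent distribution or one holding raw latents; handle both cases.
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
```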
@@ -267,6 +318,7 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
             raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]")

     @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         video: Optional[List[PipelineImageInput]] = None,
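The `replace_example_docstring` decorator added here is what ties `EXAMPLE_DOC_STRING` to the bare `Examples:` placeholder in the `__call__` docstring below. In outline it works like this (a simplified sketch of the diffusers utility, not its exact implementation):

```py
import re


def replace_example_docstring(example_docstring):
    # Simplified sketch: locate the bare "Examples:" line in the wrapped
    # function's docstring and splice in the example block at that spot.
    def docstring_decorator(fn):
        lines = (fn.__doc__ or "").split("\n")
        for i, line in enumerate(lines):
            if re.fullmatch(r"\s*Examples?:\s*", line):
                lines[i] = example_docstring
                break
        fn.__doc__ = "\n".join(lines)
        return fn

    return docstring_decorator
```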
@@ -284,6 +336,57 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
     ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            video (`List[PipelineImageInput]`, *optional*):
+                The video to be upsampled (such as an LTX 2.0 first-stage output). If not supplied, `latents` should
+                be supplied.
+            height (`int`, *optional*, defaults to `512`):
+                The height in pixels of the input video (not the generated video, which will have a larger
+                resolution).
+            width (`int`, *optional*, defaults to `768`):
+                The width in pixels of the input video (not the generated video, which will have a larger
+                resolution).
+            num_frames (`int`, *optional*, defaults to `121`):
+                The number of frames in the input video.
+            spatial_patch_size (`int`, *optional*, defaults to `1`):
+                The spatial patch size of the video latents. Used when `latents` is supplied if unpacking is
+                necessary.
+            temporal_patch_size (`int`, *optional*, defaults to `1`):
+                The temporal patch size of the video latents. Used when `latents` is supplied if unpacking is
+                necessary.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a
+                patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape `(batch_size,
+                latent_channels, latent_frames, latent_height, latent_width)`.
+            decode_timestep (`float`, *optional*, defaults to `0.0`):
+                The timestep at which the generated video is decoded.
+            decode_noise_scale (`float`, *optional*, defaults to `None`):
+                The interpolation factor between random noise and the denoised latents at the decode timestep.
+            adain_factor (`float`, *optional*, defaults to `0.0`):
+                Adaptive Instance Normalization (AdaIN) blending factor between the upsampled and original latents.
+                Should be in [-10.0, 10.0]; supplying 0.0 (the default) means that AdaIN is not performed.
+            tone_map_compression_ratio (`float`, *optional*, defaults to `0.0`):
+                The compression strength for tone mapping, which reduces the dynamic range of the latent values.
+                This is useful for regularizing high-variance latents or for conditioning outputs during generation.
+                Should be in [0, 1], where 0.0 (the default) means tone mapping is not applied and 1.0 corresponds
+                to the full compression effect.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated video. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is the upsampled video.
+        """
+
         self.check_inputs(
             video=video,
             height=height,
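The Args section above also documents a `latents` path that the example docstring never exercises. A minimal sketch of that variant follows (hypothetical usage: whether the first-stage pipeline returns latents via `output_type="latent"`, and how its output unpacks, is not shown in this diff, so treat those details as assumptions):

```py
>>> # Hypothetical: run the first stage only up to latents, skipping VAE decode
>>> latents = pipe(
...     image=image,
...     prompt=prompt,
...     output_type="latent",
...     return_dict=False,
... )[0]

>>> # Feed the latents straight to the upsampler; the `video` argument is omitted
>>> video = upsample_pipe(
...     latents=latents,
...     width=768,
...     height=512,
...     output_type="np",
...     return_dict=False,
... )[0]
```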