
LTX Video 0.9.7 (#11516)

* add upsampling pipeline

* ltx upsample pipeline conversion; pipeline fixes

* make fix-copies

* remove print

* add vae convenience methods

* update

* add tests

* support denoising strength for upscaling & video-to-video

* update docs

* update doc checkpoints

* update docs

* fix

---------

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Author: Aryan
Date: 2025-05-13 14:57:03 +05:30 (committed by GitHub)
Parent: 8b99f7e157
Commit: 06fee551e9
10 changed files with 943 additions and 55 deletions

View File

@@ -31,12 +31,103 @@ Available models:
| Model name | Recommended dtype |
|:-------------:|:-----------------:|
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
| [`LTX Video 2B 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
| [`LTX Video 2B 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
| [`LTX Video 2B 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
| [`LTX Video 13B 0.9.7`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev.safetensors) | `torch.bfloat16` |
| [`LTX Video Spatial Upscaler 0.9.7`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-spatial-upscaler-0.9.7.safetensors) | `torch.bfloat16` |
Note: The recommended dtype is for the transformer component. The VAE and text encoders can be `torch.float32`, `torch.bfloat16`, or `torch.float16`, but the recommended dtype is `torch.bfloat16` as used in the original repository.
## Recommended settings for generation
For the best results, it is recommended to follow the guidelines mentioned in the official LTX Video [repository](https://github.com/Lightricks/LTX-Video). A minimal example applying these settings is shown after the list below.
- Some variants of LTX Video are guidance-distilled. For guidance-distilled models, `guidance_scale` must be set to `1.0`. For any other models, `guidance_scale` should be set higher (e.g., `5.0`) for good generation quality.
- For variants with a timestep-aware VAE (LTXV 0.9.1 and above), it is recommended to set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`.
- For variants that support interpolation between multiple conditioning images and videos (LTXV 0.9.5 and above), it is recommended to use similar looking images/videos for the best results. High divergence between the conditionings may lead to abrupt transitions in the generated video.
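The following is a minimal, illustrative text-only call that applies these settings. It is not taken verbatim from the documentation: the checkpoint id matches the 0.9.7 example further down, while the prompt, resolution, frame count, and seed are placeholders, and `guidance_scale=5.0` assumes a non-distilled variant.
```python
import torch
from diffusers import LTXConditionPipeline
from diffusers.utils import export_to_video

pipe = LTXConditionPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

video = pipe(
    prompt="A winding mountain road covered in snow, seen from a drone",  # placeholder prompt
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    width=704,
    height=512,
    num_frames=121,
    guidance_scale=5.0,            # use 1.0 for guidance-distilled variants
    decode_timestep=0.05,          # recommended for timestep-aware VAEs (LTXV 0.9.1 and above)
    image_cond_noise_scale=0.025,
    generator=torch.Generator().manual_seed(0),
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```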
## Using LTX Video 13B 0.9.7
LTX Video 0.9.7 comes with a spatial latent upsampler and a 13B-parameter transformer. Inference involves first generating a low-resolution video, which is fast, then upscaling the latents with the upsampler and refining the upscaled video.
<!-- TODO(aryan): modify when official checkpoints are available -->
```python
import torch
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_video
pipe = LTXConditionPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-diffusers", torch_dtype=torch.bfloat16)
pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-Latent-Spatial-Upsampler-diffusers", vae=pipe.vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe_upsample.to("cuda")
pipe.vae.enable_tiling()
def round_to_nearest_resolution_acceptable_by_vae(height, width):
    # Spatial dimensions must be multiples of the VAE spatial compression ratio (32 for LTXV).
    height = height - (height % pipe.vae_spatial_compression_ratio)
    width = width - (width % pipe.vae_spatial_compression_ratio)
    return height, width
video = load_video(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
)[:21] # Use only the first 21 frames as conditioning
condition1 = LTXVideoCondition(video=video, frame_index=0)
prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
expected_height, expected_width = 768, 1152
downscale_factor = 2 / 3
num_frames = 161
# Part 1. Generate video at smaller resolution
# Text-only conditioning is also supported without the need to pass `conditions`
downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor)
downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width)
latents = pipe(
conditions=[condition1],
prompt=prompt,
negative_prompt=negative_prompt,
width=downscaled_width,
height=downscaled_height,
num_frames=num_frames,
num_inference_steps=30,
generator=torch.Generator().manual_seed(0),
output_type="latent",
).frames
# Part 2. Upscale generated video using latent upsampler with fewer inference steps
# The available latent upsampler upscales the height/width by 2x
upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2
upscaled_latents = pipe_upsample(
latents=latents,
output_type="latent"
).frames
# Part 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)
video = pipe(
conditions=[condition1],
prompt=prompt,
negative_prompt=negative_prompt,
width=upscaled_width,
height=upscaled_height,
num_frames=num_frames,
denoise_strength=0.4, # Effectively, 4 inference steps out of 10
num_inference_steps=10,
latents=upscaled_latents,
decode_timestep=0.05,
image_cond_noise_scale=0.025,
generator=torch.Generator().manual_seed(0),
output_type="pil",
).frames[0]
# Part 4. Downscale the video to the expected resolution
video = [frame.resize((expected_width, expected_height)) for frame in video]
export_to_video(video, "output.mp4", fps=24)
```
## Loading Single Files
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend `from_single_file` for the Lightricks series of models, as they plan to release multiple future models in the single-file format.
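As a rough, hedged sketch of what this can look like, the individual components can be loaded from a single-file checkpoint (URL taken from the table above) and assembled into a pipeline; the assembly below is one possible choice, not the only one:
```python
import torch
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel

# Single-file checkpoint from the table above (LTX Video 2B 0.9.0).
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors"
transformer = LTXVideoTransformer3DModel.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
pipe = LTXImageToVideoPipeline.from_pretrained(
    "Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16
)
```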
@@ -204,6 +295,12 @@ export_to_video(video, "ship.mp4", fps=24)
- all
- __call__
## LTXLatentUpsamplePipeline
[[autodoc]] LTXLatentUpsamplePipeline
- all
- __call__
## LTXPipelineOutput
[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput

View File

@@ -7,7 +7,15 @@ from accelerate import init_empty_weights
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
from diffusers import (
AutoencoderKLLTXVideo,
FlowMatchEulerDiscreteScheduler,
LTXConditionPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
LTXVideoTransformer3DModel,
)
from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -123,17 +131,10 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
state_dict[new_key] = state_dict.pop(old_key)
def convert_transformer(
ckpt_path: str,
dtype: torch.dtype,
version: str = "0.9.0",
):
def convert_transformer(ckpt_path: str, config, dtype: torch.dtype):
PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(load_file(ckpt_path))
config = {}
if version == "0.9.5":
config["_use_causal_rope_fix"] = True
with init_empty_weights():
transformer = LTXVideoTransformer3DModel(**config)
@@ -180,8 +181,59 @@ def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
return vae
def convert_spatial_latent_upsampler(ckpt_path: str, config, dtype: torch.dtype):
original_state_dict = get_state_dict(load_file(ckpt_path))
with init_empty_weights():
latent_upsampler = LTXLatentUpsamplerModel(**config)
latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
latent_upsampler.to(dtype)
return latent_upsampler
def get_transformer_config(version: str) -> Dict[str, Any]:
if version == "0.9.7":
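# 0.9.7 corresponds to the 13B transformer: 48 layers, 32 heads of dim 128, 4096 cross-attention dim.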
config = {
"in_channels": 128,
"out_channels": 128,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 32,
"attention_head_dim": 128,
"cross_attention_dim": 4096,
"num_layers": 48,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 4096,
"attention_bias": True,
"attention_out_bias": True,
}
else:
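# Earlier 2B variants (0.9.0, 0.9.1, 0.9.5).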
config = {
"in_channels": 128,
"out_channels": 128,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 32,
"attention_head_dim": 64,
"cross_attention_dim": 2048,
"num_layers": 28,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 4096,
"attention_bias": True,
"attention_out_bias": True,
}
return config
def get_vae_config(version: str) -> Dict[str, Any]:
if version == "0.9.0":
if version in ["0.9.0"]:
config = {
"in_channels": 3,
"out_channels": 3,
@@ -210,7 +262,7 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"decoder_causal": False,
"timestep_conditioning": False,
}
elif version == "0.9.1":
elif version in ["0.9.1"]:
config = {
"in_channels": 3,
"out_channels": 3,
@@ -240,7 +292,39 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"decoder_causal": False,
}
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
elif version == "0.9.5":
elif version in ["0.9.5"]:
config = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 1024, 2048),
"down_block_types": (
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": True,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"scaling_factor": 1.0,
"encoder_causal": True,
"decoder_causal": False,
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
}
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
elif version in ["0.9.7"]:
config = {
"in_channels": 3,
"out_channels": 3,
@@ -275,12 +359,33 @@ def get_vae_config(version: str) -> Dict[str, Any]:
return config
def get_spatial_latent_upsampler_config(version: str) -> Dict[str, Any]:
if version == "0.9.7":
config = {
"in_channels": 128,
"mid_channels": 512,
"num_blocks_per_stage": 4,
"dims": 3,
"spatial_upsample": True,
"temporal_upsample": False,
}
else:
raise ValueError(f"Unsupported version: {version}")
return config
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument(
"--spatial_latent_upsampler_path",
type=str,
default=None,
help="Path to original spatial latent upsampler checkpoint",
)
parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
)
@@ -294,7 +399,11 @@ def get_args():
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
parser.add_argument(
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
"--version",
type=str,
default="0.9.0",
choices=["0.9.0", "0.9.1", "0.9.5", "0.9.7"],
help="Version of the LTX model",
)
return parser.parse_args()
@@ -320,11 +429,9 @@ if __name__ == "__main__":
variant = VARIANT_MAPPING[args.dtype]
output_path = Path(args.output_path)
if args.save_pipeline:
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
if args.transformer_ckpt_path is not None:
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
config = get_transformer_config(args.version)
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, config, dtype)
if not args.save_pipeline:
transformer.save_pretrained(
output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
@@ -336,6 +443,16 @@ if __name__ == "__main__":
if not args.save_pipeline:
vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
if args.spatial_latent_upsampler_path is not None:
config = get_spatial_latent_upsampler_config(args.version)
latent_upsampler: LTXLatentUpsamplerModel = convert_spatial_latent_upsampler(
args.spatial_latent_upsampler_path, config, dtype
)
if not args.save_pipeline:
latent_upsampler.save_pretrained(
output_path / "latent_upsampler", safe_serialization=True, max_shard_size="5GB", variant=variant
)
if args.save_pipeline:
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
@@ -348,7 +465,7 @@ if __name__ == "__main__":
for param in text_encoder.parameters():
param.data = param.data.contiguous()
if args.version == "0.9.5":
if args.version in ["0.9.5", "0.9.7"]:
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
else:
scheduler = FlowMatchEulerDiscreteScheduler(
@@ -360,12 +477,40 @@ if __name__ == "__main__":
shift_terminal=0.1,
)
pipe = LTXPipeline(
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, max_shard_size="5GB")
if args.version in ["0.9.0", "0.9.1", "0.9.5"]:
pipe = LTXPipeline(
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
)
pipe.save_pretrained(
output_path.as_posix(), safe_serialization=True, variant=variant, max_shard_size="5GB"
)
elif args.version in ["0.9.7"]:
pipe = LTXConditionPipeline(
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
)
pipe_upsample = LTXLatentUpsamplePipeline(
vae=vae,
latent_upsampler=latent_upsampler,
)
pipe.save_pretrained(
(output_path / "ltx_pipeline").as_posix(),
safe_serialization=True,
variant=variant,
max_shard_size="5GB",
)
pipe_upsample.save_pretrained(
(output_path / "ltx_upsample_pipeline").as_posix(),
safe_serialization=True,
variant=variant,
max_shard_size="5GB",
)
else:
raise ValueError(f"Unsupported version: {args.version}")

View File

@@ -420,6 +420,7 @@ else:
"LEditsPPPipelineStableDiffusionXL",
"LTXConditionPipeline",
"LTXImageToVideoPipeline",
"LTXLatentUpsamplePipeline",
"LTXPipeline",
"Lumina2Pipeline",
"Lumina2Text2ImgPipeline",
@@ -1003,6 +1004,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
LTXConditionPipeline,
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
Lumina2Pipeline,
Lumina2Text2ImgPipeline,

View File

@@ -268,7 +268,12 @@ else:
]
)
_import_structure["latte"] = ["LattePipeline"]
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"]
_import_structure["ltx"] = [
"LTXPipeline",
"LTXImageToVideoPipeline",
"LTXConditionPipeline",
"LTXLatentUpsamplePipeline",
]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["marigold"].extend(
@@ -637,7 +642,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
from .marigold import (

View File

@@ -22,9 +22,11 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modeling_latent_upsampler"] = ["LTXLatentUpsamplerModel"]
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
_import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -34,9 +36,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .modeling_latent_upsampler import LTXLatentUpsamplerModel
from .pipeline_ltx import LTXPipeline
from .pipeline_ltx_condition import LTXConditionPipeline
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline
else:
import sys

View File

@@ -0,0 +1,188 @@
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
class ResBlock(torch.nn.Module):
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
super().__init__()
if mid_channels is None:
mid_channels = channels
Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
self.norm2 = torch.nn.GroupNorm(32, channels)
self.activation = torch.nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = self.activation(hidden_states + residual)
return hidden_states
class PixelShuffleND(torch.nn.Module):
def __init__(self, dims, upscale_factors=(2, 2, 2)):
super().__init__()
self.dims = dims
self.upscale_factors = upscale_factors
if dims not in [1, 2, 3]:
raise ValueError("dims must be 1, 2, or 3")
def forward(self, x):
if self.dims == 3:
# spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)
return (
x.unflatten(1, (-1, *self.upscale_factors[:3]))
.permute(0, 1, 5, 2, 6, 3, 7, 4)
.flatten(6, 7)
.flatten(4, 5)
.flatten(2, 3)
)
elif self.dims == 2:
# spatial: b (c p1 p2) h w -> b c (h p1) (w p2)
return (
x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3)
)
elif self.dims == 1:
# temporal: b (c p1) f h w -> b c (f p1) h w
return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3)
class LTXLatentUpsamplerModel(ModelMixin, ConfigMixin):
"""
Model to spatially upsample VAE latents.
Args:
in_channels (`int`, defaults to `128`):
Number of channels in the input latent
mid_channels (`int`, defaults to `512`):
Number of channels in the middle layers
num_blocks_per_stage (`int`, defaults to `4`):
Number of ResBlocks to use in each stage (pre/post upsampling)
dims (`int`, defaults to `3`):
Number of dimensions for convolutions (2 or 3)
spatial_upsample (`bool`, defaults to `True`):
Whether to spatially upsample the latent
temporal_upsample (`bool`, defaults to `False`):
Whether to temporally upsample the latent
"""
@register_to_config
def __init__(
self,
in_channels: int = 128,
mid_channels: int = 512,
num_blocks_per_stage: int = 4,
dims: int = 3,
spatial_upsample: bool = True,
temporal_upsample: bool = False,
):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
self.num_blocks_per_stage = num_blocks_per_stage
self.dims = dims
self.spatial_upsample = spatial_upsample
self.temporal_upsample = temporal_upsample
ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1)
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
self.initial_activation = torch.nn.SiLU()
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
if spatial_upsample and temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(3),
)
elif spatial_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(2),
)
elif temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(1),
)
else:
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
self.post_upsample_res_blocks = torch.nn.ModuleList(
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
)
self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
if self.dims == 2:
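# 2D mode: fold frames into the batch dimension so each frame is processed independently ([B, C, F, H, W] -> [B*F, C, H, W]).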
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.initial_conv(hidden_states)
hidden_states = self.initial_norm(hidden_states)
hidden_states = self.initial_activation(hidden_states)
for block in self.res_blocks:
hidden_states = block(hidden_states)
hidden_states = self.upsampler(hidden_states)
for block in self.post_upsample_res_blocks:
hidden_states = block(hidden_states)
hidden_states = self.final_conv(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
else:
hidden_states = self.initial_conv(hidden_states)
hidden_states = self.initial_norm(hidden_states)
hidden_states = self.initial_activation(hidden_states)
for block in self.res_blocks:
hidden_states = block(hidden_states)
if self.temporal_upsample:
hidden_states = self.upsampler(hidden_states)
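# Drop the leading frame after the 2x temporal upsampling (output has 2*F - 1 frames).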
hidden_states = hidden_states[:, :, 1:, :, :]
else:
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.upsampler(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
for block in self.post_upsample_res_blocks:
hidden_states = block(hidden_states)
hidden_states = self.final_conv(hidden_states)
return hidden_states

View File

@@ -430,6 +430,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
video,
frame_index,
strength,
denoise_strength,
height,
width,
callback_on_step_end_tensor_inputs=None,
@@ -497,6 +498,9 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength):
raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.")
if denoise_strength < 0 or denoise_strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {denoise_strength}")
@staticmethod
def _prepare_video_ids(
batch_size: int,
@@ -649,6 +653,8 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
width: int = 704,
num_frames: int = 161,
num_prefix_latent_frames: int = 2,
sigma: Optional[torch.Tensor] = None,
latents: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
@@ -658,7 +664,18 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
latent_width = width // self.vae_spatial_compression_ratio
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
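# If latents and a starting sigma are provided (e.g. video-to-video or upscale refinement), re-noise them
# to that sigma via the flow-matching interpolation below; otherwise start from pure noise.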
if latents is not None and sigma is not None:
if latents.shape != shape:
raise ValueError(
f"Latents shape {latents.shape} does not match expected shape {shape}. Please check the input."
)
latents = latents.to(device=device, dtype=dtype)
sigma = sigma.to(device=device, dtype=dtype)
latents = sigma * noise + (1 - sigma) * latents
else:
latents = noise
if len(conditions) > 0:
condition_latent_frames_mask = torch.zeros(
@@ -766,6 +783,13 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
return latents, conditioning_mask, video_ids, extra_conditioning_num_latents
def get_timesteps(self, sigmas, timesteps, num_inference_steps, strength):
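# Trim the schedule to its final `strength` fraction so denoising starts partway through, img2img-style.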
num_steps = min(int(num_inference_steps * strength), num_inference_steps)
start_index = max(num_inference_steps - num_steps, 0)
sigmas = sigmas[start_index:]
timesteps = timesteps[start_index:]
return sigmas, timesteps, num_inference_steps - start_index
@property
def guidance_scale(self):
return self._guidance_scale
@@ -799,6 +823,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
video: List[PipelineImageInput] = None,
frame_index: Union[int, List[int]] = 0,
strength: Union[float, List[float]] = 1.0,
denoise_strength: float = 1.0,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 512,
@@ -842,6 +867,10 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
generation. If not provided, one has to pass `conditions`.
strength (`float` or `List[float]`, *optional*):
The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`.
denoise_strength (`float`, defaults to `1.0`):
The strength of the noise added to the provided latents. Higher values add more noise, leading to larger
differences between the original video and the generated video. This is useful for video-to-video editing.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the generation. If not defined, one has to pass `prompt_embeds` instead.
@@ -918,8 +947,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if latents is not None:
raise ValueError("Passing latents is not yet supported.")
# 1. Check inputs. Raise error if not correct
self.check_inputs(
@@ -929,6 +956,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
video=video,
frame_index=frame_index,
strength=strength,
denoise_strength=denoise_strength,
height=height,
width=width,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
@@ -977,8 +1005,9 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
strength = [strength] * num_conditions
device = self._execution_device
vae_dtype = self.vae.dtype
# 3. Prepare text embeddings
# 3. Prepare text embeddings & conditioning image/video
(
prompt_embeds,
prompt_attention_mask,
@@ -1000,8 +1029,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
vae_dtype = self.vae.dtype
conditioning_tensors = []
is_conditioning_image_or_video = image is not None or video is not None
if is_conditioning_image_or_video:
@@ -1032,7 +1059,25 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
)
conditioning_tensors.append(condition_tensor)
# 4. Prepare latent variables
# 4. Prepare timesteps
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
sigmas = linear_quadratic_schedule(num_inference_steps)
timesteps = sigmas * 1000
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
latent_sigma = None
if denoise_strength < 1:
sigmas, timesteps, num_inference_steps = self.get_timesteps(
sigmas, timesteps, num_inference_steps, denoise_strength
)
latent_sigma = sigmas[:1].repeat(batch_size * num_videos_per_prompt)
self._num_timesteps = len(timesteps)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents(
conditioning_tensors,
@@ -1043,6 +1088,8 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
height=height,
width=width,
num_frames=num_frames,
sigma=latent_sigma,
latents=latents,
generator=generator,
device=device,
dtype=torch.float32,
@@ -1056,21 +1103,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
if self.do_classifier_free_guidance:
video_coords = torch.cat([video_coords, video_coords], dim=0)
# 5. Prepare timesteps
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
sigmas = linear_quadratic_schedule(num_inference_steps)
timesteps = sigmas * 1000
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps=timesteps,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -1168,7 +1200,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
if not self.vae.config.timestep_conditioning:
timestep = None
else:
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:

View File

@@ -0,0 +1,241 @@
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import torch
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLLTXVideo
from ...utils import get_logger
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .modeling_latent_upsampler import LTXLatentUpsamplerModel
from .pipeline_output import LTXPipelineOutput
logger = get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class LTXLatentUpsamplePipeline(DiffusionPipeline):
def __init__(
self,
vae: AutoencoderKLLTXVideo,
latent_upsampler: LTXLatentUpsamplerModel,
) -> None:
super().__init__()
self.register_modules(vae=vae, latent_upsampler=latent_upsampler)
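# Fall back to the standard LTX VAE compression ratios (32x spatial, 8x temporal) when the VAE is unavailable.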
self.vae_spatial_compression_ratio = (
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
)
self.vae_temporal_compression_ratio = (
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
def prepare_latents(
self,
video: Optional[torch.Tensor] = None,
batch_size: int = 1,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype)
video = video.to(device=device, dtype=self.vae.dtype)
if isinstance(generator, list):
if len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
init_latents = [
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
]
else:
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
init_latents = torch.cat(init_latents, dim=0).to(dtype)
init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
return init_latents
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
def _normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Normalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents
def _denormalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Denormalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents * latents_std / scaling_factor + latents_mean
return latents
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for
allowing the processing of larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
def check_inputs(self, video, height, width, latents):
if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
if video is not None and latents is not None:
raise ValueError("Only one of `video` or `latents` can be provided.")
if video is None and latents is None:
raise ValueError("One of `video` or `latents` has to be provided.")
@torch.no_grad()
def __call__(
self,
video: Optional[List[PipelineImageInput]] = None,
height: int = 512,
width: int = 704,
latents: Optional[torch.Tensor] = None,
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
self.check_inputs(
video=video,
height=height,
width=width,
latents=latents,
)
if video is not None:
# Batched video input is not yet tested/supported. TODO: take a look later
batch_size = 1
else:
batch_size = latents.shape[0]
device = self._execution_device
if video is not None:
num_frames = len(video)
if num_frames % self.vae_temporal_compression_ratio != 1:
num_frames = (
num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1
)
video = video[:num_frames]
logger.warning(
f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames."
)
video = self.video_processor.preprocess_video(video, height=height, width=width)
video = video.to(device=device, dtype=torch.float32)
latents = self.prepare_latents(
video=video,
batch_size=batch_size,
dtype=torch.float32,
device=device,
generator=generator,
latents=latents,
)
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(self.latent_upsampler.dtype)
latents = self.latent_upsampler(latents)
if output_type == "latent":
latents = self._normalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
video = latents
else:
if not self.vae.config.timestep_conditioning:
timestep = None
else:
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * batch_size
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
:, None, None, None, None
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return LTXPipelineOutput(frames=video)

View File

@@ -1307,6 +1307,21 @@ class LTXImageToVideoPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class LTXLatentUpsamplePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LTXPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,159 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from diffusers import AutoencoderKLLTXVideo, LTXLatentUpsamplePipeline
from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
from diffusers.utils.testing_utils import enable_full_determinism
from ..test_pipelines_common import PipelineTesterMixin, to_np
enable_full_determinism()
class LTXLatentUpsamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = LTXLatentUpsamplePipeline
params = {"video", "generator"}
batch_params = {"video", "generator"}
required_optional_params = frozenset(["generator", "latents", "return_dict"])
test_xformers_attention = False
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
vae = AutoencoderKLLTXVideo(
in_channels=3,
out_channels=3,
latent_channels=8,
block_out_channels=(8, 8, 8, 8),
decoder_block_out_channels=(8, 8, 8, 8),
layers_per_block=(1, 1, 1, 1, 1),
decoder_layers_per_block=(1, 1, 1, 1, 1),
spatio_temporal_scaling=(True, True, False, False),
decoder_spatio_temporal_scaling=(True, True, False, False),
decoder_inject_noise=(False, False, False, False, False),
upsample_residual=(False, False, False, False),
upsample_factor=(1, 1, 1, 1),
timestep_conditioning=False,
patch_size=1,
patch_size_t=1,
encoder_causal=True,
decoder_causal=False,
)
vae.use_framewise_encoding = False
vae.use_framewise_decoding = False
torch.manual_seed(0)
latent_upsampler = LTXLatentUpsamplerModel(
in_channels=8,
mid_channels=32,
num_blocks_per_stage=1,
dims=3,
spatial_upsample=True,
temporal_upsample=False,
)
components = {
"vae": vae,
"latent_upsampler": latent_upsampler,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
video = torch.randn((5, 3, 32, 32), generator=generator, device=device)
inputs = {
"video": video,
"generator": generator,
"height": 16,
"width": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (5, 3, 32, 32))
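# The tolerance below is an intentionally loose placeholder; this fast test mainly checks that inference runs and produces the expected shape.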
expected_video = torch.randn(5, 3, 32, 32)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
def test_vae_tiling(self, expected_diff_max: float = 0.25):
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_sample_stride_height=64,
tile_sample_stride_width=64,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
@unittest.skip("Test is not applicable.")
def test_callback_inputs(self):
pass
@unittest.skip("Test is not applicable.")
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
pass
@unittest.skip("Test is not applicable.")
def test_inference_batch_consistent(self):
pass
@unittest.skip("Test is not applicable.")
def test_inference_batch_single_identical(self):
pass