LTX Video 0.9.7 (#11516)
* add upsampling pipeline
* ltx upsample pipeline conversion; pipeline fixes
* make fix-copies
* remove print
* add vae convenience methods
* update
* add tests
* support denoising strength for upscaling & video-to-video
* update docs
* update doc checkpoints
* update docs
* fix

---------

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
@@ -31,12 +31,103 @@ Available models:

| Model name | Recommended dtype |
|:-------------:|:-----------------:|
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
| [`LTX Video 2B 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
| [`LTX Video 2B 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
| [`LTX Video 2B 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
| [`LTX Video 13B 0.9.7`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev.safetensors) | `torch.bfloat16` |
| [`LTX Video Spatial Upscaler 0.9.7`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-spatial-upscaler-0.9.7.safetensors) | `torch.bfloat16` |

Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16`, or `torch.float16`, but the recommended dtype is `torch.bfloat16` as used in the original repository.

## Recommended settings for generation

For the best results, it is recommended to follow the guidelines mentioned in the official LTX Video [repository](https://github.com/Lightricks/LTX-Video).

- Some variants of LTX Video are guidance-distilled. For guidance-distilled models, `guidance_scale` must be set to `1.0`. For any other models, `guidance_scale` should be set higher (e.g., `5.0`) for good generation quality (see the short sketch after this list).
- For variants with a timestep-aware VAE (LTXV 0.9.1 and above), it is recommended to set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`.
- For variants that support interpolation between multiple conditioning images and videos (LTXV 0.9.5 and above), it is recommended to use similar-looking images/videos for the best results. High divergence between the conditionings may lead to abrupt transitions in the generated video.
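A minimal text-to-video sketch that applies these settings is shown below. The checkpoint name and parameter values are illustrative rather than prescriptive; drop `guidance_scale` to `1.0` for guidance-distilled variants.

```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

# Illustrative settings only: swap in the checkpoint you actually want to use.
pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

video = pipe(
    prompt="A woman with long brown hair walks through a sunlit forest",
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    width=704,
    height=480,
    num_frames=161,
    guidance_scale=5.0,    # use 1.0 for guidance-distilled variants
    decode_timestep=0.05,  # timestep-aware VAE (0.9.1 and above)
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```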

## Using LTX Video 13B 0.9.7

LTX Video 0.9.7 comes with a 13B-parameter transformer and a spatial latent upscaler. Inference first generates a low-resolution video, which is fast, and then upscales and refines it.

<!-- TODO(aryan): modify when official checkpoints are available -->

```python
import torch
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_video

pipe = LTXConditionPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-diffusers", torch_dtype=torch.bfloat16)
pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-Latent-Spatial-Upsampler-diffusers", vae=pipe.vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe_upsample.to("cuda")
pipe.vae.enable_tiling()


def round_to_nearest_resolution_acceptable_by_vae(height, width):
    # Height/width must be divisible by the VAE spatial compression ratio.
    height = height - (height % pipe.vae_spatial_compression_ratio)
    width = width - (width % pipe.vae_spatial_compression_ratio)
    return height, width


video = load_video(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
)[:21]  # Use only the first 21 frames as conditioning
condition1 = LTXVideoCondition(video=video, frame_index=0)

prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
expected_height, expected_width = 768, 1152
downscale_factor = 2 / 3
num_frames = 161

# Part 1. Generate video at smaller resolution
# Text-only conditioning is also supported without the need to pass `conditions`
downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor)
downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width)
latents = pipe(
    conditions=[condition1],
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=downscaled_width,
    height=downscaled_height,
    num_frames=num_frames,
    num_inference_steps=30,
    generator=torch.Generator().manual_seed(0),
    output_type="latent",
).frames

# Part 2. Upscale generated video using latent upsampler with fewer inference steps
# The available latent upsampler upscales the height/width by 2x
upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2
upscaled_latents = pipe_upsample(
    latents=latents,
    output_type="latent"
).frames

# Part 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)
video = pipe(
    conditions=[condition1],
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=upscaled_width,
    height=upscaled_height,
    num_frames=num_frames,
    denoise_strength=0.4,  # Effectively, 4 inference steps out of 10
    num_inference_steps=10,
    latents=upscaled_latents,
    decode_timestep=0.05,
    image_cond_noise_scale=0.025,
    generator=torch.Generator().manual_seed(0),
    output_type="pil",
).frames[0]

# Part 4. Downscale the video to the expected resolution
video = [frame.resize((expected_width, expected_height)) for frame in video]

export_to_video(video, "output.mp4", fps=24)
```
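The 13B transformer is memory-hungry even at `torch.bfloat16`. If the condition pipeline alone does not fit on your GPU, Diffusers' standard model CPU offloading can reduce peak memory at the cost of speed. A minimal sketch for the base condition pipeline is below; its interaction with the shared-VAE upsampler setup above is not covered here.

```python
import torch
from diffusers import LTXConditionPipeline

pipe = LTXConditionPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-diffusers", torch_dtype=torch.bfloat16)

# Submodules are moved to the GPU only while they are needed, then back to CPU.
pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()
```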

## Loading Single Files

Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single-file format.
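As a rough sketch of what single-file loading can look like (the checkpoint URL and the choice of pipeline here are illustrative, not part of this PR):

```python
import torch
from diffusers import AutoencoderKLLTXVideo, LTXPipeline, LTXVideoTransformer3DModel

# Illustrative single-file checkpoint; substitute the variant you actually want.
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors"
transformer = LTXVideoTransformer3DModel.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
pipe = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16
)
```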

@@ -204,6 +295,12 @@ export_to_video(video, "ship.mp4", fps=24)

  - all
  - __call__

## LTXLatentUpsamplePipeline

[[autodoc]] LTXLatentUpsamplePipeline
  - all
  - __call__

## LTXPipelineOutput

[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
@@ -7,7 +7,15 @@ from accelerate import init_empty_weights
|
||||
from safetensors.torch import load_file
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
|
||||
from diffusers import (
|
||||
AutoencoderKLLTXVideo,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTXConditionPipeline,
|
||||
LTXLatentUpsamplePipeline,
|
||||
LTXPipeline,
|
||||
LTXVideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
|
||||
|
||||
|
||||
def remove_keys_(key: str, state_dict: Dict[str, Any]):
|
||||
@@ -123,17 +131,10 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def convert_transformer(
|
||||
ckpt_path: str,
|
||||
dtype: torch.dtype,
|
||||
version: str = "0.9.0",
|
||||
):
|
||||
def convert_transformer(ckpt_path: str, config, dtype: torch.dtype):
|
||||
PREFIX_KEY = "model.diffusion_model."
|
||||
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
config = {}
|
||||
if version == "0.9.5":
|
||||
config["_use_causal_rope_fix"] = True
|
||||
with init_empty_weights():
|
||||
transformer = LTXVideoTransformer3DModel(**config)
|
||||
|
||||
@@ -180,8 +181,59 @@ def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
|
||||
return vae
|
||||
|
||||
|
||||
def convert_spatial_latent_upsampler(ckpt_path: str, config, dtype: torch.dtype):
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
|
||||
with init_empty_weights():
|
||||
latent_upsampler = LTXLatentUpsamplerModel(**config)
|
||||
|
||||
latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
latent_upsampler.to(dtype)
|
||||
return latent_upsampler
|
||||
|
||||
|
||||
def get_transformer_config(version: str) -> Dict[str, Any]:
|
||||
if version == "0.9.7":
|
||||
config = {
|
||||
"in_channels": 128,
|
||||
"out_channels": 128,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attention_dim": 4096,
|
||||
"num_layers": 48,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"caption_channels": 4096,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
"in_channels": 128,
|
||||
"out_channels": 128,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 64,
|
||||
"cross_attention_dim": 2048,
|
||||
"num_layers": 28,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"caption_channels": 4096,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
if version == "0.9.0":
|
||||
if version in ["0.9.0"]:
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
@@ -210,7 +262,7 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
"decoder_causal": False,
|
||||
"timestep_conditioning": False,
|
||||
}
|
||||
elif version == "0.9.1":
|
||||
elif version in ["0.9.1"]:
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
@@ -240,7 +292,39 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
"decoder_causal": False,
|
||||
}
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
elif version == "0.9.5":
|
||||
elif version in ["0.9.5"]:
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 1024, 2048),
|
||||
"down_block_types": (
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 6, 6, 2, 2),
|
||||
"decoder_layers_per_block": (5, 5, 5, 5),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"scaling_factor": 1.0,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
}
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
|
||||
elif version in ["0.9.7"]:
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
@@ -275,12 +359,33 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
return config
|
||||
|
||||
|
||||
def get_spatial_latent_upsampler_config(version: str) -> Dict[str, Any]:
|
||||
if version == "0.9.7":
|
||||
config = {
|
||||
"in_channels": 128,
|
||||
"mid_channels": 512,
|
||||
"num_blocks_per_stage": 4,
|
||||
"dims": 3,
|
||||
"spatial_upsample": True,
|
||||
"temporal_upsample": False,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported version: {version}")
|
||||
return config
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
||||
)
|
||||
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
|
||||
parser.add_argument(
|
||||
"--spatial_latent_upsampler_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to original spatial latent upsampler checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
|
||||
)
|
||||
@@ -294,7 +399,11 @@ def get_args():
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
|
||||
parser.add_argument(
|
||||
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
|
||||
"--version",
|
||||
type=str,
|
||||
default="0.9.0",
|
||||
choices=["0.9.0", "0.9.1", "0.9.5", "0.9.7"],
|
||||
help="Version of the LTX model",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -320,11 +429,9 @@ if __name__ == "__main__":
|
||||
variant = VARIANT_MAPPING[args.dtype]
|
||||
output_path = Path(args.output_path)
|
||||
|
||||
if args.save_pipeline:
|
||||
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
|
||||
|
||||
if args.transformer_ckpt_path is not None:
|
||||
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
|
||||
config = get_transformer_config(args.version)
|
||||
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, config, dtype)
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(
|
||||
output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
@@ -336,6 +443,16 @@ if __name__ == "__main__":
|
||||
if not args.save_pipeline:
|
||||
vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
|
||||
if args.spatial_latent_upsampler_path is not None:
|
||||
config = get_spatial_latent_upsampler_config(args.version)
|
||||
latent_upsampler: LTXLatentUpsamplerModel = convert_spatial_latent_upsampler(
|
||||
args.spatial_latent_upsampler_path, config, dtype
|
||||
)
|
||||
if not args.save_pipeline:
|
||||
latent_upsampler.save_pretrained(
|
||||
output_path / "latent_upsampler", safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
)
|
||||
|
||||
if args.save_pipeline:
|
||||
text_encoder_id = "google/t5-v1_1-xxl"
|
||||
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
|
||||
@@ -348,7 +465,7 @@ if __name__ == "__main__":
|
||||
for param in text_encoder.parameters():
|
||||
param.data = param.data.contiguous()
|
||||
|
||||
if args.version == "0.9.5":
|
||||
if args.version in ["0.9.5", "0.9.7"]:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
|
||||
else:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
@@ -360,12 +477,40 @@ if __name__ == "__main__":
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
|
||||
pipe = LTXPipeline(
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, max_shard_size="5GB")
|
||||
if args.version in ["0.9.0", "0.9.1", "0.9.5"]:
|
||||
pipe = LTXPipeline(
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
)
|
||||
pipe.save_pretrained(
|
||||
output_path.as_posix(), safe_serialization=True, variant=variant, max_shard_size="5GB"
|
||||
)
|
||||
elif args.version in ["0.9.7"]:
|
||||
pipe = LTXConditionPipeline(
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
)
|
||||
pipe_upsample = LTXLatentUpsamplePipeline(
|
||||
vae=vae,
|
||||
latent_upsampler=latent_upsampler,
|
||||
)
|
||||
pipe.save_pretrained(
|
||||
(output_path / "ltx_pipeline").as_posix(),
|
||||
safe_serialization=True,
|
||||
variant=variant,
|
||||
max_shard_size="5GB",
|
||||
)
|
||||
pipe_upsample.save_pretrained(
|
||||
(output_path / "ltx_upsample_pipeline").as_posix(),
|
||||
safe_serialization=True,
|
||||
variant=variant,
|
||||
max_shard_size="5GB",
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported version: {args.version}")
|
||||
|
||||
@@ -420,6 +420,7 @@ else:
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LTXConditionPipeline",
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXLatentUpsamplePipeline",
|
||||
"LTXPipeline",
|
||||
"Lumina2Pipeline",
|
||||
"Lumina2Text2ImgPipeline",
|
||||
@@ -1003,6 +1004,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LTXConditionPipeline,
|
||||
LTXImageToVideoPipeline,
|
||||
LTXLatentUpsamplePipeline,
|
||||
LTXPipeline,
|
||||
Lumina2Pipeline,
|
||||
Lumina2Text2ImgPipeline,
|
||||
|
||||
@@ -268,7 +268,12 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["latte"] = ["LattePipeline"]
|
||||
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"]
|
||||
_import_structure["ltx"] = [
|
||||
"LTXPipeline",
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXConditionPipeline",
|
||||
"LTXLatentUpsamplePipeline",
|
||||
]
|
||||
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
|
||||
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
|
||||
_import_structure["marigold"].extend(
|
||||
@@ -637,7 +642,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
|
||||
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
|
||||
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
|
||||
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
|
||||
from .marigold import (
|
||||
|
||||
@@ -22,9 +22,11 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["modeling_latent_upsampler"] = ["LTXLatentUpsamplerModel"]
|
||||
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
|
||||
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
|
||||
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
|
||||
_import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -34,9 +36,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .modeling_latent_upsampler import LTXLatentUpsamplerModel
|
||||
from .pipeline_ltx import LTXPipeline
|
||||
from .pipeline_ltx_condition import LTXConditionPipeline
|
||||
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
|
||||
from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
src/diffusers/pipelines/ltx/modeling_latent_upsampler.py (new file, 188 lines)
@@ -0,0 +1,188 @@
|
||||
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
class ResBlock(torch.nn.Module):
|
||||
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
|
||||
super().__init__()
|
||||
if mid_channels is None:
|
||||
mid_channels = channels
|
||||
|
||||
Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||
|
||||
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
|
||||
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
|
||||
self.norm2 = torch.nn.GroupNorm(32, channels)
|
||||
self.activation = torch.nn.SiLU()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.activation(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.activation(hidden_states + residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PixelShuffleND(torch.nn.Module):
|
||||
def __init__(self, dims, upscale_factors=(2, 2, 2)):
|
||||
super().__init__()
|
||||
|
||||
self.dims = dims
|
||||
self.upscale_factors = upscale_factors
|
||||
|
||||
if dims not in [1, 2, 3]:
|
||||
raise ValueError("dims must be 1, 2, or 3")
|
||||
|
||||
def forward(self, x):
|
||||
if self.dims == 3:
|
||||
# spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)
|
||||
return (
|
||||
x.unflatten(1, (-1, *self.upscale_factors[:3]))
|
||||
.permute(0, 1, 5, 2, 6, 3, 7, 4)
|
||||
.flatten(6, 7)
|
||||
.flatten(4, 5)
|
||||
.flatten(2, 3)
|
||||
)
|
||||
elif self.dims == 2:
|
||||
# spatial: b (c p1 p2) h w -> b c (h p1) (w p2)
|
||||
return (
|
||||
x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3)
|
||||
)
|
||||
elif self.dims == 1:
|
||||
# temporal: b (c p1) f h w -> b c (f p1) h w
|
||||
return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3)
|
||||
|
||||
|
||||
class LTXLatentUpsamplerModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Model to spatially upsample VAE latents.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `128`):
|
||||
Number of channels in the input latent
|
||||
mid_channels (`int`, defaults to `512`):
|
||||
Number of channels in the middle layers
|
||||
num_blocks_per_stage (`int`, defaults to `4`):
|
||||
Number of ResBlocks to use in each stage (pre/post upsampling)
|
||||
dims (`int`, defaults to `3`):
|
||||
Number of dimensions for convolutions (2 or 3)
|
||||
spatial_upsample (`bool`, defaults to `True`):
|
||||
Whether to spatially upsample the latent
|
||||
temporal_upsample (`bool`, defaults to `False`):
|
||||
Whether to temporally upsample the latent
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
mid_channels: int = 512,
|
||||
num_blocks_per_stage: int = 4,
|
||||
dims: int = 3,
|
||||
spatial_upsample: bool = True,
|
||||
temporal_upsample: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.num_blocks_per_stage = num_blocks_per_stage
|
||||
self.dims = dims
|
||||
self.spatial_upsample = spatial_upsample
|
||||
self.temporal_upsample = temporal_upsample
|
||||
|
||||
ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||
|
||||
self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
|
||||
self.initial_activation = torch.nn.SiLU()
|
||||
|
||||
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
|
||||
|
||||
if spatial_upsample and temporal_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(3),
|
||||
)
|
||||
elif spatial_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(2),
|
||||
)
|
||||
elif temporal_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(1),
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
|
||||
|
||||
self.post_upsample_res_blocks = torch.nn.ModuleList(
|
||||
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||
)
|
||||
|
||||
self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
|
||||
if self.dims == 2:
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
||||
hidden_states = self.initial_conv(hidden_states)
|
||||
hidden_states = self.initial_norm(hidden_states)
|
||||
hidden_states = self.initial_activation(hidden_states)
|
||||
|
||||
for block in self.res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = self.upsampler(hidden_states)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = self.final_conv(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
|
||||
else:
|
||||
hidden_states = self.initial_conv(hidden_states)
|
||||
hidden_states = self.initial_norm(hidden_states)
|
||||
hidden_states = self.initial_activation(hidden_states)
|
||||
|
||||
for block in self.res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
if self.temporal_upsample:
|
||||
hidden_states = self.upsampler(hidden_states)
|
||||
hidden_states = hidden_states[:, :, 1:, :, :]
|
||||
else:
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
||||
hidden_states = self.upsampler(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = self.final_conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
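To make the shape contract of this module concrete, here is a small self-contained sketch (not part of the PR) that runs a dummy latent through an `LTXLatentUpsamplerModel` configured for spatial-only upsampling; the tiny channel counts are chosen only to keep the example fast.

```python
import torch
from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel

# Spatial-only upsampler: height and width are doubled, frame count is unchanged.
upsampler = LTXLatentUpsamplerModel(
    in_channels=8,
    mid_channels=32,
    num_blocks_per_stage=1,
    dims=3,
    spatial_upsample=True,
    temporal_upsample=False,
)

latents = torch.randn(1, 8, 3, 16, 16)  # (batch, channels, frames, height, width)
with torch.no_grad():
    upsampled = upsampler(latents)
print(upsampled.shape)  # torch.Size([1, 8, 3, 32, 32])
```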
@@ -430,6 +430,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
video,
|
||||
frame_index,
|
||||
strength,
|
||||
denoise_strength,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
@@ -497,6 +498,9 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength):
|
||||
raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.")
|
||||
|
||||
if denoise_strength < 0 or denoise_strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {denoise_strength}")
|
||||
|
||||
@staticmethod
|
||||
def _prepare_video_ids(
|
||||
batch_size: int,
|
||||
@@ -649,6 +653,8 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
width: int = 704,
|
||||
num_frames: int = 161,
|
||||
num_prefix_latent_frames: int = 2,
|
||||
sigma: Optional[torch.Tensor] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
@@ -658,7 +664,18 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
|
||||
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
if latents is not None and sigma is not None:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(
|
||||
f"Latents shape {latents.shape} does not match expected shape {shape}. Please check the input."
|
||||
)
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
sigma = sigma.to(device=device, dtype=dtype)
|
||||
latents = sigma * noise + (1 - sigma) * latents
|
||||
else:
|
||||
latents = noise
|
||||
|
||||
if len(conditions) > 0:
|
||||
condition_latent_frames_mask = torch.zeros(
|
||||
@@ -766,6 +783,13 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
|
||||
return latents, conditioning_mask, video_ids, extra_conditioning_num_latents
|
||||
|
||||
def get_timesteps(self, sigmas, timesteps, num_inference_steps, strength):
|
||||
num_steps = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
start_index = max(num_inference_steps - num_steps, 0)
|
||||
sigmas = sigmas[start_index:]
|
||||
timesteps = timesteps[start_index:]
|
||||
return sigmas, timesteps, num_inference_steps - start_index
|
||||
|
||||
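# For example, with num_inference_steps = 10 and denoise_strength = 0.4:
#   num_steps   = min(int(10 * 0.4), 10) = 4
#   start_index = max(10 - 4, 0) = 6
# so only the last 4 sigmas/timesteps are kept, i.e. 4 denoising steps out of 10.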
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
@@ -799,6 +823,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
video: List[PipelineImageInput] = None,
|
||||
frame_index: Union[int, List[int]] = 0,
|
||||
strength: Union[float, List[float]] = 1.0,
|
||||
denoise_strength: float = 1.0,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 512,
|
||||
@@ -842,6 +867,10 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
generation. If not provided, one has to pass `conditions`.
|
||||
strength (`float` or `List[float]`, *optional*):
|
||||
The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`.
|
||||
denoise_strength (`float`, defaults to `1.0`):
The strength of the noise added to the latents for editing. Higher values add more noise to the latents, leading to larger differences between the original and generated videos. This is useful for video-to-video editing.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
|
||||
@@ -918,8 +947,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
if latents is not None:
|
||||
raise ValueError("Passing latents is not yet supported.")
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
@@ -929,6 +956,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
video=video,
|
||||
frame_index=frame_index,
|
||||
strength=strength,
|
||||
denoise_strength=denoise_strength,
|
||||
height=height,
|
||||
width=width,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
@@ -977,8 +1005,9 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
strength = [strength] * num_conditions
|
||||
|
||||
device = self._execution_device
|
||||
vae_dtype = self.vae.dtype
|
||||
|
||||
# 3. Prepare text embeddings
|
||||
# 3. Prepare text embeddings & conditioning image/video
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
@@ -1000,8 +1029,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
vae_dtype = self.vae.dtype
|
||||
|
||||
conditioning_tensors = []
|
||||
is_conditioning_image_or_video = image is not None or video is not None
|
||||
if is_conditioning_image_or_video:
|
||||
@@ -1032,7 +1059,25 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
)
|
||||
conditioning_tensors.append(condition_tensor)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
# 4. Prepare timesteps
|
||||
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
sigmas = linear_quadratic_schedule(num_inference_steps)
|
||||
timesteps = sigmas * 1000
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
latent_sigma = None
|
||||
if denoise_strength < 1:
|
||||
sigmas, timesteps, num_inference_steps = self.get_timesteps(
|
||||
sigmas, timesteps, num_inference_steps, denoise_strength
|
||||
)
|
||||
latent_sigma = sigmas[:1].repeat(batch_size * num_videos_per_prompt)
|
||||
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents(
|
||||
conditioning_tensors,
|
||||
@@ -1043,6 +1088,8 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
sigma=latent_sigma,
|
||||
latents=latents,
|
||||
generator=generator,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
@@ -1056,21 +1103,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
if self.do_classifier_free_guidance:
|
||||
video_coords = torch.cat([video_coords, video_coords], dim=0)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
sigmas = linear_quadratic_schedule(num_inference_steps)
|
||||
timesteps = sigmas * 1000
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps=timesteps,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
@@ -1168,7 +1200,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
|
||||
src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py (new file, 241 lines)
@@ -0,0 +1,241 @@
|
||||
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...models import AutoencoderKLLTXVideo
|
||||
from ...utils import get_logger
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .modeling_latent_upsampler import LTXLatentUpsamplerModel
|
||||
from .pipeline_output import LTXPipelineOutput
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class LTXLatentUpsamplePipeline(DiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKLLTXVideo,
|
||||
latent_upsampler: LTXLatentUpsamplerModel,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(vae=vae, latent_upsampler=latent_upsampler)
|
||||
|
||||
self.vae_spatial_compression_ratio = (
|
||||
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
|
||||
)
|
||||
self.vae_temporal_compression_ratio = (
|
||||
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
|
||||
)
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
video: Optional[torch.Tensor] = None,
|
||||
batch_size: int = 1,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
video = video.to(device=device, dtype=self.vae.dtype)
|
||||
if isinstance(generator, list):
|
||||
if len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0).to(dtype)
|
||||
init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
|
||||
return init_latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
|
||||
def _normalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Normalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = (latents - latents_mean) * scaling_factor / latents_std
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents
|
||||
def _denormalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Denormalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents * latents_std / scaling_factor + latents_mean
|
||||
return latents
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def check_inputs(self, video, height, width, latents):
|
||||
if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
|
||||
if video is not None and latents is not None:
|
||||
raise ValueError("Only one of `video` or `latents` can be provided.")
|
||||
if video is None and latents is None:
|
||||
raise ValueError("One of `video` or `latents` has to be provided.")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
video: Optional[List[PipelineImageInput]] = None,
|
||||
height: int = 512,
|
||||
width: int = 704,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
self.check_inputs(
|
||||
video=video,
|
||||
height=height,
|
||||
width=width,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
if video is not None:
|
||||
# Batched video input is not yet tested/supported. TODO: take a look later
|
||||
batch_size = 1
|
||||
else:
|
||||
batch_size = latents.shape[0]
|
||||
device = self._execution_device
|
||||
|
||||
if video is not None:
|
||||
num_frames = len(video)
|
||||
if num_frames % self.vae_temporal_compression_ratio != 1:
|
||||
num_frames = (
|
||||
num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1
|
||||
)
|
||||
video = video[:num_frames]
|
||||
logger.warning(
|
||||
f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames."
|
||||
)
|
||||
video = self.video_processor.preprocess_video(video, height=height, width=width)
|
||||
video = video.to(device=device, dtype=torch.float32)
|
||||
|
||||
latents = self.prepare_latents(
|
||||
video=video,
|
||||
batch_size=batch_size,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(self.latent_upsampler.dtype)
|
||||
latents = self.latent_upsampler(latents)
|
||||
|
||||
if output_type == "latent":
|
||||
latents = self._normalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
video = latents
|
||||
else:
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return LTXPipelineOutput(frames=video)
|
||||
@@ -1307,6 +1307,21 @@ class LTXImageToVideoPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTXLatentUpsamplePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTXPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
tests/pipelines/ltx/test_ltx_latent_upsample.py (new file, 159 lines)
@@ -0,0 +1,159 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderKLLTXVideo, LTXLatentUpsamplePipeline
|
||||
from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
|
||||
from diffusers.utils.testing_utils import enable_full_determinism
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LTXLatentUpsamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = LTXLatentUpsamplePipeline
|
||||
params = {"video", "generator"}
|
||||
batch_params = {"video", "generator"}
|
||||
required_optional_params = frozenset(["generator", "latents", "return_dict"])
|
||||
test_xformers_attention = False
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLLTXVideo(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
latent_channels=8,
|
||||
block_out_channels=(8, 8, 8, 8),
|
||||
decoder_block_out_channels=(8, 8, 8, 8),
|
||||
layers_per_block=(1, 1, 1, 1, 1),
|
||||
decoder_layers_per_block=(1, 1, 1, 1, 1),
|
||||
spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_inject_noise=(False, False, False, False, False),
|
||||
upsample_residual=(False, False, False, False),
|
||||
upsample_factor=(1, 1, 1, 1),
|
||||
timestep_conditioning=False,
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
encoder_causal=True,
|
||||
decoder_causal=False,
|
||||
)
|
||||
vae.use_framewise_encoding = False
|
||||
vae.use_framewise_decoding = False
|
||||
|
||||
torch.manual_seed(0)
|
||||
latent_upsampler = LTXLatentUpsamplerModel(
|
||||
in_channels=8,
|
||||
mid_channels=32,
|
||||
num_blocks_per_stage=1,
|
||||
dims=3,
|
||||
spatial_upsample=True,
|
||||
temporal_upsample=False,
|
||||
)
|
||||
|
||||
components = {
|
||||
"vae": vae,
|
||||
"latent_upsampler": latent_upsampler,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
video = torch.randn((5, 3, 32, 32), generator=generator, device=device)
|
||||
|
||||
inputs = {
|
||||
"video": video,
|
||||
"generator": generator,
|
||||
"height": 16,
|
||||
"width": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
|
||||
self.assertEqual(generated_video.shape, (5, 3, 32, 32))
|
||||
expected_video = torch.randn(5, 3, 32, 32)
|
||||
max_diff = np.abs(generated_video - expected_video).max()
|
||||
self.assertLessEqual(max_diff, 1e10)
|
||||
|
||||
def test_vae_tiling(self, expected_diff_max: float = 0.25):
|
||||
generator_device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to("cpu")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Without tiling
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_without_tiling = pipe(**inputs)[0]
|
||||
|
||||
# With tiling
|
||||
pipe.vae.enable_tiling(
|
||||
tile_sample_min_height=96,
|
||||
tile_sample_min_width=96,
|
||||
tile_sample_stride_height=64,
|
||||
tile_sample_stride_width=64,
|
||||
)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_with_tiling = pipe(**inputs)[0]
|
||||
|
||||
self.assertLess(
|
||||
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
|
||||
expected_diff_max,
|
||||
"VAE tiling should not affect the inference results",
|
||||
)
|
||||
|
||||
@unittest.skip("Test is not applicable.")
|
||||
def test_callback_inputs(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test is not applicable.")
|
||||
def test_attention_slicing_forward_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||
):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test is not applicable.")
|
||||
def test_inference_batch_consistent(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test is not applicable.")
|
||||
def test_inference_batch_single_identical(self):
|
||||
pass
|
||||