Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
LTX2 distilled checkpoint support (#12934)
* add constants for distill sigmas values and allow ltx pipeline to pass in sigmas
* add time conditioning conversion and token packing for latents
* make style & quality
* remove prenorm
* add sigma param to ltx2 i2v
* fix copies and add pack latents to i2v
* Apply suggestions from code review
  Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* Infer latent dims if latents/audio_latents is supplied
* add note for predefined sigmas
* run make style and quality
* revert distill timesteps & set original_state_dict_repo_idd to default None
* add latent normalize
* add create noised state, delete last sigmas
* remove normalize step in latent upsample pipeline and move it to ltx2 pipeline
* add create noise latent to i2v pipeline
* fix copies
* parse none value in weight conversion script
* explicit shape handling
* Apply suggestions from code review
  Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* make style
* add two stage inference tests
* add ltx2 documentation
* update i2v expected_audio_slice
* Apply suggestions from code review
  Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* Apply suggestion from @dg845
  Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* Update ltx2.md to remove one-stage example
  Removed one-stage generation example code and added comments for noise scale in two-stage generation.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: Daniel Gu <dgu8957@gmail.com>
@@ -24,6 +24,179 @@ You can find all the original LTX-Video checkpoints under the [Lightricks](https

The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2).

## Two-Stage Generation

This is the recommended pipeline for production-quality generation. It is composed of two stages:

- Stage 1: Generate a video at the target resolution using diffusion sampling with classifier-free guidance (CFG). This stage produces a coherent low-noise video sequence that respects the text/image conditioning.
- Stage 2: Upsample the Stage 1 output by 2x and refine details using a distilled LoRA model to improve fidelity and visual quality. Stage 2 may apply lighter CFG to preserve the structure from Stage 1 while enhancing texture and sharpness.

Sample usage of the two-stage text-to-video pipeline:

```py
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video

device = "cuda:0"
width = 768
height = 512

pipe = LTX2Pipeline.from_pretrained(
    "Lightricks/LTX-2", torch_dtype=torch.bfloat16
)
pipe.enable_sequential_cpu_offload(device=device)

prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

# Stage 1 default (non-distilled) inference
frame_rate = 24.0
video_latent, audio_latent = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=40,
    sigmas=None,
    guidance_scale=4.0,
    output_type="latent",
    return_dict=False,
)

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    "Lightricks/LTX-2",
    subfolder="latent_upsampler",
    torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

# Load the Stage 2 distilled LoRA
pipe.load_lora_weights(
    "Lightricks/LTX-2", adapter_name="stage_2_distilled", weight_name="ltx-2-19b-distilled-lora-384.safetensors"
)
pipe.set_adapters("stage_2_distilled", 1.0)
# VAE tiling is usually necessary to avoid OOM errors during VAE decoding
pipe.vae.enable_tiling()
# Change the scheduler to use the Stage 2 distilled sigmas as-is
new_scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None
)
pipe.scheduler = new_scheduler
# Stage 2 inference with the distilled LoRA and sigmas
video, audio = pipe(
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=3,
    noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],  # renoise with the first sigma value, see https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py#L218
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_lora_distilled_sample.mp4",
)
```

## Distilled checkpoint generation

This is the fastest two-stage generation pipeline, using a distilled checkpoint:

```py
import torch
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video

device = "cuda"
width = 768
height = 512
random_seed = 42
generator = torch.Generator(device).manual_seed(random_seed)
model_path = "rootonchair/LTX-2-19b-distilled"

pipe = LTX2Pipeline.from_pretrained(
    model_path, torch_dtype=torch.bfloat16
)
pipe.enable_sequential_cpu_offload(device=device)

prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

frame_rate = 24.0
video_latent, audio_latent = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=8,
    sigmas=DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    generator=generator,
    output_type="latent",
    return_dict=False,
)

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    model_path,
    subfolder="latent_upsampler",
    torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

video, audio = pipe(
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=3,
    noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],  # renoise with first sigma value https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/distilled.py#L178
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    generator=generator,
    guidance_scale=1.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_distilled_sample.mp4",
)
```

## LTX2Pipeline

[[autodoc]] LTX2Pipeline

@@ -63,6 +63,8 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
    "up_blocks.4": "up_blocks.1",
    "up_blocks.5": "up_blocks.2.upsamplers.0",
    "up_blocks.6": "up_blocks.2",
    "last_time_embedder": "time_embedder",
    "last_scale_shift_table": "scale_shift_table",
    # Common
    # For all 3D ResNets
    "res_blocks": "resnets",

@@ -372,7 +374,9 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -
    return connectors


def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
def get_ltx2_video_vae_config(
    version: str, timestep_conditioning: bool = False
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    if version == "test":
        config = {
            "model_id": "diffusers-internal-dev/dummy-ltx2",

@@ -396,7 +400,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
            "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
            "upsample_residual": (True, True, True),
            "upsample_factor": (2, 2, 2),
            "timestep_conditioning": False,
            "timestep_conditioning": timestep_conditioning,
            "patch_size": 4,
            "patch_size_t": 1,
            "resnet_norm_eps": 1e-6,

@@ -433,7 +437,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
            "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
            "upsample_residual": (True, True, True),
            "upsample_factor": (2, 2, 2),
            "timestep_conditioning": False,
            "timestep_conditioning": timestep_conditioning,
            "patch_size": 4,
            "patch_size_t": 1,
            "resnet_norm_eps": 1e-6,

@@ -450,8 +454,10 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
    return config, rename_dict, special_keys_remap


def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
    config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
def convert_ltx2_video_vae(
    original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool
) -> Dict[str, Any]:
    config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning)
    diffusers_config = config["diffusers_config"]

    with init_empty_weights():
@@ -659,10 +665,15 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefi
def get_args():
    parser = argparse.ArgumentParser()

    def none_or_str(value: str):
        if isinstance(value, str) and value.lower() == "none":
            return None
        return value

    parser.add_argument(
        "--original_state_dict_repo_id",
        default="Lightricks/LTX-2",
        type=str,
        type=none_or_str,
        help="HF Hub repo id with LTX 2.0 checkpoint",
    )
    parser.add_argument(

@@ -682,7 +693,7 @@ def get_args():
    parser.add_argument(
        "--combined_filename",
        default="ltx-2-19b-dev.safetensors",
        type=str,
        type=none_or_str,
        help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
    )
    parser.add_argument("--vae_prefix", default="vae.", type=str)

@@ -701,22 +712,25 @@ def get_args():
    parser.add_argument(
        "--text_encoder_model_id",
        default="google/gemma-3-12b-it-qat-q4_0-unquantized",
        type=str,
        type=none_or_str,
        help="HF Hub id for the LTX 2.0 base text encoder model",
    )
    parser.add_argument(
        "--tokenizer_id",
        default="google/gemma-3-12b-it-qat-q4_0-unquantized",
        type=str,
        type=none_or_str,
        help="HF Hub id for the LTX 2.0 text tokenizer",
    )
    parser.add_argument(
        "--latent_upsampler_filename",
        default="ltx-2-spatial-upscaler-x2-1.0.safetensors",
        type=str,
        type=none_or_str,
        help="Latent upsampler filename",
    )

    parser.add_argument(
        "--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model"
    )
    parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
    parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
    parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")

@@ -786,7 +800,9 @@ def main(args):
            original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
        elif combined_ckpt is not None:
            original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
        vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
        vae = convert_ltx2_video_vae(
            original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning
        )
        if not args.full_pipeline and not args.upsample_pipeline:
            vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
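As an aside on the conversion-script changes above: the new `none_or_str` argument type lets a caller disable any of these checkpoint/filename flags by literally passing the string "none" on the command line. A minimal, self-contained sketch of that parsing behavior (the flag reused below is just one of the arguments it is applied to):

```py
# Standalone sketch of the none_or_str behavior added in the diff above.
import argparse


def none_or_str(value: str):
    if isinstance(value, str) and value.lower() == "none":
        return None
    return value


parser = argparse.ArgumentParser()
parser.add_argument("--combined_filename", default="ltx-2-19b-dev.safetensors", type=none_or_str)

args = parser.parse_args(["--combined_filename", "none"])
assert args.combined_filename is None  # the literal string "none" is parsed into a real None
```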

@@ -743,8 +743,8 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):

        # Per-channel statistics for normalizing and denormalizing the latent representation. These statistics are
        # computed over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict.
        latents_std = torch.zeros((base_channels,))
        latents_mean = torch.ones((base_channels,))
        latents_std = torch.ones((base_channels,))
        latents_mean = torch.zeros((base_channels,))
        self.register_buffer("latents_mean", latents_mean, persistent=True)
        self.register_buffer("latents_std", latents_std, persistent=True)

@@ -584,6 +584,17 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
        latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
        return latents

    @staticmethod
    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents
    def _normalize_latents(
        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
    ) -> torch.Tensor:
        # Normalize latents across the channel dimension [B, C, F, H, W]
        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents = (latents - latents_mean) * scaling_factor / latents_std
        return latents

    @staticmethod
    def _denormalize_latents(
        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0

@@ -594,12 +605,26 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
        latents = latents * latents_std / scaling_factor + latents_mean
        return latents

    @staticmethod
    def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
        latents_mean = latents_mean.to(latents.device, latents.dtype)
        latents_std = latents_std.to(latents.device, latents.dtype)
        return (latents - latents_mean) / latents_std

    @staticmethod
    def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
        latents_mean = latents_mean.to(latents.device, latents.dtype)
        latents_std = latents_std.to(latents.device, latents.dtype)
        return (latents * latents_std) + latents_mean

    @staticmethod
    def _create_noised_state(
        latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None
    ):
        noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
        noised_latents = noise_scale * noise + (1 - noise_scale) * latents
        return noised_latents

    @staticmethod
    def _pack_audio_latents(
        latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None

@@ -647,12 +672,26 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
        height: int = 512,
        width: int = 768,
        num_frames: int = 121,
        noise_scale: float = 0.0,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if latents is not None:
            if latents.ndim == 5:
                latents = self._normalize_latents(
                    latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
                )
                # latents are of shape [B, C, F, H, W], need to be packed
                latents = self._pack_latents(
                    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
                )
            if latents.ndim != 3:
                raise ValueError(
                    f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]."
                )
            latents = self._create_noised_state(latents, noise_scale, generator)
            return latents.to(device=device, dtype=dtype)

        height = height // self.vae_spatial_compression_ratio

@@ -677,29 +716,30 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
        self,
        batch_size: int = 1,
        num_channels_latents: int = 8,
        audio_latent_length: int = 1,  # 1 is just a dummy value
        num_mel_bins: int = 64,
        num_frames: int = 121,
        frame_rate: float = 25.0,
        sampling_rate: int = 16000,
        hop_length: int = 160,
        noise_scale: float = 0.0,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        duration_s = num_frames / frame_rate
        latents_per_second = (
            float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
        )
        latent_length = round(duration_s * latents_per_second)

        if latents is not None:
            return latents.to(device=device, dtype=dtype), latent_length
            if latents.ndim == 4:
                # latents are of shape [B, C, L, M], need to be packed
                latents = self._pack_audio_latents(latents)
            if latents.ndim != 3:
                raise ValueError(
                    f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]."
                )
            latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std)
            latents = self._create_noised_state(latents, noise_scale, generator)
            return latents.to(device=device, dtype=dtype)

        # TODO: confirm whether this logic is correct
        latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio

        shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
        shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(

@@ -709,7 +749,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = self._pack_audio_latents(latents)
        return latents, latent_length
        return latents

    @property
    def guidance_scale(self):

@@ -750,9 +790,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
        num_frames: int = 121,
        frame_rate: float = 24.0,
        num_inference_steps: int = 40,
        sigmas: Optional[List[float]] = None,
        timesteps: List[int] = None,
        guidance_scale: float = 4.0,
        guidance_rescale: float = 0.0,
        noise_scale: float = 0.0,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,

@@ -788,6 +830,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
            num_inference_steps (`int`, *optional*, defaults to 40):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is

@@ -804,6 +850,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
                [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
                using zero terminal SNR.
            noise_scale (`float`, *optional*, defaults to `0.0`):
                The interpolation factor between random noise and denoised latents at each timestep. Noise is applied
                to the `latents` and `audio_latents` before denoising continues.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
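For intuition, here is a minimal sketch of the interpolation behind `noise_scale` as documented above, mirroring `_create_noised_state`; the tensor shape and the 0.909375 value (the first Stage 2 distilled sigma) are only example inputs:

```py
import torch

noise_scale = 0.909375  # example: first value of STAGE_2_DISTILLED_SIGMA_VALUES
latents = torch.randn(1, 128, 64)  # stand-in for packed latents [batch, seq, features]
noise = torch.randn_like(latents)
# Same interpolation as _create_noised_state: 1.0 gives pure noise, 0.0 keeps the latents untouched.
noised_latents = noise_scale * noise + (1 - noise_scale) * latents
```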

@@ -922,6 +971,21 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
        latent_height = height // self.vae_spatial_compression_ratio
        latent_width = width // self.vae_spatial_compression_ratio
        if latents is not None:
            if latents.ndim == 5:
                logger.info(
                    "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred."
                )
                _, _, latent_num_frames, latent_height, latent_width = latents.shape  # [B, C, F, H, W]
            elif latents.ndim == 3:
                logger.warning(
                    f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
                    f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
                )
            else:
                raise ValueError(
                    f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]."
                )
        video_sequence_length = latent_num_frames * latent_height * latent_width

        num_channels_latents = self.transformer.config.in_channels

@@ -931,26 +995,45 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
            height,
            width,
            num_frames,
            noise_scale,
            torch.float32,
            device,
            generator,
            latents,
        )

        duration_s = num_frames / frame_rate
        audio_latents_per_second = (
            self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
        )
        audio_num_frames = round(duration_s * audio_latents_per_second)
        if audio_latents is not None:
            if audio_latents.ndim == 4:
                logger.info(
                    "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred."
                )
                _, _, audio_num_frames, _ = audio_latents.shape  # [B, C, L, M]
            elif audio_latents.ndim == 3:
                logger.warning(
                    f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
                    f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct."
                )
            else:
                raise ValueError(
                    f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]."
                )

        num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
        latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio

        num_channels_latents_audio = (
            self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
        )
        audio_latents, audio_num_frames = self.prepare_audio_latents(
        audio_latents = self.prepare_audio_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents=num_channels_latents_audio,
            audio_latent_length=audio_num_frames,
            num_mel_bins=num_mel_bins,
            num_frames=num_frames,  # Video frames, audio frames will be calculated from this
            frame_rate=frame_rate,
            sampling_rate=self.audio_sampling_rate,
            hop_length=self.audio_hop_length,
            noise_scale=noise_scale,
            dtype=torch.float32,
            device=device,
            generator=generator,

@@ -958,7 +1041,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        mu = calculate_shift(
            video_sequence_length,
            self.scheduler.config.get("base_image_seq_len", 1024),
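A rough worked example of the audio latent length computed above; the temporal compression ratio of 4 is an assumed value chosen only for illustration, not taken from a released config:

```py
# Worked example of the audio_num_frames calculation (assumed compression ratio).
num_frames, frame_rate = 121, 24.0
sampling_rate, hop_length = 16000, 160
audio_vae_temporal_compression_ratio = 4  # assumption for illustration only

duration_s = num_frames / frame_rate  # ~5.04 s of video
audio_latents_per_second = sampling_rate / hop_length / audio_vae_temporal_compression_ratio
audio_num_frames = round(duration_s * audio_latents_per_second)
print(audio_num_frames)  # 126 with these example numbers
```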

@@ -614,6 +614,15 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
        latents = latents * latents_std / scaling_factor + latents_mean
        return latents

    @staticmethod
    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state
    def _create_noised_state(
        latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None
    ):
        noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
        noised_latents = noise_scale * noise + (1 - noise_scale) * latents
        return noised_latents

    @staticmethod
    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents
    def _pack_audio_latents(

@@ -656,6 +665,13 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
        latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2)
        return latents

    @staticmethod
    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._normalize_audio_latents
    def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
        latents_mean = latents_mean.to(latents.device, latents.dtype)
        latents_std = latents_std.to(latents.device, latents.dtype)
        return (latents - latents_mean) / latents_std

    @staticmethod
    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents
    def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):

@@ -671,6 +687,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
        height: int = 512,
        width: int = 704,
        num_frames: int = 161,
        noise_scale: float = 0.0,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[torch.Generator] = None,

@@ -686,6 +703,15 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
        if latents is not None:
            conditioning_mask = latents.new_zeros(mask_shape)
            conditioning_mask[:, :, 0] = 1.0
            if latents.ndim == 5:
                latents = self._normalize_latents(
                    latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
                )
            latents = self._create_noised_state(latents, noise_scale * (1 - conditioning_mask), generator)
            # latents are of shape [B, C, F, H, W], need to be packed
            latents = self._pack_latents(
                latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
            )
            conditioning_mask = self._pack_latents(
                conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
            ).squeeze(-1)

@@ -737,29 +763,30 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
        self,
        batch_size: int = 1,
        num_channels_latents: int = 8,
        audio_latent_length: int = 1,  # 1 is just a dummy value
        num_mel_bins: int = 64,
        num_frames: int = 121,
        frame_rate: float = 25.0,
        sampling_rate: int = 16000,
        hop_length: int = 160,
        noise_scale: float = 0.0,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        duration_s = num_frames / frame_rate
        latents_per_second = (
            float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
        )
        latent_length = round(duration_s * latents_per_second)

        if latents is not None:
            return latents.to(device=device, dtype=dtype), latent_length
            if latents.ndim == 4:
                # latents are of shape [B, C, L, M], need to be packed
                latents = self._pack_audio_latents(latents)
            if latents.ndim != 3:
                raise ValueError(
                    f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]."
                )
            latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std)
            latents = self._create_noised_state(latents, noise_scale, generator)
            return latents.to(device=device, dtype=dtype)

        # TODO: confirm whether this logic is correct
        latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio

        shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
        shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(

@@ -769,7 +796,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = self._pack_audio_latents(latents)
        return latents, latent_length
        return latents

    @property
    def guidance_scale(self):

@@ -811,9 +838,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
        num_frames: int = 121,
        frame_rate: float = 24.0,
        num_inference_steps: int = 40,
        sigmas: Optional[List[float]] = None,
        timesteps: List[int] = None,
        guidance_scale: float = 4.0,
        guidance_rescale: float = 0.0,
        noise_scale: float = 0.0,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,

@@ -851,6 +880,10 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
            num_inference_steps (`int`, *optional*, defaults to 40):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is

@@ -867,6 +900,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
                [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
                using zero terminal SNR.
            noise_scale (`float`, *optional*, defaults to `0.0`):
                The interpolation factor between random noise and denoised latents at each timestep. Noise is applied
                to the `latents` and `audio_latents` before denoising continues.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -982,6 +1018,26 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
        )

        # 4. Prepare latent variables
        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
        latent_height = height // self.vae_spatial_compression_ratio
        latent_width = width // self.vae_spatial_compression_ratio
        if latents is not None:
            if latents.ndim == 5:
                logger.info(
                    "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred."
                )
                _, _, latent_num_frames, latent_height, latent_width = latents.shape  # [B, C, F, H, W]
            elif latents.ndim == 3:
                logger.warning(
                    f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
                    f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
                )
            else:
                raise ValueError(
                    f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]."
                )
        video_sequence_length = latent_num_frames * latent_height * latent_width

        if latents is None:
            image = self.video_processor.preprocess(image, height=height, width=width)
            image = image.to(device=device, dtype=prompt_embeds.dtype)

@@ -994,6 +1050,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
            height,
            width,
            num_frames,
            noise_scale,
            torch.float32,
            device,
            generator,

@@ -1002,20 +1059,38 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
        if self.do_classifier_free_guidance:
            conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])

        duration_s = num_frames / frame_rate
        audio_latents_per_second = (
            self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
        )
        audio_num_frames = round(duration_s * audio_latents_per_second)
        if audio_latents is not None:
            if audio_latents.ndim == 4:
                logger.info(
                    "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred."
                )
                _, _, audio_num_frames, _ = audio_latents.shape  # [B, C, L, M]
            elif audio_latents.ndim == 3:
                logger.warning(
                    f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
                    f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct."
                )
            else:
                raise ValueError(
                    f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]."
                )

        num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
        latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio

        num_channels_latents_audio = (
            self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
        )
        audio_latents, audio_num_frames = self.prepare_audio_latents(
        audio_latents = self.prepare_audio_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents=num_channels_latents_audio,
            audio_latent_length=audio_num_frames,
            num_mel_bins=num_mel_bins,
            num_frames=num_frames,  # Video frames, audio frames will be calculated from this
            frame_rate=frame_rate,
            sampling_rate=self.audio_sampling_rate,
            hop_length=self.audio_hop_length,
            noise_scale=noise_scale,
            dtype=torch.float32,
            device=device,
            generator=generator,

@@ -1023,12 +1098,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
        )

        # 5. Prepare timesteps
        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
        latent_height = height // self.vae_spatial_compression_ratio
        latent_width = width // self.vae_spatial_compression_ratio
        video_sequence_length = latent_num_frames * latent_height * latent_width

        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        mu = calculate_shift(
            video_sequence_length,
            self.scheduler.config.get("base_image_seq_len", 1024),
@@ -228,17 +228,6 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
        filtered = latents * scales
        return filtered

    @staticmethod
    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents
    def _normalize_latents(
        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
    ) -> torch.Tensor:
        # Normalize latents across the channel dimension [B, C, F, H, W]
        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents = (latents - latents_mean) * scaling_factor / latents_std
        return latents

    @staticmethod
    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents
    def _denormalize_latents(

@@ -408,9 +397,6 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline):
            latents = self.tone_map_latents(latents, tone_map_compression_ratio)

        if output_type == "latent":
            latents = self._normalize_latents(
                latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
            )
            video = latents
        else:
            if not self.vae.config.timestep_conditioning:
src/diffusers/pipelines/ltx2/utils.py (new file, 6 lines)
@@ -0,0 +1,6 @@
# Pre-trained sigma values for distilled model are taken from
# https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]

# Reduced schedule for super-resolution stage 2 (subset of distilled values)
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875]
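A small sanity check on how the two schedules relate and how the documentation examples above consume them (the step counts are the ones used in those examples):

```py
# The Stage 2 schedule is simply the tail of the full distilled schedule.
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875]

assert STAGE_2_DISTILLED_SIGMA_VALUES == DISTILLED_SIGMA_VALUES[-3:]
# Stage 1 of the distilled checkpoint runs len(DISTILLED_SIGMA_VALUES) == 8 steps,
# Stage 2 runs len(STAGE_2_DISTILLED_SIGMA_VALUES) == 3 steps, and the first Stage 2
# value doubles as the noise_scale used to re-noise the upscaled latents.
```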

@@ -222,7 +222,57 @@ class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        )
        expected_audio_slice = torch.tensor(
            [
                0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
                0.0263, 0.0528, 0.1217, 0.1104, 0.1632, 0.1072, 0.1789, 0.0949, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
            ]
        )
        # fmt: on

        video = video.flatten()
        audio = audio.flatten()
        generated_video_slice = torch.cat([video[:8], video[-8:]])
        generated_audio_slice = torch.cat([audio[:8], audio[-8:]])

        assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
        assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)

    def test_two_stages_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["output_type"] = "latent"
        first_stage_output = pipe(**inputs)
        video_latent = first_stage_output.frames
        audio_latent = first_stage_output.audio

        self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16))
        self.assertEqual(audio_latent.shape, (1, 2, 5, 2))
        self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels)

        inputs["latents"] = video_latent
        inputs["audio_latents"] = audio_latent
        inputs["output_type"] = "pt"
        second_stage_output = pipe(**inputs)
        video = second_stage_output.frames
        audio = second_stage_output.audio

        self.assertEqual(video.shape, (1, 5, 3, 32, 32))
        self.assertEqual(audio.shape[0], 1)
        self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)

        # fmt: off
        expected_video_slice = torch.tensor(
            [
                0.5514, 0.5943, 0.4260, 0.5971, 0.4306, 0.6369, 0.3124, 0.6964, 0.5419, 0.2412, 0.3882, 0.4504, 0.1941, 0.3404, 0.6037, 0.2464
            ]
        )
        expected_audio_slice = torch.tensor(
            [
                0.0252, 0.0526, 0.1211, 0.1119, 0.1638, 0.1042, 0.1776, 0.0948, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
            ]
        )
        # fmt: on

@@ -224,7 +224,57 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        )
        expected_audio_slice = torch.tensor(
            [
                0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
                0.0294, 0.0498, 0.1269, 0.1135, 0.1639, 0.1116, 0.1730, 0.0931, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
            ]
        )
        # fmt: on

        video = video.flatten()
        audio = audio.flatten()
        generated_video_slice = torch.cat([video[:8], video[-8:]])
        generated_audio_slice = torch.cat([audio[:8], audio[-8:]])

        assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
        assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)

    def test_two_stages_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["output_type"] = "latent"
        first_stage_output = pipe(**inputs)
        video_latent = first_stage_output.frames
        audio_latent = first_stage_output.audio

        self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16))
        self.assertEqual(audio_latent.shape, (1, 2, 5, 2))
        self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels)

        inputs["latents"] = video_latent
        inputs["audio_latents"] = audio_latent
        inputs["output_type"] = "pt"
        second_stage_output = pipe(**inputs)
        video = second_stage_output.frames
        audio = second_stage_output.audio

        self.assertEqual(video.shape, (1, 5, 3, 32, 32))
        self.assertEqual(audio.shape[0], 1)
        self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)

        # fmt: off
        expected_video_slice = torch.tensor(
            [
                0.2665, 0.6915, 0.2939, 0.6767, 0.2552, 0.6215, 0.1765, 0.6248, 0.2800, 0.2356, 0.3480, 0.5395, 0.3190, 0.4128, 0.4784, 0.4086
            ]
        )
        expected_audio_slice = torch.tensor(
            [
                0.0273, 0.0490, 0.1253, 0.1129, 0.1655, 0.1057, 0.1707, 0.0943, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
            ]
        )
        # fmt: on