diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md index 4c6860daf0..24776b4230 100644 --- a/docs/source/en/api/pipelines/ltx2.md +++ b/docs/source/en/api/pipelines/ltx2.md @@ -24,6 +24,179 @@ You can find all the original LTX-Video checkpoints under the [Lightricks](https The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2). +## Two-stages Generation +Recommended pipeline to achieve production quality generation, this pipeline is composed of two stages: + +- Stage 1: Generate a video at the target resolution using diffusion sampling with classifier-free guidance (CFG). This stage produces a coherent low-noise video sequence that respects the text/image conditioning. +- Stage 2: Upsample the Stage 1 output by 2 and refine details using a distilled LoRA model to improve fidelity and visual quality. Stage 2 may apply lighter CFG to preserve the structure from Stage 1 while enhancing texture and sharpness. + +Sample usage of text-to-video two stages pipeline + +```py +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline +from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel +from diffusers.pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES +from diffusers.pipelines.ltx2.export_utils import encode_video + +device = "cuda:0" +width = 768 +height = 512 + +pipe = LTX2Pipeline.from_pretrained( + "Lightricks/LTX-2", torch_dtype=torch.bfloat16 +) +pipe.enable_sequential_cpu_offload(device=device) + +prompt = "A beautiful sunset over the ocean" +negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." + +# Stage 1 default (non-distilled) inference +frame_rate = 24.0 +video_latent, audio_latent = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + num_frames=121, + frame_rate=frame_rate, + num_inference_steps=40, + sigmas=None, + guidance_scale=4.0, + output_type="latent", + return_dict=False, +) + +latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained( + "Lightricks/LTX-2", + subfolder="latent_upsampler", + torch_dtype=torch.bfloat16, +) +upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler) +upsample_pipe.enable_model_cpu_offload(device=device) +upscaled_video_latent = upsample_pipe( + latents=video_latent, + output_type="latent", + return_dict=False, +)[0] + +# Load Stage 2 distilled LoRA +pipe.load_lora_weights( + "Lightricks/LTX-2", adapter_name="stage_2_distilled", weight_name="ltx-2-19b-distilled-lora-384.safetensors" +) +pipe.set_adapters("stage_2_distilled", 1.0) +# VAE tiling is usually necessary to avoid OOM error when VAE decoding +pipe.vae.enable_tiling() +# Change scheduler to use Stage 2 distilled sigmas as is +new_scheduler = FlowMatchEulerDiscreteScheduler.from_config( + pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None +) +pipe.scheduler = new_scheduler +# Stage 2 inference with distilled LoRA and sigmas +video, audio = pipe( + latents=upscaled_video_latent, + audio_latents=audio_latent, + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=3, + noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0], # renoise with first sigma value https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py#L218 + sigmas=STAGE_2_DISTILLED_SIGMA_VALUES, + guidance_scale=1.0, + output_type="np", + return_dict=False, +) +video = (video * 255).round().astype("uint8") +video = torch.from_numpy(video) + +encode_video( + video[0], + fps=frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=pipe.vocoder.config.output_sampling_rate, + output_path="ltx2_lora_distilled_sample.mp4", +) +``` + +## Distilled checkpoint generation +Fastest two-stages generation pipeline using a distilled checkpoint. + +```py +import torch +from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline +from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel +from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES +from diffusers.pipelines.ltx2.export_utils import encode_video + +device = "cuda" +width = 768 +height = 512 +random_seed = 42 +generator = torch.Generator(device).manual_seed(random_seed) +model_path = "rootonchair/LTX-2-19b-distilled" + +pipe = LTX2Pipeline.from_pretrained( + model_path, torch_dtype=torch.bfloat16 +) +pipe.enable_sequential_cpu_offload(device=device) + +prompt = "A beautiful sunset over the ocean" +negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." + +frame_rate = 24.0 +video_latent, audio_latent = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + num_frames=121, + frame_rate=frame_rate, + num_inference_steps=8, + sigmas=DISTILLED_SIGMA_VALUES, + guidance_scale=1.0, + generator=generator, + output_type="latent", + return_dict=False, +) + +latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained( + model_path, + subfolder="latent_upsampler", + torch_dtype=torch.bfloat16, +) +upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler) +upsample_pipe.enable_model_cpu_offload(device=device) +upscaled_video_latent = upsample_pipe( + latents=video_latent, + output_type="latent", + return_dict=False, +)[0] + +video, audio = pipe( + latents=upscaled_video_latent, + audio_latents=audio_latent, + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=3, + noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0], # renoise with first sigma value https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/distilled.py#L178 + sigmas=STAGE_2_DISTILLED_SIGMA_VALUES, + generator=generator, + guidance_scale=1.0, + output_type="np", + return_dict=False, +) +video = (video * 255).round().astype("uint8") +video = torch.from_numpy(video) + +encode_video( + video[0], + fps=frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=pipe.vocoder.config.output_sampling_rate, + output_path="ltx2_distilled_sample.mp4", +) +``` + ## LTX2Pipeline [[autodoc]] LTX2Pipeline diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 5367113365..88a9c8700a 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -63,6 +63,8 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = { "up_blocks.4": "up_blocks.1", "up_blocks.5": "up_blocks.2.upsamplers.0", "up_blocks.6": "up_blocks.2", + "last_time_embedder": "time_embedder", + "last_scale_shift_table": "scale_shift_table", # Common # For all 3D ResNets "res_blocks": "resnets", @@ -372,7 +374,9 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) - return connectors -def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: +def get_ltx2_video_vae_config( + version: str, timestep_conditioning: bool = False +) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "test": config = { "model_id": "diffusers-internal-dev/dummy-ltx2", @@ -396,7 +400,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), "upsample_residual": (True, True, True), "upsample_factor": (2, 2, 2), - "timestep_conditioning": False, + "timestep_conditioning": timestep_conditioning, "patch_size": 4, "patch_size_t": 1, "resnet_norm_eps": 1e-6, @@ -433,7 +437,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), "upsample_residual": (True, True, True), "upsample_factor": (2, 2, 2), - "timestep_conditioning": False, + "timestep_conditioning": timestep_conditioning, "patch_size": 4, "patch_size_t": 1, "resnet_norm_eps": 1e-6, @@ -450,8 +454,10 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A return config, rename_dict, special_keys_remap -def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: - config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version) +def convert_ltx2_video_vae( + original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool +) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning) diffusers_config = config["diffusers_config"] with init_empty_weights(): @@ -659,10 +665,15 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefi def get_args(): parser = argparse.ArgumentParser() + def none_or_str(value: str): + if isinstance(value, str) and value.lower() == "none": + return None + return value + parser.add_argument( "--original_state_dict_repo_id", default="Lightricks/LTX-2", - type=str, + type=none_or_str, help="HF Hub repo id with LTX 2.0 checkpoint", ) parser.add_argument( @@ -682,7 +693,7 @@ def get_args(): parser.add_argument( "--combined_filename", default="ltx-2-19b-dev.safetensors", - type=str, + type=none_or_str, help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)", ) parser.add_argument("--vae_prefix", default="vae.", type=str) @@ -701,22 +712,25 @@ def get_args(): parser.add_argument( "--text_encoder_model_id", default="google/gemma-3-12b-it-qat-q4_0-unquantized", - type=str, + type=none_or_str, help="HF Hub id for the LTX 2.0 base text encoder model", ) parser.add_argument( "--tokenizer_id", default="google/gemma-3-12b-it-qat-q4_0-unquantized", - type=str, + type=none_or_str, help="HF Hub id for the LTX 2.0 text tokenizer", ) parser.add_argument( "--latent_upsampler_filename", default="ltx-2-spatial-upscaler-x2-1.0.safetensors", - type=str, + type=none_or_str, help="Latent upsampler filename", ) + parser.add_argument( + "--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model" + ) parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model") parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model") parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model") @@ -786,7 +800,9 @@ def main(args): original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename) elif combined_ckpt is not None: original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix) - vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version) + vae = convert_ltx2_video_vae( + original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning + ) if not args.full_pipeline and not args.upsample_pipeline: vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae")) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index 6c9c7dce3d..b29629a29f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -743,8 +743,8 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): # Per-channel statistics for normalizing and denormalizing the latent representation. This statics is computed over # the entire dataset and stored in model's checkpoint under AudioVAE state_dict - latents_std = torch.zeros((base_channels,)) - latents_mean = torch.ones((base_channels,)) + latents_std = torch.ones((base_channels,)) + latents_mean = torch.zeros((base_channels,)) self.register_buffer("latents_mean", latents_mean, persistent=True) self.register_buffer("latents_std", latents_std, persistent=True) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 9cf8479263..a92a7a2c88 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -584,6 +584,17 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) return latents + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + @staticmethod def _denormalize_latents( latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 @@ -594,12 +605,26 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): latents = latents * latents_std / scaling_factor + latents_mean return latents + @staticmethod + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + @staticmethod def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): latents_mean = latents_mean.to(latents.device, latents.dtype) latents_std = latents_std.to(latents.device, latents.dtype) return (latents * latents_std) + latents_mean + @staticmethod + def _create_noised_state( + latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None + ): + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + noised_latents = noise_scale * noise + (1 - noise_scale) * latents + return noised_latents + @staticmethod def _pack_audio_latents( latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None @@ -647,12 +672,26 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): height: int = 512, width: int = 768, num_frames: int = 121, + noise_scale: float = 0.0, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: if latents is not None: + if latents.ndim == 5: + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + # latents are of shape [B, C, F, H, W], need to be packed + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) + latents = self._create_noised_state(latents, noise_scale, generator) return latents.to(device=device, dtype=dtype) height = height // self.vae_spatial_compression_ratio @@ -677,29 +716,30 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): self, batch_size: int = 1, num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value num_mel_bins: int = 64, - num_frames: int = 121, - frame_rate: float = 25.0, - sampling_rate: int = 16000, - hop_length: int = 160, + noise_scale: float = 0.0, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: - duration_s = num_frames / frame_rate - latents_per_second = ( - float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) - ) - latent_length = round(duration_s * latents_per_second) - if latents is not None: - return latents.to(device=device, dtype=dtype), latent_length + if latents.ndim == 4: + # latents are of shape [B, C, L, M], need to be packed + latents = self._pack_audio_latents(latents) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) # TODO: confirm whether this logic is correct latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins) + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -709,7 +749,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_audio_latents(latents) - return latents, latent_length + return latents @property def guidance_scale(self): @@ -750,9 +790,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): num_frames: int = 121, frame_rate: float = 24.0, num_inference_steps: int = 40, + sigmas: Optional[List[float]] = None, timesteps: List[int] = None, guidance_scale: float = 4.0, guidance_rescale: float = 0.0, + noise_scale: float = 0.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -788,6 +830,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): num_inference_steps (`int`, *optional*, defaults to 40): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is @@ -804,6 +850,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when using zero terminal SNR. + noise_scale (`float`, *optional*, defaults to `0.0`): + The interpolation factor between random noise and denoised latents at each timestep. Applying noise to + the `latents` and `audio_latents` before continue denoising. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -922,6 +971,21 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 latent_height = height // self.vae_spatial_compression_ratio latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + elif latents.ndim == 3: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) + else: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." + ) video_sequence_length = latent_num_frames * latent_height * latent_width num_channels_latents = self.transformer.config.in_channels @@ -931,26 +995,45 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): height, width, num_frames, + noise_scale, torch.float32, device, generator, latents, ) + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + if audio_latents.ndim == 4: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred." + ) + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + elif audio_latents.ndim == 3: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" + f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct." + ) + else: + raise ValueError( + f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]." + ) + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - num_channels_latents_audio = ( self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 ) - audio_latents, audio_num_frames = self.prepare_audio_latents( + audio_latents = self.prepare_audio_latents( batch_size * num_videos_per_prompt, num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, num_mel_bins=num_mel_bins, - num_frames=num_frames, # Video frames, audio frames will be calculated from this - frame_rate=frame_rate, - sampling_rate=self.audio_sampling_rate, - hop_length=self.audio_hop_length, + noise_scale=noise_scale, dtype=torch.float32, device=device, generator=generator, @@ -958,7 +1041,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): ) # 5. Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas mu = calculate_shift( video_sequence_length, self.scheduler.config.get("base_image_seq_len", 1024), diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index b1711e2831..04d7ee89c5 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -614,6 +614,15 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL latents = latents * latents_std / scaling_factor + latents_mean return latents + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state + def _create_noised_state( + latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None + ): + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + noised_latents = noise_scale * noise + (1 - noise_scale) * latents + return noised_latents + @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents def _pack_audio_latents( @@ -656,6 +665,13 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) return latents + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._normalize_audio_latents + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): @@ -671,6 +687,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL height: int = 512, width: int = 704, num_frames: int = 161, + noise_scale: float = 0.0, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, @@ -686,6 +703,15 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL if latents is not None: conditioning_mask = latents.new_zeros(mask_shape) conditioning_mask[:, :, 0] = 1.0 + if latents.ndim == 5: + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = self._create_noised_state(latents, noise_scale * (1 - conditioning_mask), generator) + # latents are of shape [B, C, F, H, W], need to be packed + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) conditioning_mask = self._pack_latents( conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ).squeeze(-1) @@ -737,29 +763,30 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL self, batch_size: int = 1, num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value num_mel_bins: int = 64, - num_frames: int = 121, - frame_rate: float = 25.0, - sampling_rate: int = 16000, - hop_length: int = 160, + noise_scale: float = 0.0, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: - duration_s = num_frames / frame_rate - latents_per_second = ( - float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) - ) - latent_length = round(duration_s * latents_per_second) - if latents is not None: - return latents.to(device=device, dtype=dtype), latent_length + if latents.ndim == 4: + # latents are of shape [B, C, L, M], need to be packed + latents = self._pack_audio_latents(latents) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) # TODO: confirm whether this logic is correct latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins) + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -769,7 +796,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_audio_latents(latents) - return latents, latent_length + return latents @property def guidance_scale(self): @@ -811,9 +838,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL num_frames: int = 121, frame_rate: float = 24.0, num_inference_steps: int = 40, + sigmas: Optional[List[float]] = None, timesteps: List[int] = None, guidance_scale: float = 4.0, guidance_rescale: float = 0.0, + noise_scale: float = 0.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -851,6 +880,10 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL num_inference_steps (`int`, *optional*, defaults to 40): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is @@ -867,6 +900,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when using zero terminal SNR. + noise_scale (`float`, *optional*, defaults to `0.0`): + The interpolation factor between random noise and denoised latents at each timestep. Applying noise to + the `latents` and `audio_latents` before continue denoising. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -982,6 +1018,26 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL ) # 4. Prepare latent variables + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + elif latents.ndim == 3: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) + else: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." + ) + video_sequence_length = latent_num_frames * latent_height * latent_width + if latents is None: image = self.video_processor.preprocess(image, height=height, width=width) image = image.to(device=device, dtype=prompt_embeds.dtype) @@ -994,6 +1050,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL height, width, num_frames, + noise_scale, torch.float32, device, generator, @@ -1002,20 +1059,38 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL if self.do_classifier_free_guidance: conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + if audio_latents.ndim == 4: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred." + ) + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + elif audio_latents.ndim == 3: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" + f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct." + ) + else: + raise ValueError( + f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]." + ) + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - num_channels_latents_audio = ( self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 ) - audio_latents, audio_num_frames = self.prepare_audio_latents( + audio_latents = self.prepare_audio_latents( batch_size * num_videos_per_prompt, num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, num_mel_bins=num_mel_bins, - num_frames=num_frames, # Video frames, audio frames will be calculated from this - frame_rate=frame_rate, - sampling_rate=self.audio_sampling_rate, - hop_length=self.audio_hop_length, + noise_scale=noise_scale, dtype=torch.float32, device=device, generator=generator, @@ -1023,12 +1098,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL ) # 5. Prepare timesteps - latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 - latent_height = height // self.vae_spatial_compression_ratio - latent_width = width // self.vae_spatial_compression_ratio - video_sequence_length = latent_num_frames * latent_height * latent_width - - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas mu = calculate_shift( video_sequence_length, self.scheduler.config.get("base_image_seq_len", 1024), diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py index a44c40b043..340efd10f2 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py @@ -228,17 +228,6 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline): filtered = latents * scales return filtered - @staticmethod - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents - def _normalize_latents( - latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 - ) -> torch.Tensor: - # Normalize latents across the channel dimension [B, C, F, H, W] - latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - latents = (latents - latents_mean) * scaling_factor / latents_std - return latents - @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents def _denormalize_latents( @@ -408,9 +397,6 @@ class LTX2LatentUpsamplePipeline(DiffusionPipeline): latents = self.tone_map_latents(latents, tone_map_compression_ratio) if output_type == "latent": - latents = self._normalize_latents( - latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor - ) video = latents else: if not self.vae.config.timestep_conditioning: diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py new file mode 100644 index 0000000000..f80469817f --- /dev/null +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -0,0 +1,6 @@ +# Pre-trained sigma values for distilled model are taken from +# https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py +DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875] + +# Reduced schedule for super-resolution stage 2 (subset of distilled values) +STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875] diff --git a/tests/pipelines/ltx2/test_ltx2.py b/tests/pipelines/ltx2/test_ltx2.py index 6ffc237250..7d1a3bfc99 100644 --- a/tests/pipelines/ltx2/test_ltx2.py +++ b/tests/pipelines/ltx2/test_ltx2.py @@ -222,7 +222,57 @@ class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) expected_audio_slice = torch.tensor( [ - 0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + 0.0263, 0.0528, 0.1217, 0.1104, 0.1632, 0.1072, 0.1789, 0.0949, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + ] + ) + # fmt: on + + video = video.flatten() + audio = audio.flatten() + generated_video_slice = torch.cat([video[:8], video[-8:]]) + generated_audio_slice = torch.cat([audio[:8], audio[-8:]]) + + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) + assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + + def test_two_stages_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["output_type"] = "latent" + first_stage_output = pipe(**inputs) + video_latent = first_stage_output.frames + audio_latent = first_stage_output.audio + + self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16)) + self.assertEqual(audio_latent.shape, (1, 2, 5, 2)) + self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels) + + inputs["latents"] = video_latent + inputs["audio_latents"] = audio_latent + inputs["output_type"] = "pt" + second_stage_output = pipe(**inputs) + video = second_stage_output.frames + audio = second_stage_output.audio + + self.assertEqual(video.shape, (1, 5, 3, 32, 32)) + self.assertEqual(audio.shape[0], 1) + self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels) + + # fmt: off + expected_video_slice = torch.tensor( + [ + 0.5514, 0.5943, 0.4260, 0.5971, 0.4306, 0.6369, 0.3124, 0.6964, 0.5419, 0.2412, 0.3882, 0.4504, 0.1941, 0.3404, 0.6037, 0.2464 + ] + ) + expected_audio_slice = torch.tensor( + [ + 0.0252, 0.0526, 0.1211, 0.1119, 0.1638, 0.1042, 0.1776, 0.0948, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 ] ) # fmt: on diff --git a/tests/pipelines/ltx2/test_ltx2_image2video.py b/tests/pipelines/ltx2/test_ltx2_image2video.py index 1edae9c0e0..3653e1cfc5 100644 --- a/tests/pipelines/ltx2/test_ltx2_image2video.py +++ b/tests/pipelines/ltx2/test_ltx2_image2video.py @@ -224,7 +224,57 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) expected_audio_slice = torch.tensor( [ - 0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + 0.0294, 0.0498, 0.1269, 0.1135, 0.1639, 0.1116, 0.1730, 0.0931, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + ] + ) + # fmt: on + + video = video.flatten() + audio = audio.flatten() + generated_video_slice = torch.cat([video[:8], video[-8:]]) + generated_audio_slice = torch.cat([audio[:8], audio[-8:]]) + + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) + assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + + def test_two_stages_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["output_type"] = "latent" + first_stage_output = pipe(**inputs) + video_latent = first_stage_output.frames + audio_latent = first_stage_output.audio + + self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16)) + self.assertEqual(audio_latent.shape, (1, 2, 5, 2)) + self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels) + + inputs["latents"] = video_latent + inputs["audio_latents"] = audio_latent + inputs["output_type"] = "pt" + second_stage_output = pipe(**inputs) + video = second_stage_output.frames + audio = second_stage_output.audio + + self.assertEqual(video.shape, (1, 5, 3, 32, 32)) + self.assertEqual(audio.shape[0], 1) + self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels) + + # fmt: off + expected_video_slice = torch.tensor( + [ + 0.2665, 0.6915, 0.2939, 0.6767, 0.2552, 0.6215, 0.1765, 0.6248, 0.2800, 0.2356, 0.3480, 0.5395, 0.3190, 0.4128, 0.4784, 0.4086 + ] + ) + expected_audio_slice = torch.tensor( + [ + 0.0273, 0.0490, 0.1253, 0.1129, 0.1655, 0.1057, 0.1707, 0.0943, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 ] ) # fmt: on