diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py
index 5367113365..2794feffed 100644
--- a/scripts/convert_ltx2_to_diffusers.py
+++ b/scripts/convert_ltx2_to_diffusers.py
@@ -63,6 +63,8 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
     "up_blocks.4": "up_blocks.1",
     "up_blocks.5": "up_blocks.2.upsamplers.0",
     "up_blocks.6": "up_blocks.2",
+    "last_time_embedder": "time_embedder",
+    "last_scale_shift_table": "scale_shift_table",
     # Common
     # For all 3D ResNets
     "res_blocks": "resnets",
@@ -372,7 +374,7 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -
     return connectors
 
 
-def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+def get_ltx2_video_vae_config(version: str, timestep_conditioning: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
     if version == "test":
         config = {
             "model_id": "diffusers-internal-dev/dummy-ltx2",
@@ -396,7 +398,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, A
             "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
             "upsample_residual": (True, True, True),
             "upsample_factor": (2, 2, 2),
-            "timestep_conditioning": False,
+            "timestep_conditioning": timestep_conditioning,
             "patch_size": 4,
             "patch_size_t": 1,
             "resnet_norm_eps": 1e-6,
@@ -433,7 +435,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, A
             "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
             "upsample_residual": (True, True, True),
             "upsample_factor": (2, 2, 2),
-            "timestep_conditioning": False,
+            "timestep_conditioning": timestep_conditioning,
             "patch_size": 4,
             "patch_size_t": 1,
             "resnet_norm_eps": 1e-6,
@@ -450,8 +452,8 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, A
     return config, rename_dict, special_keys_remap
 
 
-def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
-    config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
+def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool) -> Dict[str, Any]:
+    config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning)
     diffusers_config = config["diffusers_config"]
 
     with init_empty_weights():
@@ -717,6 +719,7 @@ def get_args():
         help="Latent upsampler filename",
     )
 
+    parser.add_argument("--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model")
     parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
     parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
     parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
@@ -786,7 +789,7 @@ def main(args):
             original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
         elif combined_ckpt is not None:
            original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
-        vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
+        vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning)
         if not args.full_pipeline and not args.upsample_pipeline:
             vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
 
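The converter half of this diff threads a single boolean from a new CLI flag down into the video VAE config, replacing the previously hard-coded `"timestep_conditioning": False`. A minimal sketch of the resulting call path, assuming the script's functions are in scope and `original_vae_ckpt` is an already-loaded LTX-2 video VAE state dict (the `"test"` version is the dummy config shown in the hunks above):

```python
# Sketch only: exercising the keyword path that --timestep_conditioning enables.
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(
    version="test", timestep_conditioning=True
)
# Assumption: the VAE kwargs sit under "diffusers_config", the sub-dict that
# convert_ltx2_video_vae reads via config["diffusers_config"].
assert config["diffusers_config"]["timestep_conditioning"] is True

# main() forwards the parsed flag the same way:
#   convert_ltx2_video_vae(ckpt, version=args.version,
#                          timestep_conditioning=args.timestep_conditioning)
vae = convert_ltx2_video_vae(original_vae_ckpt, version="test", timestep_conditioning=True)
```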
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
index 7d1cd3ce65..c662d9c167 100644
--- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
@@ -653,6 +653,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if latents is not None:
+            if latents.ndim == 5:
+                # latents are of shape [B, C, F, H, W], need to be packed
+                latents = self._pack_latents(
+                    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+                )
             return latents.to(device=device, dtype=dtype)
 
         height = height // self.vae_spatial_compression_ratio
@@ -694,6 +699,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         latent_length = round(duration_s * latents_per_second)
 
         if latents is not None:
+            if latents.ndim == 4:
+                # latents are of shape [B, C, L, M], need to be packed
+                latents = self._pack_audio_latents(latents)
             return latents.to(device=device, dtype=dtype), latent_length
 
         # TODO: confirm whether this logic is correct
@@ -1097,6 +1105,8 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
                 self.transformer_spatial_patch_size,
                 self.transformer_temporal_patch_size,
             )
+            prenorm_latents = latents
+            prenorm_audio_latents = audio_latents
             latents = self._denormalize_latents(
                 latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
             )
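The pipeline half of the diff makes user-supplied latents more forgiving: already-packed latents pass through unchanged, while unpacked 5D video latents (`[B, C, F, H, W]`) and 4D audio latents (`[B, C, L, M]`) are now packed on entry via `_pack_latents` / `_pack_audio_latents`. The third hunk additionally keeps `prenorm_latents` / `prenorm_audio_latents` references to the still-normalized latents before `_denormalize_latents` runs. Below is a sketch of what video-latent packing does, assuming the same reshape/permute patchify as the existing LTX pipelines (the LTX2 internals may differ in detail):

```python
import torch

def pack_latents_sketch(latents: torch.Tensor, patch_size: int, patch_size_t: int) -> torch.Tensor:
    # [B, C, F, H, W] -> [B, (F/pt) * (H/p) * (W/p), C * pt * p * p]
    b, c, f, h, w = latents.shape
    latents = latents.reshape(
        b, c, f // patch_size_t, patch_size_t, h // patch_size, patch_size, w // patch_size, patch_size
    )
    # Group each spatiotemporal patch into a single token.
    latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
    return latents

x = torch.randn(1, 128, 8, 16, 16)     # unpacked video latents [B, C, F, H, W]
packed = pack_latents_sketch(x, 1, 1)  # with p = pt = 1: [1, 8 * 16 * 16, 128]
print(packed.shape)                    # torch.Size([1, 2048, 128])
```

This is also why the `ndim == 5` / `ndim == 4` checks suffice to distinguish the two cases: packed latents are 3D `[B, seq_len, dim]`, so a 5D video or 4D audio tensor can only be the unpacked layout.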