
add time conditioning conversion and token packing for latents

Pham Hong Vinh committed 2026-01-11 22:13:58 +07:00
parent 3d78f9d17d
commit 9c754a46aa
2 changed files with 19 additions and 6 deletions


@@ -63,6 +63,8 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
     "up_blocks.4": "up_blocks.1",
     "up_blocks.5": "up_blocks.2.upsamplers.0",
     "up_blocks.6": "up_blocks.2",
+    "last_time_embedder": "time_embedder",
+    "last_scale_shift_table": "scale_shift_table",
     # Common
     # For all 3D ResNets
     "res_blocks": "resnets",
@@ -372,7 +374,7 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
     return connectors
 
 
-def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+def get_ltx2_video_vae_config(version: str, timestep_conditioning: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
     if version == "test":
         config = {
             "model_id": "diffusers-internal-dev/dummy-ltx2",
@@ -396,7 +398,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
             "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
             "upsample_residual": (True, True, True),
             "upsample_factor": (2, 2, 2),
-            "timestep_conditioning": False,
+            "timestep_conditioning": timestep_conditioning,
             "patch_size": 4,
             "patch_size_t": 1,
             "resnet_norm_eps": 1e-6,
@@ -433,7 +435,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
             "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
             "upsample_residual": (True, True, True),
             "upsample_factor": (2, 2, 2),
-            "timestep_conditioning": False,
+            "timestep_conditioning": timestep_conditioning,
             "patch_size": 4,
             "patch_size_t": 1,
             "resnet_norm_eps": 1e-6,
@@ -450,8 +452,8 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
     return config, rename_dict, special_keys_remap
 
 
-def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
-    config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
+def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool) -> Dict[str, Any]:
+    config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning)
     diffusers_config = config["diffusers_config"]
 
     with init_empty_weights():
@@ -717,6 +719,7 @@ def get_args():
         help="Latent upsampler filename",
     )
 
+    parser.add_argument("--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model")
     parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
     parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
     parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
@@ -786,7 +789,7 @@ def main(args):
             original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
         elif combined_ckpt is not None:
             original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
-        vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
+        vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning)
 
         if not args.full_pipeline and not args.upsample_pipeline:
             vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
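
Taken together, the new flag flows from get_args() through main() into the converter. A VAE-only conversion with timestep conditioning enabled would then be invoked along these lines (the script path is an assumption; the flags are the ones defined above):

python scripts/convert_ltx2_to_diffusers.py \
    --version test \
    --vae \
    --timestep_conditioning \
    --output_path ./ltx2-diffusers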


@@ -653,6 +653,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
         latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if latents is not None:
+            if latents.ndim == 5:
+                # latents are of shape [B, C, F, H, W], need to be packed
+                latents = self._pack_latents(
+                    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+                )
             return latents.to(device=device, dtype=dtype)
 
         height = height // self.vae_spatial_compression_ratio
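
This branch lets callers pass unpacked [B, C, F, H, W] video latents and have the pipeline tokenize them itself. In the LTX pipelines, packing folds spatio-temporal patches into the feature dimension, yielding one token per patch; a sketch of that reshape, assuming the shape convention from the comment above:

import torch

def pack_latents(latents: torch.Tensor, patch_size: int, patch_size_t: int) -> torch.Tensor:
    # [B, C, F, H, W] -> [B, (F//pt)*(H//p)*(W//p), C*pt*p*p]
    batch_size, num_channels, num_frames, height, width = latents.shape
    latents = latents.reshape(
        batch_size,
        -1,
        num_frames // patch_size_t,
        patch_size_t,
        height // patch_size,
        patch_size,
        width // patch_size,
        patch_size,
    )
    # Bring the patch grid up front, then flatten the feature and token axes.
    return latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)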
@@ -694,6 +699,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
         latent_length = round(duration_s * latents_per_second)
 
         if latents is not None:
+            if latents.ndim == 4:
+                # latents are of shape [B, C, L, M], need to be packed
+                latents = self._pack_audio_latents(latents)
             return latents.to(device=device, dtype=dtype), latent_length
 
         # TODO: confirm whether this logic is correct
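
Analogously, 4D audio latents are packed before use. The diff doesn't show _pack_audio_latents itself, so the token layout below is an assumption read off the [B, C, L, M] comment: one token per latent time step, with channels and mel bins flattened together into the features.

import torch

def pack_audio_latents(latents: torch.Tensor) -> torch.Tensor:
    # Assumed layout: [B, C, L, M] -> [B, L, C*M]
    return latents.permute(0, 2, 1, 3).flatten(2, 3)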
@@ -1097,6 +1105,8 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
                 self.transformer_spatial_patch_size,
                 self.transformer_temporal_patch_size,
             )
+            prenorm_latents = latents
+            prenorm_audio_latents = audio_latents
             latents = self._denormalize_latents(
                 latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
             )
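
Stashing prenorm_latents and prenorm_audio_latents keeps the still-normalized latents available after _denormalize_latents rescales the working copies for the VAE decoder. For reference, denormalization in diffusers' LTX pipelines inverts the per-channel normalization roughly as follows (a sketch mirroring the LTX implementation; LTX2's exact signature may differ):

import torch

def denormalize_latents(
    latents: torch.Tensor,
    latents_mean: torch.Tensor,
    latents_std: torch.Tensor,
    scaling_factor: float = 1.0,
) -> torch.Tensor:
    # Invert the encode-time normalization: z_denorm = z * std / scaling_factor + mean,
    # broadcast per channel over the [B, C, F, H, W] latent video.
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    return latents * latents_std / scaling_factor + latents_mean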