add time conditioning conversion and token packing for latents
@@ -63,6 +63,8 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
     "up_blocks.4": "up_blocks.1",
     "up_blocks.5": "up_blocks.2.upsamplers.0",
     "up_blocks.6": "up_blocks.2",
+    "last_time_embedder": "time_embedder",
+    "last_scale_shift_table": "scale_shift_table",
     # Common
     # For all 3D ResNets
     "res_blocks": "resnets",
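For context, rename maps like this one are applied by substring replacement over every key in the original checkpoint. A minimal sketch of that pattern, written here as a standalone helper (the real script's loop may differ in details):

    from typing import Any, Dict

    def apply_rename_dict(state_dict: Dict[str, Any], rename_dict: Dict[str, str]) -> Dict[str, Any]:
        # Replace every mapped substring in each key. Entry order matters:
        # specific keys (e.g. "up_blocks.5") must be renamed before generic ones.
        converted = {}
        for key, value in state_dict.items():
            new_key = key
            for old, new in rename_dict.items():
                new_key = new_key.replace(old, new)
            converted[new_key] = value
        return converted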
@@ -372,7 +374,7 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -
     return connectors
 
 
-def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+def get_ltx2_video_vae_config(version: str, timestep_conditioning: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
     if version == "test":
         config = {
             "model_id": "diffusers-internal-dev/dummy-ltx2",
@@ -396,7 +398,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
             "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
             "upsample_residual": (True, True, True),
             "upsample_factor": (2, 2, 2),
-            "timestep_conditioning": False,
+            "timestep_conditioning": timestep_conditioning,
             "patch_size": 4,
             "patch_size_t": 1,
             "resnet_norm_eps": 1e-6,
@@ -433,7 +435,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
             "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
             "upsample_residual": (True, True, True),
             "upsample_factor": (2, 2, 2),
-            "timestep_conditioning": False,
+            "timestep_conditioning": timestep_conditioning,
             "patch_size": 4,
             "patch_size_t": 1,
             "resnet_norm_eps": 1e-6,
@@ -450,8 +452,8 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
     return config, rename_dict, special_keys_remap
 
 
-def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
-    config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
+def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool) -> Dict[str, Any]:
+    config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning)
     diffusers_config = config["diffusers_config"]
 
     with init_empty_weights():
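The conversion builds the model under Accelerate's init_empty_weights context manager, which instantiates parameters on the meta device so no memory is allocated until the converted weights are assigned. A minimal, self-contained sketch of that pattern (the toy module stands in for the real VAE class):

    import torch
    import torch.nn as nn
    from accelerate import init_empty_weights

    # Toy stand-in for the VAE; the script builds the diffusers VAE from
    # config["diffusers_config"] the same way.
    class TinyVAE(nn.Module):
        def __init__(self, dim: int = 8):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

    with init_empty_weights():
        model = TinyVAE()  # parameters live on the "meta" device, no memory allocated

    converted = {"proj.weight": torch.zeros(8, 8), "proj.bias": torch.zeros(8)}
    model.load_state_dict(converted, assign=True)  # swap real tensors in for the meta ones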
@@ -717,6 +719,7 @@ def get_args():
         help="Latent upsampler filename",
     )
 
+    parser.add_argument("--timestep_conditioning", action="store_true", help="Whether to add timestep conditioning to the video VAE model")
     parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
     parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
     parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
@@ -786,7 +789,7 @@ def main(args):
             original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
         elif combined_ckpt is not None:
             original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
-        vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
+        vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning)
         if not args.full_pipeline and not args.upsample_pipeline:
             vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
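With the flag wired through get_args() and main(), timestep conditioning is enabled at conversion time from the command line, e.g. python scripts/convert_ltx2_to_diffusers.py --vae --timestep_conditioning ... (script path assumed here; the other arguments follow the script's existing options). The remaining hunks are pipeline-side: the prepare-latents helpers in LTX2Pipeline now also accept unpacked video and audio latents and pack them into token sequences before use.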
@@ -653,6 +653,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
         latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if latents is not None:
+            if latents.ndim == 5:
+                # latents are of shape [B, C, F, H, W], need to be packed
+                latents = self._pack_latents(
+                    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+                )
             return latents.to(device=device, dtype=dtype)
 
         height = height // self.vae_spatial_compression_ratio
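Here user-supplied 5D latents are patchified into the flat token sequence the transformer expects: [B, C, F, H, W] becomes [B, N, D] with N = (F/pt)·(H/p)·(W/p) tokens of D = C·pt·p·p features each. A sketch of LTX-style packing as a standalone function (the pipeline's _pack_latents helper may differ in details):

    import torch

    def pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
        # [B, C, F, H, W] -> [B, N, D]: split frames/height/width into patches,
        # then flatten each patch (and its channels) into one token.
        batch_size, channels, frames, height, width = latents.shape
        latents = latents.reshape(
            batch_size,
            channels,
            frames // patch_size_t,
            patch_size_t,
            height // patch_size,
            patch_size,
            width // patch_size,
            patch_size,
        )
        latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
        return latents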
@@ -694,6 +699,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
         latent_length = round(duration_s * latents_per_second)
 
         if latents is not None:
+            if latents.ndim == 4:
+                # latents are of shape [B, C, L, M], need to be packed
+                latents = self._pack_audio_latents(latents)
             return latents.to(device=device, dtype=dtype), latent_length
 
         # TODO: confirm whether this logic is correct
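_pack_audio_latents itself is not shown in this diff. Assuming [B, C, L, M] stands for [batch, channels, latent length, mel bins], a plausible sketch that turns each latent time step into one token (purely illustrative; the actual helper may pack differently):

    import torch

    def pack_audio_latents(latents: torch.Tensor) -> torch.Tensor:
        # Assumed layout: [B, C, L, M] -> [B, L, C * M], one token per time step.
        batch_size, channels, length, mel_bins = latents.shape
        return latents.permute(0, 2, 1, 3).reshape(batch_size, length, channels * mel_bins)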
@@ -1097,6 +1105,8 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
             self.transformer_spatial_patch_size,
             self.transformer_temporal_patch_size,
         )
+        prenorm_latents = latents
+        prenorm_audio_latents = audio_latents
         latents = self._denormalize_latents(
             latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
         )
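The pre-normalization copies (prenorm_latents, prenorm_audio_latents) keep the normalized latents available after decoding begins, presumably so they can be returned as raw latents or fed to a later upsampling stage. Denormalization itself inverts the per-channel latent statistics before the VAE decode; a sketch mirroring the LTX-style helper:

    import torch

    def denormalize_latents(
        latents: torch.Tensor,
        latents_mean: torch.Tensor,
        latents_std: torch.Tensor,
        scaling_factor: float = 1.0,
    ) -> torch.Tensor:
        # Broadcast per-channel stats over [B, C, F, H, W] and invert normalization:
        # x = z * std / scaling_factor + mean
        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        return latents * latents_std / scaling_factor + latents_mean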