diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py
index f2e879c065..eb130a3549 100644
--- a/scripts/convert_ltx2_to_diffusers.py
+++ b/scripts/convert_ltx2_to_diffusers.py
@@ -8,7 +8,7 @@ import torch
 from accelerate import init_empty_weights
 from huggingface_hub import hf_hub_download
 
-from diffusers import AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel
+from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel
 from diffusers.utils.import_utils import is_accelerate_available
 from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
 
@@ -62,6 +62,9 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
     "per_channel_statistics.std-of-means": "latents_std",
 }
 
+# No substring renames are defined for the audio VAE (the rename loop is a no-op on key names).
+LTX_2_0_AUDIO_VAE_RENAME_DICT = {}
+
 LTX_2_0_VOCODER_RENAME_DICT = {
     "ups": "upsamplers",
     "resblocks": "resnets",
@@ -96,6 +99,17 @@ def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any])
     return
 
 
+def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None:
+    """Move a top-level ``per_channel_statistics*`` entry under the ``decoder.`` prefix, in place.
+
+    Keys that do not start with ``per_channel_statistics`` are left untouched.
+    """
+    if key.startswith("per_channel_statistics"):
+        state_dict[f"decoder.{key}"] = state_dict.pop(key)
+
+    return
+
+
 LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
     "video_embeddings_connector": remove_keys_inplace,
     "audio_embeddings_connector": remove_keys_inplace,
@@ -107,6 +121,12 @@ LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
     "per_channel_statistics.mean-of-stds": remove_keys_inplace,
 }
 
+# Each handler is called for every state-dict key that contains its special key as a substring.
+LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {
+    "encoder": remove_keys_inplace,
+    "per_channel_statistics": convert_ltx2_audio_vae_per_channel_statistics,
+}
+
 LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
 
 
@@ -325,6 +345,73 @@ def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) ->
     return vae
 
 
+def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+    """Return ``(config, rename_dict, special_keys_remap)`` for the audio VAE of the given LTX-2 version.
+
+    Raises:
+        ValueError: If ``version`` is not a supported version (only ``"2.0"`` is handled here).
+    """
+    if version == "2.0":
+        config = {
+            "model_id": "diffusers-internal-dev/new-ltx-model",
+            "diffusers_config": {
+                "base_channels": 128,
+                "output_channels": 2,
+                "ch_mult": (1, 2, 4),
+                "num_res_blocks": 2,
+                "attn_resolutions": None,
+                "in_channels": 2,
+                "resolution": 256,
+                "latent_channels": 8,
+                "norm_type": "pixel",
+                "causality_axis": "height",
+                "dropout": 0.0,
+                "mid_block_add_attention": False,
+                "sample_rate": 16000,
+                "mel_hop_length": 160,
+                "is_causal": True,
+                "mel_bins": 64,
+            },
+        }
+        rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
+        special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
+    else:
+        # Fail with a clear message instead of an UnboundLocalError at the return below.
+        raise ValueError(f"Unsupported LTX-2 audio VAE version: {version!r} (supported: '2.0')")
+    return config, rename_dict, special_keys_remap
+
+
+def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> AutoencoderKLLTX2Audio:
+    """Convert an original LTX-2 audio VAE state dict into a loaded ``AutoencoderKLLTX2Audio``.
+
+    ``original_state_dict`` is modified in place: its keys are rewritten through ``update_state_dict_inplace`` and
+    the special-key handlers before being loaded (``strict=True, assign=True``) into the returned model.
+    """
+    config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config(version)
+    diffusers_config = config["diffusers_config"]
+
+    with init_empty_weights():
+        vae = AutoencoderKLLTX2Audio.from_config(diffusers_config)
+
+    # Handle official code --> diffusers key remapping via the remap dict
+    for key in list(original_state_dict.keys()):
+        new_key = key
+        for replace_key, rename_key in rename_dict.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        update_state_dict_inplace(original_state_dict, key, new_key)
+
+    # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
+    # special_keys_remap
+    for key in list(original_state_dict.keys()):
+        for special_key, handler_fn_inplace in special_keys_remap.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, original_state_dict)
+
+    vae.load_state_dict(original_state_dict, strict=True, assign=True)
+    return vae
+
+
 def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
     if version == "2.0":
         config = {
@@ -513,7 +600,15 @@ def main(args):
             vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
 
     if args.audio_vae or args.full_pipeline:
-        pass
+        if args.audio_vae_filename is not None:
+            original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename)
+        elif combined_ckpt is not None:
+            original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix)
+        else:
+            raise ValueError("Audio VAE conversion needs either `audio_vae_filename` or a combined checkpoint.")
+        audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version)
+        if not args.full_pipeline:
+            audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae"))
 
     if args.dit or args.full_pipeline:
         if args.dit_filename is not None: