1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Conversion script for LTX 2.0 Audio VAE Decoder

This commit is contained in:
Daniel Gu
2025-12-23 02:48:08 +01:00
parent 5f7e43d17f
commit d303e2a6ff

View File

@@ -8,7 +8,7 @@ import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from diffusers import AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel
from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
@@ -62,6 +62,8 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
"per_channel_statistics.std-of-means": "latents_std",
}
# Substring renames (original fragment -> diffusers fragment) that convert_ltx2_audio_vae applies to every
# audio VAE checkpoint key. Currently empty, so no fragments are replaced.
LTX_2_0_AUDIO_VAE_RENAME_DICT = {}
LTX_2_0_VOCODER_RENAME_DICT = {
"ups": "upsamplers",
"resblocks": "resnets",
@@ -96,6 +98,15 @@ def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any])
return
def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None:
    """Move a top-level ``per_channel_statistics`` entry under the ``decoder.`` prefix.

    Args:
        key: A key currently present in ``state_dict``.
        state_dict: The checkpoint mapping; modified in place.

    Keys that do not start with ``"per_channel_statistics"`` are left untouched.
    """
    if not key.startswith("per_channel_statistics"):
        return

    # Pop first, then re-insert, so the entry ends up stored only under the prefixed name.
    state_dict[f"decoder.{key}"] = state_dict.pop(key)
    return
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"video_embeddings_connector": remove_keys_inplace,
"audio_embeddings_connector": remove_keys_inplace,
@@ -107,6 +118,11 @@ LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
}
# Handlers for audio VAE checkpoint keys that need more than a substring rename. convert_ltx2_audio_vae calls
# handler(key, state_dict) for every checkpoint key that contains the handler's dict key as a substring.
LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {
# Keys containing "encoder" are passed to remove_keys_inplace (defined elsewhere in this file; presumably it
# drops them from the state dict - confirm against its definition).
"encoder": remove_keys_inplace,
# Keys that start with "per_channel_statistics" are re-inserted under a "decoder." prefix.
"per_channel_statistics": convert_ltx2_audio_vae_per_channel_statistics,
}
# Special-key handlers for the vocoder checkpoint; currently empty.
# NOTE(review): the consumer is not visible in this excerpt - presumably the vocoder converter; confirm.
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
@@ -325,6 +341,60 @@ def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) ->
return vae
def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """Return the conversion settings for an LTX 2 audio VAE checkpoint.

    Args:
        version: Model version string. Only ``"2.0"`` is supported.

    Returns:
        A ``(config, rename_dict, special_keys_remap)`` tuple, where ``config`` holds a ``"model_id"`` and the
        ``"diffusers_config"`` kwargs used to instantiate the model, ``rename_dict`` maps original key fragments
        to their replacements, and ``special_keys_remap`` maps key fragments to in-place handler functions.

    Raises:
        ValueError: If ``version`` is not a supported version.
    """
    if version != "2.0":
        # An unknown version used to fall through to the return statement and die with an
        # UnboundLocalError; fail early with a message that names the offending value instead.
        raise ValueError(f"Unsupported LTX 2 audio VAE version: {version!r}. Supported versions: '2.0'.")

    config = {
        "model_id": "diffusers-internal-dev/new-ltx-model",
        "diffusers_config": {
            "base_channels": 128,
            "output_channels": 2,
            "ch_mult": (1, 2, 4),
            "num_res_blocks": 2,
            "attn_resolutions": None,
            "in_channels": 2,
            "resolution": 256,
            "latent_channels": 8,
            "norm_type": "pixel",
            "causality_axis": "height",
            "dropout": 0.0,
            "mid_block_add_attention": False,
            "sample_rate": 16000,
            "mel_hop_length": 160,
            "is_causal": True,
            "mel_bins": 64,
        },
    }
    rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
    special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
    return config, rename_dict, special_keys_remap
def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
    """Build a diffusers ``AutoencoderKLLTX2Audio`` from an original-format audio VAE state dict.

    Args:
        original_state_dict: Checkpoint weights keyed by the original parameter names. Modified in place.
        version: Model version string, forwarded to ``get_ltx2_audio_vae_config``.

    Returns:
        The instantiated model with the converted weights loaded.

    NOTE(review): the return annotation says ``Dict[str, Any]`` but the model object is returned - confirm and
    correct the annotation (the sibling converters may need the same fix).
    """
    settings, renames, special_handlers = get_ltx2_audio_vae_config(version)

    # Create the model inside accelerate's init_empty_weights context; the actual tensors are attached by
    # load_state_dict(..., assign=True) at the end.
    with init_empty_weights():
        vae = AutoencoderKLLTX2Audio.from_config(settings["diffusers_config"])

    # Pass 1: plain substring renames (official code --> diffusers naming). Iterate over a snapshot of the
    # keys because the dict is mutated while we walk it.
    for old_key in list(original_state_dict.keys()):
        renamed_key = old_key
        for old_fragment, new_fragment in renames.items():
            renamed_key = renamed_key.replace(old_fragment, new_fragment)
        update_state_dict_inplace(original_state_dict, old_key, renamed_key)

    # Pass 2: anything a 1:1 rename cannot express is delegated to the in-place handlers. Every handler whose
    # fragment occurs in the key is invoked, in the mapping's iteration order.
    for current_key in list(original_state_dict.keys()):
        for fragment, handler in special_handlers.items():
            if fragment in current_key:
                handler(current_key, original_state_dict)

    # strict=True: the converted keys must match the model's keys exactly.
    vae.load_state_dict(original_state_dict, strict=True, assign=True)
    return vae
def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "2.0":
config = {
@@ -513,7 +583,13 @@ def main(args):
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
if args.audio_vae or args.full_pipeline:
pass
if args.audio_vae_filename is not None:
original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename)
elif combined_ckpt is not None:
original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix)
audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version)
if not args.full_pipeline:
audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae"))
if args.dit or args.full_pipeline:
if args.dit_filename is not None: