mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Conversion script for LTX 2.0 Audio VAE Decoder
This commit is contained in:
@@ -8,7 +8,7 @@ import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from diffusers import AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel
|
||||
from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
|
||||
|
||||
@@ -62,6 +62,8 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
# Substring renames (original name -> diffusers name) applied to every audio VAE checkpoint key via `str.replace`.
# It is empty, so the rename pass leaves the audio VAE key names unchanged.
LTX_2_0_AUDIO_VAE_RENAME_DICT = {}
|
||||
|
||||
LTX_2_0_VOCODER_RENAME_DICT = {
|
||||
"ups": "upsamplers",
|
||||
"resblocks": "resnets",
|
||||
@@ -96,6 +98,15 @@ def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any])
|
||||
return
|
||||
|
||||
|
||||
def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None:
    """Re-home a top-level ``per_channel_statistics`` entry under the ``decoder.`` prefix.

    The state dict is modified in place: when ``key`` starts with ``per_channel_statistics`` its value is
    popped and stored again under ``"decoder." + key``. Any other key is left untouched.

    Args:
        key: Name of the state dict entry to inspect.
        state_dict: Checkpoint state dict that is mutated in place.
    """
    # Only keys that *begin* with the statistics prefix are moved; everything else is a no-op.
    if not key.startswith("per_channel_statistics"):
        return

    state_dict[f"decoder.{key}"] = state_dict.pop(key)
|
||||
|
||||
|
||||
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"video_embeddings_connector": remove_keys_inplace,
|
||||
"audio_embeddings_connector": remove_keys_inplace,
|
||||
@@ -107,6 +118,11 @@ LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
|
||||
}
|
||||
|
||||
# Special-case handlers for audio VAE checkpoint keys that a plain substring rename cannot express.
# Each entry maps a substring to an in-place handler called as `handler(key, state_dict)` for every
# checkpoint key that contains the substring.
LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {
    # NOTE(review): presumably drops the encoder weights so that only the decoder is converted -- confirm
    # against `remove_keys_inplace`.
    "encoder": remove_keys_inplace,
    # Moves top-level `per_channel_statistics.*` entries under the `decoder.` prefix.
    "per_channel_statistics": convert_ltx2_audio_vae_per_channel_statistics,
}
|
||||
|
||||
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
@@ -325,6 +341,60 @@ def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) ->
|
||||
return vae
|
||||
|
||||
|
||||
def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """Return the conversion settings for the LTX 2 audio VAE of the given version.

    Args:
        version: Model version string. Only ``"2.0"`` is currently supported.

    Returns:
        A ``(config, rename_dict, special_keys_remap)`` tuple where ``config`` holds the hub ``model_id`` and
        the ``diffusers_config`` used to instantiate the model, ``rename_dict`` is the substring rename table
        and ``special_keys_remap`` is the table of in-place special-case handlers.

    Raises:
        ValueError: If ``version`` is not a supported version.
    """
    # Fail loudly on unknown versions instead of reaching the `return` with unbound locals.
    if version != "2.0":
        raise ValueError(f"Unsupported LTX 2 audio VAE version: {version!r}. Supported versions: '2.0'.")

    config = {
        "model_id": "diffusers-internal-dev/new-ltx-model",
        "diffusers_config": {
            "base_channels": 128,
            "output_channels": 2,
            "ch_mult": (1, 2, 4),
            "num_res_blocks": 2,
            "attn_resolutions": None,
            "in_channels": 2,
            "resolution": 256,
            "latent_channels": 8,
            "norm_type": "pixel",
            "causality_axis": "height",
            "dropout": 0.0,
            "mid_block_add_attention": False,
            "sample_rate": 16000,
            "mel_hop_length": 160,
            "is_causal": True,
            "mel_bins": 64,
        },
    }
    rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
    special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
    return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> "AutoencoderKLLTX2Audio":
    """Convert an original LTX 2 audio VAE checkpoint into a loaded ``AutoencoderKLLTX2Audio`` model.

    Args:
        original_state_dict: State dict of the original checkpoint. It is modified in place (keys are
            renamed, moved or removed) and its values end up as the parameters of the returned model.
        version: Model version string forwarded to ``get_ltx2_audio_vae_config``.

    Returns:
        The ``AutoencoderKLLTX2Audio`` instance with the converted weights loaded.

    Raises:
        RuntimeError: Raised by ``load_state_dict(strict=True)`` when the converted keys do not match the
            model's expected keys exactly.
    """
    config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config(version)
    diffusers_config = config["diffusers_config"]

    # Build the model skeleton under `init_empty_weights()`; the real tensors are attached below via
    # `assign=True`.
    with init_empty_weights():
        vae = AutoencoderKLLTX2Audio.from_config(diffusers_config)

    # Handle official code --> diffusers key remapping via the remap dict.
    # Iterate over a snapshot of the keys because the dict is mutated while looping.
    for key in list(original_state_dict.keys()):
        new_key = key
        for replace_key, rename_key in rename_dict.items():
            new_key = new_key.replace(replace_key, rename_key)
        # Called unconditionally, even when no rename applied, to keep the original behaviour.
        update_state_dict_inplace(original_state_dict, key, new_key)

    # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
    # special_keys_remap. Handlers are selected by substring match and mutate the state dict in place.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in special_keys_remap.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    # strict=True surfaces any missing or unexpected key; assign=True takes the checkpoint tensors as the
    # model's parameters rather than copying into the placeholder weights.
    vae.load_state_dict(original_state_dict, strict=True, assign=True)
    return vae
|
||||
|
||||
|
||||
def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
@@ -513,7 +583,13 @@ def main(args):
|
||||
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
|
||||
|
||||
if args.audio_vae or args.full_pipeline:
|
||||
pass
|
||||
if args.audio_vae_filename is not None:
|
||||
original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix)
|
||||
audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae"))
|
||||
|
||||
if args.dit or args.full_pipeline:
|
||||
if args.dit_filename is not None:
|
||||
|
||||
Reference in New Issue
Block a user