From 58257eb0e0f1a8ac07ff4854009f35c1b2bad444 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 22 Dec 2025 15:45:56 +0530
Subject: [PATCH] up

---
 scripts/test_ltx2_audio_conversion.py         | 31 ++++++++++++-------
 .../autoencoders/autoencoder_kl_ltx2_audio.py |  7 -----
 2 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/scripts/test_ltx2_audio_conversion.py b/scripts/test_ltx2_audio_conversion.py
index 6a124f74df..8d07a6f9b1 100644
--- a/scripts/test_ltx2_audio_conversion.py
+++ b/scripts/test_ltx2_audio_conversion.py
@@ -25,13 +25,13 @@ def convert_state_dict(state_dict: dict) -> dict:
     return converted
 
 
-def load_original_decoder(device: torch.device):
+def load_original_decoder(device: torch.device, dtype: torch.dtype):
     from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
-    from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator
     from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER
-
+    from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator
+
     checkpoint_path = download_checkpoint()
-
+
     # The code below comes from `ltx-pipelines/src/ltx_pipelines/txt2vid.py`
     decoder = Builder(
         model_path=checkpoint_path,
@@ -39,10 +39,6 @@ def load_original_decoder(device: torch.device):
         model_sd_key_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
     ).build(device=device)
 
-    state_dict = decoder.state_dict()
-    for k, v in state_dict.items():
-        if "mid" in k:
-            print(f"{k=}")
     decoder.eval()
     return decoder
 
@@ -70,16 +66,27 @@ def main() -> None:
     dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}
     dtype = dtype_map[args.dtype]
 
-    original_decoder = load_original_decoder(device)
+    original_decoder = load_original_decoder(device, dtype)
     diffusers_model = build_diffusers_decoder()
     converted_state_dict = convert_state_dict(original_decoder.state_dict())
-    diffusers_model.load_state_dict(converted_state_dict, assign=True, strict=True)
+    diffusers_model.load_state_dict(converted_state_dict, assign=True, strict=False)
+
+    per_channel_len = original_decoder.per_channel_statistics.get_buffer("std-of-means").numel()
+    latent_channels = diffusers_model.decoder.latent_channels
+    mel_bins_for_match = per_channel_len // latent_channels if per_channel_len % latent_channels == 0 else None
 
     levels = len(diffusers_model.decoder.channel_multipliers)
-    latent_size = diffusers_model.decoder.resolution // (2 ** (levels - 1))
+    latent_height = diffusers_model.decoder.resolution // (2 ** (levels - 1))
+    latent_width = mel_bins_for_match or latent_height
+
     dummy = torch.randn(
-        args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device
+        args.batch,
+        diffusers_model.decoder.latent_channels,
+        latent_height,
+        latent_width,
+        device=device,
+        dtype=dtype,
     )
 
     original_out = original_decoder(dummy)
 
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py
index 457cbf5bce..e7960c3e14 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py
@@ -580,13 +580,6 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
 
     _supports_gradient_checkpointing = False
 
-    # {
-    #     'double_z': True, 'mel_bins': 64, 'z_channels': 8, 'resolution': 256, 'downsample_time': False,
-    #     'in_channels': 2, 'out_ch': 2, 'ch': 128, 'ch_mult': [1, 2, 4], 'num_res_blocks': 2,
-    #     'attn_resolutions': [], 'dropout': 0.0, 'mid_block_add_attention': False,
-    #     'norm_type': 'pixel', 'causality_axis': 'height'
-    # }
-    # sample_rate=16000, mel_hop_length=160, is_causal=True, mel_bins=64
     @register_to_config
     def __init__(
         self,
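
For reference, a minimal standalone sketch of the latent-shape inference that the updated `main()` performs. `infer_latent_shape` is a hypothetical helper (it does not appear in the patch), and it assumes, as the diff's `mel_bins_for_match` logic does, that the "std-of-means" buffer holds latent_channels * mel_bins entries. The example values are taken from the config comment removed in the second file (z_channels=8, resolution=256, ch_mult=[1, 2, 4], mel_bins=64).

    def infer_latent_shape(per_channel_len: int, latent_channels: int, resolution: int, levels: int):
        # Latent height follows the decoder's downsampling factor: resolution // 2**(levels - 1).
        latent_height = resolution // (2 ** (levels - 1))
        # If the statistics buffer factors evenly over the latent channels, take the
        # quotient as the mel-bin count and use it as the latent width; otherwise
        # fall back to a square latent, matching the script's `or` fallback.
        if per_channel_len % latent_channels == 0:
            latent_width = per_channel_len // latent_channels
        else:
            latent_width = latent_height
        return latent_height, latent_width

    # z_channels=8, resolution=256, ch_mult=[1, 2, 4] -> 3 levels, mel_bins=64:
    # the "std-of-means" buffer then has 8 * 64 = 512 entries, giving a 64x64 latent.
    assert infer_latent_shape(8 * 64, 8, 256, 3) == (64, 64)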