commit 4904fd6fa5
parent 907896d533
Author: sayakpaul
Date:   2025-12-22 13:46:58 +05:30


@@ -1,14 +1,19 @@
 #!/usr/bin/env python
 """
 Quick check that an LTX2 audio decoder checkpoint converts cleanly to the diffusers
 `AutoencoderKLLTX2Audio` layout and produces matching outputs on dummy data.
 """
 import argparse
-import sys
 from pathlib import Path

+import safetensors.torch
 import torch
+from huggingface_hub import hf_hub_download
+
+
+def download_checkpoint(
+    repo_id="diffusers-internal-dev/new-ltx-model",
+    filename="ltx-av-step-1932500-interleaved-new-vae.safetensors",
+    device="cuda",
+):
+    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
+    ckpt = safetensors.torch.load_file(ckpt_path, device=device)["audio_vae"]
+    return ckpt


 def convert_state_dict(state_dict: dict) -> dict:
@@ -23,71 +28,57 @@ def convert_state_dict(state_dict: dict) -> dict:
     return converted


-def load_original_decoder(original_repo: Path, device: torch.device, dtype: torch.dtype, checkpoint_path: Path | None):
-    ltx_core_src = original_repo / "ltx-core" / "src"
-    if not ltx_core_src.exists():
-        raise FileNotFoundError(f"ltx-core sources not found under {ltx_core_src}")
-    sys.path.insert(0, str(ltx_core_src))
+def load_original_decoder(device: torch.device, dtype: torch.dtype):
     from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator

-    decoder = VAEDecoderConfigurator.from_config({}).to(device=device, dtype=dtype)
+    # Build on the meta device (no storage allocated); real weights are attached
+    # below via load_state_dict(..., assign=True).
+    with torch.device("meta"):
+        decoder = VAEDecoderConfigurator.from_config({})
+    # Pass device by keyword; positionally it would be taken as repo_id.
+    original_state_dict = download_checkpoint(device=str(device))
-    if checkpoint_path is not None:
-        raw_state = torch.load(checkpoint_path, map_location=device)
-        state_dict = raw_state.get("state_dict", raw_state)
-        decoder_state: dict[str, torch.Tensor] = {}
-        for key, value in state_dict.items():
-            if not isinstance(value, torch.Tensor):
-                continue
-            trimmed = key
-            if trimmed.startswith("audio_vae.decoder."):
-                trimmed = trimmed[len("audio_vae.decoder.") :]
-            elif trimmed.startswith("decoder."):
-                trimmed = trimmed[len("decoder.") :]
-            decoder_state[trimmed] = value
-        decoder.load_state_dict(decoder_state, strict=False)
+    decoder_state_dict = {}
+    for key, value in original_state_dict.items():
+        if not isinstance(value, torch.Tensor):
+            continue
+        trimmed = key
+        # Strip wrapper prefixes so keys line up with the bare decoder module.
+        if trimmed.startswith("audio_vae.decoder."):
+            trimmed = trimmed[len("audio_vae.decoder.") :]
+        elif trimmed.startswith("decoder."):
+            trimmed = trimmed[len("decoder.") :]
+        decoder_state_dict[trimmed] = value
+    decoder.load_state_dict(decoder_state_dict, strict=True, assign=True)
+    # Only now is .to() valid: assign has replaced the meta tensors with real ones.
+    decoder.to(device=device, dtype=dtype)
     decoder.eval()
     return decoder


 def build_diffusers_decoder(device: torch.device, dtype: torch.dtype):
-    from diffusers.models.autoencoders.autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
+    from diffusers.models.autoencoders import AutoencoderKLLTX2Audio

-    model = AutoencoderKLLTX2Audio().to(device=device, dtype=dtype)
+    # Instantiate on meta; main() attaches real weights via
+    # load_state_dict(..., assign=True).
+    with torch.device("meta"):
+        model = AutoencoderKLLTX2Audio()
     model.eval()
     return model
+@torch.no_grad()
 def main() -> None:
     parser = argparse.ArgumentParser(description="Validate LTX2 audio decoder conversion.")
-    parser.add_argument(
-        "--original-repo",
-        type=Path,
-        default=Path("/Users/sayakpaul/Downloads/ltx-2"),
-        help="Path to the original ltx-2 repository (needed to import ltx-core).",
-    )
-    parser.add_argument(
-        "--checkpoint",
-        type=Path,
-        default=None,
-        help="Optional path to an original checkpoint containing decoder weights.",
-    )
     parser.add_argument("--device", type=str, default="cpu")
-    parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16", "float16"])
+    parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"])
     parser.add_argument("--batch", type=int, default=2)
     parser.add_argument("--output-path", type=Path, required=True)
     args = parser.parse_args()

     device = torch.device(args.device)
     dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}
     dtype = dtype_map[args.dtype]
-    original_decoder = load_original_decoder(args.original_repo, device, dtype, args.checkpoint)
+    original_decoder = load_original_decoder(device, dtype)
     diffusers_model = build_diffusers_decoder(device, dtype)

     converted_state = convert_state_dict(original_decoder.state_dict())
-    diffusers_model.load_state_dict(converted_state, strict=False)
+    diffusers_model.load_state_dict(converted_state, assign=True, strict=True)

     levels = len(diffusers_model.decoder.channel_multipliers)
     latent_size = diffusers_model.decoder.resolution // (2 ** (levels - 1))
@@ -95,9 +86,8 @@ def main() -> None:
     dummy = torch.randn(
         args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device, dtype=dtype
     )
-    with torch.no_grad():
-        original_out = original_decoder(dummy)
-        diffusers_out = diffusers_model.decode(dummy).sample
+    original_out = original_decoder(dummy)
+    diffusers_out = diffusers_model.decode(dummy).sample

     torch.testing.assert_close(diffusers_out, original_out, rtol=1e-4, atol=1e-4)
     max_diff = (diffusers_out - original_out).abs().max().item()
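
For context on the loading pattern this commit switches to: constructing a module under torch.device("meta") allocates no storage at all, and load_state_dict(..., assign=True) (PyTorch >= 2.1) then attaches the checkpoint tensors directly instead of copying them into pre-allocated parameters, so the model never exists twice in memory. A minimal self-contained sketch of the pattern follows; TinyDecoder and its shapes are made up for illustration and only torch is required:

import torch
from torch import nn

class TinyDecoder(nn.Module):  # hypothetical stand-in for the real decoder
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 4)

    def forward(self, x):
        return self.proj(x)

# 1. Build on meta: parameters carry shape/dtype metadata only, no memory.
with torch.device("meta"):
    model = TinyDecoder()
assert model.proj.weight.is_meta

# 2. A checkpoint's real tensors (here created on the fly for the demo).
state = {"proj.weight": torch.randn(4, 8), "proj.bias": torch.zeros(4)}

# 3. assign=True swaps the meta parameters for the checkpoint tensors;
#    a plain copy (assign=False) would fail, since meta tensors have no
#    data to copy into. strict=True verifies every key matches.
model.load_state_dict(state, strict=True, assign=True)
assert not model.proj.weight.is_meta

print(model(torch.randn(2, 8)).shape)  # torch.Size([2, 4])

Pairing assign=True with strict=True is what makes this useful for conversion scripts like the one above: any key the convert_state_dict mapping gets wrong surfaces immediately as a missing/unexpected key instead of being silently skipped, as strict=False previously allowed.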