Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
commit 58257eb0e0 (parent 5f0f2a03f7)
Author: sayakpaul
Date:   2025-12-22 15:45:56 +05:30

2 changed files with 19 additions and 19 deletions


@@ -25,13 +25,13 @@ def convert_state_dict(state_dict: dict) -> dict:
     return converted
 
 
-def load_original_decoder(device: torch.device):
+def load_original_decoder(device: torch.device, dtype: torch.dtype):
     from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
-    from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator
     from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER
+    from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator
 
     checkpoint_path = download_checkpoint()
     # The code below comes from `ltx-pipelines/src/ltx_pipelines/txt2vid.py`
     decoder = Builder(
         model_path=checkpoint_path,
@@ -39,10 +39,6 @@ def load_original_decoder(device: torch.device):
         model_sd_key_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
     ).build(device=device)
-    state_dict = decoder.state_dict()
-    for k, v in state_dict.items():
-        if "mid" in k:
-            print(f"{k=}")
     decoder.eval()
     return decoder
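The hunk above adds a dtype parameter to load_original_decoder but does not show where it is consumed. Below is a minimal sketch of the likely wiring; build_original_decoder is a hypothetical stand-in for the Builder(...).build(device=device) call, and the .to(dtype=dtype) cast is an assumption (it mirrors the dtype= later passed to torch.randn):

import torch
from torch import nn

def build_original_decoder(device: torch.device) -> nn.Module:
    # Hypothetical stand-in for Builder(...).build(device=device) above,
    # reduced to a one-layer module so the sketch runs end to end.
    return nn.Conv2d(8, 2, kernel_size=3, padding=1).to(device)

def load_original_decoder(device: torch.device, dtype: torch.dtype) -> nn.Module:
    decoder = build_original_decoder(device)
    # Assumption: the new dtype argument casts the weights; the diff shows
    # only the signature change, not where dtype is applied.
    decoder = decoder.to(dtype=dtype)
    decoder.eval()
    return decoder

decoder = load_original_decoder(torch.device("cpu"), torch.float32)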
@@ -70,16 +66,27 @@ def main() -> None:
     dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}
     dtype = dtype_map[args.dtype]
 
-    original_decoder = load_original_decoder(device)
+    original_decoder = load_original_decoder(device, dtype)
     diffusers_model = build_diffusers_decoder()
     converted_state_dict = convert_state_dict(original_decoder.state_dict())
-    diffusers_model.load_state_dict(converted_state_dict, assign=True, strict=True)
+    diffusers_model.load_state_dict(converted_state_dict, assign=True, strict=False)
+
+    per_channel_len = original_decoder.per_channel_statistics.get_buffer("std-of-means").numel()
+    latent_channels = diffusers_model.decoder.latent_channels
+    mel_bins_for_match = per_channel_len // latent_channels if per_channel_len % latent_channels == 0 else None
 
     levels = len(diffusers_model.decoder.channel_multipliers)
-    latent_size = diffusers_model.decoder.resolution // (2 ** (levels - 1))
+    latent_height = diffusers_model.decoder.resolution // (2 ** (levels - 1))
+    latent_width = mel_bins_for_match or latent_height
     dummy = torch.randn(
-        args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device
+        args.batch,
+        diffusers_model.decoder.latent_channels,
+        latent_height,
+        latent_width,
+        device=device,
+        dtype=dtype,
     )
     original_out = original_decoder(dummy)
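The reworked shape logic stops assuming a square latent: the height still comes from resolution divided by the total downsampling factor, while the width is recovered from the per-channel statistics buffer so it can match the mel-bin count. A worked example with the configuration recorded in the removed comment of the second file (resolution=256, ch_mult=[1, 2, 4], z_channels=8, mel_bins=64); the 512-entry buffer size is an assumption consistent with those values:

# Values taken from the removed config comment in the second file.
resolution = 256
ch_mult = [1, 2, 4]
latent_channels = 8    # z_channels
per_channel_len = 512  # assumed: "std-of-means" holds z_channels * mel_bins entries

levels = len(ch_mult)                              # 3
latent_height = resolution // (2 ** (levels - 1))  # 256 // 4 = 64
mel_bins_for_match = per_channel_len // latent_channels if per_channel_len % latent_channels == 0 else None
latent_width = mel_bins_for_match or latent_height  # 512 // 8 = 64, matching mel_bins

assert (latent_height, latent_width) == (64, 64)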


@@ -580,13 +580,6 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
     _supports_gradient_checkpointing = False
 
-    # {
-    #     'double_z': True, 'mel_bins': 64, 'z_channels': 8, 'resolution': 256, 'downsample_time': False,
-    #     'in_channels': 2, 'out_ch': 2, 'ch': 128, 'ch_mult': [1, 2, 4], 'num_res_blocks': 2,
-    #     'attn_resolutions': [], 'dropout': 0.0, 'mid_block_add_attention': False,
-    #     'norm_type': 'pixel', 'causality_axis': 'height'
-    # }
-    # sample_rate=16000, mel_hop_length=160, is_causal=True, mel_bins=64
     @register_to_config
     def __init__(
         self,
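For reference, the deleted comment recorded the original audio-VAE decoder configuration. Restated as a plain Python dict, with values copied verbatim from the removed lines (the variable name is ours):

# Original audio-VAE decoder config, as documented by the removed comment.
original_audio_vae_config = {
    "double_z": True, "mel_bins": 64, "z_channels": 8, "resolution": 256,
    "downsample_time": False, "in_channels": 2, "out_ch": 2, "ch": 128,
    "ch_mult": [1, 2, 4], "num_res_blocks": 2, "attn_resolutions": [],
    "dropout": 0.0, "mid_block_add_attention": False,
    "norm_type": "pixel", "causality_axis": "height",
}
# Mel front end noted alongside it: sample_rate=16000, mel_hop_length=160,
# is_causal=True, mel_bins=64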