up
@@ -25,13 +25,13 @@ def convert_state_dict(state_dict: dict) -> dict:
     return converted
 
 
-def load_original_decoder(device: torch.device):
+def load_original_decoder(device: torch.device, dtype: torch.dtype):
     from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
+    from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER
     from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator
 
     checkpoint_path = download_checkpoint()
 
     # The code below comes from `ltx-pipelines/src/ltx_pipelines/txt2vid.py`
     decoder = Builder(
         model_path=checkpoint_path,
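The new `dtype` argument suggests the rebuilt decoder is cast after construction. A minimal sketch of that cast, assuming the builder itself stays in float32 (the hunk does not show where `dtype` is actually consumed):

```python
import torch

def cast_decoder(decoder: torch.nn.Module, device: torch.device, dtype: torch.dtype) -> torch.nn.Module:
    # Module.to moves parameters and buffers in one pass, so a single call
    # covers both device placement and precision.
    return decoder.to(device=device, dtype=dtype)
```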
@@ -39,10 +39,6 @@ def load_original_decoder(device: torch.device):
         model_sd_key_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
     ).build(device=device)
 
-    state_dict = decoder.state_dict()
-    for k, v in state_dict.items():
-        if "mid" in k:
-            print(f"{k=}")
     decoder.eval()
     return decoder
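With the debug key dump removed here and `strict=True` relaxed to `strict=False` in the next hunk, mismatched keys would otherwise go unnoticed. A sketch that keeps the non-strict load but still surfaces skipped keys, using the named tuple that `load_state_dict` returns (`diffusers_model` and `converted_state_dict` are the names used in `main()`):

```python
# Hypothetical check around the non-strict load below.
result = diffusers_model.load_state_dict(converted_state_dict, assign=True, strict=False)
if result.missing_keys:
    print("keys missing from the converted checkpoint:", result.missing_keys)
if result.unexpected_keys:
    print("keys the diffusers model does not expect:", result.unexpected_keys)
```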
@@ -70,16 +66,27 @@ def main() -> None:
     dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}
     dtype = dtype_map[args.dtype]
 
-    original_decoder = load_original_decoder(device)
+    original_decoder = load_original_decoder(device, dtype)
     diffusers_model = build_diffusers_decoder()
 
     converted_state_dict = convert_state_dict(original_decoder.state_dict())
-    diffusers_model.load_state_dict(converted_state_dict, assign=True, strict=True)
+    diffusers_model.load_state_dict(converted_state_dict, assign=True, strict=False)
+
+    per_channel_len = original_decoder.per_channel_statistics.get_buffer("std-of-means").numel()
+    latent_channels = diffusers_model.decoder.latent_channels
+    mel_bins_for_match = per_channel_len // latent_channels if per_channel_len % latent_channels == 0 else None
 
     levels = len(diffusers_model.decoder.channel_multipliers)
-    latent_size = diffusers_model.decoder.resolution // (2 ** (levels - 1))
+    latent_height = diffusers_model.decoder.resolution // (2 ** (levels - 1))
+    latent_width = mel_bins_for_match or latent_height
 
     dummy = torch.randn(
-        args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device
+        args.batch,
+        diffusers_model.decoder.latent_channels,
+        latent_height,
+        latent_width,
+        device=device,
+        dtype=dtype,
     )
 
     original_out = original_decoder(dummy)
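A worked instance of the shape inference above, plugging in values from the config comment deleted in the last hunk ('z_channels': 8, 'mel_bins': 64, 'resolution': 256, 'ch_mult': [1, 2, 4]); the actual length of the "std-of-means" buffer is not shown in the diff, so 512 is an assumption:

```python
per_channel_len = 512                # assumed: 8 latent channels * 64 mel bins
latent_channels = 8                  # 'z_channels' in the original config
assert per_channel_len % latent_channels == 0
mel_bins_for_match = per_channel_len // latent_channels   # 64

levels = 3                           # len([1, 2, 4]) channel multipliers
latent_height = 256 // (2 ** (levels - 1))                # 256 // 4 = 64
latent_width = mel_bins_for_match or latent_height        # 64
```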
@@ -580,13 +580,6 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
 
     _supports_gradient_checkpointing = False
 
-    # {
-    #     'double_z': True, 'mel_bins': 64, 'z_channels': 8, 'resolution': 256, 'downsample_time': False,
-    #     'in_channels': 2, 'out_ch': 2, 'ch': 128, 'ch_mult': [1, 2, 4], 'num_res_blocks': 2,
-    #     'attn_resolutions': [], 'dropout': 0.0, 'mid_block_add_attention': False,
-    #     'norm_type': 'pixel', 'causality_axis': 'height'
-    # }
-    # sample_rate=16000, mel_hop_length=160, is_causal=True, mel_bins=64
     @register_to_config
     def __init__(
         self,
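The deleted comment recorded the original LTX audio-VAE config. A hypothetical sketch of how those fields could map onto the diffusers-side constructor; only `latent_channels`, `resolution`, and `channel_multipliers` appear elsewhere in this diff, so every other target name would need checking against the actual `__init__` signature:

```python
# Original config, as recorded in the deleted comment.
original_cfg = {
    "double_z": True, "mel_bins": 64, "z_channels": 8, "resolution": 256,
    "in_channels": 2, "out_ch": 2, "ch": 128, "ch_mult": [1, 2, 4],
    "num_res_blocks": 2, "dropout": 0.0, "mid_block_add_attention": False,
}

# Assumed mapping; only the three left-hand names below are confirmed by the diff.
diffusers_kwargs = {
    "latent_channels": original_cfg["z_channels"],        # 8
    "resolution": original_cfg["resolution"],             # 256
    "channel_multipliers": original_cfg["ch_mult"],       # [1, 2, 4]
}
```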