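"""Convert original LTX 2.0 checkpoints to the diffusers format.

A minimal CLI sketch (assuming this script is saved as convert_ltx2_to_diffusers.py;
the repo id and combined filename below are the script defaults, and the output path
is illustrative):

    python convert_ltx2_to_diffusers.py \
        --original_state_dict_repo_id diffusers-internal-dev/new-ltx-model \
        --combined_filename ltx-av-step-1932500-interleaved-new-vae.safetensors \
        --vae --dit --vocoder \
        --output_path ./ltx2-diffusers
"""
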
import argparse
import os
from contextlib import nullcontext
from typing import Any, Dict, Optional, Tuple

import safetensors.torch
import torch
from huggingface_hub import hf_hub_download

from diffusers import AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from diffusers.utils.import_utils import is_accelerate_available


if is_accelerate_available():
    from accelerate import init_empty_weights

# Instantiate models on the meta device when accelerate is available, so that no memory
# is allocated until the converted weights are assigned.
CTX = init_empty_weights if is_accelerate_available() else nullcontext


LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
    # Input Patchify Projections
    "patchify_proj": "proj_in",
    "audio_patchify_proj": "audio_proj_in",
    # Modulation Parameters
    # Handle adaln_single --> time_embed, audio_adaln_single --> audio_time_embed separately (see
    # LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP) as the original keys are substrings of the other
    # modulation parameters below
    "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
    "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
    "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
    "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
    # Transformer Blocks
    # Per-Block Cross Attention Modulation Parameters
    "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
    "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
    # Attention QK Norms
    "q_norm": "norm_q",
    "k_norm": "norm_k",
}

LTX_2_0_VIDEO_VAE_RENAME_DICT = {
    # Encoder
    "down_blocks.0": "down_blocks.0",
    "down_blocks.1": "down_blocks.0.downsamplers.0",
    "down_blocks.2": "down_blocks.1",
    "down_blocks.3": "down_blocks.1.downsamplers.0",
    "down_blocks.4": "down_blocks.2",
    "down_blocks.5": "down_blocks.2.downsamplers.0",
    "down_blocks.6": "down_blocks.3",
    "down_blocks.7": "down_blocks.3.downsamplers.0",
    "down_blocks.8": "mid_block",
    # Decoder
    "up_blocks.0": "mid_block",
    "up_blocks.1": "up_blocks.0.upsamplers.0",
    "up_blocks.2": "up_blocks.0",
    "up_blocks.3": "up_blocks.1.upsamplers.0",
    "up_blocks.4": "up_blocks.1",
    "up_blocks.5": "up_blocks.2.upsamplers.0",
    "up_blocks.6": "up_blocks.2",
    # Common
    # For all 3D ResNets
    "res_blocks": "resnets",
    "per_channel_statistics.mean-of-means": "latents_mean",
    "per_channel_statistics.std-of-means": "latents_std",
}

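# The rename rules above are applied as ordered substring replacements (see the convert_*
# functions below). A sketch with a hypothetical encoder key:
#     "encoder.down_blocks.1.res_blocks.0.conv1.weight"
# first matches "down_blocks.1" -> "down_blocks.0.downsamplers.0", then "res_blocks" -> "resnets",
# yielding:
#     "encoder.down_blocks.0.downsamplers.0.resnets.0.conv1.weight"
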
LTX_2_0_VOCODER_RENAME_DICT = {
    "ups": "upsamplers",
    "resblocks": "resnets",
    "conv_pre": "conv_in",
    "conv_post": "conv_out",
}


def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None:
    state_dict.pop(key)


def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None:
    # Skip keys that are not a weight or bias
    if ".weight" not in key and ".bias" not in key:
        return

    if key.startswith("adaln_single."):
        new_key = key.replace("adaln_single.", "time_embed.")
        param = state_dict.pop(key)
        state_dict[new_key] = param

    if key.startswith("audio_adaln_single."):
        new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
        param = state_dict.pop(key)
        state_dict[new_key] = param


LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "video_embeddings_connector": remove_keys_inplace,
    "audio_embeddings_connector": remove_keys_inplace,
    "adaln_single": convert_ltx2_transformer_adaln_single,
}

LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
    "per_channel_statistics.channel": remove_keys_inplace,
    "per_channel_statistics.mean-of-stds": remove_keys_inplace,
}

LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}


def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    if version == "test":
        # Produces a transformer of the same size as used in test_models_transformer_ltx2.py
        config = {
            "model_id": "diffusers-internal-dev/dummy-ltx2",
            "diffusers_config": {
                "in_channels": 4,
                "out_channels": 4,
                "patch_size": 1,
                "patch_size_t": 1,
                "num_attention_heads": 2,
                "attention_head_dim": 8,
                "cross_attention_dim": 16,
                "vae_scale_factors": (8, 32, 32),
                "pos_embed_max_pos": 20,
                "base_height": 2048,
                "base_width": 2048,
                "audio_in_channels": 4,
                "audio_out_channels": 4,
                "audio_patch_size": 1,
                "audio_patch_size_t": 1,
                "audio_num_attention_heads": 2,
                "audio_attention_head_dim": 4,
                "audio_cross_attention_dim": 8,
                "audio_scale_factor": 4,
                "audio_pos_embed_max_pos": 20,
                "audio_sampling_rate": 16000,
                "audio_hop_length": 160,
                "num_layers": 2,
                "activation_fn": "gelu-approximate",
                "qk_norm": "rms_norm_across_heads",
                "norm_elementwise_affine": False,
                "norm_eps": 1e-6,
                "caption_channels": 16,
                "attention_bias": True,
                "attention_out_bias": True,
                "rope_theta": 10000.0,
                "causal_offset": 1,
            },
        }
        rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
        special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
    elif version == "2.0":
        config = {
            "model_id": "diffusers-internal-dev/new-ltx-model",
            "diffusers_config": {
                "in_channels": 128,
                "out_channels": 128,
                "patch_size": 1,
                "patch_size_t": 1,
                "num_attention_heads": 32,
                "attention_head_dim": 128,
                "cross_attention_dim": 4096,
                "vae_scale_factors": (8, 32, 32),
                "pos_embed_max_pos": 20,
                "base_height": 2048,
                "base_width": 2048,
                "audio_in_channels": 128,
                "audio_out_channels": 128,
                "audio_patch_size": 1,
                "audio_patch_size_t": 1,
                "audio_num_attention_heads": 32,
                "audio_attention_head_dim": 64,
                "audio_cross_attention_dim": 2048,
                "audio_scale_factor": 4,
                "audio_pos_embed_max_pos": 20,
                "audio_sampling_rate": 16000,
                "audio_hop_length": 160,
                "num_layers": 48,
                "activation_fn": "gelu-approximate",
                "qk_norm": "rms_norm_across_heads",
                "norm_elementwise_affine": False,
                "norm_eps": 1e-6,
                "caption_channels": 3840,
                "attention_bias": True,
                "attention_out_bias": True,
                "rope_theta": 10000.0,
                "causal_offset": 1,
            },
        }
        rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
        special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
    else:
        raise ValueError(f"Unsupported LTX 2.0 transformer version: {version}")
    return config, rename_dict, special_keys_remap


def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> LTX2VideoTransformer3DModel:
    config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version)
    diffusers_config = config["diffusers_config"]

    with CTX():
        transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)

    # Handle official code --> diffusers key remapping via the remap dict
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in rename_dict.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
    # special_keys_remap
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in special_keys_remap.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    # assign=True is required because the model may have been instantiated on the meta device
    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer


def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    if version == "test":
        config = {
            "model_id": "diffusers-internal-dev/dummy-ltx2",
            "diffusers_config": {
                "in_channels": 3,
                "out_channels": 3,
                "latent_channels": 128,
                "block_out_channels": (256, 512, 1024, 2048),
                "down_block_types": (
                    "LTX2VideoDownBlock3D",
                    "LTX2VideoDownBlock3D",
                    "LTX2VideoDownBlock3D",
                    "LTX2VideoDownBlock3D",
                ),
                "decoder_block_out_channels": (256, 512, 1024),
                "layers_per_block": (4, 6, 6, 2, 2),
                "decoder_layers_per_block": (5, 5, 5, 5),
                "spatio_temporal_scaling": (True, True, True, True),
                "decoder_spatio_temporal_scaling": (True, True, True),
                "decoder_inject_noise": (False, False, False, False),
                "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
                "upsample_residual": (True, True, True),
                "upsample_factor": (2, 2, 2),
                "timestep_conditioning": False,
                "patch_size": 4,
                "patch_size_t": 1,
                "resnet_norm_eps": 1e-6,
                "encoder_causal": True,
                "decoder_causal": False,
                "encoder_spatial_padding_mode": "zeros",
                "decoder_spatial_padding_mode": "reflect",
                "spatial_compression_ratio": 32,
                "temporal_compression_ratio": 8,
            },
        }
        rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
        special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
    elif version == "2.0":
        config = {
            "model_id": "diffusers-internal-dev/dummy-ltx2",
            "diffusers_config": {
                "in_channels": 3,
                "out_channels": 3,
                "latent_channels": 128,
                "block_out_channels": (256, 512, 1024, 2048),
                "down_block_types": (
                    "LTX2VideoDownBlock3D",
                    "LTX2VideoDownBlock3D",
                    "LTX2VideoDownBlock3D",
                    "LTX2VideoDownBlock3D",
                ),
                "decoder_block_out_channels": (256, 512, 1024),
                "layers_per_block": (4, 6, 6, 2, 2),
                "decoder_layers_per_block": (5, 5, 5, 5),
                "spatio_temporal_scaling": (True, True, True, True),
                "decoder_spatio_temporal_scaling": (True, True, True),
                "decoder_inject_noise": (False, False, False, False),
                "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
                "upsample_residual": (True, True, True),
                "upsample_factor": (2, 2, 2),
                "timestep_conditioning": False,
                "patch_size": 4,
                "patch_size_t": 1,
                "resnet_norm_eps": 1e-6,
                "encoder_causal": True,
                "decoder_causal": False,
                "encoder_spatial_padding_mode": "zeros",
                "decoder_spatial_padding_mode": "reflect",
                "spatial_compression_ratio": 32,
                "temporal_compression_ratio": 8,
            },
        }
        rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
        special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
    else:
        raise ValueError(f"Unsupported LTX 2.0 VAE version: {version}")
    return config, rename_dict, special_keys_remap


def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> AutoencoderKLLTX2Video:
    config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
    diffusers_config = config["diffusers_config"]

    with CTX():
        vae = AutoencoderKLLTX2Video.from_config(diffusers_config)

    # Handle official code --> diffusers key remapping via the remap dict
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in rename_dict.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
    # special_keys_remap
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in special_keys_remap.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    # assign=True is required because the model may have been instantiated on the meta device
    vae.load_state_dict(original_state_dict, strict=True, assign=True)
    return vae


def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    if version == "2.0":
        config = {
            "model_id": "diffusers-internal-dev/new-ltx-model",
            "diffusers_config": {
                "in_channels": 128,
                "hidden_channels": 1024,
                "out_channels": 2,
                "upsample_kernel_sizes": [16, 15, 8, 4, 4],
                "upsample_factors": [6, 5, 2, 2, 2],
                "resnet_kernel_sizes": [3, 7, 11],
                "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                "leaky_relu_negative_slope": 0.1,
                "output_sampling_rate": 24000,
            },
        }
        rename_dict = LTX_2_0_VOCODER_RENAME_DICT
        special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
    else:
        raise ValueError(f"Unsupported LTX 2.0 vocoder version: {version}")
    return config, rename_dict, special_keys_remap


def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> LTX2Vocoder:
    config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
    diffusers_config = config["diffusers_config"]

    with CTX():
        vocoder = LTX2Vocoder.from_config(diffusers_config)

    # Handle official code --> diffusers key remapping via the remap dict
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in rename_dict.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
    # special_keys_remap
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in special_keys_remap.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    # assign=True is required because the model may have been instantiated on the meta device
    vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
    return vocoder


def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
    if args.original_state_dict_repo_id is not None:
        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
    elif args.checkpoint_path is not None:
        ckpt_path = args.checkpoint_path
    else:
        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    original_state_dict = safetensors.torch.load_file(ckpt_path)
    return original_state_dict


def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]:
    if repo_id is None and filename is None:
        raise ValueError("Please supply at least one of `repo_id` or `filename`")

    if repo_id is not None:
        if filename is None:
            raise ValueError("If `repo_id` is specified, `filename` must also be specified.")
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
    else:
        ckpt_path = filename

    _, ext = os.path.splitext(ckpt_path)
    if ext in [".safetensors", ".sft"]:
        state_dict = safetensors.torch.load_file(ckpt_path)
    else:
        state_dict = torch.load(ckpt_path, map_location="cpu")

    return state_dict


def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    # Ensure that the key prefix ends with a dot (.)
    if not prefix.endswith("."):
        prefix = prefix + "."

    model_state_dict = {}
    for param_name, param in combined_ckpt.items():
        if param_name.startswith(prefix):
            # Strip only the leading prefix (str.replace would also rewrite any later occurrences)
            model_state_dict[param_name[len(prefix):]] = param
    return model_state_dict

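# A sketch of the prefix stripping above, with hypothetical keys and tensors w1, w2:
#     combined = {"vae.encoder.conv_in.weight": w1, "model.diffusion_model.proj_in.weight": w2}
#     get_model_state_dict_from_combined_ckpt(combined, "vae")
#     # -> {"encoder.conv_in.weight": w1}
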
def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--original_state_dict_repo_id",
        default="diffusers-internal-dev/new-ltx-model",
        type=str,
        help="HF Hub repo id with the LTX 2.0 checkpoint",
    )
    parser.add_argument(
        "--checkpoint_path",
        default=None,
        type=str,
        help="Local checkpoint path for LTX 2.0. Will be used if `original_state_dict_repo_id` is not specified.",
    )
    parser.add_argument(
        "--version",
        type=str,
        default="2.0",
        choices=["test", "2.0"],
        help="Version of the LTX 2.0 model",
    )

    parser.add_argument(
        "--combined_filename",
        default="ltx-av-step-1932500-interleaved-new-vae.safetensors",
        type=str,
        help="Filename for a combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
    )
    parser.add_argument("--vae_prefix", default="vae.", type=str)
    parser.add_argument("--audio_vae_prefix", default="audio_vae.", type=str)
    parser.add_argument("--dit_prefix", default="model.diffusion_model.", type=str)
    parser.add_argument("--vocoder_prefix", default="vocoder.", type=str)

    parser.add_argument("--vae_filename", default=None, type=str, help="VAE filename; overrides combined ckpt if set")
    parser.add_argument(
        "--audio_vae_filename", default=None, type=str, help="Audio VAE filename; overrides combined ckpt if set"
    )
    parser.add_argument("--dit_filename", default=None, type=str, help="DiT filename; overrides combined ckpt if set")
    parser.add_argument(
        "--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set"
    )

    parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
    parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
    parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
    parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
    parser.add_argument(
        "--full_pipeline",
        action="store_true",
        help="Whether to save the full pipeline. This will attempt to convert all models (e.g. vae, dit, etc.)",
    )

    parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
    parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
    parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
    parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])

    parser.add_argument("--output_path", type=str, required=True, help="Path where the converted models are saved")

    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

VARIANT_MAPPING = {
    "fp32": None,
    "fp16": "fp16",
    "bf16": "bf16",
}


def main(args):
    vae_dtype = DTYPE_MAPPING[args.vae_dtype]
    audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype]
    dit_dtype = DTYPE_MAPPING[args.dit_dtype]
    vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype]

    combined_ckpt = None
    load_combined_models = any([args.vae, args.audio_vae, args.dit, args.vocoder, args.full_pipeline])
    if args.combined_filename is not None and load_combined_models:
        combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename)

    if args.vae or args.full_pipeline:
        if args.vae_filename is not None:
            original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
        elif combined_ckpt is not None:
            original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
        else:
            raise ValueError("No VAE checkpoint available; supply `--vae_filename` or a combined checkpoint")
        vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
        if not args.full_pipeline:
            vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))

    if args.audio_vae or args.full_pipeline:
        # TODO: audio VAE conversion is not yet implemented
        pass

    if args.dit or args.full_pipeline:
        if args.dit_filename is not None:
            original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
        elif combined_ckpt is not None:
            original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
        else:
            raise ValueError("No DiT checkpoint available; supply `--dit_filename` or a combined checkpoint")
        transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
        if not args.full_pipeline:
            transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))

    if args.vocoder or args.full_pipeline:
        if args.vocoder_filename is not None:
            original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename)
        elif combined_ckpt is not None:
            original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix)
        else:
            raise ValueError("No vocoder checkpoint available; supply `--vocoder_filename` or a combined checkpoint")
        vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version)
        if not args.full_pipeline:
            vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))

    if args.full_pipeline:
        # TODO: assembling and saving the full pipeline is not yet implemented
        pass


if __name__ == "__main__":
    args = get_args()
    main(args)