Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)

Commit: resolve conflicts.
scripts/convert_ltx2_to_diffusers.py (new file, 542 lines)
@@ -0,0 +1,542 @@
import argparse
import os
from contextlib import nullcontext
from typing import Any, Dict, Optional, Tuple

import safetensors.torch
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download

from diffusers import AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from diffusers.utils.import_utils import is_accelerate_available


CTX = init_empty_weights if is_accelerate_available() else nullcontext

LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
    # Input Patchify Projections
    "patchify_proj": "proj_in",
    "audio_patchify_proj": "audio_proj_in",
    # Modulation Parameters
    # Handle adaln_single --> time_embed and audio_adaln_single --> audio_time_embed separately, as the original keys
    # are substrings of the other modulation parameter keys below
    "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
    "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
    "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
    "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
    # Transformer Blocks
    # Per-Block Cross Attention Modulation Parameters
    "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
    "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
    # Attention QK Norms
    "q_norm": "norm_q",
    "k_norm": "norm_k",
}

LTX_2_0_VIDEO_VAE_RENAME_DICT = {
    # Encoder
    "down_blocks.0": "down_blocks.0",
    "down_blocks.1": "down_blocks.0.downsamplers.0",
    "down_blocks.2": "down_blocks.1",
    "down_blocks.3": "down_blocks.1.downsamplers.0",
    "down_blocks.4": "down_blocks.2",
    "down_blocks.5": "down_blocks.2.downsamplers.0",
    "down_blocks.6": "down_blocks.3",
    "down_blocks.7": "down_blocks.3.downsamplers.0",
    "down_blocks.8": "mid_block",
    # Decoder
    "up_blocks.0": "mid_block",
    "up_blocks.1": "up_blocks.0.upsamplers.0",
    "up_blocks.2": "up_blocks.0",
    "up_blocks.3": "up_blocks.1.upsamplers.0",
    "up_blocks.4": "up_blocks.1",
    "up_blocks.5": "up_blocks.2.upsamplers.0",
    "up_blocks.6": "up_blocks.2",
    # Common
    # For all 3D ResNets
    "res_blocks": "resnets",
    "per_channel_statistics.mean-of-means": "latents_mean",
    "per_channel_statistics.std-of-means": "latents_std",
}

LTX_2_0_VOCODER_RENAME_DICT = {
    "ups": "upsamplers",
    "resblocks": "resnets",
    "conv_pre": "conv_in",
    "conv_post": "conv_out",
}

def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None:
    state_dict.pop(key)


def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None:
    # Skip keys that are not weights or biases
    if ".weight" not in key and ".bias" not in key:
        return

    if key.startswith("adaln_single."):
        new_key = key.replace("adaln_single.", "time_embed.")
        param = state_dict.pop(key)
        state_dict[new_key] = param

    if key.startswith("audio_adaln_single."):
        new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
        param = state_dict.pop(key)
        state_dict[new_key] = param


LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "video_embeddings_connector": remove_keys_inplace,
    "audio_embeddings_connector": remove_keys_inplace,
    "adaln_single": convert_ltx2_transformer_adaln_single,
}

LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
    "per_channel_statistics.channel": remove_keys_inplace,
    "per_channel_statistics.mean-of-stds": remove_keys_inplace,
}

LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}

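# Illustrative example of the two conversion passes used below (the keys are hypothetical, not taken from a real
# checkpoint): the first pass applies the 1:1 substring renames from the rename dict, the second pass runs the
# special-key handlers above.
#
#     sd = {
#         "transformer_blocks.0.attn1.q_norm.weight": ...,   # rename pass: "q_norm" -> "norm_q"
#         "adaln_single.linear.weight": ...,                  # handler pass: "adaln_single." -> "time_embed."
#         "video_embeddings_connector.proj.weight": ...,      # handler pass: dropped by remove_keys_inplace
#     }
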
def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    if version == "test":
        # Produces a transformer of the same size as used in test_models_transformer_ltx2.py
        config = {
            "model_id": "diffusers-internal-dev/dummy-ltx2",
            "diffusers_config": {
                "in_channels": 4,
                "out_channels": 4,
                "patch_size": 1,
                "patch_size_t": 1,
                "num_attention_heads": 2,
                "attention_head_dim": 8,
                "cross_attention_dim": 16,
                "vae_scale_factors": (8, 32, 32),
                "pos_embed_max_pos": 20,
                "base_height": 2048,
                "base_width": 2048,
                "audio_in_channels": 4,
                "audio_out_channels": 4,
                "audio_patch_size": 1,
                "audio_patch_size_t": 1,
                "audio_num_attention_heads": 2,
                "audio_attention_head_dim": 4,
                "audio_cross_attention_dim": 8,
                "audio_scale_factor": 4,
                "audio_pos_embed_max_pos": 20,
                "audio_sampling_rate": 16000,
                "audio_hop_length": 160,
                "num_layers": 2,
                "activation_fn": "gelu-approximate",
                "qk_norm": "rms_norm_across_heads",
                "norm_elementwise_affine": False,
                "norm_eps": 1e-6,
                "caption_channels": 16,
                "attention_bias": True,
                "attention_out_bias": True,
                "rope_theta": 10000.0,
                "causal_offset": 1,
            },
        }
        rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
        special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
    elif version == "2.0":
        config = {
            "model_id": "diffusers-internal-dev/new-ltx-model",
            "diffusers_config": {
                "in_channels": 128,
                "out_channels": 128,
                "patch_size": 1,
                "patch_size_t": 1,
                "num_attention_heads": 32,
                "attention_head_dim": 128,
                "cross_attention_dim": 4096,
                "vae_scale_factors": (8, 32, 32),
                "pos_embed_max_pos": 20,
                "base_height": 2048,
                "base_width": 2048,
                "audio_in_channels": 128,
                "audio_out_channels": 128,
                "audio_patch_size": 1,
                "audio_patch_size_t": 1,
                "audio_num_attention_heads": 32,
                "audio_attention_head_dim": 64,
                "audio_cross_attention_dim": 2048,
                "audio_scale_factor": 4,
                "audio_pos_embed_max_pos": 20,
                "audio_sampling_rate": 16000,
                "audio_hop_length": 160,
                "num_layers": 48,
                "activation_fn": "gelu-approximate",
                "qk_norm": "rms_norm_across_heads",
                "norm_elementwise_affine": False,
                "norm_eps": 1e-6,
                "caption_channels": 3840,
                "attention_bias": True,
                "attention_out_bias": True,
                "rope_theta": 10000.0,
                "causal_offset": 1,
            },
        }
        rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
        special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
    return config, rename_dict, special_keys_remap


def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> LTX2VideoTransformer3DModel:
    config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version)
    diffusers_config = config["diffusers_config"]

    with init_empty_weights():
        transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)

    # Handle official code --> diffusers key remapping via the remap dict
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in rename_dict.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    # Handle any special logic which can't be expressed as a simple 1:1 remapping with the handlers in
    # special_keys_remap
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in special_keys_remap.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer

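# Minimal usage sketch for the function above (the checkpoint path is illustrative):
#
#     state_dict = load_hub_or_local_checkpoint(filename="path/to/ltx2_dit.safetensors")
#     transformer = convert_ltx2_transformer(state_dict, version="2.0")
#     transformer.to(torch.bfloat16).save_pretrained("output/transformer")
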
def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    if version == "test":
        config = {
            "model_id": "diffusers-internal-dev/dummy-ltx2",
            "diffusers_config": {
                "in_channels": 3,
                "out_channels": 3,
                "latent_channels": 128,
                "block_out_channels": (256, 512, 1024, 2048),
                "down_block_types": (
                    "LTX2VideoDownBlock3D",
                    "LTX2VideoDownBlock3D",
                    "LTX2VideoDownBlock3D",
                    "LTX2VideoDownBlock3D",
                ),
                "decoder_block_out_channels": (256, 512, 1024),
                "layers_per_block": (4, 6, 6, 2, 2),
                "decoder_layers_per_block": (5, 5, 5, 5),
                "spatio_temporal_scaling": (True, True, True, True),
                "decoder_spatio_temporal_scaling": (True, True, True),
                "decoder_inject_noise": (False, False, False, False),
                "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
                "upsample_residual": (True, True, True),
                "upsample_factor": (2, 2, 2),
                "timestep_conditioning": False,
                "patch_size": 4,
                "patch_size_t": 1,
                "resnet_norm_eps": 1e-6,
                "encoder_causal": True,
                "decoder_causal": False,
                "encoder_spatial_padding_mode": "zeros",
                "decoder_spatial_padding_mode": "reflect",
                "spatial_compression_ratio": 32,
                "temporal_compression_ratio": 8,
            },
        }
        rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
        special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
    elif version == "2.0":
        config = {
            "model_id": "diffusers-internal-dev/dummy-ltx2",
            "diffusers_config": {
                "in_channels": 3,
                "out_channels": 3,
                "latent_channels": 128,
                "block_out_channels": (256, 512, 1024, 2048),
                "down_block_types": (
                    "LTX2VideoDownBlock3D",
                    "LTX2VideoDownBlock3D",
                    "LTX2VideoDownBlock3D",
                    "LTX2VideoDownBlock3D",
                ),
                "decoder_block_out_channels": (256, 512, 1024),
                "layers_per_block": (4, 6, 6, 2, 2),
                "decoder_layers_per_block": (5, 5, 5, 5),
                "spatio_temporal_scaling": (True, True, True, True),
                "decoder_spatio_temporal_scaling": (True, True, True),
                "decoder_inject_noise": (False, False, False, False),
                "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
                "upsample_residual": (True, True, True),
                "upsample_factor": (2, 2, 2),
                "timestep_conditioning": False,
                "patch_size": 4,
                "patch_size_t": 1,
                "resnet_norm_eps": 1e-6,
                "encoder_causal": True,
                "decoder_causal": False,
                "encoder_spatial_padding_mode": "zeros",
                "decoder_spatial_padding_mode": "reflect",
                "spatial_compression_ratio": 32,
                "temporal_compression_ratio": 8,
            },
        }
        rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
        special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
    return config, rename_dict, special_keys_remap


def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> AutoencoderKLLTX2Video:
    config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
    diffusers_config = config["diffusers_config"]

    with init_empty_weights():
        vae = AutoencoderKLLTX2Video.from_config(diffusers_config)

    # Handle official code --> diffusers key remapping via the remap dict
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in rename_dict.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    # Handle any special logic which can't be expressed as a simple 1:1 remapping with the handlers in
    # special_keys_remap
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in special_keys_remap.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True, assign=True)
    return vae

def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    if version == "2.0":
        config = {
            "model_id": "diffusers-internal-dev/new-ltx-model",
            "diffusers_config": {
                "in_channels": 128,
                "hidden_channels": 1024,
                "out_channels": 2,
                "upsample_kernel_sizes": [16, 15, 8, 4, 4],
                "upsample_factors": [6, 5, 2, 2, 2],
                "resnet_kernel_sizes": [3, 7, 11],
                "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                "leaky_relu_negative_slope": 0.1,
                "output_sampling_rate": 24000,
            },
        }
        rename_dict = LTX_2_0_VOCODER_RENAME_DICT
        special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
    return config, rename_dict, special_keys_remap


def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> LTX2Vocoder:
    config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
    diffusers_config = config["diffusers_config"]

    with init_empty_weights():
        vocoder = LTX2Vocoder.from_config(diffusers_config)

    # Handle official code --> diffusers key remapping via the remap dict
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in rename_dict.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    # Handle any special logic which can't be expressed as a simple 1:1 remapping with the handlers in
    # special_keys_remap
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in special_keys_remap.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
    return vocoder

def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
    if args.original_state_dict_repo_id is not None:
        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
    elif args.checkpoint_path is not None:
        ckpt_path = args.checkpoint_path
    else:
        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    original_state_dict = safetensors.torch.load_file(ckpt_path)
    return original_state_dict


def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]:
    if repo_id is None and filename is None:
        raise ValueError("Please supply at least one of `repo_id` or `filename`")

    if repo_id is not None:
        if filename is None:
            raise ValueError("If `repo_id` is specified, `filename` must also be specified.")
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
    else:
        ckpt_path = filename

    _, ext = os.path.splitext(ckpt_path)
    if ext in [".safetensors", ".sft"]:
        state_dict = safetensors.torch.load_file(ckpt_path)
    else:
        state_dict = torch.load(ckpt_path, map_location="cpu")

    return state_dict


def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    # Ensure that the key prefix ends with a dot (.)
    if not prefix.endswith("."):
        prefix = prefix + "."

    model_state_dict = {}
    for param_name, param in combined_ckpt.items():
        if param_name.startswith(prefix):
            model_state_dict[param_name.replace(prefix, "")] = param
    return model_state_dict

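# Example of splitting a combined checkpoint by prefix (the keys are hypothetical):
#
#     combined = {"vae.encoder.conv_in.weight": ..., "model.diffusion_model.proj_in.weight": ...}
#     vae_sd = get_model_state_dict_from_combined_ckpt(combined, "vae.")
#     # -> {"encoder.conv_in.weight": ...}
#     dit_sd = get_model_state_dict_from_combined_ckpt(combined, "model.diffusion_model.")
#     # -> {"proj_in.weight": ...}
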
def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--original_state_dict_repo_id",
        default="diffusers-internal-dev/new-ltx-model",
        type=str,
        help="HF Hub repo id with the LTX 2.0 checkpoint",
    )
    parser.add_argument(
        "--checkpoint_path",
        default=None,
        type=str,
        help="Local checkpoint path for LTX 2.0. Will be used if `original_state_dict_repo_id` is not specified.",
    )
    parser.add_argument(
        "--version",
        type=str,
        default="2.0",
        choices=["test", "2.0"],
        help="Version of the LTX 2.0 model",
    )

    parser.add_argument(
        "--combined_filename",
        default="ltx-av-step-1932500-interleaved-new-vae.safetensors",
        type=str,
        help="Filename for the combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
    )
    parser.add_argument("--vae_prefix", default="vae.", type=str)
    parser.add_argument("--audio_vae_prefix", default="audio_vae.", type=str)
    parser.add_argument("--dit_prefix", default="model.diffusion_model.", type=str)
    parser.add_argument("--vocoder_prefix", default="vocoder.", type=str)

    parser.add_argument("--vae_filename", default=None, type=str, help="VAE filename; overrides combined ckpt if set")
    parser.add_argument(
        "--audio_vae_filename", default=None, type=str, help="Audio VAE filename; overrides combined ckpt if set"
    )
    parser.add_argument("--dit_filename", default=None, type=str, help="DiT filename; overrides combined ckpt if set")
    parser.add_argument(
        "--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set"
    )

    parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
    parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
    parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
    parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
    parser.add_argument(
        "--full_pipeline",
        action="store_true",
        help="Whether to save the pipeline. This will attempt to convert all models (e.g. vae, dit, etc.)",
    )

    parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
    parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
    parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
    parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])

    parser.add_argument("--output_path", type=str, required=True, help="Path where the converted models should be saved")

    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

VARIANT_MAPPING = {
    "fp32": None,
    "fp16": "fp16",
    "bf16": "bf16",
}

def main(args):
    vae_dtype = DTYPE_MAPPING[args.vae_dtype]
    audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype]
    dit_dtype = DTYPE_MAPPING[args.dit_dtype]
    vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype]

    combined_ckpt = None
    load_combined_models = any([args.vae, args.audio_vae, args.dit, args.vocoder, args.full_pipeline])
    if args.combined_filename is not None and load_combined_models:
        combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename)

    if args.vae or args.full_pipeline:
        if args.vae_filename is not None:
            original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
        elif combined_ckpt is not None:
            original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
        vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
        if not args.full_pipeline:
            vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))

    if args.audio_vae or args.full_pipeline:
        # Audio VAE conversion is not handled by this script yet
        pass

    if args.dit or args.full_pipeline:
        if args.dit_filename is not None:
            original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
        elif combined_ckpt is not None:
            original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
        transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
        if not args.full_pipeline:
            transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))

    if args.vocoder or args.full_pipeline:
        if args.vocoder_filename is not None:
            original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename)
        elif combined_ckpt is not None:
            original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix)
        vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version)
        if not args.full_pipeline:
            vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))

    if args.full_pipeline:
        # Assembling and saving the full pipeline is not handled by this script yet
        pass


if __name__ == "__main__":
    args = get_args()
    main(args)
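# Example invocation (illustrative; the repo id and combined filename below are simply the argparse defaults above,
# only --output_path is required, and the per-model flags select what gets converted):
#
#     python scripts/convert_ltx2_to_diffusers.py \
#         --original_state_dict_repo_id diffusers-internal-dev/new-ltx-model \
#         --combined_filename ltx-av-step-1932500-interleaved-new-vae.safetensors \
#         --vae --dit --vocoder \
#         --output_path ./ltx2-diffusers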
src/diffusers/__init__.py

@@ -194,6 +194,7 @@ else:
"AutoencoderKLHunyuanVideo",
"AutoencoderKLHunyuanVideo15",
"AutoencoderKLLTXVideo",
"AutoencoderKLLTX2Video",
"AutoencoderKLMagvit",
"AutoencoderKLMochi",
"AutoencoderKLQwenImage",
@@ -236,6 +237,7 @@ else:
"Kandinsky5Transformer3DModel",
"LatteTransformer3DModel",
"LTXVideoTransformer3DModel",
"LTX2VideoTransformer3DModel",
"Lumina2Transformer2DModel",
"LuminaNextDiT2DModel",
"MochiTransformer3DModel",
@@ -927,6 +929,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTXVideo,
AutoencoderKLLTX2Video,
AutoencoderKLMagvit,
AutoencoderKLMochi,
AutoencoderKLQwenImage,
@@ -969,6 +972,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
LTX2VideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,

src/diffusers/models/__init__.py

@@ -41,6 +41,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
@@ -102,6 +103,7 @@ if is_torch_available():
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
@@ -152,6 +154,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTXVideo,
AutoencoderKLLTX2Video,
AutoencoderKLMagvit,
AutoencoderKLMochi,
AutoencoderKLQwenImage,
@@ -209,6 +212,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
LTX2VideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,

src/diffusers/models/autoencoders/__init__.py

@@ -11,6 +11,7 @@ from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage

src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py (new file, 1520 lines)
File diff suppressed because it is too large.
src/diffusers/models/transformers/__init__.py

@@ -34,6 +34,7 @@ if is_torch_available():
from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
from .transformer_kandinsky import Kandinsky5Transformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_ltx2 import LTX2VideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel

src/diffusers/models/transformers/transformer_ltx2.py (new file, 1232 lines)
File diff suppressed because it is too large.

src/diffusers/pipelines/ltx2/vocoder.py (new file, 173 lines)
@@ -0,0 +1,173 @@
import math
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin


class ResBlock(nn.Module):
    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        dilations: Tuple[int, ...] = (1, 3, 5),
        leaky_relu_negative_slope: float = 0.1,
        padding_mode: str = "same",
    ):
        super().__init__()
        self.dilations = dilations
        self.negative_slope = leaky_relu_negative_slope

        self.convs1 = nn.ModuleList(
            [
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    stride=stride,
                    dilation=dilation,
                    padding=padding_mode,
                )
                for dilation in dilations
            ]
        )

        self.convs2 = nn.ModuleList(
            [
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    stride=stride,
                    dilation=1,
                    padding=padding_mode,
                )
                for _ in range(len(dilations))
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv1, conv2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, negative_slope=self.negative_slope)
            xt = conv1(xt)
            xt = F.leaky_relu(xt, negative_slope=self.negative_slope)
            xt = conv2(xt)
            x = x + xt
        return x

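# Shape sketch (illustrative): with the default "same" padding and stride 1, a ResBlock preserves the
# (batch, channels, time) shape of its input.
#
#     block = ResBlock(channels=64)
#     out = block(torch.randn(1, 64, 100))  # -> torch.Size([1, 64, 100])
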
class LTX2Vocoder(ModelMixin, ConfigMixin):
    r"""
    LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 128,
        hidden_channels: int = 1024,
        out_channels: int = 2,
        upsample_kernel_sizes: List[int] = [16, 15, 8, 4, 4],
        upsample_factors: List[int] = [6, 5, 2, 2, 2],
        resnet_kernel_sizes: List[int] = [3, 7, 11],
        resnet_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        leaky_relu_negative_slope: float = 0.1,
        output_sampling_rate: int = 24000,
    ):
        super().__init__()
        self.num_upsample_layers = len(upsample_kernel_sizes)
        self.resnets_per_upsample = len(resnet_kernel_sizes)
        self.out_channels = out_channels
        self.total_upsample_factor = math.prod(upsample_factors)
        self.negative_slope = leaky_relu_negative_slope

        if self.num_upsample_layers != len(upsample_factors):
            raise ValueError(
                f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length"
                f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively."
            )

        if self.resnets_per_upsample != len(resnet_dilations):
            raise ValueError(
                f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length"
                f" {self.resnets_per_upsample} and {len(resnet_dilations)}, respectively."
            )

        self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)

        self.upsamplers = nn.ModuleList()
        self.resnets = nn.ModuleList()
        input_channels = hidden_channels
        for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
            output_channels = input_channels // 2
            self.upsamplers.append(
                nn.ConvTranspose1d(
                    input_channels,  # hidden_channels // (2 ** i)
                    output_channels,  # hidden_channels // (2 ** (i + 1))
                    kernel_size,
                    stride=stride,
                    padding=(kernel_size - stride) // 2,
                )
            )

            for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
                self.resnets.append(
                    ResBlock(
                        output_channels,
                        kernel_size,
                        dilations=dilations,
                        leaky_relu_negative_slope=leaky_relu_negative_slope,
                    )
                )
            input_channels = output_channels

        self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)

    def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
        r"""
        Forward pass of the vocoder.

        Args:
            hidden_states (`torch.Tensor`):
                Input mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last`
                is `False` (the default) or of shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last`
                is `True`.
            time_last (`bool`, *optional*, defaults to `False`):
                Whether the last dimension of the input is the time/frame dimension or the mel bins dimension.

        Returns:
            `torch.Tensor`:
                Audio waveform tensor of shape `(batch_size, out_channels, audio_length)`.
        """
        # Ensure that the time/frame dimension is last
        if not time_last:
            hidden_states = hidden_states.transpose(2, 3)
        # Combine the channels and frequency (mel bins) dimensions
        hidden_states = hidden_states.flatten(1, 2)

        hidden_states = self.conv_in(hidden_states)

        for i in range(self.num_upsample_layers):
            hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
            hidden_states = self.upsamplers[i](hidden_states)

            # Run all resnets of this upsampling stage on hidden_states and average their outputs
            start = i * self.resnets_per_upsample
            end = (i + 1) * self.resnets_per_upsample
            resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0)
            hidden_states = torch.mean(resnet_outputs, dim=0)

        # NOTE: unlike the other activations above, this leaky ReLU uses the default F.leaky_relu negative slope of
        # 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended.
        hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
        hidden_states = self.conv_out(hidden_states)
        hidden_states = torch.tanh(hidden_states)

        return hidden_states

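# Minimal usage sketch (illustrative values): with the default config, the vocoder maps 128 flattened mel channels to
# a stereo waveform upsampled by prod(upsample_factors) = 6 * 5 * 2 * 2 * 2 = 240 samples per mel frame.
#
#     vocoder = LTX2Vocoder()
#     mel = torch.randn(1, 2, 10, 64)  # (batch, channels, time, mel_bins); 2 * 64 = 128 = in_channels
#     waveform = vocoder(mel)          # -> torch.Size([1, 2, 2400])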
tests/models/autoencoders/test_models_autoencoder_ltx2_video.py (new file, 105 lines)
@@ -0,0 +1,105 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import AutoencoderKLLTX2Video

from ...testing_utils import (
    enable_full_determinism,
    floats_tensor,
    torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin


enable_full_determinism()


class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
    model_class = AutoencoderKLLTX2Video
    main_input_name = "sample"
    base_precision = 1e-2

    def get_autoencoder_kl_ltx_video_config(self):
        return {
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": 8,
            "block_out_channels": (8, 8, 8, 8),
            "decoder_block_out_channels": (16, 32, 64),
            "layers_per_block": (1, 1, 1, 1, 1),
            "decoder_layers_per_block": (1, 1, 1, 1),
            "spatio_temporal_scaling": (True, True, True, True),
            "decoder_spatio_temporal_scaling": (True, True, True),
            "decoder_inject_noise": (False, False, False, False),
            "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
            "upsample_residual": (True, True, True),
            "upsample_factor": (2, 2, 2),
            "timestep_conditioning": False,
            "patch_size": 1,
            "patch_size_t": 1,
            "encoder_causal": True,
            "decoder_causal": False,
            "encoder_spatial_padding_mode": "zeros",
            # The full model uses `reflect`, but that mode has no deterministic backward implementation, so use `zeros`
            "decoder_spatial_padding_mode": "zeros",
        }

    @property
    def dummy_input(self):
        batch_size = 2
        num_frames = 9
        num_channels = 3
        sizes = (16, 16)

        image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)

        input_dict = {"sample": image}
        return input_dict

    @property
    def input_shape(self):
        return (3, 9, 16, 16)

    @property
    def output_shape(self):
        return (3, 9, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_autoencoder_kl_ltx_video_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "LTX2VideoEncoder3d",
            "LTX2VideoDecoder3d",
            "LTX2VideoDownBlock3D",
            "LTX2VideoMidBlock3d",
            "LTX2VideoUpBlock3d",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @unittest.skip("Unsupported test.")
    def test_outputs_equivalence(self):
        pass

    @unittest.skip("AutoencoderKLLTX2Video does not support `norm_num_groups` because it does not use GroupNorm.")
    def test_forward_with_norm_groups(self):
        pass

tests/models/transformers/test_models_transformer_ltx2.py (new file, 221 lines)
@@ -0,0 +1,221 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import LTX2VideoTransformer3DModel, attention_backend

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin


enable_full_determinism()


class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = LTX2VideoTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        # Common
        batch_size = 2

        # Video
        num_frames = 2
        num_channels = 4
        height = 16
        width = 16

        # Audio
        audio_num_frames = 9
        audio_num_channels = 2
        num_mel_bins = 2

        # Text
        embedding_dim = 16
        sequence_length = 16

        hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
        audio_hidden_states = torch.randn(
            (batch_size, audio_num_frames, audio_num_channels * num_mel_bins)
        ).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
        audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
        encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
        timestep = torch.rand((batch_size,)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "audio_hidden_states": audio_hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "audio_encoder_hidden_states": audio_encoder_hidden_states,
            "timestep": timestep,
            "encoder_attention_mask": encoder_attention_mask,
            "num_frames": num_frames,
            "height": height,
            "width": width,
            "audio_num_frames": audio_num_frames,
            "fps": 25.0,
        }

    @property
    def input_shape(self):
        return (512, 4)

    @property
    def output_shape(self):
        return (512, 4)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "in_channels": 4,
            "out_channels": 4,
            "patch_size": 1,
            "patch_size_t": 1,
            "num_attention_heads": 2,
            "attention_head_dim": 8,
            "cross_attention_dim": 16,
            "audio_in_channels": 4,
            "audio_out_channels": 4,
            "audio_num_attention_heads": 2,
            "audio_attention_head_dim": 4,
            "audio_cross_attention_dim": 8,
            "num_layers": 2,
            "qk_norm": "rms_norm_across_heads",
            "caption_channels": 16,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"LTX2VideoTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    def test_ltx2_consistency(self, seed=0, dtype=torch.float32):
        torch.manual_seed(seed)
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        # Calculate dummy inputs in a custom manner to ensure compatibility with the original code
        batch_size = 2
        num_frames = 9
        latent_frames = 2
        text_embedding_dim = 16
        text_seq_len = 16
        fps = 25.0
        sampling_rate = 16000.0
        hop_length = 160.0

        sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu")
        timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device)

        num_channels = 4
        latent_height = 4
        latent_width = 4
        hidden_states = torch.randn(
            (batch_size, num_channels, latent_frames, latent_height, latent_width),
            generator=torch.manual_seed(seed),
            dtype=dtype,
            device="cpu",
        )
        # Patchify video latents (with patch_size (1, 1, 1))
        hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1)
        hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
        encoder_hidden_states = torch.randn(
            (batch_size, text_seq_len, text_embedding_dim),
            generator=torch.manual_seed(seed),
            dtype=dtype,
            device="cpu",
        )

        audio_num_channels = 2
        num_mel_bins = 2
        latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps))
        audio_hidden_states = torch.randn(
            (batch_size, audio_num_channels, latent_length, num_mel_bins),
            generator=torch.manual_seed(seed),
            dtype=dtype,
            device="cpu",
        )
        # Patchify audio latents
        audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3)
        audio_encoder_hidden_states = torch.randn(
            (batch_size, text_seq_len, text_embedding_dim),
            generator=torch.manual_seed(seed),
            dtype=dtype,
            device="cpu",
        )

        inputs_dict = {
            "hidden_states": hidden_states.to(device=torch_device),
            "audio_hidden_states": audio_hidden_states.to(device=torch_device),
            "encoder_hidden_states": encoder_hidden_states.to(device=torch_device),
            "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device),
            "timestep": timestep,
            "num_frames": latent_frames,
            "height": latent_height,
            "width": latent_width,
            "audio_num_frames": num_frames,
            "fps": 25.0,
        }

        model = self.model_class.from_pretrained(
            "diffusers-internal-dev/dummy-ltx2",
            subfolder="transformer",
            device_map="cpu",
        )
        # torch.manual_seed(seed)
        # model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with attention_backend("native"):
            with torch.no_grad():
                output = model(**inputs_dict)

        video_output, audio_output = output.to_tuple()

        self.assertIsNotNone(video_output)
        self.assertIsNotNone(audio_output)

        # Input and output must have the same shape
        video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels)
        self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match")
        audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins)
        self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match")

        # Check against expected slices
        # fmt: off
        video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676])
        audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692])
        # fmt: on

        video_output_flat = video_output.cpu().flatten().float()
        video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]])
        self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4))

        audio_output_flat = audio_output.cpu().flatten().float()
        audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]])
        self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4))


class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = LTX2VideoTransformer3DModel

    def prepare_init_args_and_inputs_for_common(self):
        return LTX2TransformerTests().prepare_init_args_and_inputs_for_common()