1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

LTX 2 Single File Support (#12983)

* LTX 2 transformer single file support

* LTX 2 video VAE single file support

* LTX 2 audio VAE single file support

* Make it easier to distinguish LTX 1 and 2 models
This commit is contained in:
dg845
2026-01-15 22:46:42 -08:00
committed by GitHub
parent 74654df203
commit 8af8e86bc7
2 changed files with 184 additions and 1 deletion

View File

@@ -40,6 +40,9 @@ from .single_file_utils import (
convert_hunyuan_video_transformer_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_ltx2_audio_vae_to_diffusers,
convert_ltx2_transformer_to_diffusers,
convert_ltx2_vae_to_diffusers,
convert_ltx_transformer_checkpoint_to_diffusers,
convert_ltx_vae_checkpoint_to_diffusers,
convert_lumina2_to_diffusers,
@@ -176,6 +179,18 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"ZImageControlNetModel": {
"checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers,
},
"LTX2VideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_ltx2_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderKLLTX2Video": {
"checkpoint_mapping_fn": convert_ltx2_vae_to_diffusers,
"default_subfolder": "vae",
},
"AutoencoderKLLTX2Audio": {
"checkpoint_mapping_fn": convert_ltx2_audio_vae_to_diffusers,
"default_subfolder": "audio_vae",
},
}

View File

@@ -112,7 +112,8 @@ CHECKPOINT_KEY_NAMES = {
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
"patchify_proj.weight",
"transformer_blocks.27.scale_shift_table",
"vae.per_channel_statistics.mean-of-means",
"vae.decoder.last_scale_shift_table", # 0.9.1, 0.9.5, 0.9.7, 0.9.8
"vae.decoder.up_blocks.9.res_blocks.0.conv1.conv.weight", # 0.9.0
],
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
@@ -147,6 +148,11 @@ CHECKPOINT_KEY_NAMES = {
"net.pos_embedder.dim_spatial_range",
],
"flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"],
"ltx2": [
"model.diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_1.weight",
"vae.per_channel_statistics.mean-of-means",
"audio_vae.per_channel_statistics.mean-of-means",
],
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -228,6 +234,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"z-image-turbo-controlnet": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union"},
"z-image-turbo-controlnet-2.0": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0"},
"z-image-turbo-controlnet-2.1": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
"ltx2-dev": {"pretrained_model_name_or_path": "Lightricks/LTX-2"},
}
# Use to configure model sample size when original config is provided
@@ -796,6 +803,9 @@ def infer_diffusers_model_type(checkpoint):
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint:
model_type = "z-image-turbo-controlnet"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx2"]):
model_type = "ltx2-dev"
else:
model_type = "v1"
@@ -3920,3 +3930,161 @@ def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, config, **kwa
return converted_state_dict
else:
raise ValueError("Unknown Z-Image Turbo ControlNet type.")
def convert_ltx2_transformer_to_diffusers(checkpoint, **kwargs):
    """Remap an original LTX 2 transformer state dict to the diffusers layout.

    The incoming ``checkpoint`` is drained in place (every key is popped) and a
    new state dict is returned whose keys follow the diffusers
    ``LTX2VideoTransformer3DModel`` naming. Embedding-connector weights, which
    have no diffusers counterpart, are dropped.
    """
    # Plain substring renames, applied in order to every key.
    rename_map = {
        # Transformer prefix
        "model.diffusion_model.": "",
        # Input patchify projections
        "patchify_proj": "proj_in",
        "audio_patchify_proj": "audio_proj_in",
        # Modulation parameters. adaln_single --> time_embed and
        # audio_adaln_single --> audio_time_embed are handled by a special
        # handler below, because those substrings also occur inside the longer
        # modulation names renamed here.
        "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
        "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
        "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
        "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
        # Per-block cross-attention modulation parameters
        "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
        "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
        # Attention QK norms
        "q_norm": "norm_q",
        "k_norm": "norm_k",
    }

    def drop_key(key, state_dict):
        # Connector weights are not part of the diffusers model.
        state_dict.pop(key)

    def remap_adaln(key, state_dict):
        # Only weight/bias tensors are remapped; any other matching key stays.
        if ".weight" not in key and ".bias" not in key:
            return
        # The two prefixes are mutually exclusive at position 0, so the first
        # match wins.
        for src, dst in (
            ("adaln_single.", "time_embed."),
            ("audio_adaln_single.", "audio_time_embed."),
        ):
            if key.startswith(src):
                state_dict[key.replace(src, dst)] = state_dict.pop(key)
                break

    # Substring-triggered handlers for logic a simple 1:1 rename cannot express.
    special_handlers = {
        "video_embeddings_connector": drop_key,
        "audio_embeddings_connector": drop_key,
        "adaln_single": remap_adaln,
    }

    # Drain the source checkpoint into a working copy (preserves key order).
    converted = {name: checkpoint.pop(name) for name in tuple(checkpoint)}

    # First pass: plain substring renames; re-insert every key to keep order.
    for name in tuple(converted):
        target = name
        for src, dst in rename_map.items():
            target = target.replace(src, dst)
        converted[target] = converted.pop(name)

    # Second pass: run every special handler whose marker appears in the key.
    for name in tuple(converted):
        for marker, handler in special_handlers.items():
            if marker in name:
                handler(name, converted)

    return converted
def convert_ltx2_vae_to_diffusers(checkpoint, **kwargs):
    """Remap an original LTX 2 video VAE state dict to the diffusers layout.

    ``checkpoint`` is drained in place (every key is popped) and a new state
    dict is returned whose keys match ``AutoencoderKLLTX2Video``. Per-channel
    statistics that diffusers does not use are dropped.
    """
    # Plain substring renames, applied in order to every key.
    rename_map = {
        # Video VAE prefix
        "vae.": "",
        # Encoder: interleaved block/downsampler indices collapse onto
        # diffusers down blocks.
        "down_blocks.0": "down_blocks.0",
        "down_blocks.1": "down_blocks.0.downsamplers.0",
        "down_blocks.2": "down_blocks.1",
        "down_blocks.3": "down_blocks.1.downsamplers.0",
        "down_blocks.4": "down_blocks.2",
        "down_blocks.5": "down_blocks.2.downsamplers.0",
        "down_blocks.6": "down_blocks.3",
        "down_blocks.7": "down_blocks.3.downsamplers.0",
        "down_blocks.8": "mid_block",
        # Decoder: mirror mapping for up blocks/upsamplers.
        "up_blocks.0": "mid_block",
        "up_blocks.1": "up_blocks.0.upsamplers.0",
        "up_blocks.2": "up_blocks.0",
        "up_blocks.3": "up_blocks.1.upsamplers.0",
        "up_blocks.4": "up_blocks.1",
        "up_blocks.5": "up_blocks.2.upsamplers.0",
        "up_blocks.6": "up_blocks.2",
        # Common: applies to all 3D ResNets
        "res_blocks": "resnets",
        "per_channel_statistics.mean-of-means": "latents_mean",
        "per_channel_statistics.std-of-means": "latents_std",
    }

    def drop_key(key, state_dict):
        # Statistics with no diffusers counterpart are discarded.
        state_dict.pop(key)

    # Substring-triggered removals.
    special_handlers = {
        "per_channel_statistics.channel": drop_key,
        "per_channel_statistics.mean-of-stds": drop_key,
    }

    # Drain the source checkpoint into a working copy (preserves key order).
    converted = {name: checkpoint.pop(name) for name in tuple(checkpoint)}

    # First pass: plain substring renames; re-insert every key to keep order.
    for name in tuple(converted):
        target = name
        for src, dst in rename_map.items():
            target = target.replace(src, dst)
        converted[target] = converted.pop(name)

    # Second pass: drop keys matched by the special handlers.
    for name in tuple(converted):
        for marker, handler in special_handlers.items():
            if marker in name:
                handler(name, converted)

    return converted
def convert_ltx2_audio_vae_to_diffusers(checkpoint, **kwargs):
    """Remap an original LTX 2 audio VAE state dict to the diffusers layout.

    ``checkpoint`` is drained in place (every key is popped) and a new state
    dict is returned whose keys match ``AutoencoderKLLTX2Audio``.
    """
    # Plain substring renames, applied in order to every key.
    rename_map = {
        # Audio VAE prefix
        "audio_vae.": "",
        "per_channel_statistics.mean-of-means": "latents_mean",
        "per_channel_statistics.std-of-means": "latents_std",
    }

    # Drain the source checkpoint into a working copy (preserves key order).
    converted = {name: checkpoint.pop(name) for name in tuple(checkpoint)}

    # Apply the substring renames; re-insert every key to keep order.
    for name in tuple(converted):
        target = name
        for src, dst in rename_map.items():
            target = target.replace(src, dst)
        converted[target] = converted.pop(name)

    return converted