mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
LTX 2 Single File Support (#12983)
* LTX 2 transformer single file support
* LTX 2 video VAE single file support
* LTX 2 audio VAE single file support
* Make it easier to distinguish LTX 1 and 2 models
This commit is contained in:
@@ -40,6 +40,9 @@ from .single_file_utils import (
|
||||
convert_hunyuan_video_transformer_to_diffusers,
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
convert_ltx2_audio_vae_to_diffusers,
|
||||
convert_ltx2_transformer_to_diffusers,
|
||||
convert_ltx2_vae_to_diffusers,
|
||||
convert_ltx_transformer_checkpoint_to_diffusers,
|
||||
convert_ltx_vae_checkpoint_to_diffusers,
|
||||
convert_lumina2_to_diffusers,
|
||||
@@ -176,6 +179,18 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"ZImageControlNetModel": {
|
||||
"checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers,
|
||||
},
|
||||
"LTX2VideoTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_ltx2_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AutoencoderKLLTX2Video": {
|
||||
"checkpoint_mapping_fn": convert_ltx2_vae_to_diffusers,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
"AutoencoderKLLTX2Audio": {
|
||||
"checkpoint_mapping_fn": convert_ltx2_audio_vae_to_diffusers,
|
||||
"default_subfolder": "audio_vae",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -112,7 +112,8 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
|
||||
"patchify_proj.weight",
|
||||
"transformer_blocks.27.scale_shift_table",
|
||||
"vae.per_channel_statistics.mean-of-means",
|
||||
"vae.decoder.last_scale_shift_table", # 0.9.1, 0.9.5, 0.9.7, 0.9.8
|
||||
"vae.decoder.up_blocks.9.res_blocks.0.conv1.conv.weight", # 0.9.0
|
||||
],
|
||||
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
|
||||
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
|
||||
@@ -147,6 +148,11 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"net.pos_embedder.dim_spatial_range",
|
||||
],
|
||||
"flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"],
|
||||
"ltx2": [
|
||||
"model.diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_1.weight",
|
||||
"vae.per_channel_statistics.mean-of-means",
|
||||
"audio_vae.per_channel_statistics.mean-of-means",
|
||||
],
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -228,6 +234,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"z-image-turbo-controlnet": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union"},
|
||||
"z-image-turbo-controlnet-2.0": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0"},
|
||||
"z-image-turbo-controlnet-2.1": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
|
||||
"ltx2-dev": {"pretrained_model_name_or_path": "Lightricks/LTX-2"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -796,6 +803,9 @@ def infer_diffusers_model_type(checkpoint):
|
||||
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint:
|
||||
model_type = "z-image-turbo-controlnet"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx2"]):
|
||||
model_type = "ltx2-dev"
|
||||
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -3920,3 +3930,161 @@ def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, config, **kwa
|
||||
return converted_state_dict
|
||||
else:
|
||||
raise ValueError("Unknown Z-Image Turbo ControlNet type.")
|
||||
|
||||
|
||||
def convert_ltx2_transformer_to_diffusers(checkpoint, **kwargs):
    """Rename the keys of an original-format LTX 2 transformer state dict to the diffusers naming.

    Every entry is moved out of ``checkpoint`` (which is left empty) into a new dict. Each key is then
    rewritten by applying the substring substitutions below in order, and finally keys matching one of the
    special markers are either dropped or renamed by a dedicated handler.

    Args:
        checkpoint: Mutable mapping of parameter name to value. It is drained by this call.
        **kwargs: Accepted for signature compatibility; not used here.

    Returns:
        A dict holding the same values under their converted names.
    """
    # Applied in insertion order with `str.replace`, so every occurrence of a pattern is rewritten and an
    # earlier entry can change what a later entry sees.
    key_substitutions = {
        # Transformer prefix
        "model.diffusion_model.": "",
        # Input Patchify Projections
        "patchify_proj": "proj_in",
        # NOTE: by the time this entry is reached, "patchify_proj" has already been rewritten above (turning
        # "audio_patchify_proj" into "audio_proj_in"), so this pattern itself no longer matches.
        "audio_patchify_proj": "audio_proj_in",
        # Modulation Parameters
        # adaln_single --> time_embed and audio_adaln_single --> audio_time_embed are handled separately
        # further down, because those original names are substrings of the modulation parameters below.
        "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
        "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
        "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
        "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
        # Transformer Blocks
        # Per-Block Cross Attention Modulation Parameters
        "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
        "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
        # Attention QK Norms
        "q_norm": "norm_q",
        "k_norm": "norm_k",
    }

    def _drop_entry(key: str, state_dict) -> None:
        # Raises KeyError if the entry is already gone.
        del state_dict[key]

    def _rename_adaln_single(key: str, state_dict) -> None:
        # Only keys containing ".weight" or ".bias" are touched.
        if not (".weight" in key or ".bias" in key):
            return

        # The two prefixes are mutually exclusive, so at most one rename happens per key.
        prefix_pairs = (
            ("adaln_single.", "time_embed."),
            ("audio_adaln_single.", "audio_time_embed."),
        )
        for old_prefix, new_prefix in prefix_pairs:
            if key.startswith(old_prefix):
                state_dict[key.replace(old_prefix, new_prefix)] = state_dict.pop(key)

    # marker substring -> handler(key, state_dict); a handler runs for every key that contains its marker.
    special_key_handlers = {
        "video_embeddings_connector": _drop_entry,
        "audio_embeddings_connector": _drop_entry,
        "adaln_single": _rename_adaln_single,
    }

    # Move (not copy) every entry out of the incoming checkpoint.
    converted_state_dict = {}
    for name in list(checkpoint.keys()):
        converted_state_dict[name] = checkpoint.pop(name)

    # Official code --> diffusers key remapping via the substitution table.
    for old_name in list(converted_state_dict.keys()):
        renamed = old_name
        for needle, replacement in key_substitutions.items():
            renamed = renamed.replace(needle, replacement)
        converted_state_dict[renamed] = converted_state_dict.pop(old_name)

    # Anything that cannot be expressed as a plain substring substitution goes through the handlers.
    for name in list(converted_state_dict.keys()):
        for marker, handler in special_key_handlers.items():
            if marker in name:
                handler(name, converted_state_dict)

    return converted_state_dict
|
||||
|
||||
|
||||
def convert_ltx2_vae_to_diffusers(checkpoint, **kwargs):
    """Rename the keys of an original-format LTX 2 video VAE state dict to the diffusers naming.

    Every entry is moved out of ``checkpoint`` (which is left empty) into a new dict. Each key is rewritten
    by applying the substring substitutions below in order; afterwards, entries whose key contains one of
    the special markers are removed.

    Args:
        checkpoint: Mutable mapping of parameter name to value. It is drained by this call.
        **kwargs: Accepted for signature compatibility; not used here.

    Returns:
        A dict holding the remaining values under their converted names.
    """
    # Applied in insertion order with `str.replace`. Each replacement only yields block indices at or below
    # the index of a pattern that has already been processed, so one substitution never feeds a later one.
    key_substitutions = {
        # Video VAE prefix
        "vae.": "",
        # Encoder
        "down_blocks.0": "down_blocks.0",
        "down_blocks.1": "down_blocks.0.downsamplers.0",
        "down_blocks.2": "down_blocks.1",
        "down_blocks.3": "down_blocks.1.downsamplers.0",
        "down_blocks.4": "down_blocks.2",
        "down_blocks.5": "down_blocks.2.downsamplers.0",
        "down_blocks.6": "down_blocks.3",
        "down_blocks.7": "down_blocks.3.downsamplers.0",
        "down_blocks.8": "mid_block",
        # Decoder
        "up_blocks.0": "mid_block",
        "up_blocks.1": "up_blocks.0.upsamplers.0",
        "up_blocks.2": "up_blocks.0",
        "up_blocks.3": "up_blocks.1.upsamplers.0",
        "up_blocks.4": "up_blocks.1",
        "up_blocks.5": "up_blocks.2.upsamplers.0",
        "up_blocks.6": "up_blocks.2",
        # Common
        # For all 3D ResNets
        "res_blocks": "resnets",
        "per_channel_statistics.mean-of-means": "latents_mean",
        "per_channel_statistics.std-of-means": "latents_std",
    }

    def _drop_entry(key: str, state_dict) -> None:
        # Raises KeyError if the entry is already gone.
        del state_dict[key]

    # marker substring -> handler(key, state_dict); a handler runs for every key that contains its marker.
    special_key_handlers = {
        "per_channel_statistics.channel": _drop_entry,
        "per_channel_statistics.mean-of-stds": _drop_entry,
    }

    # Move (not copy) every entry out of the incoming checkpoint.
    converted_state_dict = {}
    for name in list(checkpoint.keys()):
        converted_state_dict[name] = checkpoint.pop(name)

    # Official code --> diffusers key remapping via the substitution table.
    for old_name in list(converted_state_dict.keys()):
        renamed = old_name
        for needle, replacement in key_substitutions.items():
            renamed = renamed.replace(needle, replacement)
        converted_state_dict[renamed] = converted_state_dict.pop(old_name)

    # Anything that cannot be expressed as a plain substring substitution goes through the handlers.
    for name in list(converted_state_dict.keys()):
        for marker, handler in special_key_handlers.items():
            if marker in name:
                handler(name, converted_state_dict)

    return converted_state_dict
|
||||
|
||||
|
||||
def convert_ltx2_audio_vae_to_diffusers(checkpoint, **kwargs):
    """Rename the keys of an original-format LTX 2 audio VAE state dict to the diffusers naming.

    Every entry is moved out of ``checkpoint`` (which is left empty) into a new dict, and each key is
    rewritten by applying the substring substitutions below in order. No entries are removed.

    Args:
        checkpoint: Mutable mapping of parameter name to value. It is drained by this call.
        **kwargs: Accepted for signature compatibility; not used here.

    Returns:
        A dict holding the same values under their converted names.
    """
    # Applied in insertion order with `str.replace`, so every occurrence of a pattern is rewritten.
    key_substitutions = {
        # Audio VAE prefix
        "audio_vae.": "",
        "per_channel_statistics.mean-of-means": "latents_mean",
        "per_channel_statistics.std-of-means": "latents_std",
    }

    # Move (not copy) every entry out of the incoming checkpoint.
    converted_state_dict = {}
    for name in list(checkpoint.keys()):
        converted_state_dict[name] = checkpoint.pop(name)

    # Official code --> diffusers key remapping via the substitution table.
    for old_name in list(converted_state_dict.keys()):
        renamed = old_name
        for needle, replacement in key_substitutions.items():
            renamed = renamed.replace(needle, replacement)
        converted_state_dict[renamed] = converted_state_dict.pop(old_name)

    return converted_state_dict
|
||||
|
||||
Reference in New Issue
Block a user