
Move Video and Audio Text Encoder Connectors to Transformer (#12)

* Denormalize audio latents in I2V pipeline (analogous to T2V change)

* Initial refactor to put video and audio text encoder connectors in transformer

* Get LTX 2 transformer tests working after connector refactor

* Precompute run_connectors

* fixes

* Address review comments

* Calculate RoPE double-precision freqs using torch instead of np (see the sketch after this message)

* Further simplify LTX 2 RoPE freq calc

* Make connectors a separate module (#18)

* remove text_encoder.py

* Address yiyi's comments

* up

* up

* up

* up

---------

Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
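
The two RoPE bullets above refer to moving the double-precision frequency computation from NumPy to torch. As a point of reference, a minimal sketch of a standard RoPE inverse-frequency calculation done entirely in torch (names and exact formula are illustrative, not the commit's code):

import torch

def rope_inv_freqs(head_dim: int, theta: float = 10000.0, double_precision: bool = True) -> torch.Tensor:
    # Standard RoPE inverse frequencies: theta ** (-2i / d) for i in [0, d / 2).
    # Doing this in torch.float64 instead of np.float64 avoids the NumPy round
    # trip; callers can downcast the result to the working dtype afterwards.
    dtype = torch.float64 if double_precision else torch.float32
    exponents = torch.arange(0, head_dim, 2, dtype=dtype) / head_dim
    return 1.0 / (theta**exponents)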
Authored by dg845 on 2026-01-05 06:41:13 -08:00; committed by GitHub.
Parent: aae70b90db
Commit: caae16768a
8 changed files with 629 additions and 815 deletions


@@ -8,18 +8,11 @@ import safetensors.torch
 import torch
 from accelerate import init_empty_weights
 from huggingface_hub import hf_hub_download
-from transformers import AutoModel, AutoTokenizer
+from transformers import AutoModel, AutoTokenizer, Gemma3ForConditionalGeneration
-from diffusers import (
-    AutoencoderKLLTX2Audio,
-    AutoencoderKLLTX2Video,
-    FlowMatchEulerDiscreteScheduler,
-    LTX2Pipeline,
-    LTX2VideoTransformer3DModel,
-)
+from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, FlowMatchEulerDiscreteScheduler, LTX2Pipeline, LTX2VideoTransformer3DModel
+from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder
 from diffusers.utils.import_utils import is_accelerate_available
-from diffusers.pipelines.ltx2.text_encoder import LTX2AudioVisualTextEncoder
-from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
 
 CTX = init_empty_weights if is_accelerate_available() else nullcontext
@@ -134,6 +127,17 @@ LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
     "adaln_single": convert_ltx2_transformer_adaln_single,
 }
 
+LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
+    "connectors.": "",
+    "video_embeddings_connector": "video_connector",
+    "audio_embeddings_connector": "audio_connector",
+    "transformer_1d_blocks": "transformer_blocks",
+    "text_embedding_projection.aggregate_embed": "text_proj_in",
+    # Attention QK Norms
+    "q_norm": "norm_q",
+    "k_norm": "norm_k",
+}
+
 LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
     "per_channel_statistics.channel": remove_keys_inplace,
     "per_channel_statistics.mean-of-stds": remove_keys_inplace,
@@ -146,7 +150,27 @@ LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {
 LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
 LTX_2_0_TEXT_ENCODER_SPECIAL_KEYS_REMAP = {}
 
+
+def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    connector_prefixes = (
+        "video_embeddings_connector",
+        "audio_embeddings_connector",
+        "transformer_1d_blocks",
+        "text_embedding_projection.aggregate_embed",
+        "connectors.",
+        "video_connector",
+        "audio_connector",
+        "text_proj_in",
+    )
+    transformer_state_dict, connector_state_dict = {}, {}
+    for key, value in state_dict.items():
+        if key.startswith(connector_prefixes):
+            connector_state_dict[key] = value
+        else:
+            transformer_state_dict[key] = value
+    return transformer_state_dict, connector_state_dict
+
 
 def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
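
For context on the new split helper above: it routes each weight by key prefix, so one combined checkpoint can feed both the transformer and the connectors converters. A toy run (key names are shortened illustrations, not real checkpoint keys):

state_dict = {
    "transformer_blocks.0.attn.to_q.weight": ...,
    "video_embeddings_connector.blocks.0.weight": ...,
    "text_embedding_projection.aggregate_embed.weight": ...,
}
transformer_sd, connector_sd = split_transformer_and_connector_state_dict(state_dict)
# transformer_sd keeps the first key; the other two match connector_prefixes
# and land in connector_sd.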
@@ -240,32 +264,109 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str,
     return config, rename_dict, special_keys_remap
 
 
+def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+    if version == "test":
+        config = {
+            "model_id": "diffusers-internal-dev/dummy-ltx2",
+            "diffusers_config": {
+                "caption_channels": 16,
+                "text_proj_in_factor": 3,
+                "video_connector_num_attention_heads": 4,
+                "video_connector_attention_head_dim": 8,
+                "video_connector_num_layers": 1,
+                "video_connector_num_learnable_registers": None,
+                "audio_connector_num_attention_heads": 4,
+                "audio_connector_attention_head_dim": 8,
+                "audio_connector_num_layers": 1,
+                "audio_connector_num_learnable_registers": None,
+                "connector_rope_base_seq_len": 32,
+                "rope_theta": 10000.0,
+                "rope_double_precision": False,
+                "causal_temporal_positioning": False,
+            },
+        }
+    elif version == "2.0":
+        config = {
+            "model_id": "diffusers-internal-dev/new-ltx-model",
+            "diffusers_config": {
+                "caption_channels": 3840,
+                "text_proj_in_factor": 49,
+                "video_connector_num_attention_heads": 30,
+                "video_connector_attention_head_dim": 128,
+                "video_connector_num_layers": 2,
+                "video_connector_num_learnable_registers": 128,
+                "audio_connector_num_attention_heads": 30,
+                "audio_connector_attention_head_dim": 128,
+                "audio_connector_num_layers": 2,
+                "audio_connector_num_learnable_registers": 128,
+                "connector_rope_base_seq_len": 4096,
+                "rope_theta": 10000.0,
+                "rope_double_precision": True,
+                "causal_temporal_positioning": False,
+            },
+        }
+    rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
+    special_keys_remap = {}
+    return config, rename_dict, special_keys_remap
+
+
 def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
     config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version)
     diffusers_config = config["diffusers_config"]
 
+    transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict)
+
     with init_empty_weights():
         transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)
 
     # Handle official code --> diffusers key remapping via the remap dict
-    for key in list(original_state_dict.keys()):
+    for key in list(transformer_state_dict.keys()):
         new_key = key[:]
         for replace_key, rename_key in rename_dict.items():
             new_key = new_key.replace(replace_key, rename_key)
-        update_state_dict_inplace(original_state_dict, key, new_key)
+        update_state_dict_inplace(transformer_state_dict, key, new_key)
 
     # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
     # special_keys_remap
-    for key in list(original_state_dict.keys()):
+    for key in list(transformer_state_dict.keys()):
         for special_key, handler_fn_inplace in special_keys_remap.items():
             if special_key not in key:
                 continue
-            handler_fn_inplace(key, original_state_dict)
+            handler_fn_inplace(key, transformer_state_dict)
 
-    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
+    transformer.load_state_dict(transformer_state_dict, strict=True, assign=True)
     return transformer
 
 
+def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors:
+    config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version)
+    diffusers_config = config["diffusers_config"]
+
+    _, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict)
+    if len(connector_state_dict) == 0:
+        raise ValueError("No connector weights found in the provided state dict.")
+
+    with init_empty_weights():
+        connectors = LTX2TextConnectors.from_config(diffusers_config)
+
+    for key in list(connector_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in rename_dict.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        update_state_dict_inplace(connector_state_dict, key, new_key)
+
+    for key in list(connector_state_dict.keys()):
+        for special_key, handler_fn_inplace in special_keys_remap.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, connector_state_dict)
+
+    connectors.load_state_dict(connector_state_dict, strict=True, assign=True)
+    return connectors
+
+
 def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
     if version == "test":
         config = {
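
The connector converter above follows the same rename-then-load pattern as the transformer converter: substring renames from the rename dict, then any special-key handlers, then a strict load_state_dict. A condensed sketch of just the renaming step, using two entries from LTX_2_0_CONNECTORS_KEYS_RENAME_DICT:

rename_dict = {"video_embeddings_connector": "video_connector", "q_norm": "norm_q"}
key = "video_embeddings_connector.blocks.0.attn.q_norm.weight"
new_key = key
for old, new in rename_dict.items():
    new_key = new_key.replace(old, new)
# new_key == "video_connector.blocks.0.attn.norm_q.weight"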
@@ -471,81 +572,6 @@ def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> D
     return vocoder
 
 
-def get_ltx2_text_encoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
-    if version == "2.0":
-        config = {
-            "model_id": "diffusers-internal-dev/new-ltx-model",
-            "diffusers_config": {
-                "text_encoder_hidden_dim": 3840,
-                "text_proj_in_factor": 49,
-                "video_connector_num_attention_heads": 30,
-                "video_connector_attention_head_dim": 128,
-                "video_connector_num_layers": 2,
-                "video_connector_num_learnable_registers": 128,
-                "audio_connector_num_attention_heads": 30,
-                "audio_connector_attention_head_dim": 128,
-                "audio_connector_num_layers": 2,
-                "audio_connector_num_learnable_registers": 128,
-                "rope_base_seq_len": 4096,
-                "rope_theta": 10000.0,
-                "rope_double_precision": True,
-                "causal_temporal_positioning": False,
-            },
-        }
-    rename_dict = LTX_2_0_TEXT_ENCODER_RENAME_DICT
-    special_keys_remap = LTX_2_0_TEXT_ENCODER_SPECIAL_KEYS_REMAP
-    return config, rename_dict, special_keys_remap
-
-
-def get_text_encoder_keys_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str = "model.diffusion_model."):
-    model_state_dict = {}
-    model_state_dict["text_proj_in.weight"] = combined_ckpt["text_embedding_projection.aggregate_embed.weight"]
-    text_encoder_submodules = ["video_embeddings_connector", "audio_embeddings_connector"]
-    for param_name, param in combined_ckpt.items():
-        if param_name.startswith(prefix):
-            new_param_name = param_name.replace(prefix, "")
-            for submodule_name in text_encoder_submodules:
-                if new_param_name.startswith(submodule_name):
-                    model_state_dict[new_param_name] = param
-                    break
-    return model_state_dict
-
-
-def convert_ltx2_text_encoder(original_state_dict: Dict[str, Any], version: str, text_model_id: str) -> Dict[str, Any]:
-    config, rename_dict, special_keys_remap = get_ltx2_text_encoder_config(version)
-    diffusers_config = config["diffusers_config"]
-    diffusers_config["text_model_id"] = text_model_id
-    diffusers_config["config_only"] = True
-
-    with init_empty_weights():
-        text_encoder = LTX2AudioVisualTextEncoder.from_config(diffusers_config)
-
-    # Handle official code --> diffusers key remapping via the remap dict
-    for key in list(original_state_dict.keys()):
-        new_key = key[:]
-        for replace_key, rename_key in rename_dict.items():
-            new_key = new_key.replace(replace_key, rename_key)
-        update_state_dict_inplace(original_state_dict, key, new_key)
-
-    # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
-    # special_keys_remap
-    for key in list(original_state_dict.keys()):
-        for special_key, handler_fn_inplace in special_keys_remap.items():
-            if special_key not in key:
-                continue
-            handler_fn_inplace(key, original_state_dict)
-
-    base_text_model = AutoModel.from_pretrained(text_model_id)
-    base_text_model_state_dict = base_text_model.state_dict()
-    base_text_model_state_dict = {"base_text_encoder." + k: v for k, v in base_text_model_state_dict.items()}
-    combined_state_dict = {**original_state_dict, **base_text_model_state_dict}
-    text_encoder.load_state_dict(combined_state_dict, strict=True, assign=True)
-    return text_encoder
-
-
 def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
     if args.original_state_dict_repo_id is not None:
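
With the connectors factored out into LTX2TextConnectors, the bespoke LTX2AudioVisualTextEncoder wrapper deleted above has no remaining role: the pipeline's text encoder is now just the stock Gemma 3 model (see the main() hunk further down). Roughly, under the assumption that --text_encoder_model_id names a Gemma 3 checkpoint:

from transformers import Gemma3ForConditionalGeneration

# model_id is whatever --text_encoder_model_id points to; placeholder here.
model_id = "<text_encoder_model_id>"
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(model_id)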
@@ -588,6 +614,13 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefi
     for param_name, param in combined_ckpt.items():
         if param_name.startswith(prefix):
             model_state_dict[param_name.replace(prefix, "")] = param
+
+    if prefix == "model.diffusion_model.":
+        # Some checkpoints store the text connector projection outside the diffusion model prefix.
+        connector_key = "text_embedding_projection.aggregate_embed.weight"
+        if connector_key in combined_ckpt and connector_key not in model_state_dict:
+            model_state_dict[connector_key] = combined_ckpt[connector_key]
+
     return model_state_dict
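
The new fallback matters because some combined checkpoints store the connector projection at the top level rather than under the model.diffusion_model. prefix. A toy illustration (keys abbreviated):

combined_ckpt = {
    "model.diffusion_model.transformer_blocks.0.weight": ...,
    "text_embedding_projection.aggregate_embed.weight": ...,  # top level, unprefixed
}
sd = get_model_state_dict_from_combined_ckpt(combined_ckpt, "model.diffusion_model.")
# sd holds "transformer_blocks.0.weight" (prefix stripped) plus the connector
# projection key pulled in by the fallback.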
@@ -649,6 +682,7 @@ def get_args():
     parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
     parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
     parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
+    parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model")
     parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
     parser.add_argument("--text_encoder", action="store_true", help="Whether to convert the text encoder")
     parser.add_argument(
@@ -721,6 +755,15 @@ def main(args):
         transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
         if not args.full_pipeline:
             transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
 
+    if args.connectors or args.full_pipeline:
+        if args.dit_filename is not None:
+            original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
+        elif combined_ckpt is not None:
+            original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
+        connectors = convert_ltx2_connectors(original_connectors_ckpt, version=args.version)
+        if not args.full_pipeline:
+            connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors"))
+
     if args.vocoder or args.full_pipeline:
         if args.vocoder_filename is not None:
@@ -732,8 +775,8 @@ def main(args):
             vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
 
     if args.text_encoder or args.full_pipeline:
-        text_encoder_ckpt = get_text_encoder_keys_from_combined_ckpt(combined_ckpt)
-        text_encoder = convert_ltx2_text_encoder(text_encoder_ckpt, args.version, args.text_encoder_model_id)
+        # text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id)
+        text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id)
         if not args.full_pipeline:
             text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
@@ -758,6 +801,7 @@ def main(args):
         audio_vae=audio_vae,
         text_encoder=text_encoder,
         tokenizer=tokenizer,
+        connectors=connectors,
         transformer=transformer,
         vocoder=vocoder,
     )
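
Once a --full_pipeline conversion has been saved, the connectors should ride along as a regular pipeline component. A minimal usage sketch (the output path and dtype are placeholders, not from this commit):

import torch
from diffusers import LTX2Pipeline

pipe = LTX2Pipeline.from_pretrained("path/to/converted-ltx2", torch_dtype=torch.bfloat16)
print(type(pipe.connectors))  # LTX2TextConnectors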