From 57a8b9c3300201cc9609b882c3229bce3eb5cfeb Mon Sep 17 00:00:00 2001
From: Daniel Gu
Date: Tue, 16 Dec 2025 10:38:03 +0100
Subject: [PATCH] Allow LTX 2 transformer to be loaded from local path for
 conversion

---
 scripts/convert_ltx2_to_diffusers.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py
index 286e2aed42..312559dbee 100644
--- a/scripts/convert_ltx2_to_diffusers.py
+++ b/scripts/convert_ltx2_to_diffusers.py
@@ -192,6 +192,26 @@ def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
     return original_state_dict
 
 
+def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]:
+    if repo_id is None and filename is None:
+        raise ValueError("Please supply at least one of `repo_id` or `filename`")
+
+    if repo_id is not None:
+        if filename is None:
+            raise ValueError("If `repo_id` is specified, `filename` must also be specified.")
+        ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
+    else:
+        ckpt_path = filename
+
+    _, ext = os.path.splitext(ckpt_path)
+    if ext in [".safetensors", ".sft"]:
+        state_dict = safetensors.torch.load_file(ckpt_path)
+    else:
+        state_dict = torch.load(ckpt_path, map_location="cpu")
+
+    return state_dict
+
+
 def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]:
     # Ensure that the key prefix ends with a dot (.)
     if not prefix.endswith("."):
@@ -299,7 +319,7 @@ def main(args):
 
     if args.dit or args.full_pipeline:
         if args.dit_filename is not None:
-            original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename)
+            original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
         elif combined_ckpt is not None:
             original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
         transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
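
A minimal usage sketch of the new helper, assuming it is imported from
scripts/convert_ltx2_to_diffusers.py; the repo id and file paths below are
hypothetical placeholders, not real checkpoints.

    # Sketch exercising both branches of load_hub_or_local_checkpoint.
    from convert_ltx2_to_diffusers import load_hub_or_local_checkpoint

    # Local path: `filename` alone is treated as a path on disk. ".safetensors"
    # and ".sft" files are read with safetensors; anything else with torch.load.
    local_sd = load_hub_or_local_checkpoint(filename="/models/ltx2/transformer.safetensors")

    # Hub checkpoint: with `repo_id` set, `filename` names the file inside the
    # repo, which is first fetched via hf_hub_download.
    hub_sd = load_hub_or_local_checkpoint(
        repo_id="some-org/ltx2-checkpoints",  # placeholder repo id
        filename="transformer.safetensors",
    )

    print(len(local_sd), len(hub_sd))  # tensor counts of the two state dicts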