1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add new LTX 2.0 spatial latent upsampler logic

This commit is contained in:
Daniel Gu
2026-01-06 04:47:06 +01:00
parent 084490cd98
commit d97fd2dd35
4 changed files with 296 additions and 8 deletions

View File

@@ -11,9 +11,8 @@ from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer, Gemma3ForConditionalGeneration
from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, FlowMatchEulerDiscreteScheduler, LTX2Pipeline, LTX2VideoTransformer3DModel
from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
CTX = init_empty_weights if is_accelerate_available() else nullcontext
@@ -580,11 +579,12 @@ def get_ltx2_spatial_latent_upsampler_config(version: str):
if version == "2.0":
config = {
"in_channels": 128,
"mid_channels": 512,
"mid_channels": 1024,
"num_blocks_per_stage": 4,
"dims": 3,
"spatial_upsample": True,
"temporal_upsample": False,
"rational_spatial_scale": 2.0,
}
else:
raise ValueError(f"Unsupported version: {version}")
@@ -595,7 +595,7 @@ def convert_ltx2_spatial_latent_upsampler(
original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype
):
with init_empty_weights():
latent_upsampler = LTXLatentUpsamplerModel(**config)
latent_upsampler = LTX2LatentUpsamplerModel(**config)
latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
latent_upsampler.to(dtype)
@@ -708,7 +708,10 @@ def get_args():
help="HF Hub id for the LTX 2.0 text tokenizer",
)
parser.add_argument(
"--latent_upsampler_filename", default=None, type=str, help="Latent upsampler filename"
"--latent_upsampler_filename",
default="rc1/ltx-2-spatial-upscaler-x2-1.0-rc1.safetensors",
type=str,
help="Latent upsampler filename",
)
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
@@ -818,7 +821,9 @@ def main(args):
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
if args.latent_upsampler:
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(filename=args.latent_upsampler_filename)
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
)
latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version)
latent_upsampler = convert_ltx2_spatial_latent_upsampler(
original_latent_upsampler_ckpt, latent_upsampler_config, dtype=vae_dtype,