diff --git a/scripts/ltx2_test_latent_upsampler.py b/scripts/ltx2_test_latent_upsampler.py index 5194e32d5c..745f2c8d1f 100644 --- a/scripts/ltx2_test_latent_upsampler.py +++ b/scripts/ltx2_test_latent_upsampler.py @@ -52,6 +52,7 @@ def parse_args(): parser.add_argument("--device", type=str, default="cuda:0") parser.add_argument("--dtype", type=str, default="bf16") parser.add_argument("--cpu_offload", action="store_true") + parser.add_argument("--vae_tiling", action="store_true") parser.add_argument( "--output_dir", @@ -120,6 +121,8 @@ def main(args): ) upsample_pipeline = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler) upsample_pipeline.to(device=args.device) + if args.vae_tiling: + upsample_pipeline.enable_vae_tiling() video = upsample_pipeline( video=video,