From 99c0efdb45ffde12ef2b7c215910ef41b802c4e7 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 19 Oct 2025 18:01:03 +0300
Subject: [PATCH] Expose force_parameter_static_shapes in torch.compile options

---
 nodes_model_loading.py |  4 +++-
 uni3c/nodes.py         | 11 ++++++-----
 utils.py               | 12 +++++++-----
 3 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/nodes_model_loading.py b/nodes_model_loading.py
index 3d425bc..dfe0947 100644
--- a/nodes_model_loading.py
+++ b/nodes_model_loading.py
@@ -326,6 +326,7 @@ class WanVideoTorchCompileSettings:
             },
             "optional": {
                 "dynamo_recompile_limit": ("INT", {"default": 128, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.recompile_limit"}),
+                "force_parameter_static_shapes": ("BOOLEAN", {"default": True, "tooltip": "torch._dynamo.config.force_parameter_static_shapes"}),
             },
         }
     RETURN_TYPES = ("WANCOMPILEARGS",)
@@ -334,7 +335,7 @@ class WanVideoTorchCompileSettings:
     CATEGORY = "WanVideoWrapper"
     DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch > 2.7.0 is recommended"
 
-    def set_args(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only, dynamo_recompile_limit=128):
+    def set_args(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only, dynamo_recompile_limit=128, force_parameter_static_shapes=True):
 
         compile_args = {
             "backend": backend,
@@ -344,6 +345,7 @@ class WanVideoTorchCompileSettings:
             "dynamo_cache_size_limit": dynamo_cache_size_limit,
             "dynamo_recompile_limit": dynamo_recompile_limit,
             "compile_transformer_blocks_only": compile_transformer_blocks_only,
+            "force_parameter_static_shapes": force_parameter_static_shapes,
         }
 
         return (compile_args, )
diff --git a/uni3c/nodes.py b/uni3c/nodes.py
index 064b6dc..2e5d114 100644
--- a/uni3c/nodes.py
+++ b/uni3c/nodes.py
@@ -112,12 +112,13 @@ class WanVideoUni3C_ControlnetLoader:
         del sd
 
         if compile_args is not None:
-            torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
-            try:
-                if hasattr(torch, '_dynamo') and hasattr(torch._dynamo, 'config'):
+            if hasattr(torch, '_dynamo') and hasattr(torch._dynamo, 'config'):
+                torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
+                torch._dynamo.config.force_parameter_static_shapes = compile_args["force_parameter_static_shapes"]
+                try:
                     torch._dynamo.config.recompile_limit = compile_args["dynamo_recompile_limit"]
-            except Exception as e:
-                log.warning(f"Could not set recompile_limit: {e}")
+                except Exception as e:
+                    log.warning(f"Could not set recompile_limit: {e}")
             if compile_args["compile_transformer_blocks_only"]:
                 for i, block in enumerate(controlnet.controlnet_blocks):
                     controlnet.controlnet_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
diff --git a/utils.py b/utils.py
index 99055b1..cedf537 100644
--- a/utils.py
+++ b/utils.py
@@ -502,12 +502,14 @@ def dict_to_device(tensor_dict, device, dtype=None):
 def compile_model(transformer, compile_args=None):
     if compile_args is None:
         return transformer
-    torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
-    try:
-        if hasattr(torch, '_dynamo') and hasattr(torch._dynamo, 'config'):
+    if hasattr(torch, '_dynamo') and hasattr(torch._dynamo, 'config'):
+        torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
+        torch._dynamo.config.force_parameter_static_shapes = compile_args["force_parameter_static_shapes"]
+        try:
             torch._dynamo.config.recompile_limit = compile_args["dynamo_recompile_limit"]
-    except Exception as e:
-        log.warning(f"Could not set recompile_limit: {e}")
+        except Exception as e:
+            log.warning(f"Could not set recompile_limit: {e}")
+
     if compile_args["compile_transformer_blocks_only"]:
         for i, block in enumerate(transformer.blocks):
             if hasattr(block, "_orig_mod"):