1
0
mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git synced 2026-01-26 23:41:35 +03:00

Expose force_parameter_static_shapes in torch.compile options

This commit is contained in:
kijai
2025-10-19 18:01:03 +03:00
parent a721fe3d0f
commit 99c0efdb45
3 changed files with 16 additions and 11 deletions

View File

@@ -326,6 +326,7 @@ class WanVideoTorchCompileSettings:
},
"optional": {
"dynamo_recompile_limit": ("INT", {"default": 128, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.recompile_limit"}),
"force_parameter_static_shapes": ("BOOLEAN", {"default": True, "tooltip": "torch._dynamo.config.force_parameter_static_shapes"}),
},
}
RETURN_TYPES = ("WANCOMPILEARGS",)
@@ -334,7 +335,7 @@ class WanVideoTorchCompileSettings:
CATEGORY = "WanVideoWrapper"
DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch > 2.7.0 is recommended"
def set_args(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only, dynamo_recompile_limit=128):
def set_args(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only, dynamo_recompile_limit=128, force_parameter_static_shapes=True):
compile_args = {
"backend": backend,
@@ -344,6 +345,7 @@ class WanVideoTorchCompileSettings:
"dynamo_cache_size_limit": dynamo_cache_size_limit,
"dynamo_recompile_limit": dynamo_recompile_limit,
"compile_transformer_blocks_only": compile_transformer_blocks_only,
"force_parameter_static_shapes": force_parameter_static_shapes,
}
return (compile_args, )

View File

@@ -112,12 +112,13 @@ class WanVideoUni3C_ControlnetLoader:
del sd
if compile_args is not None:
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
try:
if hasattr(torch, '_dynamo') and hasattr(torch._dynamo, 'config'):
if hasattr(torch, '_dynamo') and hasattr(torch._dynamo, 'config'):
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
torch._dynamo.config.force_parameter_static_shapes = compile_args["force_parameter_static_shapes"]
try:
torch._dynamo.config.recompile_limit = compile_args["dynamo_recompile_limit"]
except Exception as e:
log.warning(f"Could not set recompile_limit: {e}")
except Exception as e:
log.warning(f"Could not set recompile_limit: {e}")
if compile_args["compile_transformer_blocks_only"]:
for i, block in enumerate(controlnet.controlnet_blocks):
controlnet.controlnet_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])

View File

@@ -502,12 +502,14 @@ def dict_to_device(tensor_dict, device, dtype=None):
def compile_model(transformer, compile_args=None):
if compile_args is None:
return transformer
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
try:
if hasattr(torch, '_dynamo') and hasattr(torch._dynamo, 'config'):
if hasattr(torch, '_dynamo') and hasattr(torch._dynamo, 'config'):
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
torch._dynamo.config.force_parameter_static_shapes = compile_args["force_parameter_static_shapes"]
try:
torch._dynamo.config.recompile_limit = compile_args["dynamo_recompile_limit"]
except Exception as e:
log.warning(f"Could not set recompile_limit: {e}")
except Exception as e:
log.warning(f"Could not set recompile_limit: {e}")
if compile_args["compile_transformer_blocks_only"]:
for i, block in enumerate(transformer.blocks):
if hasattr(block, "_orig_mod"):