ComfyUI-WanVideoWrapper (mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git)

Expose force_parameter_static_shapes in torch.compile options
@@ -326,6 +326,7 @@ class WanVideoTorchCompileSettings:
             },
             "optional": {
                 "dynamo_recompile_limit": ("INT", {"default": 128, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.recompile_limit"}),
+                "force_parameter_static_shapes": ("BOOLEAN", {"default": True, "tooltip": "torch._dynamo.config.force_parameter_static_shapes"}),
             },
         }
     RETURN_TYPES = ("WANCOMPILEARGS",)
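
The new optional input maps directly onto torch._dynamo.config.force_parameter_static_shapes. Roughly speaking, when the flag is True (the PyTorch default) Dynamo treats nn.Module parameter shapes as static even under dynamic-shape compilation, avoiding extra guards and recompiles on parameter sizes; setting it to False lets parameter shapes vary at the cost of more general generated code. A minimal standalone sketch of toggling the flag around torch.compile, outside any ComfyUI graph (the tiny module and shapes below are illustrative, not taken from the repository):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Same knobs the node exposes, set globally on dynamo's config module.
torch._dynamo.config.cache_size_limit = 64
torch._dynamo.config.force_parameter_static_shapes = True  # PyTorch default

class TinyBlock(nn.Module):  # hypothetical stand-in for a transformer block
    def __init__(self, dim=128):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return F.gelu(self.proj(x))

compiled = torch.compile(TinyBlock(), backend="inductor", mode="default", fullgraph=False, dynamic=False)
out = compiled(torch.randn(2, 16, 128))  # first call triggers compilation
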
@@ -334,7 +335,7 @@ class WanVideoTorchCompileSettings:
     CATEGORY = "WanVideoWrapper"
     DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch > 2.7.0 is recommended"

-    def set_args(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only, dynamo_recompile_limit=128):
+    def set_args(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only, dynamo_recompile_limit=128, force_parameter_static_shapes=True):

         compile_args = {
             "backend": backend,
@@ -344,6 +345,7 @@ class WanVideoTorchCompileSettings:
             "dynamo_cache_size_limit": dynamo_cache_size_limit,
             "dynamo_recompile_limit": dynamo_recompile_limit,
             "compile_transformer_blocks_only": compile_transformer_blocks_only,
+            "force_parameter_static_shapes": force_parameter_static_shapes,
         }

         return (compile_args, )

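For orientation, this is roughly what the node produces when set_args is called directly, outside a ComfyUI graph. Only the parameter names come from the diff; the argument values and the import path are illustrative assumptions:

# from ComfyUI-WanVideoWrapper's node module (exact import path depends on your install):
# from .nodes import WanVideoTorchCompileSettings  # hypothetical

settings = WanVideoTorchCompileSettings()
(compile_args,) = settings.set_args(
    backend="inductor",
    fullgraph=False,
    mode="default",
    dynamic=False,
    dynamo_cache_size_limit=64,
    compile_transformer_blocks_only=True,
    dynamo_recompile_limit=128,
    force_parameter_static_shapes=True,  # the newly exposed flag
)

# The node returns a plain dict wrapped in a tuple, so downstream loaders
# can simply index the new key:
assert compile_args["force_parameter_static_shapes"] is True
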
@@ -112,12 +112,13 @@ class WanVideoUni3C_ControlnetLoader:
         del sd

         if compile_args is not None:
-            torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
-            try:
-                if hasattr(torch, '_dynamo') and hasattr(torch._dynamo, 'config'):
+            if hasattr(torch, '_dynamo') and hasattr(torch._dynamo, 'config'):
+                torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
+                torch._dynamo.config.force_parameter_static_shapes = compile_args["force_parameter_static_shapes"]
+                try:
                     torch._dynamo.config.recompile_limit = compile_args["dynamo_recompile_limit"]
-            except Exception as e:
-                log.warning(f"Could not set recompile_limit: {e}")
+                except Exception as e:
+                    log.warning(f"Could not set recompile_limit: {e}")
             if compile_args["compile_transformer_blocks_only"]:
                 for i, block in enumerate(controlnet.controlnet_blocks):
                     controlnet.controlnet_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])

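Both this loader and compile_model in utils.py (below) now follow the same shape: guard on the private torch._dynamo.config module, push the cache-size and static-shape settings, and wrap only the recompile_limit assignment in a try/except, presumably because that attribute is absent on some torch versions. A small helper along these lines, purely illustrative and not part of the repository, captures the pattern:

import logging
import torch

log = logging.getLogger(__name__)

def apply_dynamo_settings(compile_args: dict) -> None:
    """Push the node's compile_args into torch._dynamo.config (illustrative helper)."""
    if not (hasattr(torch, "_dynamo") and hasattr(torch._dynamo, "config")):
        return  # torch built without dynamo: nothing to configure
    torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
    torch._dynamo.config.force_parameter_static_shapes = compile_args["force_parameter_static_shapes"]
    try:
        torch._dynamo.config.recompile_limit = compile_args["dynamo_recompile_limit"]
    except Exception as e:
        log.warning(f"Could not set recompile_limit: {e}")
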
utils.py (12 changed lines)
@@ -502,12 +502,14 @@ def dict_to_device(tensor_dict, device, dtype=None):
 def compile_model(transformer, compile_args=None):
     if compile_args is None:
         return transformer
-    torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
-    try:
-        if hasattr(torch, '_dynamo') and hasattr(torch._dynamo, 'config'):
+    if hasattr(torch, '_dynamo') and hasattr(torch._dynamo, 'config'):
+        torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
+        torch._dynamo.config.force_parameter_static_shapes = compile_args["force_parameter_static_shapes"]
+        try:
             torch._dynamo.config.recompile_limit = compile_args["dynamo_recompile_limit"]
-    except Exception as e:
-        log.warning(f"Could not set recompile_limit: {e}")
+        except Exception as e:
+            log.warning(f"Could not set recompile_limit: {e}")
+
     if compile_args["compile_transformer_blocks_only"]:
         for i, block in enumerate(transformer.blocks):
             if hasattr(block, "_orig_mod"):
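
The hunk cuts off just as compile_model starts walking transformer.blocks. Modules returned by torch.compile expose the wrapped original module as _orig_mod, which is what the final hasattr check inspects; the diff does not show how the loop continues, so the sketch below is only an assumption about a typical per-block compile that avoids wrapping an already-compiled block twice, reusing the torch.compile keyword arguments seen in the loader above:

import torch

def compile_blocks(transformer, compile_args):
    # Assumed continuation, not the repository's exact code.
    for i, block in enumerate(transformer.blocks):
        # If the block was compiled on a previous run, recompile its original
        # module instead of stacking another wrapper on top of it.
        target = block._orig_mod if hasattr(block, "_orig_mod") else block
        transformer.blocks[i] = torch.compile(
            target,
            fullgraph=compile_args["fullgraph"],
            dynamic=compile_args["dynamic"],
            backend=compile_args["backend"],
            mode=compile_args["mode"],
        )
    return transformer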