
Handle torchscript issue better

Some other custom nodes globally call torch._C._jit_set_profiling_executor(False), which breaks the NLF model.
kijai
2025-12-14 19:41:43 +02:00
parent e3cfa64bd3
commit ea1677bd4a
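
The fix wraps each NLF inference call in a save/restore of the TorchScript profiling executor, so a global override by another custom node is undone for the duration of the call and then put back. The same pattern can be expressed once as a context manager; the sketch below is illustrative only (it is not part of this commit) and relies on the same behaviour the commit does, namely that torch._C._jit_set_profiling_executor(flag) returns the previous setting.

from contextlib import contextmanager
import torch

@contextmanager
def jit_profiling_executor(enabled=True):
    # Enable (or disable) the TorchScript profiling executor and remember
    # whatever state another custom node may have forced globally.
    prev_state = torch._C._jit_set_profiling_executor(enabled)
    try:
        yield
    finally:
        # Restore the previous global state even if inference raises.
        torch._C._jit_set_profiling_executor(prev_state)

# Hypothetical usage mirroring NLFPredict (model, images and device are assumed here):
# with jit_profiling_executor(True):
#     pred = model.detect_smpl_batched(images.permute(0, 3, 1, 2).to(device))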


@@ -7,7 +7,7 @@ import comfy.model_management as mm
 from comfy.utils import load_torch_file
 import folder_paths
-script_directory = os.path.dirname(os.path.abspath(__file__))
+script_directory = os.path.dirname(os.path.abspath(__file__))
 device = mm.get_torch_device()
 offload_device = mm.unet_offload_device()
@@ -16,9 +16,7 @@ local_model_path = os.path.join(folder_paths.models_dir, "nlf", "nlf_l_multi_0.3
 from .motion4d import SMPL_VQVAE, VectorQuantizer, Encoder, Decoder

 def check_jit_script_function():
-    if torch.jit.script.__name__ == "patched_jit_script":
-        raise RuntimeError("The NLF model needs torch.jit.script. Currently ComfyUI-RMBG disables this which causes the NLF model to not work. Please disable ComfyUI-RMBG nodes to use the NLF model.")
-    elif torch.jit.script.__name__ != "script":
+    if torch.jit.script.__name__ != "script":
         # Get more details about what modified it
         module = torch.jit.script.__module__
         qualname = getattr(torch.jit.script, '__qualname__', 'unknown')
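
The branch removed above special-cased a patch named patched_jit_script (the one the old error message attributed to ComfyUI-RMBG); the generic __name__ check kept here still catches it, because a replacement function carries its own __name__ and __module__. A minimal illustration of what the check reacts to, using a hypothetical patch as a stand-in for another extension:

import torch

_original_script = torch.jit.script

def patched_jit_script(obj, *args, **kwargs):
    # Stand-in for a global monkey-patch another extension might install.
    return obj  # skip TorchScript compilation entirely

torch.jit.script = patched_jit_script
print(torch.jit.script.__name__)    # "patched_jit_script", not "script"
print(torch.jit.script.__module__)  # the patching module, no longer torch.jit's

torch.jit.script = _original_script  # undo the patch for the rest of the session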
@@ -79,8 +77,13 @@ class DownloadAndLoadNLFModel:
         if warmup:
             log.info("Warming up NLF model...")
             dummy_input = torch.zeros(1, 3, 256, 256, device=device)
-            for _ in range(2): # Run warmup 2 times
-                _ = model.detect_smpl_batched(dummy_input)
+            jit_profiling_prev_state = torch._C._jit_set_profiling_executor(True)
+            try:
+                for _ in range(2):
+                    _ = model.detect_smpl_batched(dummy_input)
+            finally:
+                torch._C._jit_set_profiling_executor(jit_profiling_prev_state)
             log.info("NLF model warmed up")
         model = model.to(offload_device)
@@ -111,8 +114,12 @@ class LoadNLFModel:
         if warmup:
             log.info("Warming up NLF model...")
             dummy_input = torch.zeros(1, 3, 256, 256, device=device)
-            for _ in range(2): # Run warmup 2 times
-                _ = model.detect_smpl_batched(dummy_input)
+            jit_profiling_prev_state = torch._C._jit_set_profiling_executor(True)
+            try:
+                for _ in range(2):
+                    _ = model.detect_smpl_batched(dummy_input)
+            finally:
+                torch._C._jit_set_profiling_executor(jit_profiling_prev_state)
             log.info("NLF model warmed up")
         model = model.to(offload_device)
@@ -221,7 +228,13 @@ class NLFPredict:
         check_jit_script_function()
         model = model.to(device)
-        pred = model.detect_smpl_batched(images.permute(0, 3, 1, 2).to(device))
+        jit_profiling_prev_state = torch._C._jit_set_profiling_executor(True)
+        try:
+            pred = model.detect_smpl_batched(images.permute(0, 3, 1, 2).to(device))
+        finally:
+            torch._C._jit_set_profiling_executor(jit_profiling_prev_state)
         model = model.to(offload_device)
         pred = dict_to_device(pred, offload_device)