You've already forked ComfyUI-WanVideoWrapper
mirror of
https://github.com/kijai/ComfyUI-WanVideoWrapper.git
synced 2026-01-26 23:41:35 +03:00
Handle torchscript issue better
Some other custom nodes globally set torch._C._jit_set_profiling_executor(False) which breaks the NLF model
This commit is contained in:
31
MTV/nodes.py
31
MTV/nodes.py
@@ -7,7 +7,7 @@ import comfy.model_management as mm
|
||||
from comfy.utils import load_torch_file
|
||||
import folder_paths
|
||||
|
||||
script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
|
||||
@@ -16,9 +16,7 @@ local_model_path = os.path.join(folder_paths.models_dir, "nlf", "nlf_l_multi_0.3
|
||||
from .motion4d import SMPL_VQVAE, VectorQuantizer, Encoder, Decoder
|
||||
|
||||
def check_jit_script_function():
|
||||
if torch.jit.script.__name__ == "patched_jit_script":
|
||||
raise RuntimeError("The NLF model needs torch.jit.script. Currently ComfyUI-RMBG disables this which causes the NLF model to not work. Please disable ComfyUI-RMBG nodes to use the NLF model.")
|
||||
elif torch.jit.script.__name__ != "script":
|
||||
if torch.jit.script.__name__ != "script":
|
||||
# Get more details about what modified it
|
||||
module = torch.jit.script.__module__
|
||||
qualname = getattr(torch.jit.script, '__qualname__', 'unknown')
|
||||
@@ -79,8 +77,13 @@ class DownloadAndLoadNLFModel:
|
||||
if warmup:
|
||||
log.info("Warming up NLF model...")
|
||||
dummy_input = torch.zeros(1, 3, 256, 256, device=device)
|
||||
for _ in range(2): # Run warmup 2 times
|
||||
_ = model.detect_smpl_batched(dummy_input)
|
||||
jit_profiling_prev_state = torch._C._jit_set_profiling_executor(True)
|
||||
try:
|
||||
for _ in range(2):
|
||||
_ = model.detect_smpl_batched(dummy_input)
|
||||
finally:
|
||||
torch._C._jit_set_profiling_executor(jit_profiling_prev_state)
|
||||
|
||||
log.info("NLF model warmed up")
|
||||
|
||||
model = model.to(offload_device)
|
||||
@@ -111,8 +114,12 @@ class LoadNLFModel:
|
||||
if warmup:
|
||||
log.info("Warming up NLF model...")
|
||||
dummy_input = torch.zeros(1, 3, 256, 256, device=device)
|
||||
for _ in range(2): # Run warmup 2 times
|
||||
_ = model.detect_smpl_batched(dummy_input)
|
||||
jit_profiling_prev_state = torch._C._jit_set_profiling_executor(True)
|
||||
try:
|
||||
for _ in range(2):
|
||||
_ = model.detect_smpl_batched(dummy_input)
|
||||
finally:
|
||||
torch._C._jit_set_profiling_executor(jit_profiling_prev_state)
|
||||
log.info("NLF model warmed up")
|
||||
|
||||
model = model.to(offload_device)
|
||||
@@ -221,7 +228,13 @@ class NLFPredict:
|
||||
|
||||
check_jit_script_function()
|
||||
model = model.to(device)
|
||||
pred = model.detect_smpl_batched(images.permute(0, 3, 1, 2).to(device))
|
||||
|
||||
jit_profiling_prev_state = torch._C._jit_set_profiling_executor(True)
|
||||
try:
|
||||
pred = model.detect_smpl_batched(images.permute(0, 3, 1, 2).to(device))
|
||||
finally:
|
||||
torch._C._jit_set_profiling_executor(jit_profiling_prev_state)
|
||||
|
||||
model = model.to(offload_device)
|
||||
|
||||
pred = dict_to_device(pred, offload_device)
|
||||
|
||||
Reference in New Issue
Block a user