You've already forked ComfyUI-WanVideoWrapper
mirror of
https://github.com/kijai/ComfyUI-WanVideoWrapper.git
synced 2026-01-26 23:41:35 +03:00
51 lines
2.1 KiB
Python
51 lines
2.1 KiB
Python
import functools
from contextlib import contextmanager

import torch
|
|
|
|
@contextmanager
def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
    """Context manager that relocates newly registered module weights to ``device``.

    While the context is active, ``torch.nn.Module.register_parameter`` is
    replaced so that every parameter registered on any module is first
    registered normally and then swapped for a copy moved to ``device``
    (``torch.device("meta")`` by default).

    Args:
        device: Target device. It is passed to ``Tensor.to`` and, when
            ``include_buffers`` is true, as the ``device=`` keyword of the
            patched tensor constructors.
        include_buffers: If true, additionally:
            * patch ``torch.nn.Module.register_buffer`` so registered buffers
              are moved to ``device``;
            * patch ``torch.empty``, ``torch.zeros``, ``torch.ones`` and
              ``torch.full`` so they always allocate on ``device``. Any
              ``device=`` argument the caller passes to them is overridden.

    The patches modify global attributes of ``torch`` and
    ``torch.nn.Module``. They are undone in a ``finally`` clause, so they are
    reverted when the context exits normally or through an exception.

    Example:
        >>> with init_weights_on_device():
        ...     layer = torch.nn.Linear(4, 4)
        >>> layer.weight.device.type
        'meta'
    """
    old_register_parameter = torch.nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = torch.nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        # Let the stock implementation do all of its validation first, then
        # replace the stored parameter with a relocated copy.
        old_register_parameter(module, name, param)
        if param is not None:
            registered = module._parameters[name]
            param_cls = type(registered)
            # The parameter's instance attributes are forwarded as constructor
            # kwargs (presumably so Parameter subclasses keep their extra
            # state -- TODO confirm). Copy the dict instead of aliasing it so
            # the caller's tensor is not mutated.
            kwargs = dict(registered.__dict__)
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(registered.to(device), **kwargs)

    def register_empty_buffer(module, name, buffer, persistent=True):
        # Same approach for buffers: register normally, then relocate.
        old_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    def patch_tensor_constructor(fn):
        # Every call is forced onto ``device``; an explicit ``device=`` from
        # the caller is overwritten.
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    if include_buffers:
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ["empty", "zeros", "ones", "full"]
        }
    else:
        tensor_constructors_to_patch = {}

    try:
        torch.nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = register_empty_buffer
        for torch_function_name, original_fn in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, patch_tensor_constructor(original_fn))
        yield
    finally:
        # Restore exactly what was captured on entry, even if setup or the
        # body of the ``with`` statement raised.
        torch.nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = old_register_buffer
        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)