import torch
import torch.nn as nn

from .utils import log

# Based on ComfyUI's and MinusZoneAI's fp8_linear optimization
def fp8_linear_forward(cls, base_dtype, input):
    weight_dtype = cls.weight.dtype
    if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
        # Fast path: only 3D activations (batch, seq, features) are handled here.
        if len(input.shape) == 3:
            input_shape = input.shape

            # Per-tensor weight scale; defaults to 1.0 when the checkpoint
            # does not provide one.
            scale_weight = getattr(cls, 'scale_weight', None)
            if scale_weight is None:
                scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
            else:
                scale_weight = scale_weight.to(input.device).squeeze()

            # Inputs are not pre-scaled, so use a unit scale.
            scale_input = torch.ones((), device=input.device, dtype=torch.float32)

            # Clamp to the float8_e4m3fn representable range (+-448) to avoid
            # overflow when casting, then flatten to 2D for the matmul.
            input = torch.clamp(input, min=-448, max=448, out=input)
            inn = input.reshape(-1, input_shape[2]).to(torch.float8_e4m3fn).contiguous()  # always e4m3fn because e5m2 * e5m2 is not supported

            bias = cls.bias.to(base_dtype) if cls.bias is not None else None

            o = torch._scaled_mm(inn, cls.weight.t(), out_dtype=base_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)

            # Restore the (batch, seq, out_features) shape.
            return o.reshape((-1, input_shape[1], cls.weight.shape[0]))
        else:
            # Fallback for non-3D inputs: run the original forward in base_dtype.
            return cls.original_forward(input.to(base_dtype))
    else:
        # Weight is not fp8; nothing to patch, use the original forward.
        return cls.original_forward(input)
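
# Illustrative sketch (not part of the original module): the shape/dtype
# contract the fast path above relies on, assuming a CUDA device with fp8
# support and a PyTorch build exposing torch._scaled_mm. Sizes are
# hypothetical.
#
#   x = torch.randn(2, 77, 4096, device="cuda", dtype=torch.bfloat16)
#   w = torch.randn(4096, 4096, device="cuda").to(torch.float8_e4m3fn)
#   one = torch.ones((), device="cuda", dtype=torch.float32)
#   y = torch._scaled_mm(
#       x.reshape(-1, 4096).clamp_(-448, 448).to(torch.float8_e4m3fn),
#       w.t(),  # second operand must be column-major, hence the transpose
#       out_dtype=torch.bfloat16, scale_a=one, scale_b=one)
#   assert y.shape == (2 * 77, 4096)
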
def convert_fp8_linear(module, base_dtype, params_to_keep={}, scale_weight_keys=None):
    log.info("FP8 matmul enabled")
    for name, submodule in module.named_modules():
        # Skip layers whose names match a keyword in params_to_keep; those
        # stay in high precision.
        if not any(keyword in name for keyword in params_to_keep):
            if isinstance(submodule, nn.Linear):
                # Attach a per-layer weight scale if the checkpoint provides one.
                if scale_weight_keys is not None:
                    scale_key = f"{name}.scale_weight"
                    if scale_key in scale_weight_keys:
                        setattr(submodule, "scale_weight", scale_weight_keys[scale_key].float())
                # Keep the original forward for the fallback paths, then patch
                # in the fp8 matmul forward (m=submodule binds the layer
                # eagerly, avoiding the late-binding lambda pitfall).
                original_forward = submodule.forward
                setattr(submodule, "original_forward", original_forward)
                setattr(submodule, "forward", lambda input, m=submodule: fp8_linear_forward(m, base_dtype, input))
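
# Minimal usage sketch (an assumption, not from this file): a loader might
# collect per-layer scales from the checkpoint and convert the model like so.
# `model` and `state_dict` are hypothetical stand-ins for the wrapper's own
# objects.
#
#   scale_keys = {k: v for k, v in state_dict.items() if k.endswith(".scale_weight")}
#   convert_fp8_linear(
#       model,
#       base_dtype=torch.bfloat16,
#       params_to_keep={"norm", "bias", "embed"},  # keep these layers unpatched
#       scale_weight_keys=scale_keys or None,
#   )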