mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git synced 2026-01-28 12:20:55 +03:00

fp8 matmul for scaled models

FP8 matmul (fp8_fast) doesn't seem feasible with unmerged LoRAs: you'd need to first upcast the weight, then apply the LoRA, then downcast back to fp8, and that is too slow. Adding the LoRA directly in fp8 isn't possible either, since the fp8 dtypes simply don't support arithmetic.
kijai
2025-08-09 10:17:11 +03:00
parent 1757847e5f
commit 48fa904ad8
4 changed files with 65 additions and 55 deletions
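
The trade-off described in the commit message, as a minimal sketch with illustrative names and shapes (not code from this repository): keeping a LoRA unmerged on top of an fp8 weight means paying for an upcast, the LoRA add, and a downcast on every use, while adding directly in fp8 fails because PyTorch has no arithmetic kernels for the fp8 dtypes.

import torch

# Hypothetical fp8-stored weight and an unmerged LoRA pair; names/shapes are illustrative.
weight_fp8 = torch.randn(128, 128).to(torch.float8_e4m3fn)
lora_up = torch.randn(128, 16)
lora_down = torch.randn(16, 128)

# Unmerged path: upcast -> apply LoRA -> downcast, repeated every forward pass.
w = weight_fp8.to(torch.bfloat16)                             # upcast
w = w + (lora_up @ lora_down).to(torch.bfloat16)              # apply LoRA
weight_fp8 = w.to(torch.float8_e4m3fn)                        # downcast; the extra casts dominate

# Direct addition in fp8 is rejected outright, there is no fp8 add kernel:
# weight_fp8 + (lora_up @ lora_down).to(torch.float8_e4m3fn)  # -> RuntimeError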

View File

@@ -1,31 +1,31 @@
-#based on ComfyUI's and MinusZoneAI's fp8_linear optimization
 import torch
 import torch.nn as nn
 from .utils import log
-def fp8_linear_forward(cls, original_dtype, input):
+#based on ComfyUI's and MinusZoneAI's fp8_linear optimization
+def fp8_linear_forward(cls, base_dtype, input):
     weight_dtype = cls.weight.dtype
     if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
         if len(input.shape) == 3:
-            #target_dtype = torch.float8_e5m2 if weight_dtype == torch.float8_e4m3fn else torch.float8_e4m3fn
-            inn = input.reshape(-1, input.shape[2]).to(weight_dtype)
-            w = cls.weight.t()
-            scale = torch.ones((1), device=input.device, dtype=torch.float32)
-            bias = cls.bias.to(original_dtype) if cls.bias is not None else None
-            if bias is not None:
-                o = torch._scaled_mm(inn, w, out_dtype=original_dtype, bias=bias, scale_a=scale, scale_b=scale)
+            input_shape = input.shape
+            scale_weight = getattr(cls, 'scale_weight', None)
+            if scale_weight is None:
+                scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
             else:
-                o = torch._scaled_mm(inn, w, out_dtype=original_dtype, scale_a=scale, scale_b=scale)
+                scale_weight = scale_weight.to(input.device)
+            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+            input = torch.clamp(input, min=-448, max=448, out=input)
+            inn = input.reshape(-1, input_shape[2]).to(torch.float8_e4m3fn).contiguous() #always e4m3fn because e5m2 * e5m2 is not supported
-            if isinstance(o, tuple):
-                o = o[0]
+            bias = cls.bias.to(base_dtype) if cls.bias is not None else None
-            return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
+            o = torch._scaled_mm(inn, cls.weight.t(), out_dtype=base_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
+            return o.reshape((-1, input_shape[1], cls.weight.shape[0]))
         else:
-            return cls.original_forward(input.to(original_dtype))
+            return cls.original_forward(input.to(base_dtype))
     else:
         return cls.original_forward(input)
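
For reference, a standalone sketch of the per-tensor-scale torch._scaled_mm call the new forward path relies on. This assumes an fp8-capable GPU (Ada/Hopper class) and a recent PyTorch; the 0.5 weight scale is a stand-in for a loaded .scale_weight, and all names/shapes are illustrative.

import torch

device = "cuda"  # fp8 _scaled_mm requires an fp8-capable GPU

# Illustrative shapes, kept as multiples of 16 as the fp8 GEMM requires.
x = torch.randn(32, 64, device=device).clamp_(-448, 448).to(torch.float8_e4m3fn)
w = torch.randn(128, 64, device=device).to(torch.float8_e4m3fn)  # (out_features, in_features)

scale_input = torch.ones((), device=device, dtype=torch.float32)        # dummy per-tensor input scale
scale_weight = torch.full((), 0.5, device=device, dtype=torch.float32)  # stand-in for a stored .scale_weight

# w.t() yields the column-major layout _scaled_mm expects for its second operand.
o = torch._scaled_mm(x, w.t(), scale_a=scale_input, scale_b=scale_weight, out_dtype=torch.bfloat16)
print(o.shape)  # torch.Size([32, 128])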
@@ -67,20 +67,23 @@ def linear_with_lora_and_scale_forward(cls, input):
         weight = apply_lora(weight, lora, cls.step).to(input.dtype)
     return torch.nn.functional.linear(input, weight, bias)
-def convert_fp8_linear(module, original_dtype, params_to_keep={}):
-    setattr(module, "fp8_matmul_enabled", True)
+def convert_fp8_linear(module, base_dtype, params_to_keep={}, scale_weight_keys=None):
     log.info("FP8 matmul enabled")
     for name, submodule in module.named_modules():
         if not any(keyword in name for keyword in params_to_keep):
             if isinstance(submodule, nn.Linear):
+                if scale_weight_keys is not None:
+                    scale_key = f"{name}.scale_weight"
+                    if scale_key in scale_weight_keys:
+                        print("Setting scale_weight for", name)
+                        setattr(submodule, "scale_weight", scale_weight_keys[scale_key])
                 original_forward = submodule.forward
                 setattr(submodule, "original_forward", original_forward)
-                setattr(submodule, "forward", lambda input, m=submodule: fp8_linear_forward(m, original_dtype, input))
+                setattr(submodule, "forward", lambda input, m=submodule: fp8_linear_forward(m, base_dtype, input))
 def convert_linear_with_lora_and_scale(module, scale_weight_keys=None, patches=None, params_to_keep={}):
     log.info("Patching Linear layers...")
     for name, submodule in module.named_modules():
         if not any(keyword in name for keyword in params_to_keep):
             # Set scale_weight if present
@@ -90,6 +93,9 @@ def convert_linear_with_lora_and_scale(module, scale_weight_keys=None, patches=N
                     setattr(submodule, "scale_weight", scale_weight_keys[scale_key])
             # Set LoRA if present
+            if hasattr(submodule, "lora"):
+                print(f"removing old LoRA in {name}" )
+                delattr(submodule, "lora")
             if patches is not None:
                 patch_key = f"diffusion_model.{name}.weight"
                 patch = patches.get(patch_key, [])
@@ -108,17 +114,17 @@ def convert_linear_with_lora_and_scale(module, scale_weight_keys=None, patches=N
                     lora_strengths = [p[0] for p in patch]
                     lora = (lora_diffs, lora_strengths)
                     setattr(submodule, "lora", lora)
                     print(f"Added LoRA to {name} with {len(lora_diffs)} diffs and strengths {lora_strengths}")
             # Set forward if Linear and has either scale or lora
             if isinstance(submodule, nn.Linear):
                 has_scale = hasattr(submodule, "scale_weight")
                 has_lora = hasattr(submodule, "lora")
+                if not hasattr(submodule, "original_forward"):
+                    setattr(submodule, "original_forward", submodule.forward)
                 if has_scale or has_lora:
-                    original_forward_ = submodule.forward
-                    setattr(submodule, "original_forward_", original_forward_)
                     setattr(submodule, "forward", lambda input, m=submodule: linear_with_lora_and_scale_forward(m, input))
-                    setattr(submodule, "step", 0) # Initialize step for LoRA if needed
+                    setattr(submodule, "step", 0) # Initialize step for LoRA scheduling
 def remove_lora_from_module(module):
     unloaded = False
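
A side note on the patching pattern above: the lambdas bind the submodule through a default argument (m=submodule) to dodge Python's late-binding closure pitfall, where every lambda created in a loop would otherwise end up calling the last module. A self-contained illustration (not repository code):

import torch
import torch.nn as nn

layers = [nn.Linear(4, 4) for _ in range(3)]

# Late binding: every lambda closes over the loop variable, which ends up as the last layer.
late = [lambda x: layer(x) for layer in layers]
# Default-argument binding, as in convert_fp8_linear: each lambda keeps its own module.
bound = [lambda x, m=layer: m(x) for layer in layers]

x = torch.randn(1, 4)
print(all(torch.equal(f(x), layers[-1](x)) for f in late))  # True: they all call the last layer
print(torch.equal(bound[0](x), layers[0](x)))               # True: each bound lambda calls its own layer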

View File

@@ -1494,19 +1494,28 @@ class WanVideoSampler:
         transformer = model.diffusion_model
         dtype = model["dtype"]
+        fp8_matmul = model["fp8_matmul"]
         gguf = model["gguf"]
+        scale_weights = model["scale_weights"]
         control_lora = model["control_lora"]
         transformer_options = patcher.model_options.get("transformer_options", None)
+        merge_loras = transformer_options["merge_loras"]
         is_5b = transformer.out_dim == 48
         vae_upscale_factor = 16 if is_5b else 8
-        if len(patcher.patches) != 0 and transformer_options.get("linear_with_lora", False) is True:
+        patch_linear = transformer_options.get("patch_linear", False)
+        if gguf:
+            set_lora_params(transformer, patcher.patches)
+        elif len(patcher.patches) != 0 and patch_linear:
             log.info(f"Using {len(patcher.patches)} LoRA weight patches for WanVideo model")
-            if not gguf:
-                convert_linear_with_lora_and_scale(transformer, patches=patcher.patches)
-            else:
-                set_lora_params(transformer, patcher.patches)
+            if not merge_loras and fp8_matmul:
+                raise NotImplementedError("FP8 matmul with unmerged LoRAs is not supported")
+            convert_linear_with_lora_and_scale(transformer, patches=patcher.patches, scale_weight_keys=scale_weights)
+        elif patch_linear:
+            convert_linear_with_lora_and_scale(transformer, scale_weight_keys=scale_weights)
         else:
             remove_lora_from_module(transformer)

View File

@@ -698,7 +698,7 @@ class WanVideoSetLoRAs:
         if 'transformer_options' not in patcher.model_options:
            patcher.model_options['transformer_options'] = {}
-        patcher.model_options['transformer_options']["linear_with_lora"] = True
+        patcher.model_options['transformer_options']["patch_linear"] = True
         return (patcher,)
@@ -711,7 +711,7 @@ class WanVideoModelLoader:
                 "model": (folder_paths.get_filename_list("unet_gguf") + folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}),
                 "base_precision": (["fp32", "bf16", "fp16", "fp16_fast"], {"default": "bf16"}),
-                "quantization": (["disabled", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp8_e4m3fn_fast_no_ffn", "fp8_e4m3fn_scaled", "fp8_e5m2_scaled"], {"default": "disabled", "tooltip": "optional quantization method"}),
+                "quantization": (["disabled", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e4m3fn_scaled", "fp8_e4m3fn_scaled_fast", "fp8_e5m2", "fp8_e5m2_fast", "fp8_e5m2_scaled", "fp8_e5m2_scaled_fast"], {"default": "disabled", "tooltip": "optional quantization method"}),
                 "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
             },
             "optional": {
@@ -1029,10 +1029,8 @@
         )
         if not gguf:
+            scale_weights = {}
-            if "scaled" in quantization:
-                scale_weights = {}
+            if "fp8" in quantization:
                 for k, v in sd.items():
                     if k.endswith(".scale_weight"):
                         scale_weights[k] = v
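
For context, the loop above just peels the per-layer scalar scales out of the checkpoint: scaled fp8 checkpoints carry them as float tensors stored under "<layer>.scale_weight" next to the fp8 weights. A hedged sketch with a made-up state dict (key names are illustrative):

import torch

# Made-up state dict mimicking a scaled fp8 checkpoint: fp8 weights plus per-layer scales.
sd = {
    "blocks.0.ffn.0.weight": torch.randn(16, 16).to(torch.float8_e4m3fn),
    "blocks.0.ffn.0.scale_weight": torch.tensor(0.02),
}

scale_weights = {k: v for k, v in sd.items() if k.endswith(".scale_weight")}
print(list(scale_weights))  # ['blocks.0.ffn.0.scale_weight']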
@@ -1140,6 +1138,8 @@
                 patcher, device, transformer_load_device,
                 params_to_keep=params_to_keep, dtype=dtype, base_dtype=base_dtype, state_dict=sd,
                 low_mem_load=lora_low_mem_load, control_lora=control_lora, scale_weights=scale_weights)
+            scale_weights.clear()
+            patcher.patches.clear()
         if gguf:
             #from diffusers.quantizers.gguf.utils import _replace_with_gguf_linear, GGUFParameter
@@ -1175,22 +1175,14 @@
             patcher.model.is_patched = True
-        if "fast" in quantization:
-            if not merge_loras:
-                raise ValueError("FP8 fast quantization requires LoRAs to be merged into the model, please set merge_loras=True in the LoRA input")
-            from .fp8_optimization import convert_fp8_linear
-            if quantization == "fp8_e4m3fn_fast_no_ffn":
-                params_to_keep.update({"ffn"})
-            print(params_to_keep)
-            convert_fp8_linear(patcher.model.diffusion_model, base_dtype, params_to_keep=params_to_keep)
+        patch_linear = (True if "scaled" in quantization or not merge_loras else False)
-        if "scaled" in quantization and not merge_loras:
-            log.info("Using FP8 scaled linear quantization")
-            convert_linear_with_lora_and_scale(patcher.model.diffusion_model, scale_weights, patches=patcher.patches)
-        elif lora is not None and not merge_loras and not gguf:
-            log.info("LoRAs will be applied at runtime")
-            convert_linear_with_lora_and_scale(patcher.model.diffusion_model, patches=patcher.patches)
+        if "fast" in quantization:
+            if lora is not None and not merge_loras:
+                raise NotImplementedError("fp8_fast is not supported with unmerged LoRAs")
+            from .fp8_optimization import convert_fp8_linear
+            convert_fp8_linear(transformer, base_dtype, params_to_keep, scale_weight_keys=scale_weights)
+            patch_linear = False
         del sd
@@ -1255,11 +1247,14 @@
         patcher.model["control_lora"] = control_lora
         patcher.model["compile_args"] = compile_args
         patcher.model["gguf"] = gguf
+        patcher.model["fp8_matmul"] = "fast" in quantization
+        patcher.model["scale_weights"] = scale_weights
         if 'transformer_options' not in patcher.model_options:
             patcher.model_options['transformer_options'] = {}
         patcher.model_options["transformer_options"]["block_swap_args"] = block_swap_args
-        patcher.model_options["transformer_options"]["linear_with_lora"] = True if not merge_loras else False
+        patcher.model_options["transformer_options"]["patch_linear"] = patch_linear
+        patcher.model_options["transformer_options"]["merge_loras"] = merge_loras
         for model in mm.current_loaded_models:
             if model._model() == patcher:

View File

@@ -149,7 +149,7 @@ class WanVideoDiffusionForcingSampler:
         gguf = model["gguf"]
         transformer_options = patcher.model_options.get("transformer_options", None)
-        if len(patcher.patches) != 0 and transformer_options.get("linear_with_lora", False) is True:
+        if len(patcher.patches) != 0 and transformer_options.get("linear_patched", False) is True:
            log.info(f"Using {len(patcher.patches)} LoRA weight patches for WanVideo model")
            if not gguf:
                convert_linear_with_lora_and_scale(transformer, patches=patcher.patches)