Mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git, synced 2026-01-28 12:20:55 +03:00
fp8 matmul for scaled models
FP8 matmul (fp8_fast) doesn't seem feasible with unmerged LoRAs: you'd need to upcast the weight, apply the LoRA, then downcast back to fp8, and that is too slow. Adding the LoRA directly in fp8 is also not possible, since fp8 dtypes simply don't support addition.
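A minimal sketch of that constraint (not part of the commit), assuming PyTorch 2.1+ with the float8 dtypes available; names like w_fp8 and lora_delta are illustrative, not from the repository. Element-wise arithmetic is not implemented for fp8 tensors, so an unmerged LoRA can only be applied through the upcast/downcast round trip, which defeats the point of the fp8 matmul fast path:

# Illustrative only, not part of this commit: why fp8_fast needs merged LoRAs.
import torch

w_fp8 = torch.randn(128, 128).to(torch.float8_e4m3fn)   # quantized base weight
lora_delta = 0.01 * torch.randn(128, 128)                # LoRA delta in full precision

# Direct accumulation in fp8 is not implemented for the fp8 dtypes:
try:
    w_fp8 += lora_delta.to(torch.float8_e4m3fn)
except (RuntimeError, NotImplementedError) as err:
    print("fp8 add unsupported:", err)

# The only alternative is the slow round trip described above:
w_patched = (w_fp8.to(torch.bfloat16) + lora_delta.to(torch.bfloat16)).to(torch.float8_e4m3fn)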
fp8_optimization.py
@@ -1,31 +1,31 @@
-#based on ComfyUI's and MinusZoneAI's fp8_linear optimization
 import torch
 import torch.nn as nn
 from .utils import log
 
-def fp8_linear_forward(cls, original_dtype, input):
+#based on ComfyUI's and MinusZoneAI's fp8_linear optimization
+def fp8_linear_forward(cls, base_dtype, input):
     weight_dtype = cls.weight.dtype
     if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
         if len(input.shape) == 3:
-            #target_dtype = torch.float8_e5m2 if weight_dtype == torch.float8_e4m3fn else torch.float8_e4m3fn
-            inn = input.reshape(-1, input.shape[2]).to(weight_dtype)
-            w = cls.weight.t()
-
-            scale = torch.ones((1), device=input.device, dtype=torch.float32)
-            bias = cls.bias.to(original_dtype) if cls.bias is not None else None
-
-            if bias is not None:
-                o = torch._scaled_mm(inn, w, out_dtype=original_dtype, bias=bias, scale_a=scale, scale_b=scale)
-            else:
-                o = torch._scaled_mm(inn, w, out_dtype=original_dtype, scale_a=scale, scale_b=scale)
-
-            if isinstance(o, tuple):
-                o = o[0]
-
-            return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
+            input_shape = input.shape
+
+            scale_weight = getattr(cls, 'scale_weight', None)
+            if scale_weight is None:
+                scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
+            else:
+                scale_weight = scale_weight.to(input.device)
+
+            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+            input = torch.clamp(input, min=-448, max=448, out=input)
+            inn = input.reshape(-1, input_shape[2]).to(torch.float8_e4m3fn).contiguous() #always e4m3fn because e5m2 * e5m2 is not supported
+
+            bias = cls.bias.to(base_dtype) if cls.bias is not None else None
+
+            o = torch._scaled_mm(inn, cls.weight.t(), out_dtype=base_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
+
+            return o.reshape((-1, input_shape[1], cls.weight.shape[0]))
         else:
-            return cls.original_forward(input.to(original_dtype))
+            return cls.original_forward(input.to(base_dtype))
     else:
         return cls.original_forward(input)
@@ -67,20 +67,23 @@ def linear_with_lora_and_scale_forward(cls, input):
         weight = apply_lora(weight, lora, cls.step).to(input.dtype)
 
     return torch.nn.functional.linear(input, weight, bias)
 
 
-def convert_fp8_linear(module, original_dtype, params_to_keep={}):
-    setattr(module, "fp8_matmul_enabled", True)
+def convert_fp8_linear(module, base_dtype, params_to_keep={}, scale_weight_keys=None):
+    log.info("FP8 matmul enabled")
     for name, submodule in module.named_modules():
         if not any(keyword in name for keyword in params_to_keep):
             if isinstance(submodule, nn.Linear):
+                if scale_weight_keys is not None:
+                    scale_key = f"{name}.scale_weight"
+                    if scale_key in scale_weight_keys:
+                        print("Setting scale_weight for", name)
+                        setattr(submodule, "scale_weight", scale_weight_keys[scale_key])
                 original_forward = submodule.forward
                 setattr(submodule, "original_forward", original_forward)
-                setattr(submodule, "forward", lambda input, m=submodule: fp8_linear_forward(m, original_dtype, input))
+                setattr(submodule, "forward", lambda input, m=submodule: fp8_linear_forward(m, base_dtype, input))
 
 
 def convert_linear_with_lora_and_scale(module, scale_weight_keys=None, patches=None, params_to_keep={}):
     log.info("Patching Linear layers...")
     for name, submodule in module.named_modules():
         if not any(keyword in name for keyword in params_to_keep):
             # Set scale_weight if present
@@ -90,6 +93,9 @@ def convert_linear_with_lora_and_scale(module, scale_weight_keys=None, patches=N
                         setattr(submodule, "scale_weight", scale_weight_keys[scale_key])
 
             # Set LoRA if present
+            if hasattr(submodule, "lora"):
+                print(f"removing old LoRA in {name}" )
+                delattr(submodule, "lora")
             if patches is not None:
                 patch_key = f"diffusion_model.{name}.weight"
                 patch = patches.get(patch_key, [])
@@ -108,17 +114,17 @@ def convert_linear_with_lora_and_scale(module, scale_weight_keys=None, patches=N
                     lora_strengths = [p[0] for p in patch]
                     lora = (lora_diffs, lora_strengths)
                     setattr(submodule, "lora", lora)
                     print(f"Added LoRA to {name} with {len(lora_diffs)} diffs and strengths {lora_strengths}")
 
             # Set forward if Linear and has either scale or lora
             if isinstance(submodule, nn.Linear):
                 has_scale = hasattr(submodule, "scale_weight")
                 has_lora = hasattr(submodule, "lora")
+                if not hasattr(submodule, "original_forward"):
+                    setattr(submodule, "original_forward", submodule.forward)
                 if has_scale or has_lora:
-                    original_forward_ = submodule.forward
-                    setattr(submodule, "original_forward_", original_forward_)
                     setattr(submodule, "forward", lambda input, m=submodule: linear_with_lora_and_scale_forward(m, input))
-                    setattr(submodule, "step", 0) # Initialize step for LoRA if needed
+                    setattr(submodule, "step", 0) # Initialize step for LoRA scheduling
 
 
 def remove_lora_from_module(module):
     unloaded = False
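For orientation, a hypothetical usage sketch (not from the repository) of the scaled path added above: it mirrors the torch._scaled_mm call in the patched fp8_linear_forward, with a per-tensor scale_weight attached to a Linear layer the way convert_fp8_linear reads it from the state dict. It assumes a CUDA GPU with fp8 support (Ada/Hopper class) and a recent PyTorch; the layer sizes, the computed scale, and the 448 clamp (the e4m3fn maximum) are illustrative, since real scaled checkpoints ship precomputed ".scale_weight" tensors.

# Hypothetical sketch, not part of the commit: the scaled fp8 matmul path in isolation.
import torch
import torch.nn as nn

device = "cuda"                        # requires an fp8-capable GPU (e.g. Ada/Hopper)
base_dtype = torch.bfloat16

lin = nn.Linear(256, 256, bias=True, device=device, dtype=base_dtype)

# Per-tensor weight scale, stored the way a ".scale_weight" entry would provide it.
scale_weight = lin.weight.detach().abs().amax().float() / 448.0
lin.scale_weight = scale_weight                                   # attribute convert_fp8_linear sets
lin.weight.data = (lin.weight.data / scale_weight).to(torch.float8_e4m3fn)

x = torch.randn(1, 64, 256, device=device, dtype=base_dtype)      # (batch, tokens, hidden)

# Same steps as the patched fp8_linear_forward: clamp, cast to e4m3fn, scaled matmul.
inn = torch.clamp(x, min=-448, max=448).reshape(-1, 256).to(torch.float8_e4m3fn).contiguous()
scale_input = torch.ones((), device=device, dtype=torch.float32)
o = torch._scaled_mm(inn, lin.weight.t(), out_dtype=base_dtype, bias=lin.bias,
                     scale_a=scale_input, scale_b=lin.scale_weight)
if isinstance(o, tuple):               # older PyTorch returned (out, amax)
    o = o[0]
print(o.reshape(1, 64, 256).shape)     # torch.Size([1, 64, 256])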
nodes.py
@@ -1494,19 +1494,28 @@ class WanVideoSampler:
         transformer = model.diffusion_model
 
         dtype = model["dtype"]
+        fp8_matmul = model["fp8_matmul"]
         gguf = model["gguf"]
+        scale_weights = model["scale_weights"]
         control_lora = model["control_lora"]
 
         transformer_options = patcher.model_options.get("transformer_options", None)
+        merge_loras = transformer_options["merge_loras"]
 
         is_5b = transformer.out_dim == 48
         vae_upscale_factor = 16 if is_5b else 8
 
-        if len(patcher.patches) != 0 and transformer_options.get("linear_with_lora", False) is True:
+        patch_linear = transformer_options.get("patch_linear", False)
+
+        if gguf:
+            set_lora_params(transformer, patcher.patches)
+        elif len(patcher.patches) != 0 and patch_linear:
             log.info(f"Using {len(patcher.patches)} LoRA weight patches for WanVideo model")
-            if not gguf:
-                convert_linear_with_lora_and_scale(transformer, patches=patcher.patches)
-            else:
-                set_lora_params(transformer, patcher.patches)
+            if not merge_loras and fp8_matmul:
+                raise NotImplementedError("FP8 matmul with unmerged LoRAs is not supported")
+            convert_linear_with_lora_and_scale(transformer, patches=patcher.patches, scale_weight_keys=scale_weights)
+        elif patch_linear:
+            convert_linear_with_lora_and_scale(transformer, scale_weight_keys=scale_weights)
         else:
             remove_lora_from_module(transformer)
@@ -698,7 +698,7 @@ class WanVideoSetLoRAs:
         if 'transformer_options' not in patcher.model_options:
             patcher.model_options['transformer_options'] = {}
 
-        patcher.model_options['transformer_options']["linear_with_lora"] = True
+        patcher.model_options['transformer_options']["patch_linear"] = True
 
         return (patcher,)
@@ -711,7 +711,7 @@ class WanVideoModelLoader:
                 "model": (folder_paths.get_filename_list("unet_gguf") + folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}),
 
                 "base_precision": (["fp32", "bf16", "fp16", "fp16_fast"], {"default": "bf16"}),
-                "quantization": (["disabled", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp8_e4m3fn_fast_no_ffn", "fp8_e4m3fn_scaled", "fp8_e5m2_scaled"], {"default": "disabled", "tooltip": "optional quantization method"}),
+                "quantization": (["disabled", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e4m3fn_scaled", "fp8_e4m3fn_scaled_fast", "fp8_e5m2", "fp8_e5m2_fast", "fp8_e5m2_scaled", "fp8_e5m2_scaled_fast"], {"default": "disabled", "tooltip": "optional quantization method"}),
                 "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
             },
             "optional": {
@@ -1029,10 +1029,8 @@ class WanVideoModelLoader:
         )
 
-        if not gguf:
-            scale_weights = {}
-
-        if "scaled" in quantization:
+        scale_weights = {}
+        if "fp8" in quantization:
             for k, v in sd.items():
                 if k.endswith(".scale_weight"):
                     scale_weights[k] = v
@@ -1140,6 +1138,8 @@ class WanVideoModelLoader:
                 patcher, device, transformer_load_device,
                 params_to_keep=params_to_keep, dtype=dtype, base_dtype=base_dtype, state_dict=sd,
                 low_mem_load=lora_low_mem_load, control_lora=control_lora, scale_weights=scale_weights)
+            scale_weights.clear()
+            patcher.patches.clear()
 
         if gguf:
             #from diffusers.quantizers.gguf.utils import _replace_with_gguf_linear, GGUFParameter
@@ -1175,22 +1175,14 @@ class WanVideoModelLoader:
             patcher.model.is_patched = True
 
-        if "fast" in quantization:
-            if not merge_loras:
-                raise ValueError("FP8 fast quantization requires LoRAs to be merged into the model, please set merge_loras=True in the LoRA input")
-            from .fp8_optimization import convert_fp8_linear
-            if quantization == "fp8_e4m3fn_fast_no_ffn":
-                params_to_keep.update({"ffn"})
-            print(params_to_keep)
-            convert_fp8_linear(patcher.model.diffusion_model, base_dtype, params_to_keep=params_to_keep)
-
-        if "scaled" in quantization and not merge_loras:
-            log.info("Using FP8 scaled linear quantization")
-            convert_linear_with_lora_and_scale(patcher.model.diffusion_model, scale_weights, patches=patcher.patches)
-        elif lora is not None and not merge_loras and not gguf:
-            log.info("LoRAs will be applied at runtime")
-            convert_linear_with_lora_and_scale(patcher.model.diffusion_model, patches=patcher.patches)
+        patch_linear = (True if "scaled" in quantization or not merge_loras else False)
+
+        if "fast" in quantization:
+            if lora is not None and not merge_loras:
+                raise NotImplementedError("fp8_fast is not supported with unmerged LoRAs")
+            from .fp8_optimization import convert_fp8_linear
+            convert_fp8_linear(transformer, base_dtype, params_to_keep, scale_weight_keys=scale_weights)
+            patch_linear = False
 
         del sd
@@ -1255,11 +1247,14 @@ class WanVideoModelLoader:
         patcher.model["control_lora"] = control_lora
         patcher.model["compile_args"] = compile_args
         patcher.model["gguf"] = gguf
+        patcher.model["fp8_matmul"] = "fast" in quantization
+        patcher.model["scale_weights"] = scale_weights
 
         if 'transformer_options' not in patcher.model_options:
             patcher.model_options['transformer_options'] = {}
         patcher.model_options["transformer_options"]["block_swap_args"] = block_swap_args
-        patcher.model_options["transformer_options"]["linear_with_lora"] = True if not merge_loras else False
+        patcher.model_options["transformer_options"]["patch_linear"] = patch_linear
+        patcher.model_options["transformer_options"]["merge_loras"] = merge_loras
 
         for model in mm.current_loaded_models:
             if model._model() == patcher:
@@ -149,7 +149,7 @@ class WanVideoDiffusionForcingSampler:
         gguf = model["gguf"]
         transformer_options = patcher.model_options.get("transformer_options", None)
 
-        if len(patcher.patches) != 0 and transformer_options.get("linear_with_lora", False) is True:
+        if len(patcher.patches) != 0 and transformer_options.get("linear_patched", False) is True:
             log.info(f"Using {len(patcher.patches)} LoRA weight patches for WanVideo model")
             if not gguf:
                 convert_linear_with_lora_and_scale(transformer, patches=patcher.patches)