From 48fa904ad89f70b605853eaee181cfd3dfd6df5b Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sat, 9 Aug 2025 10:17:11 +0300 Subject: [PATCH] fp8 matmul for scaled models Fp8 matmul (fp8_fast) doesn't seem feasible with unmerged LoRAs as you'd need to first upcast, then apply LoRA, then downcast back to fp8 and that is too slow. Direct adding in fp8 is also not possible since that's just not something fp8 dtypes support. --- fp8_optimization.py | 62 +++++++++++++++++++++++------------------- nodes.py | 19 +++++++++---- nodes_model_loading.py | 37 +++++++++++-------------- skyreels/nodes.py | 2 +- 4 files changed, 65 insertions(+), 55 deletions(-) diff --git a/fp8_optimization.py b/fp8_optimization.py index b63e91d..527d622 100644 --- a/fp8_optimization.py +++ b/fp8_optimization.py @@ -1,31 +1,31 @@ -#based on ComfyUI's and MinusZoneAI's fp8_linear optimization - import torch import torch.nn as nn from .utils import log -def fp8_linear_forward(cls, original_dtype, input): +#based on ComfyUI's and MinusZoneAI's fp8_linear optimization +def fp8_linear_forward(cls, base_dtype, input): weight_dtype = cls.weight.dtype if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: if len(input.shape) == 3: - #target_dtype = torch.float8_e5m2 if weight_dtype == torch.float8_e4m3fn else torch.float8_e4m3fn - inn = input.reshape(-1, input.shape[2]).to(weight_dtype) - w = cls.weight.t() - - scale = torch.ones((1), device=input.device, dtype=torch.float32) - bias = cls.bias.to(original_dtype) if cls.bias is not None else None - - if bias is not None: - o = torch._scaled_mm(inn, w, out_dtype=original_dtype, bias=bias, scale_a=scale, scale_b=scale) + input_shape = input.shape + + scale_weight = getattr(cls, 'scale_weight', None) + if scale_weight is None: + scale_weight = torch.ones((), device=input.device, dtype=torch.float32) else: - o = torch._scaled_mm(inn, w, out_dtype=original_dtype, scale_a=scale, scale_b=scale) + scale_weight = scale_weight.to(input.device) + + scale_input = torch.ones((), device=input.device, dtype=torch.float32) + input = torch.clamp(input, min=-448, max=448, out=input) + inn = input.reshape(-1, input_shape[2]).to(torch.float8_e4m3fn).contiguous() #always e4m3fn because e5m2 * e5m2 is not supported - if isinstance(o, tuple): - o = o[0] + bias = cls.bias.to(base_dtype) if cls.bias is not None else None - return o.reshape((-1, input.shape[1], cls.weight.shape[0])) + o = torch._scaled_mm(inn, cls.weight.t(), out_dtype=base_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) + + return o.reshape((-1, input_shape[1], cls.weight.shape[0])) else: - return cls.original_forward(input.to(original_dtype)) + return cls.original_forward(input.to(base_dtype)) else: return cls.original_forward(input) @@ -67,20 +67,23 @@ def linear_with_lora_and_scale_forward(cls, input): weight = apply_lora(weight, lora, cls.step).to(input.dtype) return torch.nn.functional.linear(input, weight, bias) - -def convert_fp8_linear(module, original_dtype, params_to_keep={}): - setattr(module, "fp8_matmul_enabled", True) - +def convert_fp8_linear(module, base_dtype, params_to_keep={}, scale_weight_keys=None): + log.info("FP8 matmul enabled") for name, submodule in module.named_modules(): if not any(keyword in name for keyword in params_to_keep): if isinstance(submodule, nn.Linear): + if scale_weight_keys is not None: + scale_key = f"{name}.scale_weight" + if scale_key in scale_weight_keys: + print("Setting scale_weight for", name) + setattr(submodule, "scale_weight", scale_weight_keys[scale_key]) original_forward = submodule.forward setattr(submodule, "original_forward", original_forward) - setattr(submodule, "forward", lambda input, m=submodule: fp8_linear_forward(m, original_dtype, input)) - - + setattr(submodule, "forward", lambda input, m=submodule: fp8_linear_forward(m, base_dtype, input)) + def convert_linear_with_lora_and_scale(module, scale_weight_keys=None, patches=None, params_to_keep={}): + log.info("Patching Linear layers...") for name, submodule in module.named_modules(): if not any(keyword in name for keyword in params_to_keep): # Set scale_weight if present @@ -90,6 +93,9 @@ def convert_linear_with_lora_and_scale(module, scale_weight_keys=None, patches=N setattr(submodule, "scale_weight", scale_weight_keys[scale_key]) # Set LoRA if present + if hasattr(submodule, "lora"): + print(f"removing old LoRA in {name}" ) + delattr(submodule, "lora") if patches is not None: patch_key = f"diffusion_model.{name}.weight" patch = patches.get(patch_key, []) @@ -108,17 +114,17 @@ def convert_linear_with_lora_and_scale(module, scale_weight_keys=None, patches=N lora_strengths = [p[0] for p in patch] lora = (lora_diffs, lora_strengths) setattr(submodule, "lora", lora) + print(f"Added LoRA to {name} with {len(lora_diffs)} diffs and strengths {lora_strengths}") # Set forward if Linear and has either scale or lora if isinstance(submodule, nn.Linear): has_scale = hasattr(submodule, "scale_weight") has_lora = hasattr(submodule, "lora") + if not hasattr(submodule, "original_forward"): + setattr(submodule, "original_forward", submodule.forward) if has_scale or has_lora: - original_forward_ = submodule.forward - setattr(submodule, "original_forward_", original_forward_) setattr(submodule, "forward", lambda input, m=submodule: linear_with_lora_and_scale_forward(m, input)) - setattr(submodule, "step", 0) # Initialize step for LoRA if needed - + setattr(submodule, "step", 0) # Initialize step for LoRA scheduling def remove_lora_from_module(module): unloaded = False diff --git a/nodes.py b/nodes.py index d749afd..b66f59d 100644 --- a/nodes.py +++ b/nodes.py @@ -1494,19 +1494,28 @@ class WanVideoSampler: transformer = model.diffusion_model dtype = model["dtype"] + fp8_matmul = model["fp8_matmul"] gguf = model["gguf"] + scale_weights = model["scale_weights"] control_lora = model["control_lora"] + transformer_options = patcher.model_options.get("transformer_options", None) + merge_loras = transformer_options["merge_loras"] is_5b = transformer.out_dim == 48 vae_upscale_factor = 16 if is_5b else 8 - if len(patcher.patches) != 0 and transformer_options.get("linear_with_lora", False) is True: + patch_linear = transformer_options.get("patch_linear", False) + + if gguf: + set_lora_params(transformer, patcher.patches) + elif len(patcher.patches) != 0 and patch_linear: log.info(f"Using {len(patcher.patches)} LoRA weight patches for WanVideo model") - if not gguf: - convert_linear_with_lora_and_scale(transformer, patches=patcher.patches) - else: - set_lora_params(transformer, patcher.patches) + if not merge_loras and fp8_matmul: + raise NotImplementedError("FP8 matmul with unmerged LoRAs is not supported") + convert_linear_with_lora_and_scale(transformer, patches=patcher.patches, scale_weight_keys=scale_weights) + elif patch_linear: + convert_linear_with_lora_and_scale(transformer, scale_weight_keys=scale_weights) else: remove_lora_from_module(transformer) diff --git a/nodes_model_loading.py b/nodes_model_loading.py index a4a42e9..fbbd946 100644 --- a/nodes_model_loading.py +++ b/nodes_model_loading.py @@ -698,7 +698,7 @@ class WanVideoSetLoRAs: if 'transformer_options' not in patcher.model_options: patcher.model_options['transformer_options'] = {} - patcher.model_options['transformer_options']["linear_with_lora"] = True + patcher.model_options['transformer_options']["patch_linear"] = True return (patcher,) @@ -711,7 +711,7 @@ class WanVideoModelLoader: "model": (folder_paths.get_filename_list("unet_gguf") + folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}), "base_precision": (["fp32", "bf16", "fp16", "fp16_fast"], {"default": "bf16"}), - "quantization": (["disabled", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp8_e4m3fn_fast_no_ffn", "fp8_e4m3fn_scaled", "fp8_e5m2_scaled"], {"default": "disabled", "tooltip": "optional quantization method"}), + "quantization": (["disabled", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e4m3fn_scaled", "fp8_e4m3fn_scaled_fast", "fp8_e5m2", "fp8_e5m2_fast", "fp8_e5m2_scaled", "fp8_e5m2_scaled_fast"], {"default": "disabled", "tooltip": "optional quantization method"}), "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}), }, "optional": { @@ -1029,10 +1029,8 @@ class WanVideoModelLoader: ) if not gguf: - scale_weights = {} - if "scaled" in quantization: - scale_weights = {} + if "fp8" in quantization: for k, v in sd.items(): if k.endswith(".scale_weight"): scale_weights[k] = v @@ -1140,6 +1138,8 @@ class WanVideoModelLoader: patcher, device, transformer_load_device, params_to_keep=params_to_keep, dtype=dtype, base_dtype=base_dtype, state_dict=sd, low_mem_load=lora_low_mem_load, control_lora=control_lora, scale_weights=scale_weights) + scale_weights.clear() + patcher.patches.clear() if gguf: #from diffusers.quantizers.gguf.utils import _replace_with_gguf_linear, GGUFParameter @@ -1175,22 +1175,14 @@ class WanVideoModelLoader: patcher.model.is_patched = True - - if "fast" in quantization: - if not merge_loras: - raise ValueError("FP8 fast quantization requires LoRAs to be merged into the model, please set merge_loras=True in the LoRA input") - from .fp8_optimization import convert_fp8_linear - if quantization == "fp8_e4m3fn_fast_no_ffn": - params_to_keep.update({"ffn"}) - print(params_to_keep) - convert_fp8_linear(patcher.model.diffusion_model, base_dtype, params_to_keep=params_to_keep) + patch_linear = (True if "scaled" in quantization or not merge_loras else False) - if "scaled" in quantization and not merge_loras: - log.info("Using FP8 scaled linear quantization") - convert_linear_with_lora_and_scale(patcher.model.diffusion_model, scale_weights, patches=patcher.patches) - elif lora is not None and not merge_loras and not gguf: - log.info("LoRAs will be applied at runtime") - convert_linear_with_lora_and_scale(patcher.model.diffusion_model, patches=patcher.patches) + if "fast" in quantization: + if lora is not None and not merge_loras: + raise NotImplementedError("fp8_fast is not supported with unmerged LoRAs") + from .fp8_optimization import convert_fp8_linear + convert_fp8_linear(transformer, base_dtype, params_to_keep, scale_weight_keys=scale_weights) + patch_linear = False del sd @@ -1255,11 +1247,14 @@ class WanVideoModelLoader: patcher.model["control_lora"] = control_lora patcher.model["compile_args"] = compile_args patcher.model["gguf"] = gguf + patcher.model["fp8_matmul"] = "fast" in quantization + patcher.model["scale_weights"] = scale_weights if 'transformer_options' not in patcher.model_options: patcher.model_options['transformer_options'] = {} patcher.model_options["transformer_options"]["block_swap_args"] = block_swap_args - patcher.model_options["transformer_options"]["linear_with_lora"] = True if not merge_loras else False + patcher.model_options["transformer_options"]["patch_linear"] = patch_linear + patcher.model_options["transformer_options"]["merge_loras"] = merge_loras for model in mm.current_loaded_models: if model._model() == patcher: diff --git a/skyreels/nodes.py b/skyreels/nodes.py index e0da49b..134c1dc 100644 --- a/skyreels/nodes.py +++ b/skyreels/nodes.py @@ -149,7 +149,7 @@ class WanVideoDiffusionForcingSampler: gguf = model["gguf"] transformer_options = patcher.model_options.get("transformer_options", None) - if len(patcher.patches) != 0 and transformer_options.get("linear_with_lora", False) is True: + if len(patcher.patches) != 0 and transformer_options.get("linear_patched", False) is True: log.info(f"Using {len(patcher.patches)} LoRA weight patches for WanVideo model") if not gguf: convert_linear_with_lora_and_scale(transformer, patches=patcher.patches)