diff --git a/fp8_optimization.py b/fp8_optimization.py
index b63e91d..527d622 100644
--- a/fp8_optimization.py
+++ b/fp8_optimization.py
@@ -1,31 +1,31 @@
-#based on ComfyUI's and MinusZoneAI's fp8_linear optimization
-
 import torch
 import torch.nn as nn
 
 from .utils import log
 
-def fp8_linear_forward(cls, original_dtype, input):
+#based on ComfyUI's and MinusZoneAI's fp8_linear optimization
+def fp8_linear_forward(cls, base_dtype, input):
     weight_dtype = cls.weight.dtype
     if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
         if len(input.shape) == 3:
-            #target_dtype = torch.float8_e5m2 if weight_dtype == torch.float8_e4m3fn else torch.float8_e4m3fn
-            inn = input.reshape(-1, input.shape[2]).to(weight_dtype)
-            w = cls.weight.t()
-
-            scale = torch.ones((1), device=input.device, dtype=torch.float32)
-            bias = cls.bias.to(original_dtype) if cls.bias is not None else None
-
-            if bias is not None:
-                o = torch._scaled_mm(inn, w, out_dtype=original_dtype, bias=bias, scale_a=scale, scale_b=scale)
+            input_shape = input.shape
+
+            scale_weight = getattr(cls, 'scale_weight', None)
+            if scale_weight is None:
+                scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
             else:
-                o = torch._scaled_mm(inn, w, out_dtype=original_dtype, scale_a=scale, scale_b=scale)
+                scale_weight = scale_weight.to(input.device)
+
+            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+            input = torch.clamp(input, min=-448, max=448, out=input)
+            inn = input.reshape(-1, input_shape[2]).to(torch.float8_e4m3fn).contiguous() #always e4m3fn because e5m2 * e5m2 is not supported
 
-            if isinstance(o, tuple):
-                o = o[0]
+            bias = cls.bias.to(base_dtype) if cls.bias is not None else None
 
-            return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
+            o = torch._scaled_mm(inn, cls.weight.t(), out_dtype=base_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
+
+            return o.reshape((-1, input_shape[1], cls.weight.shape[0]))
         else:
-            return cls.original_forward(input.to(original_dtype))
+            return cls.original_forward(input.to(base_dtype))
     else:
         return cls.original_forward(input)
@@ -67,20 +67,23 @@ def linear_with_lora_and_scale_forward(cls, input):
         weight = apply_lora(weight, lora, cls.step).to(input.dtype)
 
     return torch.nn.functional.linear(input, weight, bias)
-
-def convert_fp8_linear(module, original_dtype, params_to_keep={}):
-    setattr(module, "fp8_matmul_enabled", True)
-
+def convert_fp8_linear(module, base_dtype, params_to_keep={}, scale_weight_keys=None):
+    log.info("FP8 matmul enabled")
     for name, submodule in module.named_modules():
         if not any(keyword in name for keyword in params_to_keep):
            if isinstance(submodule, nn.Linear):
+                if scale_weight_keys is not None:
+                    scale_key = f"{name}.scale_weight"
+                    if scale_key in scale_weight_keys:
+                        print("Setting scale_weight for", name)
+                        setattr(submodule, "scale_weight", scale_weight_keys[scale_key])
                 original_forward = submodule.forward
                 setattr(submodule, "original_forward", original_forward)
-                setattr(submodule, "forward", lambda input, m=submodule: fp8_linear_forward(m, original_dtype, input))
-
-
+                setattr(submodule, "forward", lambda input, m=submodule: fp8_linear_forward(m, base_dtype, input))
+
 def convert_linear_with_lora_and_scale(module, scale_weight_keys=None, patches=None, params_to_keep={}):
+    log.info("Patching Linear layers...")
     for name, submodule in module.named_modules():
         if not any(keyword in name for keyword in params_to_keep):
             # Set scale_weight if present
@@ -90,6 +93,9 @@ def convert_linear_with_lora_and_scale(module, scale_weight_keys=None, patches=N
                     setattr(submodule, "scale_weight", scale_weight_keys[scale_key])
 
             # Set LoRA if present
+            if hasattr(submodule, "lora"):
+                print(f"Removing old LoRA in {name}")
+                delattr(submodule, "lora")
             if patches is not None:
                 patch_key = f"diffusion_model.{name}.weight"
                 patch = patches.get(patch_key, [])
@@ -108,17 +114,17 @@ def convert_linear_with_lora_and_scale(module, scale_weight_keys=None, patches=N
                 lora_strengths = [p[0] for p in patch]
                 lora = (lora_diffs, lora_strengths)
                 setattr(submodule, "lora", lora)
+                print(f"Added LoRA to {name} with {len(lora_diffs)} diffs and strengths {lora_strengths}")
 
             # Set forward if Linear and has either scale or lora
             if isinstance(submodule, nn.Linear):
                 has_scale = hasattr(submodule, "scale_weight")
                 has_lora = hasattr(submodule, "lora")
+                if not hasattr(submodule, "original_forward"):
+                    setattr(submodule, "original_forward", submodule.forward)
                 if has_scale or has_lora:
-                    original_forward_ = submodule.forward
-                    setattr(submodule, "original_forward_", original_forward_)
                     setattr(submodule, "forward", lambda input, m=submodule: linear_with_lora_and_scale_forward(m, input))
-                    setattr(submodule, "step", 0) # Initialize step for LoRA if needed
-
+                    setattr(submodule, "step", 0) # Initialize step for LoRA scheduling
 
 def remove_lora_from_module(module):
     unloaded = False
diff --git a/nodes.py b/nodes.py
index d749afd..b66f59d 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1494,19 +1494,28 @@ class WanVideoSampler:
         transformer = model.diffusion_model
         dtype = model["dtype"]
+        fp8_matmul = model["fp8_matmul"]
         gguf = model["gguf"]
+        scale_weights = model["scale_weights"]
         control_lora = model["control_lora"]
+
         transformer_options = patcher.model_options.get("transformer_options", None)
+        merge_loras = transformer_options["merge_loras"]
 
         is_5b = transformer.out_dim == 48
         vae_upscale_factor = 16 if is_5b else 8
 
-        if len(patcher.patches) != 0 and transformer_options.get("linear_with_lora", False) is True:
+        patch_linear = transformer_options.get("patch_linear", False)
+
+        if gguf:
+            set_lora_params(transformer, patcher.patches)
+        elif len(patcher.patches) != 0 and patch_linear:
             log.info(f"Using {len(patcher.patches)} LoRA weight patches for WanVideo model")
-            if not gguf:
-                convert_linear_with_lora_and_scale(transformer, patches=patcher.patches)
-            else:
-                set_lora_params(transformer, patcher.patches)
+            if not merge_loras and fp8_matmul:
+                raise NotImplementedError("FP8 matmul with unmerged LoRAs is not supported")
+            convert_linear_with_lora_and_scale(transformer, patches=patcher.patches, scale_weight_keys=scale_weights)
+        elif patch_linear:
+            convert_linear_with_lora_and_scale(transformer, scale_weight_keys=scale_weights)
         else:
             remove_lora_from_module(transformer)
diff --git a/nodes_model_loading.py b/nodes_model_loading.py
index a4a42e9..fbbd946 100644
--- a/nodes_model_loading.py
+++ b/nodes_model_loading.py
@@ -698,7 +698,7 @@ class WanVideoSetLoRAs:
 
         if 'transformer_options' not in patcher.model_options:
             patcher.model_options['transformer_options'] = {}
-        patcher.model_options['transformer_options']["linear_with_lora"] = True
+        patcher.model_options['transformer_options']["patch_linear"] = True
 
         return (patcher,)
@@ -711,7 +711,7 @@ class WanVideoModelLoader:
                 "model": (folder_paths.get_filename_list("unet_gguf") + folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}),
                 "base_precision": (["fp32", "bf16", "fp16", "fp16_fast"], {"default": "bf16"}),
-                "quantization": (["disabled", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp8_e4m3fn_fast_no_ffn", "fp8_e4m3fn_scaled", "fp8_e5m2_scaled"], {"default": "disabled", "tooltip": "optional quantization method"}),
+                "quantization": (["disabled", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e4m3fn_scaled", "fp8_e4m3fn_scaled_fast", "fp8_e5m2", "fp8_e5m2_fast", "fp8_e5m2_scaled", "fp8_e5m2_scaled_fast"], {"default": "disabled", "tooltip": "optional quantization method"}),
                 "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
             },
             "optional": {
@@ -1029,10 +1029,8 @@ class WanVideoModelLoader:
             )
 
         if not gguf:
-            scale_weights = {}
-            if "scaled" in quantization:
-                scale_weights = {}
+            if "fp8" in quantization:
                 for k, v in sd.items():
                     if k.endswith(".scale_weight"):
                         scale_weights[k] = v
@@ -1140,6 +1138,8 @@ class WanVideoModelLoader:
                 patcher, device, transformer_load_device, params_to_keep=params_to_keep, dtype=dtype, base_dtype=base_dtype, state_dict=sd, low_mem_load=lora_low_mem_load, control_lora=control_lora, scale_weights=scale_weights)
+            scale_weights.clear()
+            patcher.patches.clear()
 
         if gguf:
             #from diffusers.quantizers.gguf.utils import _replace_with_gguf_linear, GGUFParameter
@@ -1175,22 +1175,14 @@ class WanVideoModelLoader:
             patcher.model.is_patched = True
-
-        if "fast" in quantization:
-            if not merge_loras:
-                raise ValueError("FP8 fast quantization requires LoRAs to be merged into the model, please set merge_loras=True in the LoRA input")
-            from .fp8_optimization import convert_fp8_linear
-            if quantization == "fp8_e4m3fn_fast_no_ffn":
-                params_to_keep.update({"ffn"})
-                print(params_to_keep)
-            convert_fp8_linear(patcher.model.diffusion_model, base_dtype, params_to_keep=params_to_keep)
+        patch_linear = (True if "scaled" in quantization or not merge_loras else False)
 
-        if "scaled" in quantization and not merge_loras:
-            log.info("Using FP8 scaled linear quantization")
-            convert_linear_with_lora_and_scale(patcher.model.diffusion_model, scale_weights, patches=patcher.patches)
-        elif lora is not None and not merge_loras and not gguf:
-            log.info("LoRAs will be applied at runtime")
-            convert_linear_with_lora_and_scale(patcher.model.diffusion_model, patches=patcher.patches)
+        if "fast" in quantization:
+            if lora is not None and not merge_loras:
+                raise NotImplementedError("fp8_fast is not supported with unmerged LoRAs")
+            from .fp8_optimization import convert_fp8_linear
+            convert_fp8_linear(transformer, base_dtype, params_to_keep, scale_weight_keys=scale_weights)
+            patch_linear = False
 
         del sd
@@ -1255,11 +1247,14 @@ class WanVideoModelLoader:
         patcher.model["control_lora"] = control_lora
         patcher.model["compile_args"] = compile_args
         patcher.model["gguf"] = gguf
+        patcher.model["fp8_matmul"] = "fast" in quantization
+        patcher.model["scale_weights"] = scale_weights
 
         if 'transformer_options' not in patcher.model_options:
             patcher.model_options['transformer_options'] = {}
         patcher.model_options["transformer_options"]["block_swap_args"] = block_swap_args
-        patcher.model_options["transformer_options"]["linear_with_lora"] = True if not merge_loras else False
+        patcher.model_options["transformer_options"]["patch_linear"] = patch_linear
+        patcher.model_options["transformer_options"]["merge_loras"] = merge_loras
 
         for model in mm.current_loaded_models:
             if model._model() == patcher:
diff --git a/skyreels/nodes.py b/skyreels/nodes.py
index e0da49b..134c1dc 100644
--- a/skyreels/nodes.py
+++ b/skyreels/nodes.py
@@ -149,7 +149,7 @@ class WanVideoDiffusionForcingSampler:
         gguf = model["gguf"]
         transformer_options = patcher.model_options.get("transformer_options", None)
 
-        if len(patcher.patches) != 0 and transformer_options.get("linear_with_lora", False) is True:
+        if len(patcher.patches) != 0 and transformer_options.get("patch_linear", False) is True:
             log.info(f"Using {len(patcher.patches)} LoRA weight patches for WanVideo model")
             if not gguf:
                 convert_linear_with_lora_and_scale(transformer, patches=patcher.patches)
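Reviewer note on the scaled-FP8 path: the new `fp8_linear_forward` assumes a per-tensor scale convention in which the stored FP8 weight satisfies `weight_fp8 * scale_weight ≈ weight_full_precision`, so calling `torch._scaled_mm` with `scale_a=1.0` for the clamped, e4m3fn-cast input and `scale_b=scale_weight` reproduces the original linear layer in `base_dtype`. The sketch below is illustrative only and not part of this patch: the helper names `quantize_weight_e4m3` and `reference_scaled_linear` are hypothetical, and a CPU dequantize-and-matmul stands in for the fused `_scaled_mm` call, which typically needs a CUDA device with FP8 support and 16-aligned shapes.

```python
# Illustrative sketch (not part of the patch): check on CPU that dequantizing
# an e4m3fn weight with a per-tensor scale reproduces the full-precision
# linear layer -- the equivalence fp8_linear_forward relies on when it passes
# scale_b=scale_weight to torch._scaled_mm.
import torch

def quantize_weight_e4m3(weight: torch.Tensor):
    # Hypothetical helper: choose a per-tensor scale so the largest magnitude
    # maps to 448, the largest finite value of torch.float8_e4m3fn.
    scale = weight.abs().amax().float() / 448.0
    w_fp8 = (weight / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
    return w_fp8, scale

def reference_scaled_linear(x, w_fp8, scale_weight, bias=None):
    # Hypothetical reference for what _scaled_mm computes with scale_a=1.0
    # and scale_b=scale_weight: dequantize the weight, then a plain matmul.
    w = w_fp8.to(torch.float32) * scale_weight
    return torch.nn.functional.linear(x, w, bias)

if __name__ == "__main__":
    torch.manual_seed(0)
    weight = torch.randn(64, 128)  # [out_features, in_features]
    x = torch.randn(2, 16, 128)    # [batch, seq, in_features], the 3D case fp8_linear_forward handles
    w_fp8, scale_weight = quantize_weight_e4m3(weight)
    y_ref = torch.nn.functional.linear(x, weight)
    y_fp8 = reference_scaled_linear(x, w_fp8, scale_weight)
    print("max abs error:", (y_ref - y_fp8).abs().max().item())
```

The same assumption is why the fast path now rejects unmerged LoRAs: a LoRA delta would have to be folded into the FP8 weight (and its scale) before the scaled matmul, which `fp8_linear_forward` does not do.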