ComfyUI-WanVideoWrapper/custom_linear.py

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from .gguf.gguf_utils import GGUFParameter, dequantize_gguf_tensor
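
# The three LoRA/linear kernels below are registered as torch.library custom ops
# under the "wanvideo" namespace. Each op also gets a register_fake (meta)
# implementation that only describes the output's shape/dtype, so the ops can be
# traced by torch.compile without running the real math (see
# CustomLinear._get_weight_with_lora, which uses them to avoid graph breaks).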
@torch.library.custom_op("wanvideo::apply_lora", mutates_args=())
def apply_lora(weight: torch.Tensor, lora_diff_0: torch.Tensor, lora_diff_1: torch.Tensor, lora_diff_2: float, lora_strength: torch.Tensor) -> torch.Tensor:
patch_diff = torch.mm(
lora_diff_0.flatten(start_dim=1),
lora_diff_1.flatten(start_dim=1)
).reshape(weight.shape)
alpha = lora_diff_2 / lora_diff_1.shape[0] if lora_diff_2 != 0.0 else 1.0
scale = lora_strength * alpha
return weight + patch_diff * scale
@apply_lora.register_fake
def _(weight, lora_diff_0, lora_diff_1, lora_diff_2, lora_strength):
# Return weight with same metadata
return weight.clone()
@torch.library.custom_op("wanvideo::apply_single_lora", mutates_args=())
def apply_single_lora(weight: torch.Tensor, lora_diff: torch.Tensor, lora_strength: torch.Tensor) -> torch.Tensor:
return weight + lora_diff * lora_strength
@apply_single_lora.register_fake
def _(weight, lora_diff, lora_strength):
# Return weight with same metadata
return weight.clone()
@torch.library.custom_op("wanvideo::linear_forward", mutates_args=())
def linear_forward(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None) -> torch.Tensor:
return torch.nn.functional.linear(input, weight, bias)
@linear_forward.register_fake
def _(input, weight, bias):
# Calculate output shape: (..., out_features)
out_features = weight.shape[0]
output_shape = list(input.shape[:-1]) + [out_features]
return input.new_empty(output_shape)
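
# Shape sketch for apply_lora, assuming the usual LoRA up/down convention:
#   lora_diff_0: (out_features, rank)   "up" matrix
#   lora_diff_1: (rank, in_features)    "down" matrix
#   mm(up, down) -> (out_features, in_features), i.e. weight.shape
#   lora_diff_2: LoRA alpha; effective scale = lora_strength * alpha / rank
#                (rank read from lora_diff_1.shape[0]), or just lora_strength
#                when alpha is 0.0.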
#based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/quantizers/gguf/utils.py
def _replace_linear(model, compute_dtype, state_dict, prefix="", patches=None, scale_weights=None, compile_args=None, modules_to_not_convert=[]):
    has_children = list(model.children())
    if not has_children:
        return

    allow_compile = False
    for name, module in model.named_children():
        if compile_args is not None:
            allow_compile = compile_args.get("allow_unmerged_lora_compile", False)
        module_prefix = prefix + name + "."
        module_prefix = module_prefix.replace("_orig_mod.", "")
        _replace_linear(module, compute_dtype, state_dict, module_prefix, patches, scale_weights, compile_args, modules_to_not_convert)

        if isinstance(module, nn.Linear) and "loras" not in module_prefix and "dual_controller" not in module_prefix and name not in modules_to_not_convert:
            weight_key = module_prefix + "weight"
            if weight_key not in state_dict:
                continue
            in_features = state_dict[weight_key].shape[1]
            out_features = state_dict[weight_key].shape[0]

            is_gguf = isinstance(state_dict[weight_key], GGUFParameter)

            scale_weight = None
            if not is_gguf and scale_weights is not None:
                scale_key = f"{module_prefix}scale_weight"
                scale_weight = scale_weights.get(scale_key)

            with init_empty_weights():
                model._modules[name] = CustomLinear(
                    in_features,
                    out_features,
                    module.bias is not None,
                    compute_dtype=compute_dtype,
                    scale_weight=scale_weight,
                    allow_compile=allow_compile,
                    is_gguf=is_gguf
                )
            model._modules[name].source_cls = type(module)
            model._modules[name].requires_grad_(False)

    return model
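
# set_lora_params walks the (already converted) model and attaches unmerged LoRA
# patches to every CustomLinear. `patches` is keyed by
# "diffusion_model.<module path>.weight"; each value is a list of
# (strength, lora_obj) pairs where lora_obj either exposes a .weights tuple or is
# a ("diff", tensor) tuple. A strength given as a list is treated as a per-step
# schedule (see set_lora_strengths / update_lora_step).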
def set_lora_params(module, patches, module_prefix="", device=torch.device("cpu")):
    remove_lora_from_module(module)
    # Recursively set lora_diffs and lora_strengths for all CustomLinear layers
    for name, child in module.named_children():
        params = list(child.parameters())
        if params:
            device = params[0].device
        else:
            device = torch.device("cpu")
        child_prefix = f"{module_prefix}{name}."
        set_lora_params(child, patches, child_prefix, device)

    if isinstance(module, CustomLinear):
        key = f"diffusion_model.{module_prefix}weight"
        patch = patches.get(key, [])
        #print(f"Processing LoRA patches for {key}: {len(patch)} patches found")
        if len(patch) == 0:
            key = key.replace("_orig_mod.", "")
            patch = patches.get(key, [])
            #print(f"Processing LoRA patches for {key}: {len(patch)} patches found")
        if len(patch) != 0:
            lora_diffs = []
            for p in patch:
                lora_obj = p[1]
                if "head" in key:
                    continue # For now skip LoRA for head layers
                elif hasattr(lora_obj, "weights"):
                    lora_diffs.append(lora_obj.weights)
                elif isinstance(lora_obj, tuple) and lora_obj[0] == "diff":
                    lora_diffs.append(lora_obj[1])
                else:
                    continue
            lora_strengths = [p[0] for p in patch]
            module.set_lora_diffs(lora_diffs, device=device)
            module.set_lora_strengths(lora_strengths, device=device)
            module._step.fill_(0) # Initialize step for LoRA scheduling
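
# CustomLinear replaces nn.Linear modules converted by _replace_linear. At
# forward time it dequantizes GGUF weights, applies an optional per-layer
# scale_weight (to whichever of weight/input is smaller), patches the weight
# with any attached LoRA diffs (unmerged, recomputed every call), and then runs
# the plain linear. Depending on allow_compile it routes through the registered
# custom ops or the direct implementations below.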
class CustomLinear(nn.Linear):
    def __init__(
        self,
        in_features,
        out_features,
        bias=False,
        compute_dtype=None,
        device=None,
        scale_weight=None,
        allow_compile=False,
        is_gguf=False
    ) -> None:
        super().__init__(in_features, out_features, bias, device)
        self.compute_dtype = compute_dtype
        self.lora_diffs = []
        self.register_buffer("_step", torch.zeros((), dtype=torch.long))
        self.scale_weight = scale_weight
        self.lora_strengths = []
        self.allow_compile = allow_compile
        self.is_gguf = is_gguf

        if not allow_compile:
            self._apply_lora_impl = self._apply_lora_custom_op
            self._apply_single_lora_impl = self._apply_single_lora_custom_op
            self._linear_forward_impl = self._linear_forward_custom_op
        else:
            self._apply_lora_impl = self._apply_lora_direct
            self._apply_single_lora_impl = self._apply_single_lora_direct
            self._linear_forward_impl = self._linear_forward_direct

    # Direct implementations (no custom ops)
    def _apply_lora_direct(self, weight, lora_diff_0, lora_diff_1, lora_diff_2, lora_strength):
        patch_diff = torch.mm(
            lora_diff_0.flatten(start_dim=1),
            lora_diff_1.flatten(start_dim=1)
        ).reshape(weight.shape) + 0
        alpha = lora_diff_2 / lora_diff_1.shape[0] if lora_diff_2 != 0.0 else 1.0
        scale = lora_strength * alpha
        return weight + patch_diff * scale

    def _apply_single_lora_direct(self, weight, lora_diff, lora_strength):
        return weight + lora_diff * lora_strength

    def _linear_forward_direct(self, input, weight, bias):
        return torch.nn.functional.linear(input, weight, bias)

    # Custom op implementations
    def _apply_lora_custom_op(self, weight, lora_diff_0, lora_diff_1, lora_diff_2, lora_strength):
        return torch.ops.wanvideo.apply_lora(
            weight, lora_diff_0, lora_diff_1,
            float(lora_diff_2) if lora_diff_2 is not None else 0.0, lora_strength
        )

    def _apply_single_lora_custom_op(self, weight, lora_diff, lora_strength):
        return torch.ops.wanvideo.apply_single_lora(weight, lora_diff, lora_strength)

    def _linear_forward_custom_op(self, input, weight, bias):
        return torch.ops.wanvideo.linear_forward(input, weight, bias)
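
    # LoRA diffs are stored as registered buffers named "lora_diff_{i}_0" /
    # "lora_diff_{i}_1" (plus a plain "lora_diff_{i}_2" alpha attribute), or a
    # single "lora_diff_{i}_0" buffer for full-weight diffs. The same names are
    # used by _get_weight_with_lora and remove_lora_from_module.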
    def set_lora_diffs(self, lora_diffs, device=torch.device("cpu")):
        self.lora_diffs = []
        for i, diff in enumerate(lora_diffs):
            if len(diff) > 1:
                self.register_buffer(f"lora_diff_{i}_0", diff[0].to(device, self.compute_dtype))
                self.register_buffer(f"lora_diff_{i}_1", diff[1].to(device, self.compute_dtype))
                setattr(self, f"lora_diff_{i}_2", diff[2])
                self.lora_diffs.append((f"lora_diff_{i}_0", f"lora_diff_{i}_1", f"lora_diff_{i}_2"))
            else:
                self.register_buffer(f"lora_diff_{i}_0", diff[0].to(device, self.compute_dtype))
                self.lora_diffs.append(f"lora_diff_{i}_0")

    def set_lora_strengths(self, lora_strengths, device=torch.device("cpu")):
        self._lora_strength_tensors = []
        self._lora_strength_is_scheduled = []
        self._step = self._step.to(device)
        for i, strength in enumerate(lora_strengths):
            if isinstance(strength, list):
                tensor = torch.tensor(strength, dtype=self.compute_dtype, device=device)
                self.register_buffer(f"_lora_strength_{i}", tensor)
                self._lora_strength_is_scheduled.append(True)
            else:
                tensor = torch.tensor([strength], dtype=self.compute_dtype, device=device)
                self.register_buffer(f"_lora_strength_{i}", tensor)
                self._lora_strength_is_scheduled.append(False)

    def _get_lora_strength(self, idx):
        strength_tensor = getattr(self, f"_lora_strength_{idx}")
        if self._lora_strength_is_scheduled[idx]:
            return strength_tensor.index_select(0, self._step).squeeze(0)
        return strength_tensor[0]

    def _get_weight_with_lora(self, weight):
        """Apply LoRA using custom ops to avoid graph breaks"""
        if not hasattr(self, "lora_diff_0_0"):
            return weight
        for idx, lora_diff_names in enumerate(self.lora_diffs):
            lora_strength = self._get_lora_strength(idx)
            if isinstance(lora_diff_names, tuple):
                lora_diff_0 = getattr(self, lora_diff_names[0])
                lora_diff_1 = getattr(self, lora_diff_names[1])
                lora_diff_2 = getattr(self, lora_diff_names[2])
                weight = self._apply_lora_impl(
                    weight, lora_diff_0, lora_diff_1,
                    float(lora_diff_2) if lora_diff_2 is not None else 0.0, lora_strength
                )
            else:
                lora_diff = getattr(self, lora_diff_names)
                weight = self._apply_single_lora_impl(weight, lora_diff, lora_strength)
        return weight

    def _prepare_weight(self, input):
        """Prepare weight tensor - handles both regular and GGUF weights"""
        if self.is_gguf:
            weight = dequantize_gguf_tensor(self.weight).to(self.compute_dtype)
        else:
            weight = self.weight.to(input)
        return weight

    def forward(self, input):
        weight = self._prepare_weight(input)

        if self.bias is not None:
            bias = self.bias.to(input if not self.is_gguf else self.compute_dtype)
        else:
            bias = None

        # Only apply scale_weight for non-GGUF models
        if not self.is_gguf and self.scale_weight is not None:
            if weight.numel() < input.numel():
                weight = weight * self.scale_weight
            else:
                input = input * self.scale_weight

        weight = self._get_weight_with_lora(weight)

        out = self._linear_forward_impl(input, weight, bias)
        del weight, input, bias
        return out

def update_lora_step(module, step):
    for name, submodule in module.named_modules():
        if isinstance(submodule, CustomLinear) and hasattr(submodule, "_step"):
            submodule._step.fill_(step)

def remove_lora_from_module(module):
    for name, submodule in module.named_modules():
        if hasattr(submodule, "lora_diffs"):
            for i in range(len(submodule.lora_diffs)):
                if hasattr(submodule, f"lora_diff_{i}_0"):
                    delattr(submodule, f"lora_diff_{i}_0")
                if hasattr(submodule, f"lora_diff_{i}_1"):
                    delattr(submodule, f"lora_diff_{i}_1")
                if hasattr(submodule, f"lora_diff_{i}_2"):
                    delattr(submodule, f"lora_diff_{i}_2")