diff --git a/custom_linear.py b/custom_linear.py
index ff3e7fc..7234c4c 100644
--- a/custom_linear.py
+++ b/custom_linear.py
@@ -101,7 +101,7 @@ class CustomLinear(nn.Linear):
         return torch.nn.functional.linear(input, weight, bias)
 
-    @torch.compiler.disable()
+    #@torch.compiler.disable()
     def apply_lora(self, weight):
         for lora_diff, lora_strength in zip(self.lora[0], self.lora[1]):
             if isinstance(lora_strength, list):
@@ -113,7 +113,7 @@ class CustomLinear(nn.Linear):
                 patch_diff = torch.mm(
                     lora_diff[0].flatten(start_dim=1).to(weight.device),
                     lora_diff[1].flatten(start_dim=1).to(weight.device)
-                ).reshape(weight.shape)
+                ).reshape(weight.shape) + 0
                 alpha = lora_diff[2] / lora_diff[1].shape[0] if lora_diff[2] is not None else 1.0
                 scale = lora_strength * alpha
                 weight = weight.add(patch_diff, alpha=scale)
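
For reference, a minimal standalone sketch of apply_lora as it reads after this patch. The (down, up, alpha) layout of each lora_diff and the reading of the two changes are assumptions inferred from the hunks, not stated by the author: commenting out @torch.compiler.disable() lets torch.compile trace the method, and the trailing "+ 0" presumably materializes the reshaped view of the torch.mm output into a fresh tensor, a common workaround for view-aliasing problems under compilation. The isinstance(lora_strength, list) branch hidden between the two hunks is omitted here.

import torch

def apply_lora(weight, lora_diffs, lora_strengths):
    # Hypothetical free-function version of CustomLinear.apply_lora after this
    # diff; the original iterates zip(self.lora[0], self.lora[1]).
    for lora_diff, lora_strength in zip(lora_diffs, lora_strengths):
        # Low-rank update: (out x r) @ (r x in), reshaped to the weight's shape.
        patch_diff = torch.mm(
            lora_diff[0].flatten(start_dim=1).to(weight.device),
            lora_diff[1].flatten(start_dim=1).to(weight.device),
        ).reshape(weight.shape) + 0  # "+ 0" yields a new tensor rather than a view
        # Alpha scaling: network alpha divided by rank, defaulting to 1.0.
        alpha = lora_diff[2] / lora_diff[1].shape[0] if lora_diff[2] is not None else 1.0
        weight = weight.add(patch_diff, alpha=lora_strength * alpha)
    return weight

# Example with made-up shapes: an 8x8 weight patched by one rank-4 LoRA pair
# at strength 0.8, with no network alpha.
w = torch.randn(8, 8)
down, up = torch.randn(8, 4), torch.randn(4, 8)
patched = apply_lora(w, [(down, up, None)], [0.8])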