diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index bb83897457..7bc573bf72 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -14,6 +14,7 @@ from typing import Optional +import torch.nn.functional as F from torch import nn @@ -91,7 +92,9 @@ class LoRACompatibleConv(nn.Conv2d): def forward(self, x): if self.lora_layer is None: - return super().forward(x) + # make sure to use the functional Conv2D function as otherwise torch.compile's graph will break + # see: https://github.com/huggingface/diffusers/pull/4315 + return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) else: return super().forward(x) + self.lora_layer(x)