1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Torch.compile] Fixes torch compile graph break (#4315)

* fix torch compile

* Fix all

* make style
This commit is contained in:
Patrick von Platen
2023-07-27 13:53:36 +02:00
committed by GitHub
parent 80c10d8245
commit d8bc1a4e51

View File

@@ -14,6 +14,7 @@
from typing import Optional
import torch.nn.functional as F
from torch import nn
@@ -91,7 +92,9 @@ class LoRACompatibleConv(nn.Conv2d):
def forward(self, x):
if self.lora_layer is None:
return super().forward(x)
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break
# see: https://github.com/huggingface/diffusers/pull/4315
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
else:
return super().forward(x) + self.lora_layer(x)