mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Torch.compile] Fixes torch compile graph break (#4315)
* fix torch compile * Fix all * make style
This commit is contained in:
committed by
GitHub
parent
80c10d8245
commit
d8bc1a4e51
@@ -14,6 +14,7 @@
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
@@ -91,7 +92,9 @@ class LoRACompatibleConv(nn.Conv2d):
|
||||
|
||||
def forward(self, x):
|
||||
if self.lora_layer is None:
|
||||
return super().forward(x)
|
||||
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break
|
||||
# see: https://github.com/huggingface/diffusers/pull/4315
|
||||
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
else:
|
||||
return super().forward(x) + self.lora_layer(x)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user