From 7eaae83f168e0d7ebca1ef238cbc58c5c46d39d9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 6 Oct 2023 17:14:47 +0200 Subject: [PATCH] [LoRA] fix: torch.compile() for lora conv (#5298) fix: torch.compile() for lora conv --- src/diffusers/models/lora.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index cc8e3e231e..a777bb93e1 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -164,7 +164,10 @@ class LoRACompatibleConv(nn.Conv2d): hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups ) else: - return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) + original_outputs = F.conv2d( + hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups + ) + return original_outputs + (scale * self.lora_layer(hidden_states)) class LoRACompatibleLinear(nn.Linear):