1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Make LoRACompatibleConv padding_mode work. (#6031)

* Make LoRACompatibleConv padding_mode work.

* Format code style.

* add fast test

* Update src/diffusers/models/lora.py

Simplify the code by patrickvonplaten.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* code refactor

* apply patrickvonplaten suggestion to simplify the code.

* rm test_lora_layers_old_backend.py and add test case in test_lora_layers_peft.py

* update test case.

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
jinghuan-Chen
2024-02-27 08:05:13 +08:00
committed by GitHub
parent ad310af0d6
commit 88aa7f6ebf
2 changed files with 30 additions and 9 deletions

View File

@@ -361,16 +361,19 @@ class LoRACompatibleConv(nn.Conv2d):
self.w_down = None
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
if self.lora_layer is None:
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break
# see: https://github.com/huggingface/diffusers/pull/4315
return F.conv2d(
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
)
if self.padding_mode != "zeros":
hidden_states = F.pad(hidden_states, self._reversed_padding_repeated_twice, mode=self.padding_mode)
padding = (0, 0)
else:
padding = self.padding
original_outputs = F.conv2d(
hidden_states, self.weight, self.bias, self.stride, padding, self.dilation, self.groups
)
if self.lora_layer is None:
return original_outputs
else:
original_outputs = F.conv2d(
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
)
return original_outputs + (scale * self.lora_layer(hidden_states))

View File

@@ -1177,6 +1177,24 @@ class PeftLoraLoaderMixinTests:
# Just makes sure it works..
_ = pipe(**inputs, generator=torch.manual_seed(0)).images
def test_modify_padding_mode(self):
def set_pad_mode(network, mode="circular"):
for _, module in network.named_modules():
if isinstance(module, torch.nn.Conv2d):
module.padding_mode = mode
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
_pad_mode = "circular"
set_pad_mode(pipe.vae, _pad_mode)
set_pad_mode(pipe.unet, _pad_mode)
_, _, inputs = self.get_dummy_inputs()
_ = pipe(**inputs).images
class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipeline_class = StableDiffusionPipeline