mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Make LoRACompatibleConv padding_mode work. (#6031)
* Make LoRACompatibleConv padding_mode work. * Format code style. * add fast test * Update src/diffusers/models/lora.py Simplify the code by patrickvonplaten. Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * code refactor * apply patrickvonplaten suggestion to simplify the code. * rm test_lora_layers_old_backend.py and add test case in test_lora_layers_peft.py * update test case. --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
@@ -361,16 +361,19 @@ class LoRACompatibleConv(nn.Conv2d):
|
||||
self.w_down = None
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
||||
if self.lora_layer is None:
|
||||
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break
|
||||
# see: https://github.com/huggingface/diffusers/pull/4315
|
||||
return F.conv2d(
|
||||
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
|
||||
)
|
||||
if self.padding_mode != "zeros":
|
||||
hidden_states = F.pad(hidden_states, self._reversed_padding_repeated_twice, mode=self.padding_mode)
|
||||
padding = (0, 0)
|
||||
else:
|
||||
padding = self.padding
|
||||
|
||||
original_outputs = F.conv2d(
|
||||
hidden_states, self.weight, self.bias, self.stride, padding, self.dilation, self.groups
|
||||
)
|
||||
|
||||
if self.lora_layer is None:
|
||||
return original_outputs
|
||||
else:
|
||||
original_outputs = F.conv2d(
|
||||
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
|
||||
)
|
||||
return original_outputs + (scale * self.lora_layer(hidden_states))
|
||||
|
||||
|
||||
|
||||
@@ -1177,6 +1177,24 @@ class PeftLoraLoaderMixinTests:
|
||||
# Just makes sure it works..
|
||||
_ = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
def test_modify_padding_mode(self):
|
||||
def set_pad_mode(network, mode="circular"):
|
||||
for _, module in network.named_modules():
|
||||
if isinstance(module, torch.nn.Conv2d):
|
||||
module.padding_mode = mode
|
||||
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, _, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_pad_mode = "circular"
|
||||
set_pad_mode(pipe.vae, _pad_mode)
|
||||
set_pad_mode(pipe.unet, _pad_mode)
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs()
|
||||
_ = pipe(**inputs).images
|
||||
|
||||
|
||||
class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
pipeline_class = StableDiffusionPipeline
|
||||
|
||||
Reference in New Issue
Block a user