
Fix unloading of LoRAs when xformers attention procs are in use (#4179)

Batuhan Taskaya
2023-07-21 11:59:20 +03:00
committed by GitHub
parent 7a47df22a5
commit ad787082e2
3 changed files with 31 additions and 16 deletions
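
To make the intent of the change concrete, here is a minimal sketch of the scenario being fixed: with xformers memory-efficient attention enabled, loading and then unloading LoRA weights should restore the xformers attention processors. The model id and LoRA repo id below are placeholders, not part of this commit; the final check mirrors the test added at the bottom of this diff.

import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import Attention, XFormersAttnProcessor

# Placeholder model and LoRA ids -- any SD 1.x checkpoint plus a compatible LoRA works.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()  # installs XFormersAttnProcessor on the UNet

pipe.load_lora_weights("some-user/some-lora")  # hypothetical LoRA repo id
pipe.unload_lora_weights()

# After this fix, the UNet is back on xformers processors instead of being left on LoRA ones.
for _, module in pipe.unet.named_modules():
    if isinstance(module, Attention):
        assert isinstance(module.processor, XFormersAttnProcessor)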


@@ -26,6 +26,7 @@ from huggingface_hub import hf_hub_download
 from torch import nn
 
 from .models.attention_processor import (
+    LORA_ATTENTION_PROCESSORS,
     AttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
     AttnProcessor,
@@ -1293,22 +1294,21 @@ class LoraLoaderMixin:
         >>> ...
         ```
         """
-        is_unet_lora = all(
-            isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor, LoRAAttnAddedKVProcessor))
-            for _, processor in self.unet.attn_processors.items()
-        )
-        # Handle attention processors that are a mix of regular attention and AddedKV
-        # attention.
-        if is_unet_lora:
-            is_attn_procs_mixed = all(
-                isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor))
-                for _, processor in self.unet.attn_processors.items()
-            )
-            if not is_attn_procs_mixed:
-                unet_attn_proc_cls = AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
-                self.unet.set_attn_processor(unet_attn_proc_cls())
-            else:
-                self.unet.set_default_attn_processor()
+        unet_attention_classes = {type(processor) for _, processor in self.unet.attn_processors.items()}
+
+        if unet_attention_classes.issubset(LORA_ATTENTION_PROCESSORS):
+            # Handle attention processors that are a mix of regular attention and AddedKV
+            # attention.
+            if len(unet_attention_classes) > 1 or LoRAAttnAddedKVProcessor in unet_attention_classes:
+                self.unet.set_default_attn_processor()
+            else:
+                regular_attention_classes = {
+                    LoRAAttnProcessor: AttnProcessor,
+                    LoRAAttnProcessor2_0: AttnProcessor2_0,
+                    LoRAXFormersAttnProcessor: XFormersAttnProcessor,
+                }
+                [attention_proc_class] = unet_attention_classes
+                self.unet.set_attn_processor(regular_attention_classes[attention_proc_class]())
 
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
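
Read on its own, the new branch in unload_lora_weights does three things: collect the set of processor classes on the UNet, bail out if any of them is not a LoRA processor, and otherwise either fall back to the default processor (mixed or AddedKV case) or map the single LoRA class back to its regular counterpart. The helper below is not library code, just an illustrative restatement of that logic assuming the post-commit module layout.

from diffusers.models.attention_processor import (
    LORA_ATTENTION_PROCESSORS,
    AttnProcessor,
    AttnProcessor2_0,
    LoRAAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)

def revert_lora_processors(unet):
    """Illustrative helper mirroring the branch above (not the library implementation)."""
    classes = {type(proc) for proc in unet.attn_processors.values()}
    if not classes.issubset(LORA_ATTENTION_PROCESSORS):
        return  # not a pure-LoRA UNet -> nothing to revert
    if len(classes) > 1 or LoRAAttnAddedKVProcessor in classes:
        unet.set_default_attn_processor()  # mixed regular/AddedKV case
        return
    regular = {
        LoRAAttnProcessor: AttnProcessor,
        LoRAAttnProcessor2_0: AttnProcessor2_0,
        LoRAXFormersAttnProcessor: XFormersAttnProcessor,
    }
    [lora_cls] = classes  # exactly one LoRA processor class is in use
    unet.set_attn_processor(regular[lora_cls]())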


@@ -167,7 +167,7 @@ class Attention(nn.Module):
     ):
         is_lora = hasattr(self, "processor") and isinstance(
             self.processor,
-            (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
+            LORA_ATTENTION_PROCESSORS,
         )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
             self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
@@ -1623,6 +1623,13 @@ AttentionProcessor = Union[
     CustomDiffusionXFormersAttnProcessor,
 ]
 
+LORA_ATTENTION_PROCESSORS = (
+    LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    LoRAAttnAddedKVProcessor,
+)
+
 
 class SpatialNorm(nn.Module):
     """


@@ -464,6 +464,14 @@ class LoraLoaderMixinTests(unittest.TestCase):
             if isinstance(module, Attention):
                 self.assertIsInstance(module.processor, LoRAXFormersAttnProcessor)
 
+        # unload lora weights
+        sd_pipe.unload_lora_weights()
+
+        # check if attention processors are reverted back to xFormers
+        for _, module in sd_pipe.unet.named_modules():
+            if isinstance(module, Attention):
+                self.assertIsInstance(module.processor, XFormersAttnProcessor)
+
     @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
     def test_lora_save_load_with_xformers(self):
         pipeline_components, lora_components = self.get_dummy_components()