Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Fix unloading of LoRAs when xformers attention procs are in use (#4179)
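For context, a minimal, hedged sketch of the behaviour this fix targets: after enabling xformers memory-efficient attention and loading LoRA weights, calling unload_lora_weights() should restore the xformers attention processors instead of resetting the UNet to the default processors. This snippet is not part of the commit; the model and LoRA checkpoint ids are placeholders.

# Sketch only: illustrates the scenario covered by the new test further down.
# "runwayml/stable-diffusion-v1-5" and "some-user/some-lora" are placeholder ids.
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import Attention, XFormersAttnProcessor

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()

pipe.load_lora_weights("some-user/some-lora")  # placeholder LoRA repo
pipe.unload_lora_weights()

# With this fix, the UNet's attention processors are xformers processors again,
# rather than being replaced by the default (non-xformers) processors.
assert all(
    isinstance(module.processor, XFormersAttnProcessor)
    for _, module in pipe.unet.named_modules()
    if isinstance(module, Attention)
)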
@@ -26,6 +26,7 @@ from huggingface_hub import hf_hub_download
 from torch import nn
 
 from .models.attention_processor import (
+    LORA_ATTENTION_PROCESSORS,
     AttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
     AttnProcessor,
@@ -1293,22 +1294,21 @@ class LoraLoaderMixin:
         >>> ...
         ```
         """
-        is_unet_lora = all(
-            isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor, LoRAAttnAddedKVProcessor))
-            for _, processor in self.unet.attn_processors.items()
-        )
-        # Handle attention processors that are a mix of regular attention and AddedKV
-        # attention.
-        if is_unet_lora:
-            is_attn_procs_mixed = all(
-                isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor))
-                for _, processor in self.unet.attn_processors.items()
-            )
-            if not is_attn_procs_mixed:
-                unet_attn_proc_cls = AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
-                self.unet.set_attn_processor(unet_attn_proc_cls())
-            else:
-                self.unet.set_default_attn_processor()
+        unet_attention_classes = {type(processor) for _, processor in self.unet.attn_processors.items()}
+
+        if unet_attention_classes.issubset(LORA_ATTENTION_PROCESSORS):
+            # Handle attention processors that are a mix of regular attention and AddedKV
+            # attention.
+            if len(unet_attention_classes) > 1 or LoRAAttnAddedKVProcessor in unet_attention_classes:
+                self.unet.set_default_attn_processor()
+            else:
+                regular_attention_classes = {
+                    LoRAAttnProcessor: AttnProcessor,
+                    LoRAAttnProcessor2_0: AttnProcessor2_0,
+                    LoRAXFormersAttnProcessor: XFormersAttnProcessor,
+                }
+                [attention_proc_class] = unet_attention_classes
+                self.unet.set_attn_processor(regular_attention_classes[attention_proc_class]())
 
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
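In short, the unload path now collects the set of processor classes found on the UNet and, when they are all LoRA variants, maps the single remaining class back to its non-LoRA counterpart, falling back to the default processor for mixed or AddedKV cases. A small standalone illustration of that mapping (not part of the diff), using the same classes:

# Illustration only: resolving the replacement processor when every UNet
# attention processor is the xformers LoRA variant.
from diffusers.models.attention_processor import (
    AttnProcessor,
    AttnProcessor2_0,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)

# Set of processor types found on the UNet; here they are all the xformers LoRA class.
unet_attention_classes = {LoRAXFormersAttnProcessor}

regular_attention_classes = {
    LoRAAttnProcessor: AttnProcessor,
    LoRAAttnProcessor2_0: AttnProcessor2_0,
    LoRAXFormersAttnProcessor: XFormersAttnProcessor,
}

# Exactly one class remains, so the unpacking succeeds and the mapping picks
# XFormersAttnProcessor as the processor to restore.
[attention_proc_class] = unet_attention_classes
print(regular_attention_classes[attention_proc_class])  # -> XFormersAttnProcessor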
@@ -167,7 +167,7 @@ class Attention(nn.Module):
     ):
         is_lora = hasattr(self, "processor") and isinstance(
             self.processor,
-            (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
+            LORA_ATTENTION_PROCESSORS,
         )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
             self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
@@ -1623,6 +1623,13 @@ AttentionProcessor = Union[
     CustomDiffusionXFormersAttnProcessor,
 ]
 
+LORA_ATTENTION_PROCESSORS = (
+    LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    LoRAAttnAddedKVProcessor,
+)
+
 
 class SpatialNorm(nn.Module):
     """
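The new LORA_ATTENTION_PROCESSORS tuple gives the loader and the Attention class one shared definition of a LoRA processor: it serves both as the class tuple accepted by isinstance/issubclass checks and as the superset for the set-based check in unload_lora_weights. A small illustration, not part of the diff:

# Illustration only: the two ways the tuple is used in this commit.
from diffusers.models.attention_processor import (
    LORA_ATTENTION_PROCESSORS,
    LoRAXFormersAttnProcessor,
)

# As a class tuple for isinstance/issubclass-style checks (attention_processor.py).
print(issubclass(LoRAXFormersAttnProcessor, LORA_ATTENTION_PROCESSORS))  # True

# As the superset for the set-based check in unload_lora_weights (loaders.py).
print({LoRAXFormersAttnProcessor}.issubset(LORA_ATTENTION_PROCESSORS))  # True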
@@ -464,6 +464,14 @@ class LoraLoaderMixinTests(unittest.TestCase):
             if isinstance(module, Attention):
                 self.assertIsInstance(module.processor, LoRAXFormersAttnProcessor)
 
+        # unload lora weights
+        sd_pipe.unload_lora_weights()
+
+        # check if attention processors are reverted back to xFormers
+        for _, module in sd_pipe.unet.named_modules():
+            if isinstance(module, Attention):
+                self.assertIsInstance(module.processor, XFormersAttnProcessor)
+
     @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
     def test_lora_save_load_with_xformers(self):
         pipeline_components, lora_components = self.get_dummy_components()