diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 8ce5989b5f..be40ae586d 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -26,6 +26,7 @@ from huggingface_hub import hf_hub_download
 from torch import nn
 
 from .models.attention_processor import (
+    LORA_ATTENTION_PROCESSORS,
     AttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
     AttnProcessor,
@@ -1293,22 +1294,21 @@ class LoraLoaderMixin:
         >>> ...
         ```
         """
-        is_unet_lora = all(
-            isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor, LoRAAttnAddedKVProcessor))
-            for _, processor in self.unet.attn_processors.items()
-        )
-        # Handle attention processors that are a mix of regular attention and AddedKV
-        # attention.
-        if is_unet_lora:
-            is_attn_procs_mixed = all(
-                isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor))
-                for _, processor in self.unet.attn_processors.items()
-            )
-            if not is_attn_procs_mixed:
-                unet_attn_proc_cls = AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
-                self.unet.set_attn_processor(unet_attn_proc_cls())
-            else:
+        unet_attention_classes = {type(processor) for _, processor in self.unet.attn_processors.items()}
+
+        if unet_attention_classes.issubset(LORA_ATTENTION_PROCESSORS):
+            # Handle attention processors that are a mix of regular attention and AddedKV
+            # attention.
+            if len(unet_attention_classes) > 1 or LoRAAttnAddedKVProcessor in unet_attention_classes:
                 self.unet.set_default_attn_processor()
+            else:
+                regular_attention_classes = {
+                    LoRAAttnProcessor: AttnProcessor,
+                    LoRAAttnProcessor2_0: AttnProcessor2_0,
+                    LoRAXFormersAttnProcessor: XFormersAttnProcessor,
+                }
+                [attention_proc_class] = unet_attention_classes
+                self.unet.set_attn_processor(regular_attention_classes[attention_proc_class]())
 
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 5206ae7a4b..69889337ed 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -167,7 +167,7 @@ class Attention(nn.Module):
     ):
         is_lora = hasattr(self, "processor") and isinstance(
             self.processor,
-            (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
+            LORA_ATTENTION_PROCESSORS,
         )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
             self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
@@ -1623,6 +1623,13 @@ AttentionProcessor = Union[
     CustomDiffusionXFormersAttnProcessor,
 ]
 
+LORA_ATTENTION_PROCESSORS = (
+    LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    LoRAAttnAddedKVProcessor,
+)
+
 
 class SpatialNorm(nn.Module):
     """
diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py
index 1396561367..58cc5620c6 100644
--- a/tests/models/test_lora_layers.py
+++ b/tests/models/test_lora_layers.py
@@ -464,6 +464,14 @@ class LoraLoaderMixinTests(unittest.TestCase):
             if isinstance(module, Attention):
                 self.assertIsInstance(module.processor, LoRAXFormersAttnProcessor)
 
+        # unload lora weights
+        sd_pipe.unload_lora_weights()
+
+        # check if attention processors are reverted back to xFormers
+        for _, module in sd_pipe.unet.named_modules():
+            if isinstance(module, Attention):
+                self.assertIsInstance(module.processor, XFormersAttnProcessor)
+
     @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
     def test_lora_save_load_with_xformers(self):
         pipeline_components, lora_components = self.get_dummy_components()
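
For reference, a minimal usage sketch of the behaviour the new test asserts: after loading LoRA weights on an xFormers-enabled pipeline and then calling unload_lora_weights(), the UNet's attention processors should be restored to XFormersAttnProcessor rather than the default processor. This sketch is not part of the diff; the checkpoint id and LoRA path are placeholders, and a CUDA device with xformers installed is assumed.

import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import Attention, XFormersAttnProcessor

# Placeholder checkpoint; any Stable Diffusion pipeline works the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Placeholder LoRA source: a repo id or local directory containing LoRA weights.
pipe.load_lora_weights("path/to/lora")
pipe.enable_xformers_memory_efficient_attention()

# Unloading now maps LoRAXFormersAttnProcessor back to XFormersAttnProcessor
# instead of falling back to the default attention processor.
pipe.unload_lora_weights()
for _, module in pipe.unet.named_modules():
    if isinstance(module, Attention):
        assert isinstance(module.processor, XFormersAttnProcessor)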