Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Fix to apply LoRAXFormersAttnProcessor instead of LoRAAttnProcessor when xFormers is enabled (#3556)

* fix to use LoRAXFormersAttnProcessor
* add test
* use new LoraLoaderMixin.save_lora_weights
* add test_lora_save_load_with_xformers
@@ -27,7 +27,9 @@ from .models.attention_processor import (
     CustomDiffusionXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,
     LoRAAttnProcessor,
+    LoRAXFormersAttnProcessor,
     SlicedAttnAddedKVProcessor,
+    XFormersAttnProcessor,
 )
 from .utils import (
     DIFFUSERS_CACHE,
@@ -279,7 +281,10 @@ class UNet2DConditionLoadersMixin:
                 attn_processor_class = LoRAAttnAddedKVProcessor
             else:
                 cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
-                attn_processor_class = LoRAAttnProcessor
+                if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
+                    attn_processor_class = LoRAXFormersAttnProcessor
+                else:
+                    attn_processor_class = LoRAAttnProcessor

             attn_processors[key] = attn_processor_class(
                 hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
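This hunk is the core of the fix: when a module's current processor is an xFormers variant, the LoRA loader now instantiates LoRAXFormersAttnProcessor, so memory-efficient attention survives loading LoRA weights. A standalone sketch of the same dispatch (pick_lora_processor_class is a hypothetical helper name, not part of this commit):

from diffusers.models.attention_processor import (
    LoRAAttnProcessor,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)


def pick_lora_processor_class(attn_processor):
    # Preserve xFormers attention when it is already active.
    # LoRAXFormersAttnProcessor is included in the isinstance check so that
    # loading LoRA weights a second time does not silently fall back to the
    # vanilla LoRAAttnProcessor.
    if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
        return LoRAXFormersAttnProcessor
    return LoRAAttnProcessor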
@@ -22,7 +22,14 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
-from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.models.attention_processor import (
+    Attention,
+    AttnProcessor,
+    AttnProcessor2_0,
+    LoRAAttnProcessor,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
 from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device
@@ -212,3 +219,90 @@ class LoraLoaderMixinTests(unittest.TestCase):

         # Outputs shouldn't match.
         self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
+
+    def create_lora_weight_file(self, tmpdirname):
+        _, lora_components = self.get_dummy_components()
+        LoraLoaderMixin.save_lora_weights(
+            save_directory=tmpdirname,
+            unet_lora_layers=lora_components["unet_lora_layers"],
+            text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
+        )
+        self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
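The helper serializes both the UNet and text-encoder LoRA layers into a single pytorch_lora_weights.bin. A minimal inspection sketch, assuming the "unet."/"text_encoder." key prefixes that LoraLoaderMixin uses when saving (an assumption worth verifying against your diffusers version); tmpdirname is the directory passed to save_lora_weights above:

import os

import torch

state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

# Assumption: UNet entries are prefixed "unet." and text-encoder entries
# "text_encoder." in the combined checkpoint.
unet_keys = [k for k in state_dict if k.startswith("unet.")]
text_encoder_keys = [k for k in state_dict if k.startswith("text_encoder.")]
print(len(unet_keys), len(text_encoder_keys))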
+
+    def test_lora_unet_attn_processors(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            self.create_lora_weight_file(tmpdirname)
+
+            pipeline_components, _ = self.get_dummy_components()
+            sd_pipe = StableDiffusionPipeline(**pipeline_components)
+            sd_pipe = sd_pipe.to(torch_device)
+            sd_pipe.set_progress_bar_config(disable=None)
+
+            # check if vanilla attention processors are used
+            for _, module in sd_pipe.unet.named_modules():
+                if isinstance(module, Attention):
+                    self.assertIsInstance(module.processor, (AttnProcessor, AttnProcessor2_0))
+
+            # load LoRA weight file
+            sd_pipe.load_lora_weights(tmpdirname)
+
+            # check if lora attention processors are used
+            for _, module in sd_pipe.unet.named_modules():
+                if isinstance(module, Attention):
+                    self.assertIsInstance(module.processor, LoRAAttnProcessor)
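The first assertion accepts either AttnProcessor or AttnProcessor2_0 because diffusers selects the default processor based on the installed PyTorch: under torch >= 2.0 it uses AttnProcessor2_0 (built on scaled_dot_product_attention), otherwise AttnProcessor. A sketch of that version check, stated as an assumption about how the default is chosen:

import torch.nn.functional as F

from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0

# Assumption: diffusers picks AttnProcessor2_0 when PyTorch exposes
# scaled_dot_product_attention (torch >= 2.0), and AttnProcessor otherwise.
expected_default = (
    AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
)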
+
+    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+    def test_lora_unet_attn_processors_with_xformers(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            self.create_lora_weight_file(tmpdirname)
+
+            pipeline_components, _ = self.get_dummy_components()
+            sd_pipe = StableDiffusionPipeline(**pipeline_components)
+            sd_pipe = sd_pipe.to(torch_device)
+            sd_pipe.set_progress_bar_config(disable=None)
+
+            # enable XFormers
+            sd_pipe.enable_xformers_memory_efficient_attention()
+
+            # check if xFormers attention processors are used
+            for _, module in sd_pipe.unet.named_modules():
+                if isinstance(module, Attention):
+                    self.assertIsInstance(module.processor, XFormersAttnProcessor)
+
+            # load LoRA weight file
+            sd_pipe.load_lora_weights(tmpdirname)
+
+            # check if lora attention processors are used
+            for _, module in sd_pipe.unet.named_modules():
+                if isinstance(module, Attention):
+                    self.assertIsInstance(module.processor, LoRAXFormersAttnProcessor)
+
+    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+    def test_lora_save_load_with_xformers(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPipeline(**pipeline_components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        noise, input_ids, pipeline_inputs = self.get_dummy_inputs()
+
+        # enable XFormers
+        sd_pipe.enable_xformers_memory_efficient_attention()
+
+        original_images = sd_pipe(**pipeline_inputs).images
+        orig_image_slice = original_images[0, -3:, -3:, -1]
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            LoraLoaderMixin.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            sd_pipe.load_lora_weights(tmpdirname)
+
+        lora_images = sd_pipe(**pipeline_inputs).images
+        lora_image_slice = lora_images[0, -3:, -3:, -1]
+
+        # Outputs shouldn't match.
+        self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
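Taken together, the fix means users no longer lose memory-efficient attention by loading LoRA weights after enabling xFormers. A minimal end-to-end sketch, not part of this commit; the model id and LoRA directory are illustrative assumptions, and it requires a CUDA GPU with xformers installed:

import torch

from diffusers import StableDiffusionPipeline

# Hypothetical model id and LoRA path for illustration only.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipe.enable_xformers_memory_efficient_attention()

# With this fix, loading LoRA weights keeps the xFormers processors
# (LoRAXFormersAttnProcessor) instead of falling back to LoRAAttnProcessor.
pipe.load_lora_weights("./my_lora")

image = pipe("an astronaut riding a horse").images[0]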