From f653ded7eda647a3f48b2d5eddff140f0ef2af4e Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 26 Jan 2023 22:26:11 +0200
Subject: [PATCH] [LoRA] Make sure LoRA can be disabled after it's run (#2128)

---
 src/diffusers/models/cross_attention.py       | 14 ++++
 tests/models/test_models_unet_2d_condition.py | 82 +++++++++++++------
 2 files changed, 71 insertions(+), 25 deletions(-)

diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index bfa6fc3611..d834432ec4 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -17,9 +17,13 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
+from ..utils import logging
 from ..utils.import_utils import is_xformers_available
 
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 if is_xformers_available():
     import xformers
     import xformers.ops
@@ -151,6 +155,16 @@ class CrossAttention(nn.Module):
         self.set_processor(processor)
 
     def set_processor(self, processor: "AttnProcessor"):
+        # if current processor is in `self._modules` and if passed `processor` is not, we need to
+        # pop `processor` from `self._modules`
+        if (
+            hasattr(self, "processor")
+            and isinstance(self.processor, torch.nn.Module)
+            and not isinstance(processor, torch.nn.Module)
+        ):
+            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+            self._modules.pop("processor")
+
         self.processor = processor
 
     def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py
index f8f29e9f23..4fe3ab518c 100644
--- a/tests/models/test_models_unet_2d_condition.py
+++ b/tests/models/test_models_unet_2d_condition.py
@@ -20,7 +20,7 @@ import unittest
 import torch
 
 from diffusers import UNet2DConditionModel
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.cross_attention import CrossAttnProcessor, LoRACrossAttnProcessor
 from diffusers.utils import (
     floats_tensor,
     load_hf_numpy,
@@ -40,6 +40,34 @@ logger = logging.get_logger(__name__)
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
+def create_lora_layers(model):
+    lora_attn_procs = {}
+    for name in model.attn_processors.keys():
+        cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+        if name.startswith("mid_block"):
+            hidden_size = model.config.block_out_channels[-1]
+        elif name.startswith("up_blocks"):
+            block_id = int(name[len("up_blocks.")])
+            hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+        elif name.startswith("down_blocks"):
+            block_id = int(name[len("down_blocks.")])
+            hidden_size = model.config.block_out_channels[block_id]
+
+        lora_attn_procs[name] = LoRACrossAttnProcessor(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        )
+        lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
+
+        # add 1 to weights to mock trained weights
+        with torch.no_grad():
+            lora_attn_procs[name].to_q_lora.up.weight += 1
+            lora_attn_procs[name].to_k_lora.up.weight += 1
+            lora_attn_procs[name].to_v_lora.up.weight += 1
+            lora_attn_procs[name].to_out_lora.up.weight += 1
+
+    return lora_attn_procs
+
+
 class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = UNet2DConditionModel
 
@@ -336,30 +364,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         with torch.no_grad():
             old_sample = model(**inputs_dict).sample
 
-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRACrossAttnProcessor(
-                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-            )
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
-            # add 1 to weights to mock trained weights
-            with torch.no_grad():
-                lora_attn_procs[name].to_q_lora.up.weight += 1
-                lora_attn_procs[name].to_k_lora.up.weight += 1
-                lora_attn_procs[name].to_v_lora.up.weight += 1
-                lora_attn_procs[name].to_out_lora.up.weight += 1
-
+        lora_attn_procs = create_lora_layers(model)
         model.set_attn_processor(lora_attn_procs)
 
         with torch.no_grad():
@@ -380,6 +385,33 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         # LoRA and no LoRA should NOT be the same
         assert (sample - old_sample).abs().max() > 1e-4
 
+    def test_lora_on_off(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            old_sample = model(**inputs_dict).sample
+
+        lora_attn_procs = create_lora_layers(model)
+        model.set_attn_processor(lora_attn_procs)
+
+        with torch.no_grad():
+            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
+
+        model.set_attn_processor(CrossAttnProcessor())
+
+        with torch.no_grad():
+            new_sample = model(**inputs_dict).sample
+
+        assert (sample - new_sample).abs().max() < 1e-4
+        assert (sample - old_sample).abs().max() < 1e-4
+
 
 @slow
 class UNet2DConditionModelIntegrationTests(unittest.TestCase):
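
For context, here is a minimal standalone sketch (not part of the patch) of the behavior the new `test_lora_on_off` test exercises: after LoRA processors are attached with `set_attn_processor`, LoRA can be turned off either by passing `cross_attention_kwargs={"scale": 0.0}` or by restoring the plain `CrossAttnProcessor`, which the `set_processor` fix above now allows. The tiny UNet configuration and dummy tensor shapes are illustrative assumptions chosen so the example runs quickly on CPU; they are not defined by this patch.

```python
import torch

from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import CrossAttnProcessor, LoRACrossAttnProcessor

torch.manual_seed(0)
# Assumed tiny config, mirroring the toy config used in the diffusers test suite.
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    sample_size=32,
)

# Dummy inputs: a latent batch, a timestep, and "text encoder" hidden states.
sample = torch.randn(1, 4, 32, 32)
timestep = 1
encoder_hidden_states = torch.randn(1, 77, 32)

with torch.no_grad():
    baseline = unet(sample, timestep, encoder_hidden_states).sample

# Build one LoRA processor per attention block, sized exactly as in the
# `create_lora_layers` helper added above, and mock "trained" weights so the
# LoRA path visibly changes the output.
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )
    with torch.no_grad():
        lora_attn_procs[name].to_q_lora.up.weight += 1
        lora_attn_procs[name].to_k_lora.up.weight += 1
        lora_attn_procs[name].to_v_lora.up.weight += 1
        lora_attn_procs[name].to_out_lora.up.weight += 1

unet.set_attn_processor(lora_attn_procs)

with torch.no_grad():
    # LoRA enabled: the output should differ from the baseline.
    lora_on = unet(sample, timestep, encoder_hidden_states).sample
    # LoRA "off", option 1: keep the processors but zero out their contribution.
    off_via_scale = unet(
        sample, timestep, encoder_hidden_states, cross_attention_kwargs={"scale": 0.0}
    ).sample

# LoRA "off", option 2: swap the plain processor back in. Before this patch the
# old LoRA module was left registered in `_modules`; the new `set_processor`
# logic pops it so the replacement takes effect cleanly.
unet.set_attn_processor(CrossAttnProcessor())
with torch.no_grad():
    off_via_reset = unet(sample, timestep, encoder_hidden_states).sample

assert (lora_on - baseline).abs().max() > 1e-4
assert (off_via_scale - baseline).abs().max() < 1e-4
assert (off_via_reset - baseline).abs().max() < 1e-4
```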