mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

[LoRA] Make sure LoRA can be disabled after it's run (#2128)

Author: Patrick von Platen
Date: 2023-01-26 22:26:11 +02:00
Committed by: GitHub
Commit: f653ded7ed
Parent: e92d43feb0

2 changed files with 71 additions and 25 deletions
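In user-facing terms, the behavior exercised below is that LoRA can be switched off again after it has been used: either per call, by passing `cross_attention_kwargs={"scale": 0.0}`, or permanently, by restoring the default attention processors. A rough sketch of what that looks like at the pipeline level, assuming a diffusers release contemporary with this commit (around v0.12/v0.13); the checkpoint id and LoRA weight path are placeholders:

```python
from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import CrossAttnProcessor

# Placeholder checkpoint and LoRA weights, for illustration only.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.unet.load_attn_procs("path/to/lora_weights")  # installs LoRA attention processors

prompt = "a photo of an astronaut riding a horse"

# LoRA active (the LoRA scale defaults to 1.0).
image_with_lora = pipe(prompt).images[0]

# Disable LoRA for a single call by zeroing its scale ...
image_scale_zero = pipe(prompt, cross_attention_kwargs={"scale": 0.0}).images[0]

# ... or remove it entirely by restoring the plain attention processors.
pipe.unet.set_attn_processor(CrossAttnProcessor())
image_without_lora = pipe(prompt).images[0]
```

The unit test added in this diff verifies the same two paths directly on `UNet2DConditionModel`.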


@@ -20,7 +20,7 @@ import unittest
 import torch
 
 from diffusers import UNet2DConditionModel
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.cross_attention import CrossAttnProcessor, LoRACrossAttnProcessor
 from diffusers.utils import (
     floats_tensor,
     load_hf_numpy,
@@ -40,6 +40,34 @@ logger = logging.get_logger(__name__)
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
+def create_lora_layers(model):
+    lora_attn_procs = {}
+    for name in model.attn_processors.keys():
+        cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+        if name.startswith("mid_block"):
+            hidden_size = model.config.block_out_channels[-1]
+        elif name.startswith("up_blocks"):
+            block_id = int(name[len("up_blocks.")])
+            hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+        elif name.startswith("down_blocks"):
+            block_id = int(name[len("down_blocks.")])
+            hidden_size = model.config.block_out_channels[block_id]
+
+        lora_attn_procs[name] = LoRACrossAttnProcessor(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        )
+        lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
+
+        # add 1 to weights to mock trained weights
+        with torch.no_grad():
+            lora_attn_procs[name].to_q_lora.up.weight += 1
+            lora_attn_procs[name].to_k_lora.up.weight += 1
+            lora_attn_procs[name].to_v_lora.up.weight += 1
+            lora_attn_procs[name].to_out_lora.up.weight += 1
+
+    return lora_attn_procs
+
+
 class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = UNet2DConditionModel
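As a quick worked example of the name-to-`hidden_size` lookup in `create_lora_layers`, here is a standalone sketch with an illustrative `block_out_channels=(32, 64)` (the value is not taken from this diff); the names follow the key format of `attn_processors`:

```python
# Standalone sketch of the hidden_size lookup used in create_lora_layers,
# with an illustrative block_out_channels=(32, 64).
block_out_channels = (32, 64)

def hidden_size_for(processor_name: str) -> int:
    if processor_name.startswith("mid_block"):
        # The mid block runs at the deepest resolution -> last channel count.
        return block_out_channels[-1]
    if processor_name.startswith("up_blocks"):
        # Up blocks go from deepest back to shallowest -> reversed channel order.
        # Note: the single-character slice assumes fewer than 10 blocks,
        # as in the original helper.
        block_id = int(processor_name[len("up_blocks.")])
        return list(reversed(block_out_channels))[block_id]
    if processor_name.startswith("down_blocks"):
        # Down blocks follow the config order directly.
        block_id = int(processor_name[len("down_blocks.")])
        return block_out_channels[block_id]
    raise ValueError(f"unexpected processor name: {processor_name}")

assert hidden_size_for("down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor") == 32
assert hidden_size_for("mid_block.attentions.0.transformer_blocks.0.attn2.processor") == 64
assert hidden_size_for("up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor") == 32
```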
@@ -336,30 +364,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         with torch.no_grad():
             old_sample = model(**inputs_dict).sample
 
-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRACrossAttnProcessor(
-                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-            )
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
-            # add 1 to weights to mock trained weights
-            with torch.no_grad():
-                lora_attn_procs[name].to_q_lora.up.weight += 1
-                lora_attn_procs[name].to_k_lora.up.weight += 1
-                lora_attn_procs[name].to_v_lora.up.weight += 1
-                lora_attn_procs[name].to_out_lora.up.weight += 1
-
+        lora_attn_procs = create_lora_layers(model)
         model.set_attn_processor(lora_attn_procs)
 
         with torch.no_grad():
@@ -380,6 +385,33 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         # LoRA and no LoRA should NOT be the same
         assert (sample - old_sample).abs().max() > 1e-4
 
+    def test_lora_on_off(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            old_sample = model(**inputs_dict).sample
+
+        lora_attn_procs = create_lora_layers(model)
+        model.set_attn_processor(lora_attn_procs)
+
+        with torch.no_grad():
+            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
+
+        model.set_attn_processor(CrossAttnProcessor())
+
+        with torch.no_grad():
+            new_sample = model(**inputs_dict).sample
+
+        assert (sample - new_sample).abs().max() < 1e-4
+        assert (sample - old_sample).abs().max() < 1e-4
+
 
 @slow
 class UNet2DConditionModelIntegrationTests(unittest.TestCase):
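For reference, the behavior covered by `test_lora_on_off` can be reproduced outside the test harness with a self-contained sketch like the one below, assuming a diffusers version contemporary with this commit; the tiny UNet config is illustrative rather than taken from this diff:

```python
import torch

from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import CrossAttnProcessor, LoRACrossAttnProcessor

torch.manual_seed(0)

# Deliberately tiny, illustrative config (not taken from this diff).
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=(8, 16),
)
unet.eval()

sample = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([1])
encoder_hidden_states = torch.randn(1, 4, 32)

with torch.no_grad():
    baseline = unet(sample, timestep, encoder_hidden_states).sample

# Build one LoRA processor per attention layer (same mapping as create_lora_layers).
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )
    # Nudge the zero-initialized "up" weights so LoRA visibly changes the output,
    # the same trick the tests use to mock trained weights.
    with torch.no_grad():
        lora_attn_procs[name].to_q_lora.up.weight += 1
        lora_attn_procs[name].to_k_lora.up.weight += 1
        lora_attn_procs[name].to_v_lora.up.weight += 1
        lora_attn_procs[name].to_out_lora.up.weight += 1

unet.set_attn_processor(lora_attn_procs)

with torch.no_grad():
    lora_sample = unet(sample, timestep, encoder_hidden_states).sample  # LoRA on
    off_by_scale = unet(
        sample, timestep, encoder_hidden_states, cross_attention_kwargs={"scale": 0.0}
    ).sample  # LoRA off for this call

unet.set_attn_processor(CrossAttnProcessor())  # LoRA removed entirely
with torch.no_grad():
    off_by_reset = unet(sample, timestep, encoder_hidden_states).sample

assert not torch.allclose(lora_sample, baseline)          # LoRA changed the output
assert torch.allclose(off_by_scale, baseline, atol=1e-4)  # scale=0.0 restores it
assert torch.allclose(off_by_reset, baseline, atol=1e-4)  # so does resetting the processors
```

Both "off" paths recover the pre-LoRA output because the LoRA contribution is either multiplied by zero at call time or removed from the attention processors entirely.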