
Fix running LoRA with xformers (#2286)

* Fix running LoRA with xformers

* support disabling xformers

* reformat

* Add test
Author: bddppq
Date: 2023-02-13 02:58:18 -08:00
Committed by: GitHub
Commit: 5d4f59ee96
Parent: f2eae16849

2 changed files with 69 additions and 4 deletions
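
What the fix means in practice: enabling or disabling xformers on a model whose attention processors are LoRA processors used to replace them with plain CrossAttnProcessor / XFormersCrossAttnProcessor instances, silently dropping the LoRA layers. After this commit the LoRA configuration and weights are carried over in both directions. Below is a rough sketch of the user-facing flow this targets; the checkpoint name and the per-block hidden_size lookup are illustrative (patterned on the LoRA training example of the same era, not part of this commit), and the processor classes are assumed to be importable from diffusers.models.cross_attention.

from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import LoRACrossAttnProcessor

# Illustrative checkpoint; any UNet2DConditionModel behaves the same way.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Attach a rank-4 LoRA processor to every attention block (lookup patterned on the
# official LoRA training example; the hidden size depends on which block the layer sits in).
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=4
    )
unet.set_attn_processor(lora_attn_procs)

# Requires a CUDA device with xformers installed. Before this commit the first call
# silently replaced the LoRA processors with plain XFormersCrossAttnProcessor instances;
# now the LoRA weights are carried over in both directions.
unet.enable_xformers_memory_efficient_attention()
unet.disable_xformers_memory_efficient_attention()

The same toggling is exercised end to end by the new test_lora_xformers_on_off test at the bottom of this diff.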


@@ -105,6 +105,10 @@ class CrossAttention(nn.Module):
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
+        is_lora = hasattr(self, "processor") and isinstance(
+            self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor)
+        )
+
         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
                 # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
@@ -138,9 +142,28 @@ class CrossAttention(nn.Module):
                 except Exception as e:
                     raise e
-            processor = XFormersCrossAttnProcessor(attention_op=attention_op)
+
+            if is_lora:
+                processor = LoRAXFormersCrossAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    rank=self.processor.rank,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                processor.to(self.processor.to_q_lora.up.weight.device)
+            else:
+                processor = XFormersCrossAttnProcessor(attention_op=attention_op)
         else:
-            processor = CrossAttnProcessor()
+            if is_lora:
+                processor = LoRACrossAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    rank=self.processor.rank,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                processor.to(self.processor.to_q_lora.up.weight.device)
+            else:
+                processor = CrossAttnProcessor()
+
         self.set_processor(processor)
@@ -324,6 +347,10 @@ class LoRACrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()

+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+
         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
@@ -437,9 +464,14 @@ class XFormersCrossAttnProcessor:
 class LoRAXFormersCrossAttnProcessor(nn.Module):
-    def __init__(self, hidden_size, cross_attention_dim, rank=4):
+    def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
         super().__init__()

+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+        self.attention_op = attention_op
+
         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
@@ -462,7 +494,9 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
         key = attn.head_to_batch_dim(key).contiguous()
         value = attn.head_to_batch_dim(value).contiguous()

-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op
+        )
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
@@ -595,4 +629,6 @@ AttnProcessor = Union[
     SlicedAttnProcessor,
     CrossAttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
+    LoRACrossAttnProcessor,
+    LoRAXFormersCrossAttnProcessor,
 ]
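
The swap above works because both LoRA processor classes register the same LoRA sub-layers (to_q_lora, to_k_lora, to_v_lora, plus an output LoRA layer), so one can load the other's state dict directly. A minimal sketch of that round trip, again assuming the diffusers.models.cross_attention import path; constructing the processors needs neither a GPU nor xformers, only calling them does:

import torch
from diffusers.models.cross_attention import (
    LoRACrossAttnProcessor,
    LoRAXFormersCrossAttnProcessor,
)

# A regular LoRA processor with non-trivial weights standing in for trained ones.
lora_proc = LoRACrossAttnProcessor(hidden_size=64, cross_attention_dim=32, rank=4)
for param in lora_proc.parameters():
    torch.nn.init.normal_(param, std=0.02)

# Rebuild it as the xformers variant and carry the weights over, the same
# pattern the updated set_use_memory_efficient_attention_xformers uses.
xformers_proc = LoRAXFormersCrossAttnProcessor(
    hidden_size=lora_proc.hidden_size,
    cross_attention_dim=lora_proc.cross_attention_dim,
    rank=lora_proc.rank,
    attention_op=None,
)
xformers_proc.load_state_dict(lora_proc.state_dict())
xformers_proc.to(lora_proc.to_q_lora.up.weight.device)

# The LoRA weights survive the round trip, so toggling xformers no longer resets them.
assert torch.equal(xformers_proc.to_q_lora.up.weight, lora_proc.to_q_lora.up.weight)

This is what set_use_memory_efficient_attention_xformers now does for every attention block, and it is what the new test below checks end to end.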


@@ -412,6 +412,35 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         assert (sample - new_sample).abs().max() < 1e-4
         assert (sample - old_sample).abs().max() < 1e-4

+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_lora_xformers_on_off(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        lora_attn_procs = create_lora_layers(model)
+        model.set_attn_processor(lora_attn_procs)
+
+        # default
+        with torch.no_grad():
+            sample = model(**inputs_dict).sample
+
+            model.enable_xformers_memory_efficient_attention()
+            on_sample = model(**inputs_dict).sample
+
+            model.disable_xformers_memory_efficient_attention()
+            off_sample = model(**inputs_dict).sample
+
+        assert (sample - on_sample).abs().max() < 1e-4
+        assert (sample - off_sample).abs().max() < 1e-4
+

 @slow
 class UNet2DConditionModelIntegrationTests(unittest.TestCase):