Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Fix running LoRA with xformers (#2286)

* Fix running LoRA with xformers
* support disabling xformers
* reformat
* Add test
```diff
@@ -105,6 +105,10 @@ class CrossAttention(nn.Module):
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
+        is_lora = hasattr(self, "processor") and isinstance(
+            self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor)
+        )
+
         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
                 # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
```
```diff
@@ -138,9 +142,28 @@ class CrossAttention(nn.Module):
             except Exception as e:
                 raise e
 
-            processor = XFormersCrossAttnProcessor(attention_op=attention_op)
+            if is_lora:
+                processor = LoRAXFormersCrossAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    rank=self.processor.rank,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                processor.to(self.processor.to_q_lora.up.weight.device)
+            else:
+                processor = XFormersCrossAttnProcessor(attention_op=attention_op)
         else:
-            processor = CrossAttnProcessor()
+            if is_lora:
+                processor = LoRACrossAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    rank=self.processor.rank,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                processor.to(self.processor.to_q_lora.up.weight.device)
+            else:
+                processor = CrossAttnProcessor()
 
         self.set_processor(processor)
 
```
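The key move in this hunk is that toggling xformers no longer discards LoRA weights: the `is_lora` flag records whether the module currently runs a LoRA processor, and if so the freshly constructed counterpart inherits the old processor's parameters via `load_state_dict` and is moved to the same device. A minimal self-contained sketch of that pattern (the class names here are illustrative stand-ins, not the diffusers ones):

```python
import torch
import torch.nn as nn

class AdapterA(nn.Module):
    # stand-in for LoRACrossAttnProcessor: holds trainable low-rank weights
    def __init__(self, dim, rank):
        super().__init__()
        self.down = nn.Linear(dim, rank, bias=False)
        self.up = nn.Linear(rank, dim, bias=False)

class AdapterB(AdapterA):
    # stand-in for LoRAXFormersCrossAttnProcessor: identical parameters,
    # only the attention backend used in forward() would differ
    pass

old = AdapterA(dim=32, rank=4)
new = AdapterB(dim=32, rank=4)
new.load_state_dict(old.state_dict())  # weights carry over because the keys match
new.to(old.up.weight.device)           # keep the replacement on the old device

assert torch.equal(new.down.weight, old.down.weight)
```

Because the swap is symmetric, the same trick restores a plain `LoRACrossAttnProcessor` when xformers is disabled again.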
```diff
@@ -324,6 +347,10 @@ class LoRACrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()
 
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+
         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
```
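Storing `hidden_size`, `cross_attention_dim`, and `rank` on the processor is what lets the swap code above rebuild an equivalent processor without guessing its shape. For reference, a `LoRALinearLayer` is a low-rank residual adapter: a down-projection to `rank` dimensions followed by an up-projection back out, with the up weights zero-initialized so the adapter starts as a no-op. A sketch of the idea, not the exact diffusers implementation (which also handles dtype casting):

```python
import torch.nn as nn

class LoRALinearLayerSketch(nn.Module):
    """Low-rank adapter: out = up(down(x)); behaves as the zero function until trained."""

    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)  # zero init: no effect on the base model at step 0

    def forward(self, hidden_states):
        return self.up(self.down(hidden_states))
```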
```diff
@@ -437,9 +464,14 @@ class XFormersCrossAttnProcessor:
 
 
 class LoRAXFormersCrossAttnProcessor(nn.Module):
-    def __init__(self, hidden_size, cross_attention_dim, rank=4):
+    def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
         super().__init__()
 
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+        self.attention_op = attention_op
+
         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
```
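With `attention_op` now accepted and stored, the LoRA variant gains the same kernel-selection hook that `XFormersCrossAttnProcessor` already exposed. A hypothetical usage sketch (the size and rank values are placeholders, and pinning a specific op is optional; xformers auto-selects one when `attention_op` is `None`):

```python
import xformers.ops

# request the flash-attention kernel explicitly instead of auto-selection
processor = LoRAXFormersCrossAttnProcessor(
    hidden_size=320,
    cross_attention_dim=768,
    rank=4,
    attention_op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp,
)
```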
```diff
@@ -462,7 +494,9 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
         key = attn.head_to_batch_dim(key).contiguous()
         value = attn.head_to_batch_dim(value).contiguous()
 
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op
+        )
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
```
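For context, the query, key, and value entering this xformers call are already LoRA-adjusted: earlier in the same `__call__` (not part of this hunk) the processor adds the scaled low-rank output to each base projection, roughly along these lines (paraphrased; `scale` is the LoRA scale argument of `__call__`):

```python
# inside LoRAXFormersCrossAttnProcessor.__call__ (paraphrase, not in this hunk)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
```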
```diff
@@ -595,4 +629,6 @@ AttnProcessor = Union[
     SlicedAttnProcessor,
     CrossAttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
+    LoRACrossAttnProcessor,
+    LoRAXFormersCrossAttnProcessor,
 ]
```
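Adding the two LoRA classes to the `AttnProcessor` union means they are accepted wherever the library type-checks processors, such as `set_attn_processor`. A hedged usage sketch, assuming the classes are importable from the module patched here and using placeholder sizes (in practice the hidden size depends on which UNet block each processor wraps):

```python
from diffusers.models.cross_attention import LoRACrossAttnProcessor

# assign a LoRA processor to every attention module of an already-loaded UNet
lora_attn_procs = {
    name: LoRACrossAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4)
    for name in unet.attn_processors.keys()
}
unet.set_attn_processor(lora_attn_procs)
```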
```diff
@@ -412,6 +412,35 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         assert (sample - new_sample).abs().max() < 1e-4
         assert (sample - old_sample).abs().max() < 1e-4
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_lora_xformers_on_off(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        lora_attn_procs = create_lora_layers(model)
+        model.set_attn_processor(lora_attn_procs)
+
+        # default
+        with torch.no_grad():
+            sample = model(**inputs_dict).sample
+
+            model.enable_xformers_memory_efficient_attention()
+            on_sample = model(**inputs_dict).sample
+
+            model.disable_xformers_memory_efficient_attention()
+            off_sample = model(**inputs_dict).sample
+
+        assert (sample - on_sample).abs().max() < 1e-4
+        assert (sample - off_sample).abs().max() < 1e-4
+
 
 @slow
 class UNet2DConditionModelIntegrationTests(unittest.TestCase):
```
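The test pins down the behavior end users rely on: with LoRA processors installed, toggling xformers on and off produces outputs identical to the baseline within 1e-4, so the adapter survives both switches. A hedged end-to-end sketch of the same flow at the pipeline level (the model ID and LoRA path are placeholders):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.unet.load_attn_procs("path/to/lora")  # installs LoRACrossAttnProcessor modules

pipe.enable_xformers_memory_efficient_attention()   # swaps in LoRAXFormersCrossAttnProcessor
image_on = pipe("a photo of an astronaut", num_inference_steps=20).images[0]

pipe.disable_xformers_memory_efficient_attention()  # swaps back, LoRA weights intact
image_off = pipe("a photo of an astronaut", num_inference_steps=20).images[0]
```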