1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Allow disabling torch 2_0 attention (#3273)

* Allow disabling torch 2_0 attention

* make style

* Update src/diffusers/models/attention.py
This commit is contained in:
Patrick von Platen
2023-04-28 13:31:11 +02:00
committed by GitHub
parent a7b0671c07
commit 4d35d7fea3

View File

@@ -71,6 +71,7 @@ class AttentionBlock(nn.Module):
self.proj_attn = nn.Linear(channels, channels, bias=True)
self._use_memory_efficient_attention_xformers = False
self._use_2_0_attn = True
self._attention_op = None
def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
@@ -142,9 +143,8 @@ class AttentionBlock(nn.Module):
scale = 1 / math.sqrt(self.channels / self.num_heads)
use_torch_2_0_attn = (
hasattr(F, "scaled_dot_product_attention") and not self._use_memory_efficient_attention_xformers
)
_use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn
query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)