Allow disabling torch 2_0 attention (#3273)
* Allow disabling torch 2_0 attention
* make style
* Update src/diffusers/models/attention.py
committed by GitHub
parent a7b0671c07
commit 4d35d7fea3
src/diffusers/models/attention.py
@@ -71,6 +71,7 @@ class AttentionBlock(nn.Module):
         self.proj_attn = nn.Linear(channels, channels, bias=True)

         self._use_memory_efficient_attention_xformers = False
+        self._use_2_0_attn = True
         self._attention_op = None

     def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
@@ -142,9 +143,8 @@ class AttentionBlock(nn.Module):

         scale = 1 / math.sqrt(self.channels / self.num_heads)

-        use_torch_2_0_attn = (
-            hasattr(F, "scaled_dot_product_attention") and not self._use_memory_efficient_attention_xformers
-        )
+        _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn

         query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
         key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
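The net effect of the two hunks is a single boolean gate: PyTorch 2.0's scaled_dot_product_attention is used only when it exists, xformers is off, and the new _use_2_0_attn flag is left at its default of True. A minimal sketch of that gating, assuming a module shaped like AttentionBlock (only the _use_2_0_attn and _use_memory_efficient_attention_xformers attributes come from the diff; the class and helper names below are hypothetical):

import torch
import torch.nn.functional as F


class AttentionBlockSketch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._use_memory_efficient_attention_xformers = False
        self._use_2_0_attn = True  # new flag from this commit: set False to opt out of SDPA

    def _should_use_torch_2_0_attn(self) -> bool:
        # Mirrors the gating in the second hunk: SDPA is chosen only when
        # the flag is True, xformers is disabled, and torch actually
        # exposes F.scaled_dot_product_attention (i.e. torch >= 2.0).
        _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
        return hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn


block = AttentionBlockSketch()
print(block._should_use_torch_2_0_attn())  # True on torch >= 2.0

block._use_2_0_attn = False  # disable torch 2.0 attention
print(block._should_use_torch_2_0_attn())  # False: falls back to the manual attention path

Flipping the flag to False forces the pre-2.0 manual attention path even on torch >= 2.0, e.g. to compare outputs against the fused SDPA path.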