1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

fuse attention mask (#2111)

* fuse attention mask

* lint

* use 0 beta when no attention mask re: @Birch-san
This commit is contained in:
Will Berman
2023-01-26 08:36:07 -08:00
committed by GitHub
parent 96af5bf7d9
commit 14976500ed

View File

@@ -185,17 +185,23 @@ class CrossAttention(nn.Module):
query = query.float()
key = key.float()
if attention_mask is None:
baddbmm_input = torch.empty(
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
)
beta = 0
else:
baddbmm_input = attention_mask
beta = 1
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
baddbmm_input,
query,
key.transpose(-1, -2),
beta=0,
beta=beta,
alpha=self.scale,
)
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
if self.upcast_softmax:
attention_scores = attention_scores.float()
@@ -228,11 +234,12 @@ class CrossAttnProcessor:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)