fuse attention mask (#2111)

* fuse attention mask
* lint
* use 0 beta when no attention mask re: @Birch-san
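For context on the beta trick in this commit: torch.baddbmm(input, batch1, batch2, beta=beta, alpha=alpha) computes beta * input + alpha * bmm(batch1, batch2) in a single kernel, so an additive attention mask can ride along as the input argument with beta=1. A minimal sketch (shapes made up for illustration):

import torch

# torch.baddbmm(input, batch1, batch2, beta=b, alpha=a)
# computes: b * input + a * bmm(batch1, batch2)
q = torch.randn(2, 4, 8)       # (batch * heads, seq_q, head_dim)
k = torch.randn(2, 5, 8)       # (batch * heads, seq_k, head_dim)
mask = torch.randn(2, 4, 5)    # additive attention mask over the scores

fused = torch.baddbmm(mask, q, k.transpose(-1, -2), beta=1, alpha=0.5)
manual = 0.5 * torch.bmm(q, k.transpose(-1, -2)) + mask
assert torch.allclose(fused, manual, atol=1e-5)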
@@ -185,17 +185,23 @@ class CrossAttention(nn.Module):
             query = query.float()
             key = key.float()

+        if attention_mask is None:
+            baddbmm_input = torch.empty(
+                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+            )
+            beta = 0
+        else:
+            baddbmm_input = attention_mask
+            beta = 1
+
         attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            baddbmm_input,
             query,
             key.transpose(-1, -2),
-            beta=0,
+            beta=beta,
             alpha=self.scale,
         )

-        if attention_mask is not None:
-            attention_scores = attention_scores + attention_mask
-
         if self.upcast_softmax:
             attention_scores = attention_scores.float()
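A quick sanity check of the refactor above (a standalone sketch with made-up shapes and a stand-in scale, not the library code): fusing the mask as the input argument with beta=1 reproduces the old matmul-then-add path, and with no mask, beta=0 is documented to ignore input entirely, so the uninitialized torch.empty buffer cannot leak into the scores:

import torch

scale = 0.125                      # stand-in for self.scale
query = torch.randn(2, 4, 8)
key = torch.randn(2, 5, 8)
attention_mask = torch.randn(2, 4, 5)

# old path: beta=0 matmul, mask added afterwards
old = torch.baddbmm(
    torch.empty(2, 4, 5), query, key.transpose(-1, -2), beta=0, alpha=scale
) + attention_mask
# new path: mask fused in as `input` with beta=1
new = torch.baddbmm(attention_mask, query, key.transpose(-1, -2), beta=1, alpha=scale)
assert torch.allclose(old, new, atol=1e-5)

# no-mask path: beta=0 ignores `input` (nan/inf in it are not propagated),
# so the arbitrary contents of the empty buffer are safe
no_mask = torch.baddbmm(
    torch.empty(2, 4, 5), query, key.transpose(-1, -2), beta=0, alpha=scale
)
assert torch.allclose(no_mask, scale * torch.bmm(query, key.transpose(-1, -2)), atol=1e-5)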
@@ -228,11 +234,12 @@ class CrossAttnProcessor:
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

         query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query)

         encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
-        query = attn.head_to_batch_dim(query)
+
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
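The second hunk only moves query = attn.head_to_batch_dim(query) up next to its projection; the computation is unchanged. For readers unfamiliar with the helper: it folds the head dimension into the batch dimension so the batched matmuls above run per head. A rough standalone sketch of the idea (hypothetical, not the exact library code):

import torch

def head_to_batch_dim(tensor: torch.Tensor, heads: int) -> torch.Tensor:
    # (batch, seq_len, heads * head_dim) -> (batch * heads, seq_len, head_dim)
    batch_size, seq_len, dim = tensor.shape
    head_dim = dim // heads
    tensor = tensor.reshape(batch_size, seq_len, heads, head_dim)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, head_dim)

x = torch.randn(2, 77, 320)                  # e.g. 8 heads of width 40
print(head_to_batch_dim(x, heads=8).shape)   # torch.Size([16, 77, 40])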