From 14976500edd28f20dcda9fae7bd8cdb0cc6b1426 Mon Sep 17 00:00:00 2001
From: Will Berman
Date: Thu, 26 Jan 2023 08:36:07 -0800
Subject: [PATCH] fuse attention mask (#2111)

* fuse attention mask

* lint

* use 0 beta when no attention mask re: @Birch-san
---
 src/diffusers/models/cross_attention.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index 7dda30fbda..bfa6fc3611 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -185,17 +185,23 @@ class CrossAttention(nn.Module):
             query = query.float()
             key = key.float()
 
+        if attention_mask is None:
+            baddbmm_input = torch.empty(
+                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+            )
+            beta = 0
+        else:
+            baddbmm_input = attention_mask
+            beta = 1
+
         attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            baddbmm_input,
             query,
             key.transpose(-1, -2),
-            beta=0,
+            beta=beta,
             alpha=self.scale,
         )
 
-        if attention_mask is not None:
-            attention_scores = attention_scores + attention_mask
-
         if self.upcast_softmax:
             attention_scores = attention_scores.float()
 
@@ -228,11 +234,12 @@ class CrossAttnProcessor:
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
         query = attn.to_q(hidden_states)
-        query = attn.head_to_batch_dim(query)
 
         encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
 
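
Below is a minimal, self-contained sketch (not part of the patch) of the fusion this change relies on. torch.baddbmm(input, batch1, batch2, beta=..., alpha=...) computes beta * input + alpha * (batch1 @ batch2), so the attention mask can be passed as `input` with beta=1, folding the mask addition into the same call as the scaled Q @ K^T matmul; with beta=0 the input buffer is ignored, which is why an uninitialized torch.empty tensor is safe when there is no mask. The tensor shapes and mask values here are illustrative assumptions, not taken from the diffusers code.

# Sketch of fusing the attention-mask addition into torch.baddbmm (assumed shapes).
import torch

batch_heads, q_len, kv_len, head_dim = 2, 4, 6, 8
scale = head_dim ** -0.5

query = torch.randn(batch_heads, q_len, head_dim)
key = torch.randn(batch_heads, kv_len, head_dim)

# Additive mask: large negative values at masked key positions.
attention_mask = torch.zeros(batch_heads, q_len, kv_len)
attention_mask[:, :, kv_len // 2 :] = -10000.0

# Unfused reference: matmul first, then add the mask in a separate op.
reference = scale * torch.bmm(query, key.transpose(-1, -2)) + attention_mask

# Fused: the mask is baddbmm's `input`, added with beta=1 in the same call.
fused = torch.baddbmm(attention_mask, query, key.transpose(-1, -2), beta=1, alpha=scale)
torch.testing.assert_close(fused, reference)

# No mask: pass an uninitialized buffer with beta=0; its contents are never read.
empty = torch.empty(batch_heads, q_len, kv_len, dtype=query.dtype, device=query.device)
no_mask = torch.baddbmm(empty, query, key.transpose(-1, -2), beta=0, alpha=scale)
torch.testing.assert_close(no_mask, scale * torch.bmm(query, key.transpose(-1, -2)))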