From 0c49f4cf30b5cd98b28845152c82ea2c83d624e0 Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Tue, 22 Nov 2022 12:19:15 +0100 Subject: [PATCH] Woops --- src/diffusers/models/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index c71e2a8336..7d8b962905 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -520,12 +520,12 @@ class CrossAttention(nn.Module): query = ( self.reshape_heads_to_batch_dim(query) .permute(0, 2, 1, 3) - .reshape(batch_size * head_size, seq_len, dim // head_size) + .reshape(batch_size * self.heads, seq_len, dim // self.heads) ) value = ( self.reshape_heads_to_batch_dim(value) .permute(0, 2, 1, 3) - .reshape(batch_size * head_size, seq_len, dim // head_size) + .reshape(batch_size * self.heads, seq_len, dim // self.heads) ) # TODO(PVP) - mask is currently never used. Remember to re-implement when used