From 8187865aef3185e9beb9a9dabddf0cff4effaea2 Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Mon, 19 Sep 2022 14:08:29 +0200
Subject: [PATCH] Fix CrossAttention._sliced_attention

---
 src/diffusers/models/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index e4cedbff8c..e99a0745ab 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -267,7 +267,7 @@ class CrossAttention(nn.Module):
 
         if self._slice_size is None or query.shape[0] // self._slice_size == 1:
             hidden_states = self._attention(query, key, value)
         else:
-            hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
+            hidden_states = self._sliced_attention(query, key, value, sequence_length, dim=query.shape[-1] * self.heads)
 
         return self.to_out(hidden_states)