From 895320910c7eb0739fcd744abd7136235dd474f3 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 25 Apr 2023 00:44:00 +0200
Subject: [PATCH] [Bug fix] Fix batch size attention head size mismatch (#3214)

---
 src/diffusers/models/attention.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 1085c452b0..8e537c6f36 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -86,8 +86,10 @@ class AttentionBlock(nn.Module):
         head_size = self.num_heads
 
         if unmerge_head_and_batch:
-            batch_size, seq_len, dim = tensor.shape
-            tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+            batch_head_size, seq_len, dim = tensor.shape
+            batch_size = batch_head_size // head_size
+
+            tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
         else:
             batch_size, _, seq_len, dim = tensor.shape
 
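
Note: for intuition, here is a minimal standalone sketch of the reshape this
hunk fixes. The helper name `unmerge_heads` is hypothetical; in the patch the
logic lives inside `AttentionBlock`. The leading dimension of the incoming
tensor is the fused batch * head_size, so the true batch size must be
recovered by integer division before splitting out the head axis:

    import torch

    def unmerge_heads(tensor: torch.Tensor, head_size: int) -> torch.Tensor:
        # Leading dim of `tensor` is batch * head_size; recover the true
        # batch size before reshaping to (batch, heads, seq_len, dim).
        batch_head_size, seq_len, dim = tensor.shape
        batch_size = batch_head_size // head_size
        return tensor.reshape(batch_size, head_size, seq_len, dim)

    # Example: 2 samples x 4 heads merged into a leading dim of 8.
    merged = torch.randn(8, 16, 32)
    unmerged = unmerge_heads(merged, head_size=4)
    assert unmerged.shape == (2, 4, 16, 32)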