Mirror of https://github.com/huggingface/diffusers.git

[Bug fix] Fix batch size attention head size mismatch (#3214)

Author: Patrick von Platen
Date: 2023-04-25 00:44:00 +02:00
Committed by: Daniel Gu
Parent: de05ea0f50
Commit: 895320910c


@@ -86,8 +86,10 @@ class AttentionBlock(nn.Module):
         head_size = self.num_heads
         if unmerge_head_and_batch:
-            batch_size, seq_len, dim = tensor.shape
-            tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+            batch_head_size, seq_len, dim = tensor.shape
+            batch_size = batch_head_size // head_size
+            tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
         else:
             batch_size, _, seq_len, dim = tensor.shape
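
For context: the reshape itself is unchanged; the fix makes `batch_size` hold the true per-sample batch size rather than the merged batch-times-heads count, so code further down the function that relies on `batch_size` no longer mismatches the attention head count. Below is a minimal standalone sketch of the corrected unmerging logic; the helper name `unmerge_heads_from_batch` and the example shapes are illustrative only and are not part of the diffusers API.

# Minimal sketch (not the diffusers API): recover the true batch size from a
# tensor whose leading dimension packs batch and attention heads together.
import torch

def unmerge_heads_from_batch(tensor: torch.Tensor, head_size: int) -> torch.Tensor:
    # Leading dimension is batch_size * head_size (heads merged into the batch).
    batch_head_size, seq_len, dim = tensor.shape
    batch_size = batch_head_size // head_size  # true per-sample batch size
    return tensor.reshape(batch_size, head_size, seq_len, dim)

# Example: batch of 2, 4 heads, sequence length 16, per-head dimension 64.
x = torch.randn(2 * 4, 16, 64)
assert unmerge_heads_from_batch(x, head_size=4).shape == (2, 4, 16, 64)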