mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Bug fix] Fix batch size attention head size mismatch (#3214)
This commit is contained in:
committed by
Daniel Gu
parent
de05ea0f50
commit
895320910c
@@ -86,8 +86,10 @@ class AttentionBlock(nn.Module):
|
||||
head_size = self.num_heads
|
||||
|
||||
if unmerge_head_and_batch:
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
||||
batch_head_size, seq_len, dim = tensor.shape
|
||||
batch_size = batch_head_size // head_size
|
||||
|
||||
tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
|
||||
else:
|
||||
batch_size, _, seq_len, dim = tensor.shape
|
||||
|
||||
|
||||
Reference in New Issue
Block a user