diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py
index 25ae5d0a5d..246f3afaf5 100644
--- a/src/diffusers/models/attention_flax.py
+++ b/src/diffusers/models/attention_flax.py
@@ -216,8 +216,8 @@ class FlaxAttention(nn.Module):
             hidden_states = jax_memory_efficient_attention(
                 query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
             )
-            hidden_states = hidden_states.transpose(1, 0, 2)
+            hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         else:
             # compute attentions
             if self.split_head_dim:
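
For context on the change: the removed `transpose(1, 0, 2)` only permutes axes, whereas `reshape_batch_dim_to_heads` also merges the folded head dimension back into the feature dimension before the output projection. The helper's body is not part of this hunk, so the standalone sketch below is an assumption based on the usual diffusers head-folding layout, not a verbatim copy of the method:

import jax.numpy as jnp

def reshape_batch_dim_to_heads(tensor, head_size):
    # Assumed semantics: undo the head folding applied before attention.
    # Input:  (batch * head_size, seq_len, dim_per_head)
    # Output: (batch, seq_len, head_size * dim_per_head)
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
    tensor = jnp.transpose(tensor, (0, 2, 1, 3))
    return tensor.reshape(batch_size // head_size, seq_len, dim * head_size)

Under that layout, the memory-efficient attention output arrives with heads folded into the batch axis, so a bare axis permutation would leave the tensor with the wrong shape for the subsequent projection; the reshape helper restores the (batch, seq_len, inner_dim) shape the rest of `FlaxAttention` expects.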