From a5f35ee4731b731d6bd8977525873b0bc480cb42 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Sat, 14 Dec 2024 08:45:45 -0800
Subject: [PATCH] add reshape to fix use_memory_efficient_attention in flax
 (#7918)

Co-authored-by: Juan Acevedo
Co-authored-by: Sayak Paul
Co-authored-by: Aryan
---
 src/diffusers/models/attention_flax.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py
index 25ae5d0a5d..246f3afaf5 100644
--- a/src/diffusers/models/attention_flax.py
+++ b/src/diffusers/models/attention_flax.py
@@ -216,8 +216,8 @@ class FlaxAttention(nn.Module):
             hidden_states = jax_memory_efficient_attention(
                 query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
             )
-
             hidden_states = hidden_states.transpose(1, 0, 2)
+            hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         else:
             # compute attentions
             if self.split_head_dim:
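
Note on the fix (not part of the patch): `jax_memory_efficient_attention` is called with query/key/value transposed to (seq, batch * heads, head_dim) layout, so its output, even after the transpose back, is still (batch * heads, seq, head_dim) with the heads folded into the batch axis, while the subsequent output projection expects (batch, seq, inner_dim). The sketch below is a hypothetical standalone version of the `reshape_batch_dim_to_heads` helper (the real one is a method on `FlaxAttention` that reads `self.heads`), shown only to illustrate the shape restoration the added line performs; the example shapes at the bottom are illustrative, not taken from the patch:

    import jax.numpy as jnp

    def reshape_batch_dim_to_heads(tensor, head_size):
        # input: (batch * heads, seq_len, head_dim), as produced by the
        # memory-efficient attention path after the transpose back
        batch_times_heads, seq_len, head_dim = tensor.shape
        batch_size = batch_times_heads // head_size
        # unfold the heads out of the batch axis:
        # (batch, heads, seq_len, head_dim)
        tensor = tensor.reshape(batch_size, head_size, seq_len, head_dim)
        # move heads next to the feature axis: (batch, seq_len, heads, head_dim)
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        # merge heads back into the feature axis: (batch, seq_len, heads * head_dim)
        return tensor.reshape(batch_size, seq_len, head_size * head_dim)

    # illustrative shapes (hypothetical values)
    batch, heads, seq_len, head_dim = 2, 8, 64, 40
    hidden_states = jnp.zeros((batch * heads, seq_len, head_dim))
    out = reshape_batch_dim_to_heads(hidden_states, heads)
    assert out.shape == (batch, seq_len, heads * head_dim)  # (2, 64, 320)

Without this final reshape, the memory-efficient path handed a (batch * heads, seq, head_dim) tensor to the output projection, which is the bug the added `self.reshape_batch_dim_to_heads(hidden_states)` call fixes.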