1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

add reshape to fix use_memory_efficient_attention in flax (#7918)

Co-authored-by: Juan Acevedo <jfacevedo@google.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
This commit is contained in:
Juan Acevedo
2024-12-14 08:45:45 -08:00
committed by GitHub
parent 63243406ba
commit a5f35ee473

View File

@@ -216,8 +216,8 @@ class FlaxAttention(nn.Module):
hidden_states = jax_memory_efficient_attention(
query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
)
hidden_states = hidden_states.transpose(1, 0, 2)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
else:
# compute attentions
if self.split_head_dim: