mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add reshape to fix use_memory_efficient_attention in flax (#7918)
Co-authored-by: Juan Acevedo <jfacevedo@google.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Aryan <aryan@huggingface.co>
This commit is contained in:
@@ -216,8 +216,8 @@ class FlaxAttention(nn.Module):
|
||||
hidden_states = jax_memory_efficient_attention(
|
||||
query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 0, 2)
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
else:
|
||||
# compute attentions
|
||||
if self.split_head_dim:
|
||||
|
||||
Reference in New Issue
Block a user