Context has its own sequence length
@@ -512,6 +512,7 @@ class CrossAttention(nn.Module):
 
         query = self.to_q(hidden_states)
         context = context if context is not None else hidden_states
+        context_sequence_length = context.shape[1]
         key = self.to_k(context)
         value = self.to_v(context)
 
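The new `context_sequence_length` exists because in cross-attention the query is projected from `hidden_states` (the latent features) while key and value are projected from `context` (typically the text-encoder output), so the two sequence lengths generally differ. A minimal sketch of the shapes involved, with assumed example values (batch 2, 64 latent tokens, 77 text tokens, inner dim 320) that are not taken from the commit:

    import torch

    hidden_states = torch.randn(2, 64, 320)  # (batch, latent tokens, inner dim)
    context = torch.randn(2, 77, 320)        # (batch, text tokens, inner dim)

    sequence_length = hidden_states.shape[1]    # 64 -> governs the query
    context_sequence_length = context.shape[1]  # 77 -> governs key and value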
@@ -525,19 +526,19 @@ class CrossAttention(nn.Module):
         value = (
             self.reshape_heads_to_batch_dim(value)
             .permute(0, 2, 1, 3)
-            .reshape(batch_size * self.heads, sequence_length, dim // self.heads)
+            .reshape(batch_size * self.heads, context_sequence_length, dim // self.heads)
         )
 
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
-            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 1, 3).reshape(batch_size * self.heads, sequence_length, dim // self.heads)
+            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 1, 3).reshape(batch_size * self.heads, context_sequence_length, dim // self.heads)
             hidden_states = self._memory_efficient_attention_xformers(query, key, value)
             # Some versions of xformers return output in fp32, cast it back to the dtype of the input
             hidden_states = hidden_states.to(query.dtype)
         else:
-            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 3, 1).reshape(batch_size * self.heads, dim // self.heads, sequence_length)
+            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 3, 1).reshape(batch_size * self.heads, dim // self.heads, context_sequence_length)
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                 hidden_states = self._attention(query, key, value)
             else:
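Why the `sequence_length` → `context_sequence_length` substitution is needed here: the flattened key/value tensors must carry the context's token count, and the old code only worked in self-attention, where the two lengths coincide. A runnable sketch of the failure and of the corrected layout, using the same assumed shapes as above and plain torch ops rather than the module's `reshape_heads_to_batch_dim` helper:

    import torch

    batch_size, heads, dim = 2, 8, 320
    query_len, context_len = 64, 77

    key = torch.randn(batch_size, context_len, dim)  # projected from the context

    # Using the query's length, as before the fix: the element counts do not
    # match (2*77*320 != 16*64*40), so the reshape raises.
    try:
        key.reshape(batch_size * heads, query_len, dim // heads)
    except RuntimeError as err:
        print("fails:", err)

    # Using the context's own length, as after the fix: (B*H, M, D/H) as intended.
    fixed = (
        key.reshape(batch_size, context_len, heads, dim // heads)
        .permute(0, 2, 1, 3)
        .reshape(batch_size * heads, context_len, dim // heads)
    )
    print(fixed.shape)  # torch.Size([16, 77, 40])

The same mismatch is why both branches of the hunk change: the xformers path and the sliced-attention path each flatten the key with their own reshape, so each needs the context's length.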