mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
committed by
GitHub
parent
17c2c0600b
commit
888468dd90
@@ -244,7 +244,9 @@ class CrossAttention(nn.Module):
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
|
||||
def reshape_heads_to_batch_dim(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
@@ -283,7 +285,11 @@ class CrossAttention(nn.Module):
|
||||
else:
|
||||
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
|
||||
|
||||
return self.to_out(hidden_states)
|
||||
# linear proj
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def _attention(self, query, key, value):
|
||||
# TODO: use baddbmm for better performance
|
||||
@@ -354,12 +360,19 @@ class FeedForward(nn.Module):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
project_in = GEGLU(dim, inner_dim)
|
||||
self.net = nn.ModuleList([])
|
||||
|
||||
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
||||
# project in
|
||||
self.net.append(GEGLU(dim, inner_dim))
|
||||
# project dropout
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
# project out
|
||||
self.net.append(nn.Linear(inner_dim, dim_out))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
return self.net(hidden_states)
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# feedforward
|
||||
|
||||
Reference in New Issue
Block a user