Remove nn sequential (#1086)

* Remove nn sequential

* up
Patrick von Platen
2022-10-31 19:01:42 +01:00
committed by GitHub
parent 17c2c0600b
commit 888468dd90


@@ -244,7 +244,9 @@ class CrossAttention(nn.Module):
         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
 
-        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(inner_dim, query_dim))
+        self.to_out.append(nn.Dropout(dropout))
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
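
Note: the new ModuleList registers its children under the same numeric names ("0", "1") that nn.Sequential used, so checkpoint keys such as to_out.0.weight are unchanged. A minimal sketch of that equivalence (the 320-dim sizes are made up for illustration):

    import torch.nn as nn

    old_to_out = nn.Sequential(nn.Linear(320, 320), nn.Dropout(0.0))

    new_to_out = nn.ModuleList([])
    new_to_out.append(nn.Linear(320, 320))
    new_to_out.append(nn.Dropout(0.0))

    # Dropout has no parameters, so both variants expose exactly {'0.weight', '0.bias'}.
    assert set(old_to_out.state_dict().keys()) == set(new_to_out.state_dict().keys())
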
@@ -283,7 +285,11 @@ class CrossAttention(nn.Module):
         else:
             hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
 
-        return self.to_out(hidden_states)
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
+        # dropout
+        hidden_states = self.to_out[1](hidden_states)
+        return hidden_states
 
     def _attention(self, query, key, value):
         # TODO: use baddbmm for better performance
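
Because nn.ModuleList is not callable, the forward pass now indexes the two submodules explicitly; the computation itself is unchanged. A small self-contained check, assuming illustrative shapes (batch 2, sequence length 77, dim 320):

    import torch
    import torch.nn as nn

    linear, drop = nn.Linear(320, 320), nn.Dropout(0.0)
    old_to_out = nn.Sequential(linear, drop)      # previous: called as one block
    new_to_out = nn.ModuleList([linear, drop])    # now: indexed in forward()

    x = torch.randn(2, 77, 320)
    # linear proj, then dropout -- same result as old_to_out(x)
    assert torch.equal(old_to_out(x), new_to_out[1](new_to_out[0](x)))
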
@@ -354,12 +360,19 @@ class FeedForward(nn.Module):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
-        project_in = GEGLU(dim, inner_dim)
 
-        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+        self.net = nn.ModuleList([])
+        # project in
+        self.net.append(GEGLU(dim, inner_dim))
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(nn.Linear(inner_dim, dim_out))
 
     def forward(self, hidden_states):
-        return self.net(hidden_states)
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states
 
 
 # feedforward
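
Put together, a runnable sketch of the reworked FeedForward. The GEGLU below is a stand-in gated-GELU projection included only so the example is self-contained; the repository ships its own GEGLU class:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GEGLU(nn.Module):
        # Stand-in gated-GELU projection so the sketch runs on its own.
        def __init__(self, dim_in, dim_out):
            super().__init__()
            self.proj = nn.Linear(dim_in, dim_out * 2)

        def forward(self, hidden_states):
            hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
            return hidden_states * F.gelu(gate)

    class FeedForward(nn.Module):
        def __init__(self, dim, dim_out=None, mult=4, dropout=0.0):
            super().__init__()
            inner_dim = int(dim * mult)
            dim_out = dim_out if dim_out is not None else dim

            self.net = nn.ModuleList([])
            # project in
            self.net.append(GEGLU(dim, inner_dim))
            # project dropout
            self.net.append(nn.Dropout(dropout))
            # project out
            self.net.append(nn.Linear(inner_dim, dim_out))

        def forward(self, hidden_states):
            # ModuleList is not callable, so apply each submodule in order.
            for module in self.net:
                hidden_states = module(hidden_states)
            return hidden_states

    ff = FeedForward(dim=320)
    print(ff(torch.randn(2, 77, 320)).shape)  # torch.Size([2, 77, 320])
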