From 888468dd90b4e353225d7d505fa28d2963e65678 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 31 Oct 2022 19:01:42 +0100
Subject: [PATCH] Remove nn sequential (#1086)

* Remove nn sequential

* up
---
 src/diffusers/models/attention.py | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index bf04c3e6a3..af441ef861 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -244,7 +244,9 @@ class CrossAttention(nn.Module):
         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
 
-        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(inner_dim, query_dim))
+        self.to_out.append(nn.Dropout(dropout))
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -283,7 +285,11 @@ class CrossAttention(nn.Module):
         else:
             hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
 
-        return self.to_out(hidden_states)
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
+        # dropout
+        hidden_states = self.to_out[1](hidden_states)
+        return hidden_states
 
     def _attention(self, query, key, value):
         # TODO: use baddbmm for better performance
@@ -354,12 +360,19 @@ class FeedForward(nn.Module):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
-        project_in = GEGLU(dim, inner_dim)
+        self.net = nn.ModuleList([])
 
-        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+        # project in
+        self.net.append(GEGLU(dim, inner_dim))
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(nn.Linear(inner_dim, dim_out))
 
     def forward(self, hidden_states):
-        return self.net(hidden_states)
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states