From f2ed5d8b44f54bb5cc27d68727a9482312ac9eb4 Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Tue, 22 Nov 2022 00:48:50 +0100
Subject: [PATCH] black

---
 src/diffusers/models/attention.py | 27 ++++++++++++++++++++++-----
 src/diffusers/models/resnet.py    |  2 +-
 2 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index fb0674051f..c67a28da8d 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -288,9 +288,22 @@ class AttentionBlock(nn.Module):
 
         # get scores
         if self.num_heads > 1:
-            query_states = self.transpose_for_scores(query_proj).contiguous().view(batch * self.num_heads, height * width, self.num_head_size)
-            key_states = self.transpose_for_scores(key_proj).transpose(3,2).contiguous().view(batch * self.num_heads, self.num_head_size, height * width)
-            value_states = self.transpose_for_scores(value_proj).contiguous().view(batch * self.num_heads, height * width, self.num_head_size)
+            query_states = (
+                self.transpose_for_scores(query_proj)
+                .contiguous()
+                .view(batch * self.num_heads, height * width, self.num_head_size)
+            )
+            key_states = (
+                self.transpose_for_scores(key_proj)
+                .transpose(3, 2)
+                .contiguous()
+                .view(batch * self.num_heads, self.num_head_size, height * width)
+            )
+            value_states = (
+                self.transpose_for_scores(value_proj)
+                .contiguous()
+                .view(batch * self.num_heads, height * width, self.num_head_size)
+            )
         else:
             query_states, key_states, value_states = query_proj, key_proj.transpose(-1, -2), value_proj
 
@@ -313,7 +326,11 @@ class AttentionBlock(nn.Module):
         # compute attention output
         hidden_states = torch.bmm(attention_probs, value_states)
         if self.num_heads > 1:
-            hidden_states = hidden_states.view(batch, self.num_heads, height * width, self.num_head_size).permute(0, 2, 1, 3).contiguous()
+            hidden_states = (
+                hidden_states.view(batch, self.num_heads, height * width, self.num_head_size)
+                .permute(0, 2, 1, 3)
+                .contiguous()
+            )
             new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
             hidden_states = hidden_states.view(new_hidden_states_shape)
 
@@ -322,7 +339,7 @@ class AttentionBlock(nn.Module):
         hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
 
         # res connect and rescale
-        hidden_states = (hidden_states + residual)
+        hidden_states = hidden_states + residual
         if self.rescale_output_factor != 1.0:
             hidden_states = hidden_states / self.rescale_output_factor
         return hidden_states
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 2e8d0609cd..a6cb3828fd 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -476,7 +476,7 @@ class ResnetBlock2D(nn.Module):
         if self.conv_shortcut is not None:
             input_tensor = self.conv_shortcut(input_tensor)
 
-        output_tensor = (input_tensor + hidden_states)
+        output_tensor = input_tensor + hidden_states
 
         if self.output_scale_factor != 1.0:
             output_tensor = output_tensor / self.output_scale_factor
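
Note (not part of the patch): the chained calls Black re-wraps above are the usual fold-heads-into-batch attention pattern: reshape each projection to (batch * num_heads, seq, head_size) so a single torch.bmm covers every head at once, then unfold the heads and merge them back into the channel dimension. The sketch below reproduces that pattern standalone; the shape constants and the free function transpose_for_scores are illustrative assumptions, not code taken from diffusers.

import torch

# Illustrative sizes; in the patch the sequence length is height * width.
batch, num_heads, seq_len, head_size = 2, 4, 64, 32


def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
    # (batch, seq, num_heads * head_size) -> (batch, num_heads, seq, head_size)
    new_shape = x.size()[:-1] + (num_heads, head_size)
    return x.view(new_shape).permute(0, 2, 1, 3)


query_proj = torch.randn(batch, seq_len, num_heads * head_size)
key_proj = torch.randn(batch, seq_len, num_heads * head_size)
value_proj = torch.randn(batch, seq_len, num_heads * head_size)

# Fold heads into the batch dimension, mirroring the
# view(batch * self.num_heads, ...) calls in the patch.
query_states = transpose_for_scores(query_proj).contiguous().view(batch * num_heads, seq_len, head_size)
key_states = (
    transpose_for_scores(key_proj)
    .transpose(3, 2)  # pre-transpose keys so bmm yields (seq, seq) scores
    .contiguous()
    .view(batch * num_heads, head_size, seq_len)
)
value_states = transpose_for_scores(value_proj).contiguous().view(batch * num_heads, seq_len, head_size)

attention_scores = torch.bmm(query_states, key_states) / head_size**0.5
attention_probs = attention_scores.softmax(dim=-1)

# One bmm for all heads, then unfold heads and merge them back into channels.
hidden_states = torch.bmm(attention_probs, value_states)
hidden_states = (
    hidden_states.view(batch, num_heads, seq_len, head_size)
    .permute(0, 2, 1, 3)
    .contiguous()
    .view(batch, seq_len, num_heads * head_size)
)
assert hidden_states.shape == (batch, seq_len, num_heads * head_size)

Folding the heads into the batch dimension replaces a Python-level loop over heads with one batched matmul, which is why the patched code pays for the extra .contiguous() calls before each view.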