mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
black
This commit is contained in:
@@ -288,9 +288,22 @@ class AttentionBlock(nn.Module):
|
||||
|
||||
# get scores
|
||||
if self.num_heads > 1:
|
||||
query_states = self.transpose_for_scores(query_proj).contiguous().view(batch * self.num_heads, height * width, self.num_head_size)
|
||||
key_states = self.transpose_for_scores(key_proj).transpose(3,2).contiguous().view(batch * self.num_heads, self.num_head_size, height * width)
|
||||
value_states = self.transpose_for_scores(value_proj).contiguous().view(batch * self.num_heads, height * width, self.num_head_size)
|
||||
query_states = (
|
||||
self.transpose_for_scores(query_proj)
|
||||
.contiguous()
|
||||
.view(batch * self.num_heads, height * width, self.num_head_size)
|
||||
)
|
||||
key_states = (
|
||||
self.transpose_for_scores(key_proj)
|
||||
.transpose(3, 2)
|
||||
.contiguous()
|
||||
.view(batch * self.num_heads, self.num_head_size, height * width)
|
||||
)
|
||||
value_states = (
|
||||
self.transpose_for_scores(value_proj)
|
||||
.contiguous()
|
||||
.view(batch * self.num_heads, height * width, self.num_head_size)
|
||||
)
|
||||
else:
|
||||
query_states, key_states, value_states = query_proj, key_proj.transpose(-1, -2), value_proj
|
||||
|
||||
@@ -313,7 +326,11 @@ class AttentionBlock(nn.Module):
|
||||
# compute attention output
|
||||
hidden_states = torch.bmm(attention_probs, value_states)
|
||||
if self.num_heads > 1:
|
||||
hidden_states = hidden_states.view(batch, self.num_heads, height * width, self.num_head_size).permute(0, 2, 1, 3).contiguous()
|
||||
hidden_states = (
|
||||
hidden_states.view(batch, self.num_heads, height * width, self.num_head_size)
|
||||
.permute(0, 2, 1, 3)
|
||||
.contiguous()
|
||||
)
|
||||
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
|
||||
hidden_states = hidden_states.view(new_hidden_states_shape)
|
||||
|
||||
@@ -322,7 +339,7 @@ class AttentionBlock(nn.Module):
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
|
||||
|
||||
# res connect and rescale
|
||||
hidden_states = (hidden_states + residual)
|
||||
hidden_states = hidden_states + residual
|
||||
if self.rescale_output_factor != 1.0:
|
||||
hidden_states = hidden_states / self.rescale_output_factor
|
||||
return hidden_states
|
||||
|
||||
@@ -476,7 +476,7 @@ class ResnetBlock2D(nn.Module):
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states)
|
||||
output_tensor = input_tensor + hidden_states
|
||||
if self.output_scale_factor != 1.0:
|
||||
output_tensor = output_tensor / self.output_scale_factor
|
||||
|
||||
|
||||
Reference in New Issue
Block a user