
Attn added kv processor torch 2.0 block (#3023)

add AttnAddedKVProcessor2_0 block
Author:    Will Berman (committed by GitHub)
Date:      2023-04-11 16:54:22 -07:00
Parent:    98c5e5da31
Commit:    ea39cd7e64
4 changed files with 109 additions and 12 deletions

View File

@@ -255,11 +255,15 @@ class Attention(nn.Module):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
 
-    def head_to_batch_dim(self, tensor):
+    def head_to_batch_dim(self, tensor, out_dim=3):
         head_size = self.heads
         batch_size, seq_len, dim = tensor.shape
         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if out_dim == 3:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+
         return tensor
 
     def get_attention_scores(self, query, key, attention_mask=None):
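For reference, a standalone shape sketch of the new out_dim behaviour (not part of the diff; the head count and tensor sizes are made up, and the reshape logic mirrors head_to_batch_dim above):

import torch

heads = 8                        # assumed head count for the sketch
x = torch.randn(2, 64, 320)      # (batch, seq_len, inner_dim)

batch_size, seq_len, dim = x.shape
t = x.reshape(batch_size, seq_len, heads, dim // heads).permute(0, 2, 1, 3)

# out_dim=4 keeps an explicit head axis, which is what F.scaled_dot_product_attention expects
print(t.shape)                                                      # torch.Size([2, 8, 64, 40])

# out_dim=3 folds the heads into the batch axis, matching the previous behaviour
print(t.reshape(batch_size * heads, seq_len, dim // heads).shape)   # torch.Size([16, 64, 40])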
@@ -293,7 +297,7 @@ class Attention(nn.Module):
 
         return attention_probs
 
-    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None):
+    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
         if batch_size is None:
             deprecate(
                 "batch_size=None",
@@ -320,8 +324,13 @@
             else:
                 attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
 
-        if attention_mask.shape[0] < batch_size * head_size:
-            attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        if out_dim == 3:
+            if attention_mask.shape[0] < batch_size * head_size:
+                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        elif out_dim == 4:
+            attention_mask = attention_mask.unsqueeze(1)
+            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
         return attention_mask
 
     def norm_encoder_hidden_states(self, encoder_hidden_states):
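Similarly, a standalone sketch of the mask shapes the two out_dim branches return (illustrative only; batch size, head count, and target length are assumed values):

import torch

batch_size, head_size, target_length = 2, 8, 64
attention_mask = torch.zeros(batch_size, 1, target_length)   # already padded to target_length

# out_dim == 3: heads folded into the batch dimension (previous behaviour)
mask3 = attention_mask.repeat_interleave(head_size, dim=0)
print(mask3.shape)   # torch.Size([16, 1, 64])

# out_dim == 4: explicit head dimension, broadcastable against (batch, heads, q_len, kv_len)
mask4 = attention_mask.unsqueeze(1).repeat_interleave(head_size, dim=1)
print(mask4.shape)   # torch.Size([2, 8, 1, 64])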
@@ -499,6 +508,64 @@ class AttnAddedKVProcessor:
         return hidden_states
 
 
+class AttnAddedKVProcessor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        residual = hidden_states
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query, out_dim=4)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key, out_dim=4)
+            value = attn.head_to_batch_dim(value, out_dim=4)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class XFormersAttnProcessor:
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op
@@ -764,6 +831,7 @@ AttentionProcessor = Union[
     SlicedAttnProcessor,
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
 ]
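A small standalone check (not part of the commit) that the fused F.scaled_dot_product_attention call used by the new processor matches an explicit softmax(QK^T / sqrt(head_dim) + mask) V reference at the default 1/sqrt(head_dim) scale; supporting a custom attn.scale is what the Torch 2.1 TODO above defers. All tensor sizes below are arbitrary:

import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, q_len, kv_len, head_dim = 2, 4, 16, 24, 8

query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# additive mask broadcastable to (batch, heads, q_len, kv_len), i.e. the out_dim=4 layout
attn_mask = torch.zeros(batch, heads, 1, kv_len)
attn_mask[..., kv_len // 2 :] = float("-inf")   # mask out the second half of the keys

# fused PyTorch 2.0 path, as called by AttnAddedKVProcessor2_0
sdpa_out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)

# explicit reference with the default 1/sqrt(head_dim) scale
scores = query @ key.transpose(-1, -2) / math.sqrt(head_dim) + attn_mask
reference = scores.softmax(dim=-1) @ value

print(torch.allclose(sdpa_out, reference, atol=1e-5))   # expected: True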

View File

@@ -15,10 +15,11 @@ from typing import Optional
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 from torch import nn
 
 from .attention import AdaGroupNorm, AttentionBlock
-from .attention_processor import Attention, AttnAddedKVProcessor
+from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
 from .dual_transformer_2d import DualTransformer2DModel
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
 from .transformer_2d import Transformer2DModel
@@ -612,6 +613,10 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
         attentions = []
 
         for _ in range(num_layers):
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=in_channels,
@@ -624,7 +629,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
             resnets.append(
@@ -1396,6 +1401,11 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     skip_time_act=skip_time_act,
                 )
             )
+
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=out_channels,
@@ -1408,7 +1418,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
         self.attentions = nn.ModuleList(attentions)
@@ -2399,6 +2409,11 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     skip_time_act=skip_time_act,
                 )
             )
+
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=out_channels,
@@ -2411,7 +2426,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
            )
         self.attentions = nn.ModuleList(attentions)
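As a usage sketch (not part of the commit), the same fallback these blocks use can be reproduced on a single added-KV Attention layer; the sizes are arbitrary, with query_dim/cross_attention_dim standing in for the block channel count and added_kv_proj_dim for the text-encoder width:

import torch
import torch.nn.functional as F

from diffusers.models.attention_processor import (
    Attention,
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
)

# Prefer the SDPA processor on PyTorch >= 2.0, exactly like the blocks above.
processor = (
    AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
)

attn = Attention(
    query_dim=32,              # block channel count (made-up size)
    cross_attention_dim=32,    # equals the channel count in the simple cross-attention blocks
    heads=4,
    dim_head=8,
    added_kv_proj_dim=16,      # text-encoder hidden size (made-up)
    norm_num_groups=8,         # the processor applies attn.group_norm, so this must be set
    bias=True,
    processor=processor,
)

hidden_states = torch.randn(2, 32, 8, 8)         # (batch, channels, height, width)
encoder_hidden_states = torch.randn(2, 77, 16)   # (batch, tokens, added_kv_proj_dim)

out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)   # torch.Size([2, 32, 8, 8])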

View File

@@ -8,7 +8,12 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
 from ...models.attention import Attention
-from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor, AttnProcessor
+from ...models.attention_processor import (
+    AttentionProcessor,
+    AttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
+    AttnProcessor,
+)
 from ...models.dual_transformer_2d import DualTransformer2DModel
 from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
@@ -1545,6 +1550,10 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
         attentions = []
 
         for _ in range(num_layers):
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=in_channels,
@@ -1557,7 +1566,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
            )
             resnets.append(

View File

@@ -421,7 +421,12 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
     def test_attention_slicing_forward_pass(self):
         test_max_difference = torch_device == "cpu"
 
-        self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference)
+        # Check is relaxed because there is not a torch 2.0 sliced attention added kv processor
+        expected_max_diff = 1e-2
+
+        self._test_attention_slicing_forward_pass(
+            test_max_difference=test_max_difference, expected_max_diff=expected_max_diff
+        )
 
     # Overriding PipelineTesterMixin::test_inference_batch_single_identical
     # because UnCLIP undeterminism requires a looser check.