Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Attn added kv processor torch 2.0 block (#3023)
add AttnAddedKVProcessor2_0 block
@@ -255,11 +255,15 @@ class Attention(nn.Module):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor

-    def head_to_batch_dim(self, tensor):
+    def head_to_batch_dim(self, tensor, out_dim=3):
         head_size = self.heads
         batch_size, seq_len, dim = tensor.shape
         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if out_dim == 3:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+
         return tensor

     def get_attention_scores(self, query, key, attention_mask=None):
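For reference, a minimal standalone sketch (shapes chosen arbitrarily, not part of the commit) of what the new out_dim argument controls: out_dim=4 keeps an explicit head axis, as expected by F.scaled_dot_product_attention, while the default out_dim=3 folds the heads into the batch axis exactly as before.

    import torch

    batch_size, seq_len, dim, heads = 2, 77, 320, 8
    tensor = torch.randn(batch_size, seq_len, dim)

    # Common part of head_to_batch_dim: (batch, seq, dim) -> (batch, heads, seq, dim // heads)
    tensor = tensor.reshape(batch_size, seq_len, heads, dim // heads).permute(0, 2, 1, 3)
    print(tensor.shape)  # torch.Size([2, 8, 77, 40])  <- what out_dim=4 returns

    # out_dim=3 additionally merges the head axis into the batch axis (the previous behaviour)
    print(tensor.reshape(batch_size * heads, seq_len, dim // heads).shape)  # torch.Size([16, 77, 40])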
@@ -293,7 +297,7 @@ class Attention(nn.Module):
         return attention_probs

-    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None):
+    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
         if batch_size is None:
             deprecate(
                 "batch_size=None",
@@ -320,8 +324,13 @@ class Attention(nn.Module):
             else:
                 attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

-        if attention_mask.shape[0] < batch_size * head_size:
-            attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        if out_dim == 3:
+            if attention_mask.shape[0] < batch_size * head_size:
+                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        elif out_dim == 4:
+            attention_mask = attention_mask.unsqueeze(1)
+            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

         return attention_mask

     def norm_encoder_hidden_states(self, encoder_hidden_states):
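Again as a standalone sketch (example shapes only, not from the commit), the two out_dim branches lay the mask out for the two attention code paths: a 3-D mask whose batch axis already includes the heads, or a 4-D mask that broadcasts over the (batch, heads, query_len, key_len) attention scores used by scaled_dot_product_attention.

    import torch

    batch_size, heads, target_length = 2, 8, 77
    attention_mask = torch.zeros(batch_size, 1, target_length)

    # out_dim == 3: heads folded into the batch axis, for the existing baddbmm/softmax path
    mask3 = attention_mask.repeat_interleave(heads, dim=0)
    print(mask3.shape)  # torch.Size([16, 1, 77])

    # out_dim == 4: explicit head axis, broadcastable over SDPA's attention scores
    mask4 = attention_mask.unsqueeze(1).repeat_interleave(heads, dim=1)
    print(mask4.shape)  # torch.Size([2, 8, 1, 77])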
@@ -499,6 +508,64 @@ class AttnAddedKVProcessor:
         return hidden_states


+class AttnAddedKVProcessor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        residual = hidden_states
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query, out_dim=4)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key, out_dim=4)
+            value = attn.head_to_batch_dim(value, out_dim=4)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class XFormersAttnProcessor:
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op
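The heart of the new processor, reduced to plain torch calls (a hedged sketch with made-up shapes, not taken from the commit): query, key and value keep their head axis, the added text keys/values are prepended along the sequence axis, and a 4-D additive mask broadcasts over the attention scores.

    import torch
    import torch.nn.functional as F

    batch, heads, head_dim = 2, 8, 40
    spatial_len, text_len = 256, 77

    query = torch.randn(batch, heads, spatial_len, head_dim)
    # torch.cat([...], dim=2) mirrors how the added text keys/values are prepended to the self-attention ones
    key = torch.cat([torch.randn(batch, heads, text_len, head_dim), torch.randn(batch, heads, spatial_len, head_dim)], dim=2)
    value = torch.cat([torch.randn(batch, heads, text_len, head_dim), torch.randn(batch, heads, spatial_len, head_dim)], dim=2)
    attn_mask = torch.zeros(batch, heads, 1, text_len + spatial_len)  # additive mask with an explicit head axis

    out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
    print(out.shape)  # torch.Size([2, 8, 256, 40]) -> (batch, num_heads, seq_len, head_dim)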
@@ -764,6 +831,7 @@ AttentionProcessor = Union[
     SlicedAttnProcessor,
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
 ]
@@ -15,10 +15,11 @@ from typing import Optional

 import numpy as np
 import torch
+import torch.nn.functional as F
 from torch import nn

 from .attention import AdaGroupNorm, AttentionBlock
-from .attention_processor import Attention, AttnAddedKVProcessor
+from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
 from .dual_transformer_2d import DualTransformer2DModel
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
 from .transformer_2d import Transformer2DModel
@@ -612,6 +613,10 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
         attentions = []

         for _ in range(num_layers):
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=in_channels,
@@ -624,7 +629,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
             resnets.append(
@@ -1396,6 +1401,11 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     skip_time_act=skip_time_act,
                 )
             )
+
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=out_channels,
@@ -1408,7 +1418,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
         self.attentions = nn.ModuleList(attentions)
@@ -2399,6 +2409,11 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     skip_time_act=skip_time_act,
                 )
             )
+
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=out_channels,
@@ -2411,7 +2426,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
         self.attentions = nn.ModuleList(attentions)
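The same fallback selection can also be applied after the fact to an already-instantiated model. A hedged usage sketch (the unet variable is assumed to be a diffusers UNet built from these added-KV attention blocks; set_attn_processor is the existing diffusers API for swapping processors):

    import torch.nn.functional as F

    from diffusers.models.attention_processor import AttnAddedKVProcessor, AttnAddedKVProcessor2_0

    def enable_added_kv_sdpa(unet):
        # Prefer the torch 2.0 processor when scaled_dot_product_attention is available, else keep the vanilla one
        processor = (
            AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
        )
        unet.set_attn_processor(processor)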
@@ -8,7 +8,12 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
 from ...models.attention import Attention
-from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor, AttnProcessor
+from ...models.attention_processor import (
+    AttentionProcessor,
+    AttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
+    AttnProcessor,
+)
 from ...models.dual_transformer_2d import DualTransformer2DModel
 from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
@@ -1545,6 +1550,10 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
         attentions = []

         for _ in range(num_layers):
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=in_channels,
@@ -1557,7 +1566,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
             resnets.append(
@@ -421,7 +421,12 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
     def test_attention_slicing_forward_pass(self):
         test_max_difference = torch_device == "cpu"

-        self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference)
+        # Check is relaxed because there is not a torch 2.0 sliced attention added kv processor
+        expected_max_diff = 1e-2
+
+        self._test_attention_slicing_forward_pass(
+            test_max_difference=test_max_difference, expected_max_diff=expected_max_diff
+        )

     # Overriding PipelineTesterMixin::test_inference_batch_single_identical
     # because UnCLIP undeterminism requires a looser check.