
[UNet2DConditionModel] add an option to upcast attention to fp32 (#1590)

upcast attention
Author: Suraj Patil
Date: 2022-12-07 14:36:22 +01:00
Committed by: GitHub
Parent: dc87f526d4
Commit: 170ebd288f

4 changed files with 50 additions and 1 deletion
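For context, the core of the change is the pattern added to CrossAttention._attention and CrossAttention._sliced_attention below: when the new upcast_attention flag is set, the query/key product and the softmax run in float32 even when the model itself runs in float16, and the probabilities are cast back to the value dtype before the final matmul. A minimal, self-contained sketch of that pattern (the function name and shapes are illustrative, not part of the library API):

    import torch

    def attention_with_upcast(query, key, value, scale, upcast_attention=True):
        # query/key/value: (batch * heads, seq_len, dim_head), possibly fp16
        if upcast_attention:
            query = query.float()
            key = key.float()

        # with beta=0, baddbmm ignores the (uninitialized) first argument and
        # computes alpha * query @ key^T, in fp32 when the inputs were upcast
        attention_scores = torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=scale,
        )
        attention_probs = attention_scores.softmax(dim=-1)

        # cast back to the original dtype before multiplying with the values
        attention_probs = attention_probs.to(value.dtype)
        return torch.bmm(attention_probs, value)

Upcasting trades a little speed and memory for a numerically stable softmax under fp16 inference.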

src/diffusers/models/attention.py

@@ -101,6 +101,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         num_embeds_ada_norm: Optional[int] = None,
         use_linear_projection: bool = False,
         only_cross_attention: bool = False,
+        upcast_attention: bool = False,
     ):
         super().__init__()
         self.use_linear_projection = use_linear_projection
@@ -159,6 +160,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                     num_embeds_ada_norm=num_embeds_ada_norm,
                     attention_bias=attention_bias,
                     only_cross_attention=only_cross_attention,
+                    upcast_attention=upcast_attention,
                 )
                 for d in range(num_layers)
             ]
@@ -403,6 +405,7 @@ class BasicTransformerBlock(nn.Module):
         num_embeds_ada_norm: Optional[int] = None,
         attention_bias: bool = False,
         only_cross_attention: bool = False,
+        upcast_attention: bool = False,
     ):
         super().__init__()
         self.only_cross_attention = only_cross_attention
@@ -416,6 +419,7 @@ class BasicTransformerBlock(nn.Module):
             dropout=dropout,
             bias=attention_bias,
             cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+            upcast_attention=upcast_attention,
         )  # is a self-attention
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
@@ -428,6 +432,7 @@ class BasicTransformerBlock(nn.Module):
                 dim_head=attention_head_dim,
                 dropout=dropout,
                 bias=attention_bias,
+                upcast_attention=upcast_attention,
             )  # is self-attn if context is none
         else:
             self.attn2 = None
@@ -525,10 +530,12 @@ class CrossAttention(nn.Module):
         dim_head: int = 64,
         dropout: float = 0.0,
         bias=False,
+        upcast_attention: bool = False,
     ):
         super().__init__()
         inner_dim = dim_head * heads
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.upcast_attention = upcast_attention

         self.scale = dim_head**-0.5
         self.heads = heads
@@ -601,6 +608,10 @@ class CrossAttention(nn.Module):
         return hidden_states

     def _attention(self, query, key, value):
+        if self.upcast_attention:
+            query = query.float()
+            key = key.float()
+
         attention_scores = torch.baddbmm(
             torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
             query,
@@ -609,8 +620,11 @@ class CrossAttention(nn.Module):
             alpha=self.scale,
         )
         attention_probs = attention_scores.softmax(dim=-1)

-        # compute attention output
+        # cast back to the original dtype
+        attention_probs = attention_probs.to(value.dtype)
+
+        # compute attention output
         hidden_states = torch.bmm(attention_probs, value)

         # reshape hidden_states
@@ -626,6 +640,14 @@ class CrossAttention(nn.Module):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
+
+            query_slice = query[start_idx:end_idx]
+            key_slice = key[start_idx:end_idx]
+
+            if self.upcast_attention:
+                query_slice = query_slice.float()
+                key_slice = key_slice.float()
+
             attn_slice = torch.baddbmm(
                 torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
                 query[start_idx:end_idx],
@@ -634,6 +656,9 @@ class CrossAttention(nn.Module):
                 alpha=self.scale,
             )
             attn_slice = attn_slice.softmax(dim=-1)
+
+            # cast back to the original dtype
+            attn_slice = attn_slice.to(value.dtype)

             attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
             hidden_states[start_idx:end_idx] = attn_slice
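Why upcast at all? fp16 carries roughly three decimal digits of precision, so attention logits and the softmax over them accumulate visible rounding error in half precision. A small demonstration, independent of diffusers (the magnitudes are exaggerated to make the effect easy to see):

    import torch

    scores = torch.randn(1, 77, 77, dtype=torch.float32) * 20

    # round-tripping the logits through fp16 already perturbs the result
    probs_via_fp16 = scores.to(torch.float16).to(torch.float32).softmax(dim=-1)
    probs_fp32 = scores.softmax(dim=-1)

    print((probs_via_fp16 - probs_fp32).abs().max())  # small but nonzero drift

Computing the scores and the softmax in fp32, as _attention does above when upcast_attention is set, avoids this while keeping the rest of the model in fp16.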

src/diffusers/models/unet_2d_blocks.py

@@ -35,6 +35,7 @@ def get_down_block(
     dual_cross_attention=False,
     use_linear_projection=False,
     only_cross_attention=False,
+    upcast_attention=False,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -80,6 +81,7 @@ def get_down_block(
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
+            upcast_attention=upcast_attention,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -146,6 +148,7 @@ def get_up_block(
     dual_cross_attention=False,
     use_linear_projection=False,
     only_cross_attention=False,
+    upcast_attention=False,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -178,6 +181,7 @@ def get_up_block(
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
+            upcast_attention=upcast_attention,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -335,6 +339,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         cross_attention_dim=1280,
         dual_cross_attention=False,
         use_linear_projection=False,
+        upcast_attention=False,
     ):
         super().__init__()
@@ -370,6 +375,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else:
@@ -514,6 +520,7 @@ class CrossAttnDownBlock2D(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         only_cross_attention=False,
+        upcast_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -549,6 +556,7 @@ class CrossAttnDownBlock2D(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else:
@@ -1096,6 +1104,7 @@ class CrossAttnUpBlock2D(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         only_cross_attention=False,
+        upcast_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -1133,6 +1142,7 @@ class CrossAttnUpBlock2D(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else:

src/diffusers/models/unet_2d_condition.py

@@ -111,6 +111,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
         num_class_embeds: Optional[int] = None,
+        upcast_attention: bool = False,
     ):
         super().__init__()
@@ -163,6 +164,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
             )
             self.down_blocks.append(down_block)
@@ -179,6 +181,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             resnet_groups=norm_num_groups,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
+            upcast_attention=upcast_attention,
         )

         # count how many layers upsample the images
@@ -219,6 +222,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
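With upcast_attention registered in the UNet2DConditionModel config, the option can be enabled when the model is instantiated or loaded, and it is threaded through every cross-attention block constructed above. A hedged usage sketch (the checkpoint id is illustrative; it assumes the usual diffusers behavior that from_pretrained kwargs matching config attributes override the stored config):

    import torch
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # illustrative checkpoint
        subfolder="unet",
        torch_dtype=torch.float16,
        upcast_attention=True,  # attention scores and softmax computed in fp32
    )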

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

@@ -189,6 +189,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
         num_class_embeds: Optional[int] = None,
+        upcast_attention: bool = False,
     ):
         super().__init__()
@@ -241,6 +242,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
             )
             self.down_blocks.append(down_block)
@@ -257,6 +259,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             resnet_groups=norm_num_groups,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
+            upcast_attention=upcast_attention,
         )

         # count how many layers upsample the images
@@ -297,6 +300,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -716,6 +720,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         only_cross_attention=False,
+        upcast_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -751,6 +756,7 @@ class CrossAttnDownBlockFlat(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else:
@@ -912,6 +918,7 @@ class CrossAttnUpBlockFlat(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         only_cross_attention=False,
+        upcast_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -949,6 +956,7 @@ class CrossAttnUpBlockFlat(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else:
@@ -1031,6 +1039,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         cross_attention_dim=1280,
         dual_cross_attention=False,
         use_linear_projection=False,
+        upcast_attention=False,
     ):
         super().__init__()
@@ -1066,6 +1075,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else: