Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[UNet2DConditionModel] add an option to upcast attention to fp32 (#1590)
upcast attention
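Why this option exists: in fp16, the pre-softmax attention scores (q·kᵀ scaled by 1/√d) can exceed fp16's maximum of roughly 65504 and overflow to inf, which turns the softmax into NaNs (the classic black-image failure when running some checkpoints in half precision). Computing just the scores and the softmax in fp32 avoids this at a small speed/memory cost. A minimal standalone sketch of the failure mode and the fix; the tensor sizes and magnitudes are illustrative, not taken from the patch:

```python
import torch

torch.manual_seed(0)
# Illustrative q/k with large activations, as can occur in deep UNets.
q = (torch.randn(1, 77, 64) * 50).half()
k = (torch.randn(1, 77, 64) * 50).half()
scale = 64 ** -0.5

# fp16 path: the q @ k^T intermediate can exceed fp16's max (~65504) and
# overflow to inf; a softmax over a row containing inf produces NaN.
# (On builds without fp16 matmul support on CPU, run this on a CUDA device.)
scores_fp16 = torch.bmm(q, k.transpose(-1, -2)) * scale
print(scores_fp16.isinf().any())                  # likely True
print(scores_fp16.softmax(dim=-1).isnan().any())  # likely True

# Upcast path (what upcast_attention=True does): same math in fp32,
# then cast the probabilities back to the model dtype.
probs = (torch.bmm(q.float(), k.float().transpose(-1, -2)) * scale).softmax(dim=-1)
probs = probs.to(q.dtype)
print(probs.isnan().any())                        # False
</parameter>
```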
@@ -101,6 +101,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         num_embeds_ada_norm: Optional[int] = None,
         use_linear_projection: bool = False,
         only_cross_attention: bool = False,
+        upcast_attention: bool = False,
     ):
         super().__init__()
         self.use_linear_projection = use_linear_projection
@@ -159,6 +160,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                     num_embeds_ada_norm=num_embeds_ada_norm,
                     attention_bias=attention_bias,
                     only_cross_attention=only_cross_attention,
+                    upcast_attention=upcast_attention,
                 )
                 for d in range(num_layers)
             ]
@@ -403,6 +405,7 @@ class BasicTransformerBlock(nn.Module):
         num_embeds_ada_norm: Optional[int] = None,
         attention_bias: bool = False,
         only_cross_attention: bool = False,
+        upcast_attention: bool = False,
     ):
         super().__init__()
         self.only_cross_attention = only_cross_attention
@@ -416,6 +419,7 @@ class BasicTransformerBlock(nn.Module):
             dropout=dropout,
             bias=attention_bias,
             cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+            upcast_attention=upcast_attention,
         )  # is a self-attention
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
 
@@ -428,6 +432,7 @@ class BasicTransformerBlock(nn.Module):
                 dim_head=attention_head_dim,
                 dropout=dropout,
                 bias=attention_bias,
+                upcast_attention=upcast_attention,
             )  # is self-attn if context is none
         else:
             self.attn2 = None
@@ -525,10 +530,12 @@ class CrossAttention(nn.Module):
         dim_head: int = 64,
         dropout: float = 0.0,
         bias=False,
+        upcast_attention: bool = False,
     ):
         super().__init__()
         inner_dim = dim_head * heads
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.upcast_attention = upcast_attention
 
         self.scale = dim_head**-0.5
         self.heads = heads
@@ -601,6 +608,10 @@ class CrossAttention(nn.Module):
         return hidden_states
 
     def _attention(self, query, key, value):
+        if self.upcast_attention:
+            query = query.float()
+            key = key.float()
+
         attention_scores = torch.baddbmm(
             torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
             query,
@@ -609,8 +620,11 @@ class CrossAttention(nn.Module):
             alpha=self.scale,
         )
         attention_probs = attention_scores.softmax(dim=-1)
-        # compute attention output
 
+        # cast back to the original dtype
+        attention_probs = attention_probs.to(value.dtype)
+
+        # compute attention output
         hidden_states = torch.bmm(attention_probs, value)
 
         # reshape hidden_states
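For reference, the upcast path through `CrossAttention._attention` reads, in isolation, roughly as below. This is a condensed restatement of the two hunks above as a free function (the `scale` and `upcast_attention` arguments stand in for the module attributes), not additional library code:

```python
import torch

def attention_upcast(query, key, value, scale, upcast_attention=True):
    """Scores and softmax in fp32 (when requested), probabilities cast back
    to the value dtype before the output matmul."""
    if upcast_attention:
        # query/key are reassigned, so query.dtype below is already fp32
        query = query.float()
        key = key.float()

    # With beta=0, baddbmm documents that the `input` tensor's values are
    # ignored (NaN/inf in uninitialized memory are not propagated); torch.empty
    # is only there to pin the output shape, dtype, and device.
    attention_scores = torch.baddbmm(
        torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
        query,
        key.transpose(-1, -2),
        beta=0,
        alpha=scale,
    )
    attention_probs = attention_scores.softmax(dim=-1)

    # cast back to the original dtype
    attention_probs = attention_probs.to(value.dtype)

    # compute attention output
    return torch.bmm(attention_probs, value)
```

Calling `attention_upcast(q, k, v, scale=0.125)` on fp16 q/k/v returns fp16 output whose softmax was computed in fp32.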
@@ -626,6 +640,14 @@ class CrossAttention(nn.Module):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
+
+            query_slice = query[start_idx:end_idx]
+            key_slice = key[start_idx:end_idx]
+
+            if self.upcast_attention:
+                query_slice = query_slice.float()
+                key_slice = key_slice.float()
+
             attn_slice = torch.baddbmm(
                 torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
                 query[start_idx:end_idx],
@@ -634,6 +656,9 @@ class CrossAttention(nn.Module):
                 alpha=self.scale,
             )
             attn_slice = attn_slice.softmax(dim=-1)
+
+            # cast back to the original dtype
+            attn_slice = attn_slice.to(value.dtype)
             attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
 
             hidden_states[start_idx:end_idx] = attn_slice
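The sliced (memory-chunked) path applies the same idea per batch slice. Below is a minimal standalone sketch of slice-wise upcast attention, a hypothetical free-function analogue of `_sliced_attention` rather than the library code itself; it uses a plain scaled `bmm` where the library uses `baddbmm` with an empty bias. Note that the upcast slices must be the tensors actually fed to the matmul for the upcast to take effect:

```python
import torch

def sliced_attention_upcast(query, key, value, scale, slice_size, upcast_attention=True):
    # Output buffer stays in the model dtype; only the score/softmax math is upcast.
    # Assumes the batch dimension is divisible by slice_size.
    batch, seq_len, dim = query.shape[0], query.shape[1], value.shape[2]
    hidden_states = torch.zeros(batch, seq_len, dim, dtype=query.dtype, device=query.device)

    for i in range(batch // slice_size):
        start_idx = i * slice_size
        end_idx = (i + 1) * slice_size

        query_slice = query[start_idx:end_idx]
        key_slice = key[start_idx:end_idx]

        if upcast_attention:
            query_slice = query_slice.float()
            key_slice = key_slice.float()

        # scores for this slice (fp32 when upcast), softmax, then back to value dtype
        attn_slice = torch.bmm(query_slice, key_slice.transpose(-1, -2)) * scale
        attn_slice = attn_slice.softmax(dim=-1).to(value.dtype)

        hidden_states[start_idx:end_idx] = torch.bmm(attn_slice, value[start_idx:end_idx])

    return hidden_states
```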
@@ -35,6 +35,7 @@ def get_down_block(
     dual_cross_attention=False,
     use_linear_projection=False,
     only_cross_attention=False,
+    upcast_attention=False,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -80,6 +81,7 @@ def get_down_block(
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
+            upcast_attention=upcast_attention,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -146,6 +148,7 @@ def get_up_block(
     dual_cross_attention=False,
     use_linear_projection=False,
     only_cross_attention=False,
+    upcast_attention=False,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -178,6 +181,7 @@ def get_up_block(
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
+            upcast_attention=upcast_attention,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -335,6 +339,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         cross_attention_dim=1280,
         dual_cross_attention=False,
         use_linear_projection=False,
+        upcast_attention=False,
     ):
         super().__init__()
 
@@ -370,6 +375,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
+                        upcast_attention=upcast_attention,
                     )
                 )
            else:
@@ -514,6 +520,7 @@ class CrossAttnDownBlock2D(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         only_cross_attention=False,
+        upcast_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -549,6 +556,7 @@ class CrossAttnDownBlock2D(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
                     )
                 )
            else:
@@ -1096,6 +1104,7 @@ class CrossAttnUpBlock2D(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         only_cross_attention=False,
+        upcast_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -1133,6 +1142,7 @@ class CrossAttnUpBlock2D(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
                     )
                 )
            else:
@@ -111,6 +111,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
         num_class_embeds: Optional[int] = None,
+        upcast_attention: bool = False,
     ):
         super().__init__()
 
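Because `UNet2DConditionModel.__init__` is registered to the config via `ConfigMixin`, the new flag is saved with the model and can be toggled at load time. A usage sketch: the checkpoint name is illustrative, passing `upcast_attention=True` through `from_pretrained` relies on diffusers' config-override kwargs, and the attribute path at the end follows the classes touched in this diff:

```python
import torch
from diffusers import UNet2DConditionModel

# Load a UNet with fp16 weights but keep the attention score math in fp32.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1",  # illustrative checkpoint
    subfolder="unet",
    torch_dtype=torch.float16,
    upcast_attention=True,  # the option added in this PR
)

# The flag threads down to every CrossAttention module:
# mid_block -> Transformer2DModel -> BasicTransformerBlock -> attn1.
print(unet.mid_block.attentions[0].transformer_blocks[0].attn1.upcast_attention)  # True
```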
@@ -163,6 +164,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
             )
             self.down_blocks.append(down_block)
 
@@ -179,6 +181,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             resnet_groups=norm_num_groups,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
+            upcast_attention=upcast_attention,
         )
 
         # count how many layers upsample the images
@@ -219,6 +222,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -189,6 +189,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
         num_class_embeds: Optional[int] = None,
+        upcast_attention: bool = False,
     ):
         super().__init__()
 
@@ -241,6 +242,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
             )
             self.down_blocks.append(down_block)
 
@@ -257,6 +259,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             resnet_groups=norm_num_groups,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
+            upcast_attention=upcast_attention,
         )
 
         # count how many layers upsample the images
@@ -297,6 +300,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -716,6 +720,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         only_cross_attention=False,
+        upcast_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -751,6 +756,7 @@ class CrossAttnDownBlockFlat(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
                     )
                 )
            else:
@@ -912,6 +918,7 @@ class CrossAttnUpBlockFlat(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         only_cross_attention=False,
+        upcast_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -949,6 +956,7 @@ class CrossAttnUpBlockFlat(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
                     )
                 )
            else:
@@ -1031,6 +1039,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         cross_attention_dim=1280,
         dual_cross_attention=False,
         use_linear_projection=False,
+        upcast_attention=False,
     ):
         super().__init__()
 
@@ -1066,6 +1075,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
+                        upcast_attention=upcast_attention,
                     )
                 )
            else: