diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 7ac7a263ce..f09172be41 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -101,6 +101,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         num_embeds_ada_norm: Optional[int] = None,
         use_linear_projection: bool = False,
         only_cross_attention: bool = False,
+        upcast_attention: bool = False,
     ):
         super().__init__()
         self.use_linear_projection = use_linear_projection
@@ -159,6 +160,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                     num_embeds_ada_norm=num_embeds_ada_norm,
                     attention_bias=attention_bias,
                     only_cross_attention=only_cross_attention,
+                    upcast_attention=upcast_attention,
                 )
                 for d in range(num_layers)
             ]
@@ -403,6 +405,7 @@ class BasicTransformerBlock(nn.Module):
         num_embeds_ada_norm: Optional[int] = None,
         attention_bias: bool = False,
         only_cross_attention: bool = False,
+        upcast_attention: bool = False,
     ):
         super().__init__()
         self.only_cross_attention = only_cross_attention
@@ -416,6 +419,7 @@ class BasicTransformerBlock(nn.Module):
             dropout=dropout,
             bias=attention_bias,
             cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+            upcast_attention=upcast_attention,
         )  # is a self-attention
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)

@@ -428,6 +432,7 @@ class BasicTransformerBlock(nn.Module):
                 dim_head=attention_head_dim,
                 dropout=dropout,
                 bias=attention_bias,
+                upcast_attention=upcast_attention,
             )  # is self-attn if context is none
         else:
             self.attn2 = None
@@ -525,10 +530,12 @@ class CrossAttention(nn.Module):
         dim_head: int = 64,
         dropout: float = 0.0,
         bias=False,
+        upcast_attention: bool = False,
     ):
         super().__init__()
         inner_dim = dim_head * heads
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.upcast_attention = upcast_attention

         self.scale = dim_head**-0.5
         self.heads = heads
@@ -601,6 +608,10 @@ class CrossAttention(nn.Module):
         return hidden_states

     def _attention(self, query, key, value):
+        if self.upcast_attention:
+            query = query.float()
+            key = key.float()
+
         attention_scores = torch.baddbmm(
             torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
             query,
@@ -609,8 +620,11 @@ class CrossAttention(nn.Module):
             alpha=self.scale,
         )
         attention_probs = attention_scores.softmax(dim=-1)
-        # compute attention output

+        # cast back to the original dtype
+        attention_probs = attention_probs.to(value.dtype)
+
+        # compute attention output
         hidden_states = torch.bmm(attention_probs, value)

         # reshape hidden_states
@@ -626,14 +640,25 @@ class CrossAttention(nn.Module):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
+
+            query_slice = query[start_idx:end_idx]
+            key_slice = key[start_idx:end_idx]
+
+            if self.upcast_attention:
+                query_slice = query_slice.float()
+                key_slice = key_slice.float()
+
             attn_slice = torch.baddbmm(
-                torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
-                query[start_idx:end_idx],
-                key[start_idx:end_idx].transpose(-1, -2),
+                torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
+                query_slice,
+                key_slice.transpose(-1, -2),
                 beta=0,
                 alpha=self.scale,
             )
             attn_slice = attn_slice.softmax(dim=-1)
+
+            # cast back to the original dtype
+            attn_slice = attn_slice.to(value.dtype)
             attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

             hidden_states[start_idx:end_idx] = attn_slice
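The core of the change is the pair of upcasts in `CrossAttention`: when `upcast_attention` is set, the Q·Kᵀ product and the softmax run in fp32 even if the model itself runs in fp16, which protects the attention logits from fp16 overflow; the probabilities are then cast back so the `bmm` with the fp16 `value` tensor stays in the model's dtype. A minimal standalone sketch of that path (illustrative only, not library code; the helper name and shapes are made up):

```python
import torch


def upcast_attention_probs(query, key, scale, upcast=True):
    # query/key: (batch * heads, seq_len, dim_head), possibly fp16
    out_dtype = query.dtype
    if upcast:
        # compute Q @ K^T and the softmax in fp32
        query = query.float()
        key = key.float()
    scores = torch.baddbmm(
        torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
        query,
        key.transpose(-1, -2),
        beta=0,  # with beta=0 the uninitialized `input` tensor is ignored
        alpha=scale,
    )
    # cast back to the original dtype, as in the patch
    return scores.softmax(dim=-1).to(out_dtype)


q = torch.randn(2, 8, 40, dtype=torch.float16)
k = torch.randn(2, 8, 40, dtype=torch.float16)
probs = upcast_attention_probs(q, k, scale=40**-0.5)
assert probs.dtype == torch.float16
```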
diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index 63e2b809d7..726f050e65 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -35,6 +35,7 @@ def get_down_block(
     dual_cross_attention=False,
     use_linear_projection=False,
     only_cross_attention=False,
+    upcast_attention=False,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -80,6 +81,7 @@ def get_down_block(
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
+            upcast_attention=upcast_attention,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -146,6 +148,7 @@ def get_up_block(
     dual_cross_attention=False,
     use_linear_projection=False,
     only_cross_attention=False,
+    upcast_attention=False,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -178,6 +181,7 @@ def get_up_block(
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
+            upcast_attention=upcast_attention,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -335,6 +339,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         cross_attention_dim=1280,
         dual_cross_attention=False,
         use_linear_projection=False,
+        upcast_attention=False,
     ):
         super().__init__()

@@ -370,6 +375,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else:
@@ -514,6 +520,7 @@ class CrossAttnDownBlock2D(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         only_cross_attention=False,
+        upcast_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -549,6 +556,7 @@ class CrossAttnDownBlock2D(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else:
@@ -1096,6 +1104,7 @@ class CrossAttnUpBlock2D(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         only_cross_attention=False,
+        upcast_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -1133,6 +1142,7 @@ class CrossAttnUpBlock2D(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else:
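The remaining files are plumbing: `unet_2d_blocks.py` forwards the new argument through `get_down_block`/`get_up_block` and the cross-attention block classes, and the `unet_2d_condition.py` changes below expose it on `UNet2DConditionModel` itself. Once that wiring is in place, a smoke test might look like the following hedged sketch (the toy config values are assumptions chosen only so the model builds quickly, not part of the patch):

```python
from diffusers import UNet2DConditionModel

# minimal toy config (assumed values)
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    upcast_attention=True,
)

# the flag should have threaded all the way down to CrossAttention
attn1 = unet.down_blocks[1].attentions[0].transformer_blocks[0].attn1
assert attn1.upcast_attention
```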
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index a712169e86..7d6db4aba6 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -111,6 +111,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
         num_class_embeds: Optional[int] = None,
+        upcast_attention: bool = False,
     ):
         super().__init__()

@@ -163,6 +164,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
             )
             self.down_blocks.append(down_block)

@@ -179,6 +181,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             resnet_groups=norm_num_groups,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
+            upcast_attention=upcast_attention,
         )

         # count how many layers upsample the images
@@ -219,6 +222,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 3f833d17a6..f1cf46aaf6 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -189,6 +189,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
         num_class_embeds: Optional[int] = None,
+        upcast_attention: bool = False,
     ):
         super().__init__()

@@ -241,6 +242,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
             )
             self.down_blocks.append(down_block)

@@ -257,6 +259,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             resnet_groups=norm_num_groups,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
+            upcast_attention=upcast_attention,
         )

         # count how many layers upsample the images
@@ -297,6 +300,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -716,6 +720,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         only_cross_attention=False,
+        upcast_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -751,6 +756,7 @@ class CrossAttnDownBlockFlat(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else:
@@ -912,6 +918,7 @@ class CrossAttnUpBlockFlat(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         only_cross_attention=False,
+        upcast_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -949,6 +956,7 @@ class CrossAttnUpBlockFlat(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else:
@@ -1031,6 +1039,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         cross_attention_dim=1280,
         dual_cross_attention=False,
         use_linear_projection=False,
+        upcast_attention=False,
     ):
         super().__init__()

@@ -1066,6 +1075,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else:
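Because `upcast_attention` is a constructor argument on a `ConfigMixin` model, it is persisted in the UNet's `config.json`, so a checkpoint can opt in without any code changes. A hedged usage sketch (the model id and its config contents are assumptions about a typical consumer of this flag, not part of the diff):

```python
import torch
from diffusers import StableDiffusionPipeline

# If the checkpoint's UNet config sets `upcast_attention: true`,
# attention scores are computed in fp32 even under fp16 inference.
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",  # assumed example checkpoint
    torch_dtype=torch.float16,
)
print(pipe.unet.config.upcast_attention)
```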