From e00a9cf2d99ff619b50102e7c5364c485cb828d6 Mon Sep 17 00:00:00 2001
From: anton-l
Date: Tue, 15 Nov 2022 17:53:26 +0100
Subject: [PATCH] revert dual attn

---
 src/diffusers/models/attention.py         | 167 ----------
 src/diffusers/models/unet_2d_blocks.py    | 378 +---------------------
 src/diffusers/models/unet_2d_condition.py |   4 +-
 3 files changed, 2 insertions(+), 547 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index c436d47269..e8ea37970e 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -225,173 +225,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
 
 
-class DualTransformer2DModel(ModelMixin, ConfigMixin):
-    """
-    Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
-    embeddings) inputs.
-
-    When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
-    transformer action. Finally, reshape to image.
-
-    When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
-    embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
-    classes of unnoised image.
-
-    Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
-    image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
-
-    Parameters:
-        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
-        in_channels (`int`, *optional*):
-            Pass if the input is continuous. The number of channels in the input and output.
-        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
-        dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
-        cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
-        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
-            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
-            `ImagePositionalEmbeddings`.
-        num_vector_embeds (`int`, *optional*):
-            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
-            Includes the class for the masked latent pixel.
-        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
-        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
-            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
-            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
-            up to but not more than steps than `num_embeds_ada_norm`.
-        attention_bias (`bool`, *optional*):
-            Configure if the TransformerBlocks' attention should contain a bias parameter.
-    """
-
-    @register_to_config
-    def __init__(
-        self,
-        num_attention_heads: int = 16,
-        attention_head_dim: int = 88,
-        in_channels: Optional[int] = None,
-        num_layers: int = 1,
-        dropout: float = 0.0,
-        norm_num_groups: int = 32,
-        cross_attention_dim: Optional[int] = None,
-        attention_bias: bool = False,
-        activation_fn: str = "geglu",
-        num_embeds_ada_norm: Optional[int] = None,
-    ):
-        super().__init__()
-        self.num_attention_heads = num_attention_heads
-        self.attention_head_dim = attention_head_dim
-        inner_dim = num_attention_heads * attention_head_dim
-
-        # 1. DualTransformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
-        # Define whether input is continuous or discrete depending on configuration
-        self.is_input_continuous = in_channels is not None
-
-        # 2. Define input layers
-        self.in_channels = in_channels
-
-        self.norm_0 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
-        self.proj_in_0 = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
-        self.norm_1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
-        self.proj_in_1 = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
-
-        # 3. Define transformers blocks
-        self.transformer_blocks_0 = nn.ModuleList(
-            [
-                BasicTransformerBlock(
-                    inner_dim,
-                    num_attention_heads,
-                    attention_head_dim,
-                    dropout=dropout,
-                    cross_attention_dim=cross_attention_dim,
-                    activation_fn=activation_fn,
-                    num_embeds_ada_norm=num_embeds_ada_norm,
-                    attention_bias=attention_bias,
-                )
-                for d in range(num_layers)
-            ]
-        )
-        self.transformer_blocks_1 = nn.ModuleList(
-            [
-                BasicTransformerBlock(
-                    inner_dim,
-                    num_attention_heads,
-                    attention_head_dim,
-                    dropout=dropout,
-                    cross_attention_dim=cross_attention_dim,
-                    activation_fn=activation_fn,
-                    num_embeds_ada_norm=num_embeds_ada_norm,
-                    attention_bias=attention_bias,
-                )
-                for d in range(num_layers)
-            ]
-        )
-
-        # 4. Define output layers
-        self.proj_out_0 = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
-        self.proj_out_1 = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
-
-    def _set_attention_slice(self, slice_size):
-        for block in self.transformer_blocks:
-            block._set_attention_slice(slice_size)
-
-    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
-        """
-        Args:
-            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
-                When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
-                hidden_states
-            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
-                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
-                self-attention.
-            timestep ( `torch.long`, *optional*):
-                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
-
-        Returns:
-            [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
-            if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
-            tensor.
-        """
-        # 1. Input
-        if self.is_input_continuous:
-            batch, channel, height, weight = hidden_states.shape
-            residual = hidden_states
-            hidden_states = self.norm(hidden_states)
-            hidden_states = self.proj_in(hidden_states)
-            inner_dim = hidden_states.shape[1]
-            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
-        elif self.is_input_vectorized:
-            hidden_states = self.latent_image_embedding(hidden_states)
-
-        # 2. Blocks
-        for block in self.transformer_blocks:
-            hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
-
-        # 3. Output
-        if self.is_input_continuous:
-            hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
-            hidden_states = self.proj_out(hidden_states)
-            output = hidden_states + residual
-        elif self.is_input_vectorized:
-            hidden_states = self.norm_out(hidden_states)
-            logits = self.out(hidden_states)
-            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
-            logits = logits.permute(0, 2, 1)
-
-            # log(p(x_0))
-            output = F.log_softmax(logits.double(), dim=1).float()
-
-        if not return_dict:
-            return (output,)
-
-        return Transformer2DModelOutput(sample=output)
-
-    def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        for block in self.transformer_blocks:
-            block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
-
-
 class AttentionBlock(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
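Note on the removal above: DualTransformer2DModel duplicated every input/output layer of Transformer2DModel into `_0`/`_1` copies (`norm_0`/`norm_1`, `proj_in_0`/`proj_in_1`, `transformer_blocks_0`/`transformer_blocks_1`, `proj_out_0`/`proj_out_1`), but its `forward` and `_set_attention_slice` still referenced the single-stream attributes `self.norm`, `self.proj_in`, `self.transformer_blocks`, and `self.proj_out`, none of which the dual `__init__` defines, so any forward pass would have raised an AttributeError. For orientation only, here is a minimal self-contained sketch of the two-stream idea; `DualStreamBlock`, `mix_ratio`, and the use of `nn.TransformerEncoderLayer` as a stand-in for `BasicTransformerBlock` are illustrative assumptions, not this patch's API:

import torch
from torch import nn

class DualStreamBlock(nn.Module):
    # Illustrative only: run two parallel transformer streams over the same
    # tokens and blend the results with a fixed weight.
    def __init__(self, dim: int = 64, num_heads: int = 8, mix_ratio: float = 0.5):
        super().__init__()
        # stand-ins for the removed transformer_blocks_0 / transformer_blocks_1
        self.stream_0 = nn.TransformerEncoderLayer(dim, num_heads, batch_first=True)
        self.stream_1 = nn.TransformerEncoderLayer(dim, num_heads, batch_first=True)
        self.mix_ratio = mix_ratio  # assumed blending weight, not present in the removed code

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # both streams see the same (batch, tokens, dim) input
        out_0 = self.stream_0(hidden_states)
        out_1 = self.stream_1(hidden_states)
        return self.mix_ratio * out_0 + (1.0 - self.mix_ratio) * out_1

# usage: DualStreamBlock()(torch.randn(2, 16, 64)) keeps the input shape,
# which is what would make a dual block drop-in compatible with Transformer2DModel.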
diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index eef099eff2..770043f053 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -15,7 +15,7 @@ import numpy as np
 import torch
 from torch import nn
 
-from .attention import AttentionBlock, Transformer2DModel, DualTransformer2DModel
+from .attention import AttentionBlock, Transformer2DModel
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
 
 
@@ -75,22 +75,6 @@ def get_down_block(
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
         )
-    elif down_block_type == "DualCrossAttnDownBlock2D":
-        if cross_attention_dim is None:
-            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
-        return DualCrossAttnDownBlock2D(
-            num_layers=num_layers,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            temb_channels=temb_channels,
-            add_downsample=add_downsample,
-            resnet_eps=resnet_eps,
-            resnet_act_fn=resnet_act_fn,
-            resnet_groups=resnet_groups,
-            downsample_padding=downsample_padding,
-            cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
-        )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
             num_layers=num_layers,
@@ -183,22 +167,6 @@ def get_up_block(
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
         )
-    elif up_block_type == "DualCrossAttnUpBlock2D":
-        if cross_attention_dim is None:
-            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
-        return DualCrossAttnUpBlock2D(
-            num_layers=num_layers,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            prev_output_channel=prev_output_channel,
-            temb_channels=temb_channels,
-            add_upsample=add_upsample,
-            resnet_eps=resnet_eps,
-            resnet_act_fn=resnet_act_fn,
-            resnet_groups=resnet_groups,
-            cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
-        )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
             num_layers=num_layers,
@@ -436,103 +404,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         return hidden_states
 
 
-class UNetMidBlock2DDualCrossAttn(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        temb_channels: int,
-        dropout: float = 0.0,
-        num_layers: int = 1,
-        resnet_eps: float = 1e-6,
-        resnet_time_scale_shift: str = "default",
-        resnet_act_fn: str = "swish",
-        resnet_groups: int = 32,
-        resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
-        attention_type="default",
-        output_scale_factor=1.0,
-        cross_attention_dim=1280,
-        **kwargs,
-    ):
-        super().__init__()
-
-        self.attention_type = attention_type
-        self.attn_num_head_channels = attn_num_head_channels
-        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
-
-        # there is always at least one resnet
-        resnets = [
-            ResnetBlock2D(
-                in_channels=in_channels,
-                out_channels=in_channels,
-                temb_channels=temb_channels,
-                eps=resnet_eps,
-                groups=resnet_groups,
-                dropout=dropout,
-                time_embedding_norm=resnet_time_scale_shift,
-                non_linearity=resnet_act_fn,
-                output_scale_factor=output_scale_factor,
-                pre_norm=resnet_pre_norm,
-            )
-        ]
-        attentions = []
-
-        for _ in range(num_layers):
-            attentions.append(
-                DualTransformer2DModel(
-                    attn_num_head_channels,
-                    in_channels // attn_num_head_channels,
-                    in_channels=in_channels,
-                    num_layers=1,
-                    cross_attention_dim=cross_attention_dim,
-                    norm_num_groups=resnet_groups,
-                )
-            )
-            resnets.append(
-                ResnetBlock2D(
-                    in_channels=in_channels,
-                    out_channels=in_channels,
-                    temb_channels=temb_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
-                )
-            )
-
-        self.attentions = nn.ModuleList(attentions)
-        self.resnets = nn.ModuleList(resnets)
-
-    def set_attention_slice(self, slice_size):
-        if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a divisor of "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
-            )
-        if slice_size is not None and slice_size > self.attn_num_head_channels:
-            raise ValueError(
-                f"Chunk_size {slice_size} has to be smaller or equal to "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        for attn in self.attentions:
-            attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
-
-    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
-        hidden_states = self.resnets[0](hidden_states, temb)
-        for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            hidden_states = attn(hidden_states, encoder_hidden_states).sample
-            hidden_states = resnet(hidden_states, temb)
-
-        return hidden_states
-
-
 class AttnDownBlock2D(nn.Module):
     def __init__(
         self,
@@ -736,127 +607,6 @@ class CrossAttnDownBlock2D(nn.Module):
         return hidden_states, output_states
 
 
-class DualCrossAttnDownBlock2D(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        temb_channels: int,
-        dropout: float = 0.0,
-        num_layers: int = 1,
-        resnet_eps: float = 1e-6,
-        resnet_time_scale_shift: str = "default",
-        resnet_act_fn: str = "swish",
-        resnet_groups: int = 32,
-        resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
-        cross_attention_dim=1280,
-        attention_type="default",
-        output_scale_factor=1.0,
-        downsample_padding=1,
-        add_downsample=True,
-    ):
-        super().__init__()
-        resnets = []
-        attentions = []
-
-        self.attention_type = attention_type
-        self.attn_num_head_channels = attn_num_head_channels
-
-        for i in range(num_layers):
-            in_channels = in_channels if i == 0 else out_channels
-            resnets.append(
-                ResnetBlock2D(
-                    in_channels=in_channels,
-                    out_channels=out_channels,
-                    temb_channels=temb_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
-                )
-            )
-            attentions.append(
-                DualTransformer2DModel(
-                    attn_num_head_channels,
-                    out_channels // attn_num_head_channels,
-                    in_channels=out_channels,
-                    num_layers=1,
-                    cross_attention_dim=cross_attention_dim,
-                    norm_num_groups=resnet_groups,
-                )
-            )
-        self.attentions = nn.ModuleList(attentions)
-        self.resnets = nn.ModuleList(resnets)
-
-        if add_downsample:
-            self.downsamplers = nn.ModuleList(
-                [
-                    Downsample2D(
-                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
-                    )
-                ]
-            )
-        else:
-            self.downsamplers = None
-
-        self.gradient_checkpointing = False
-
-    def set_attention_slice(self, slice_size):
-        if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a divisor of "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
-            )
-        if slice_size is not None and slice_size > self.attn_num_head_channels:
-            raise ValueError(
-                f"Chunk_size {slice_size} has to be smaller or equal to "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        for attn in self.attentions:
-            attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
-
-    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
-        output_states = ()
-
-        for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
-                )[0]
-            else:
-                hidden_states = resnet(hidden_states, temb)
-                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
-
-            output_states += (hidden_states,)
-
-        if self.downsamplers is not None:
-            for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
-
-            output_states += (hidden_states,)
-
-        return hidden_states, output_states
-
-
 class DownBlock2D(nn.Module):
     def __init__(
@@ -1446,132 +1196,6 @@ class CrossAttnUpBlock2D(nn.Module):
         return hidden_states
 
 
-class DualCrossAttnUpBlock2D(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        prev_output_channel: int,
-        temb_channels: int,
-        dropout: float = 0.0,
-        num_layers: int = 1,
-        resnet_eps: float = 1e-6,
-        resnet_time_scale_shift: str = "default",
-        resnet_act_fn: str = "swish",
-        resnet_groups: int = 32,
-        resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
-        cross_attention_dim=1280,
-        attention_type="default",
-        output_scale_factor=1.0,
-        add_upsample=True,
-    ):
-        super().__init__()
-        resnets = []
-        attentions = []
-
-        self.attention_type = attention_type
-        self.attn_num_head_channels = attn_num_head_channels
-
-        for i in range(num_layers):
-            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
-            resnet_in_channels = prev_output_channel if i == 0 else out_channels
-
-            resnets.append(
-                ResnetBlock2D(
-                    in_channels=resnet_in_channels + res_skip_channels,
-                    out_channels=out_channels,
-                    temb_channels=temb_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    time_embedding_norm=resnet_time_scale_shift,
-                    non_linearity=resnet_act_fn,
-                    output_scale_factor=output_scale_factor,
-                    pre_norm=resnet_pre_norm,
-                )
-            )
-            attentions.append(
-                DualTransformer2DModel(
-                    attn_num_head_channels,
-                    out_channels // attn_num_head_channels,
-                    in_channels=out_channels,
-                    num_layers=1,
-                    cross_attention_dim=cross_attention_dim,
-                    norm_num_groups=resnet_groups,
-                )
-            )
-        self.attentions = nn.ModuleList(attentions)
-        self.resnets = nn.ModuleList(resnets)
-
-        if add_upsample:
-            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
-        else:
-            self.upsamplers = None
-
-        self.gradient_checkpointing = False
-
-    def set_attention_slice(self, slice_size):
-        if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a divisor of "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
-            )
-        if slice_size is not None and slice_size > self.attn_num_head_channels:
-            raise ValueError(
-                f"Chunk_size {slice_size} has to be smaller or equal to "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
-        self.gradient_checkpointing = False
-
-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        for attn in self.attentions:
-            attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
-
-    def forward(
-        self,
-        hidden_states,
-        res_hidden_states_tuple,
-        temb=None,
-        encoder_hidden_states=None,
-        upsample_size=None,
-    ):
-        for resnet, attn in zip(self.resnets, self.attentions):
-            # pop res hidden states
-            res_hidden_states = res_hidden_states_tuple[-1]
-            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
-            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
-                )[0]
-            else:
-                hidden_states = resnet(hidden_states, temb)
-                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
-
-        if self.upsamplers is not None:
-            for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size)
-
-        return hidden_states
-
-
 class UpBlock2D(nn.Module):
     def __init__(
         self,
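Note: the three `Dual*` blocks removed above are structural copies of `UNetMidBlock2DCrossAttn`, `CrossAttnDownBlock2D`, and `CrossAttnUpBlock2D` whose only substantive difference is instantiating `DualTransformer2DModel` in place of `Transformer2DModel`. After the revert, the `get_down_block`/`get_up_block` factories dispatch only on the single-stream names, and an unknown type string falls through to their closing `raise ValueError`. A usage sketch under that assumption (the config values below are illustrative, not taken from this patch):

from diffusers import UNet2DConditionModel

# only single-stream block names resolve after the revert; requesting
# "DualCrossAttnDownBlock2D" would now raise a ValueError at construction time
unet = UNet2DConditionModel(
    sample_size=32,
    down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(320, 640, 1280),
    cross_attention_dim=768,
)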
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 5dcc705466..7f7f3ecd44 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -27,7 +27,6 @@ from .unet_2d_blocks import (
     CrossAttnUpBlock2D,
     DownBlock2D,
     UNetMidBlock2DCrossAttn,
-    UNetMidBlock2DDualCrossAttn,
     UpBlock2D,
     get_down_block,
     get_up_block,
@@ -150,8 +149,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             self.down_blocks.append(down_block)
 
         # mid
-        # TODO: temporary, need to add get_mid_block()
-        self.mid_block = UNetMidBlock2DDualCrossAttn(
+        self.mid_block = UNetMidBlock2DCrossAttn(
             in_channels=block_out_channels[-1],
             temb_channels=time_embed_dim,
             resnet_eps=norm_eps,
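With the final hunk, the mid block is hard-coded to `UNetMidBlock2DCrossAttn` again, and the removed `# TODO: temporary, need to add get_mid_block()` comment goes with it. If such a factory were added later, it would presumably mirror `get_down_block`/`get_up_block`; a hedged sketch, assuming the same string-dispatch convention (no `get_mid_block` exists in the code this patch touches):

from diffusers.models.unet_2d_blocks import UNetMidBlock2DCrossAttn

def get_mid_block(mid_block_type: str, **kwargs):
    # mirror the string dispatch used by get_down_block / get_up_block
    if mid_block_type == "UNetMidBlock2DCrossAttn":
        return UNetMidBlock2DCrossAttn(**kwargs)
    raise ValueError(f"{mid_block_type} does not exist.")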