From f07a16e09bb5b1cf4fa2306bfa4ea791f24fa968 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Wed, 23 Nov 2022 20:46:30 +0100 Subject: [PATCH] update unet2d (#1376) * boom boom * remove duplicate arg * add use_linear_proj arg * fix copies * style * add fast tests * use_linear_proj -> use_linear_projection --- src/diffusers/models/attention.py | 35 +++++++++++++---- src/diffusers/models/unet_2d_blocks.py | 10 +++++ src/diffusers/models/unet_2d_condition.py | 21 ++++++---- .../versatile_diffusion/modeling_text_unet.py | 27 +++++++++---- tests/models/test_models_unet_2d.py | 38 +++++++++++++++++++ 5 files changed, 110 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 6b2bd5205b..92d84acbbe 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -99,8 +99,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin): num_vector_embeds: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, ): super().__init__() + self.use_linear_projection = use_linear_projection self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim @@ -126,7 +128,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin): self.in_channels = in_channels self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) - self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" @@ -159,7 +164,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin): # 4. Define output layers if self.is_input_continuous: - self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) @@ -191,10 +199,18 @@ class Transformer2DModel(ModelMixin, ConfigMixin): if self.is_input_continuous: batch, channel, height, weight = hidden_states.shape residual = hidden_states + hidden_states = self.norm(hidden_states) - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) @@ -204,8 +220,13 @@ class Transformer2DModel(ModelMixin, ConfigMixin): # 3. 
Output if self.is_input_continuous: - hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) - hidden_states = self.proj_out(hidden_states) + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) + output = hidden_states + residual elif self.is_input_vectorized: hidden_states = self.norm_out(hidden_states) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 4dd15845e0..5a8a97187f 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -33,6 +33,7 @@ def get_down_block( cross_attention_dim=None, downsample_padding=None, dual_cross_attention=False, + use_linear_projection=False, ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock2D": @@ -76,6 +77,7 @@ def get_down_block( cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, ) elif down_block_type == "SkipDownBlock2D": return SkipDownBlock2D( @@ -140,6 +142,7 @@ def get_up_block( resnet_groups=None, cross_attention_dim=None, dual_cross_attention=False, + use_linear_projection=False, ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock2D": @@ -170,6 +173,7 @@ def get_up_block( cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, ) elif up_block_type == "AttnUpBlock2D": return AttnUpBlock2D( @@ -327,6 +331,7 @@ class UNetMidBlock2DCrossAttn(nn.Module): output_scale_factor=1.0, cross_attention_dim=1280, dual_cross_attention=False, + use_linear_projection=False, **kwargs, ): super().__init__() @@ -362,6 +367,7 @@ class UNetMidBlock2DCrossAttn(nn.Module): num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, ) ) else: @@ -523,6 +529,7 @@ class CrossAttnDownBlock2D(nn.Module): downsample_padding=1, add_downsample=True, dual_cross_attention=False, + use_linear_projection=False, ): super().__init__() resnets = [] @@ -556,6 +563,7 @@ class CrossAttnDownBlock2D(nn.Module): num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, ) ) else: @@ -1120,6 +1128,7 @@ class CrossAttnUpBlock2D(nn.Module): output_scale_factor=1.0, add_upsample=True, dual_cross_attention=False, + use_linear_projection=False, ): super().__init__() resnets = [] @@ -1155,6 +1164,7 @@ class CrossAttnUpBlock2D(nn.Module): num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, ) ) else: diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 4eaed803ce..2060971493 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -61,7 +61,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): in_channels (`int`, *optional*, defaults to 4): The number of channels in the input 
sample. out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): @@ -106,8 +106,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): norm_num_groups: int = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, - attention_head_dim: int = 8, + attention_head_dim: Union[int, Tuple[int]] = 8, dual_cross_attention: bool = False, + use_linear_projection: bool = False, ): super().__init__() @@ -127,6 +128,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): self.mid_block = None self.up_blocks = nn.ModuleList([]) + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): @@ -145,9 +149,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, + attn_num_head_channels=attention_head_dim[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, ) self.down_blocks.append(down_block) @@ -160,9 +165,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift="default", cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, + attn_num_head_channels=attention_head_dim[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, ) # count how many layers upsample the images @@ -170,6 +176,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): # up reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 @@ -197,8 +204,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, + attn_num_head_channels=reversed_attention_head_dim[i], dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -256,8 +264,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): Args: sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): - (batch_size, sequence_length, hidden_size) encoder hidden states + encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead 
of a plain tuple. diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index c89080a59e..6d521228e3 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -124,7 +124,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): @@ -174,8 +174,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): norm_num_groups: int = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, - attention_head_dim: int = 8, + attention_head_dim: Union[int, Tuple[int]] = 8, dual_cross_attention: bool = False, + use_linear_projection: bool = False, ): super().__init__() @@ -195,6 +196,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): self.mid_block = None self.up_blocks = nn.ModuleList([]) + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): @@ -213,9 +217,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, + attn_num_head_channels=attention_head_dim[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, ) self.down_blocks.append(down_block) @@ -228,9 +233,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift="default", cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, + attn_num_head_channels=attention_head_dim[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, ) # count how many layers upsample the images @@ -238,6 +244,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): # up reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 @@ -265,8 +272,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, + attn_num_head_channels=reversed_attention_head_dim[i], dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -324,8 +332,7 @@ class UNetFlatConditionModel(ModelMixin, 
ConfigMixin): Args: sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): - (batch_size, sequence_length, hidden_size) encoder hidden states + encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. @@ -640,6 +647,7 @@ class CrossAttnDownBlockFlat(nn.Module): downsample_padding=1, add_downsample=True, dual_cross_attention=False, + use_linear_projection=False, ): super().__init__() resnets = [] @@ -673,6 +681,7 @@ class CrossAttnDownBlockFlat(nn.Module): num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, ) ) else: @@ -851,6 +860,7 @@ class CrossAttnUpBlockFlat(nn.Module): output_scale_factor=1.0, add_upsample=True, dual_cross_attention=False, + use_linear_projection=False, ): super().__init__() resnets = [] @@ -886,6 +896,7 @@ class CrossAttnUpBlockFlat(nn.Module): num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, ) ) else: @@ -988,6 +999,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module): output_scale_factor=1.0, cross_attention_dim=1280, dual_cross_attention=False, + use_linear_projection=False, **kwargs, ): super().__init__() @@ -1023,6 +1035,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module): num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, ) ) else: diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index 81437311c6..02c6d314bf 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -296,6 +296,44 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase): for name, param in named_params.items(): self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5)) + def test_model_with_attention_head_dim_tuple(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_model_with_use_linear_projection(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["use_linear_projection"] = True + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNet2DModel
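Usage sketch (not part of the patch): this commit adds two UNet2DConditionModel config arguments, use_linear_projection (Transformer2DModel then builds its proj_in/proj_out as nn.Linear instead of 1x1 nn.Conv2d) and a per-block attention_head_dim tuple (one head dimension per down/up block, with the mid block taking the last entry). The snippet below shows how the two could be exercised together, modeled loosely on the new fast tests; the block layout, channel widths, and tensor shapes are illustrative assumptions, not values taken from the diff.

    # Illustrative configuration only; mirrors the spirit of the added fast tests.
    import torch
    from diffusers import UNet2DConditionModel

    model = UNet2DConditionModel(
        sample_size=32,
        in_channels=4,
        out_channels=4,
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        block_out_channels=(32, 64),
        layers_per_block=2,
        cross_attention_dim=32,
        attention_head_dim=(8, 16),   # new: one head dim per down/up block
        use_linear_projection=True,   # new: nn.Linear proj_in/proj_out in Transformer2DModel
    )
    model.eval()

    sample = torch.randn(1, 4, 32, 32)              # (batch, channel, height, width) noisy latents
    encoder_hidden_states = torch.randn(1, 77, 32)  # (batch, seq_len, cross_attention_dim) text embeddings

    with torch.no_grad():
        out = model(sample, timestep=10, encoder_hidden_states=encoder_hidden_states).sample

    assert out.shape == sample.shape  # the UNet preserves the latent shape

With use_linear_projection=True, the reordered forward in attention.py flattens the hidden states to (batch, height * width, channels) before proj_in and only reshapes back to (batch, channels, height, width) after proj_out, so both projections act on the channel dimension of the flattened sequence.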