diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index 770043f053..c645f9f607 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -32,6 +32,7 @@ def get_down_block(
     resnet_groups=None,
     cross_attention_dim=None,
     downsample_padding=None,
+    dual_cross_attention=False,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -74,6 +75,7 @@ def get_down_block(
             downsample_padding=downsample_padding,
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
+            dual_cross_attention=dual_cross_attention,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -137,6 +139,7 @@ def get_up_block(
     attn_num_head_channels,
     resnet_groups=None,
     cross_attention_dim=None,
+    dual_cross_attention=False,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -322,6 +325,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         attention_type="default",
         output_scale_factor=1.0,
         cross_attention_dim=1280,
+        dual_cross_attention=False,
         **kwargs,
     ):
         super().__init__()
@@ -505,6 +509,7 @@ class CrossAttnDownBlock2D(nn.Module):
         output_scale_factor=1.0,
         downsample_padding=1,
         add_downsample=True,
+        dual_cross_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -529,16 +534,28 @@ class CrossAttnDownBlock2D(nn.Module):
                     pre_norm=resnet_pre_norm,
                 )
             )
-            attentions.append(
-                Transformer2DModel(
-                    attn_num_head_channels,
-                    out_channels // attn_num_head_channels,
-                    in_channels=out_channels,
-                    num_layers=1,
-                    cross_attention_dim=cross_attention_dim,
-                    norm_num_groups=resnet_groups,
+            if dual_cross_attention is False:
+                attentions.append(
+                    Transformer2DModel(
+                        attn_num_head_channels,
+                        out_channels // attn_num_head_channels,
+                        in_channels=out_channels,
+                        num_layers=1,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                    )
+                )
+            else:
+                attentions.append(
+                    DualTransformer2DModel(
+                        attn_num_head_channels,
+                        out_channels // attn_num_head_channels,
+                        in_channels=out_channels,
+                        num_layers=1,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                    )
                 )
-            )
 
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 5a02a3ba1e..49ccb66e4a 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -106,6 +106,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
         attention_head_dim: int = 8,
+        dual_cross_attention: bool = False,
     ):
         super().__init__()
 
@@ -145,6 +146,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
                 cross_attention_dim=cross_attention_dim,
                 attn_num_head_channels=attention_head_dim,
                 downsample_padding=downsample_padding,
+                dual_cross_attention=dual_cross_attention,
             )
             self.down_blocks.append(down_block)
 
@@ -159,6 +161,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attention_head_dim,
             resnet_groups=norm_num_groups,
+            dual_cross_attention=dual_cross_attention,
         )
 
         # count how many layers upsample the images
@@ -194,6 +197,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
                 attn_num_head_channels=attention_head_dim,
+                dual_cross_attention=dual_cross_attention,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_to_text.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_to_text.py
index 648ef96758..dbaaeeb262 100644
--- a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_to_text.py
+++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_to_text.py
@@ -40,7 +40,7 @@ class VersatileDiffusionImageToTextPipelineIntegrationTests(unittest.TestCase):
         pipe.set_progress_bar_config(disable=None)
 
         image_prompt = load_image(
-            "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
+            "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/boy_and_girl.jpg"
         )
         generator = torch.Generator(device=torch_device).manual_seed(0)
         text = pipe(
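
The patch threads a single `dual_cross_attention` flag from `UNet2DConditionModel.__init__` through `get_down_block`/`get_up_block` into the block constructors; when the flag is set, `CrossAttnDownBlock2D` builds `DualTransformer2DModel` attention blocks instead of `Transformer2DModel` (the corresponding import and the mid/up-block switches are not visible in this excerpt). Below is a minimal sketch of how the new flag would be exercised, assuming a diffusers build that includes this patch; the channel and block-type configuration is illustrative, not taken from the diff:

```python
# Minimal sketch (assumes a diffusers checkout containing this patch).
# Every config value below is illustrative; only `dual_cross_attention`
# is the parameter introduced by the diff.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=768,
    attention_head_dim=8,
    dual_cross_attention=True,  # new flag added by this patch
)

# With the flag set, the cross-attention down block should hold
# DualTransformer2DModel instances rather than Transformer2DModel.
print(type(unet.down_blocks[0].attentions[0]).__name__)
```

Because the flag defaults to `False` at every level, construction is unchanged for existing configs and checkpoints; only callers that opt in get the dual-transformer attention blocks.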