
propose change

Patrick von Platen
2022-11-23 08:43:17 +00:00
parent 95e37119e9
commit 22c6b32672
3 changed files with 31 additions and 10 deletions


@@ -32,6 +32,7 @@ def get_down_block(
     resnet_groups=None,
     cross_attention_dim=None,
     downsample_padding=None,
+    dual_cross_attention=False,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -74,6 +75,7 @@ def get_down_block(
             downsample_padding=downsample_padding,
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
+            dual_cross_attention=dual_cross_attention,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -137,6 +139,7 @@ def get_up_block(
     attn_num_head_channels,
     resnet_groups=None,
     cross_attention_dim=None,
+    dual_cross_attention=False,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -322,6 +325,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
attention_type="default",
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
**kwargs,
):
super().__init__()
@@ -505,6 +509,7 @@ class CrossAttnDownBlock2D(nn.Module):
         output_scale_factor=1.0,
         downsample_padding=1,
         add_downsample=True,
+        dual_cross_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -529,16 +534,28 @@ class CrossAttnDownBlock2D(nn.Module):
                     pre_norm=resnet_pre_norm,
                 )
             )
-            attentions.append(
-                Transformer2DModel(
-                    attn_num_head_channels,
-                    out_channels // attn_num_head_channels,
-                    in_channels=out_channels,
-                    num_layers=1,
-                    cross_attention_dim=cross_attention_dim,
-                    norm_num_groups=resnet_groups,
-                )
-            )
+            if dual_cross_attention is False:
+                attentions.append(
+                    Transformer2DModel(
+                        attn_num_head_channels,
+                        out_channels // attn_num_head_channels,
+                        in_channels=out_channels,
+                        num_layers=1,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                    )
+                )
+            else:
+                attentions.append(
+                    DualTransformer2DModel(
+                        attn_num_head_channels,
+                        out_channels // attn_num_head_channels,
+                        in_channels=out_channels,
+                        num_layers=1,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                    )
+                )
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

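For context, a minimal sketch of what the flag changes at the block level. The import path and all constructor values below are assumptions for illustration; only the dual_cross_attention argument itself comes from this change:

# Illustrative sketch, not from this commit: module path and sizes are assumed.
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D

block = CrossAttnDownBlock2D(
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    cross_attention_dim=768,
    attn_num_head_channels=8,
    dual_cross_attention=True,  # new: build DualTransformer2DModel attentions
)

# With the flag set, the attention sub-modules are DualTransformer2DModel
# instances; with the default False they remain Transformer2DModel.
print(type(block.attentions[0]).__name__)
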

@@ -106,6 +106,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
         attention_head_dim: int = 8,
+        dual_cross_attention: bool = False,
     ):
         super().__init__()
@@ -145,6 +146,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
                 cross_attention_dim=cross_attention_dim,
                 attn_num_head_channels=attention_head_dim,
                 downsample_padding=downsample_padding,
+                dual_cross_attention=dual_cross_attention,
             )
             self.down_blocks.append(down_block)
@@ -159,6 +161,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attention_head_dim,
             resnet_groups=norm_num_groups,
+            dual_cross_attention=dual_cross_attention,
         )

         # count how many layers upsample the images
@@ -194,6 +197,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
                 attn_num_head_channels=attention_head_dim,
+                dual_cross_attention=dual_cross_attention,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel

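At the model level, the flag becomes an ordinary config argument of UNet2DConditionModel and is forwarded to the down, mid and up blocks. A hedged usage sketch, assuming the remaining blocks accept the flag as the signature changes above suggest; the sizes are illustrative, not values from this commit:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    cross_attention_dim=768,
    attention_head_dim=8,
    dual_cross_attention=True,  # new config option introduced by this change
)

# The first down block is a CrossAttnDownBlock2D by default, so its attention
# layers should now be DualTransformer2DModel rather than Transformer2DModel.
print(type(unet.down_blocks[0].attentions[0]).__name__)
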

@@ -40,7 +40,7 @@ class VersatileDiffusionImageToTextPipelineIntegrationTests(unittest.TestCase):
         pipe.set_progress_bar_config(disable=None)

         image_prompt = load_image(
-            "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
+            "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/boy_and_girl.jpg"
         )
         generator = torch.Generator(device=torch_device).manual_seed(0)
         text = pipe(