mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
propose change
This commit is contained in:
@@ -32,6 +32,7 @@ def get_down_block(
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
downsample_padding=None,
|
||||
dual_cross_attention=False,
|
||||
):
|
||||
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
||||
if down_block_type == "DownBlock2D":
|
||||
@@ -74,6 +75,7 @@ def get_down_block(
|
||||
downsample_padding=downsample_padding,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
)
|
||||
elif down_block_type == "SkipDownBlock2D":
|
||||
return SkipDownBlock2D(
|
||||
@@ -137,6 +139,7 @@ def get_up_block(
|
||||
attn_num_head_channels,
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
dual_cross_attention=False,
|
||||
):
|
||||
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
||||
if up_block_type == "UpBlock2D":
|
||||
@@ -322,6 +325,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
cross_attention_dim=1280,
|
||||
dual_cross_attention=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -505,6 +509,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
output_scale_factor=1.0,
|
||||
downsample_padding=1,
|
||||
add_downsample=True,
|
||||
dual_cross_attention=False,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -529,16 +534,28 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
if dual_cross_attention is False:
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
else:
|
||||
attentions.append(
|
||||
DualTransformer2DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
|
||||
@@ -106,6 +106,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
attention_head_dim: int = 8,
|
||||
dual_cross_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -145,6 +146,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
downsample_padding=downsample_padding,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
@@ -159,6 +161,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
)
|
||||
|
||||
# count how many layers upsample the images
|
||||
@@ -194,6 +197,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
@@ -40,7 +40,7 @@ class VersatileDiffusionImageToTextPipelineIntegrationTests(unittest.TestCase):
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
image_prompt = load_image(
|
||||
"https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
|
||||
"https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/boy_and_girl.jpg"
|
||||
)
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
text = pipe(
|
||||
|
||||
Reference in New Issue
Block a user