1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Adapt UNet2D for supre-resolution (#1385)

* allow disabling self attention

* add class_embedding

* fix copies

* fix condition

* fix copies

* do_self_attention -> only_cross_attention

* fix copies

* num_classes -> num_class_embeds

* fix default value
This commit is contained in:
Suraj Patil
2022-11-24 14:49:03 +01:00
committed by GitHub
parent 30f6f44104
commit cecdd8bdd1
4 changed files with 60 additions and 1 deletions

View File

@@ -100,6 +100,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
):
super().__init__()
self.use_linear_projection = use_linear_projection
@@ -157,6 +158,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
)
for d in range(num_layers)
]
@@ -387,14 +389,17 @@ class BasicTransformerBlock(nn.Module):
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.attn2 = CrossAttention(
@@ -461,7 +466,11 @@ class BasicTransformerBlock(nn.Module):
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
hidden_states = self.attn1(norm_hidden_states) + hidden_states
if self.only_cross_attention:
hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
else:
hidden_states = self.attn1(norm_hidden_states) + hidden_states
# 2. Cross-Attention
norm_hidden_states = (

View File

@@ -34,6 +34,7 @@ def get_down_block(
downsample_padding=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D":
@@ -78,6 +79,7 @@ def get_down_block(
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D(
@@ -143,6 +145,7 @@ def get_up_block(
cross_attention_dim=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
@@ -174,6 +177,7 @@ def get_up_block(
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
elif up_block_type == "AttnUpBlock2D":
return AttnUpBlock2D(
@@ -530,6 +534,7 @@ class CrossAttnDownBlock2D(nn.Module):
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
super().__init__()
resnets = []
@@ -564,6 +569,7 @@ class CrossAttnDownBlock2D(nn.Module):
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
)
else:
@@ -1129,6 +1135,7 @@ class CrossAttnUpBlock2D(nn.Module):
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
super().__init__()
resnets = []
@@ -1165,6 +1172,7 @@ class CrossAttnUpBlock2D(nn.Module):
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
)
else:

View File

@@ -98,6 +98,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
"DownBlock2D",
),
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
@@ -109,6 +110,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
num_class_embeds: Optional[int] = None,
):
super().__init__()
@@ -124,10 +126,17 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
@@ -153,6 +162,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
)
self.down_blocks.append(down_block)
@@ -177,6 +187,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
@@ -207,6 +218,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -258,6 +270,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
@@ -310,6 +323,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.config.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)

View File

@@ -166,6 +166,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
"CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
@@ -177,6 +178,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
num_class_embeds: Optional[int] = None,
):
super().__init__()
@@ -192,10 +194,17 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
@@ -221,6 +230,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
)
self.down_blocks.append(down_block)
@@ -245,6 +255,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
@@ -275,6 +286,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -326,6 +338,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
@@ -378,6 +391,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.config.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)
@@ -648,6 +667,7 @@ class CrossAttnDownBlockFlat(nn.Module):
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
super().__init__()
resnets = []
@@ -682,6 +702,7 @@ class CrossAttnDownBlockFlat(nn.Module):
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
)
else:
@@ -861,6 +882,7 @@ class CrossAttnUpBlockFlat(nn.Module):
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
super().__init__()
resnets = []
@@ -897,6 +919,7 @@ class CrossAttnUpBlockFlat(nn.Module):
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
)
else: