Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
update unet2d (#1376)
* boom boom
* remove duplicate arg
* add use_linear_proj arg
* fix copies
* style
* add fast tests
* use_linear_proj -> use_linear_projection
@@ -99,8 +99,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         num_vector_embeds: Optional[int] = None,
         activation_fn: str = "geglu",
         num_embeds_ada_norm: Optional[int] = None,
+        use_linear_projection: bool = False,
     ):
         super().__init__()
+        self.use_linear_projection = use_linear_projection
         self.num_attention_heads = num_attention_heads
         self.attention_head_dim = attention_head_dim
         inner_dim = num_attention_heads * attention_head_dim
@@ -126,7 +128,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             self.in_channels = in_channels

             self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
-            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+            if use_linear_projection:
+                self.proj_in = nn.Linear(in_channels, inner_dim)
+            else:
+                self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
         elif self.is_input_vectorized:
             assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
             assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
@@ -159,7 +164,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):

         # 4. Define output layers
         if self.is_input_continuous:
-            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+            if use_linear_projection:
+                self.proj_out = nn.Linear(in_channels, inner_dim)
+            else:
+                self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
         elif self.is_input_vectorized:
             self.norm_out = nn.LayerNorm(inner_dim)
             self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
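The two projection flavors are mathematically interchangeable: a 1x1 Conv2d is a Linear layer applied independently at every spatial position. A standalone sketch (not part of the diff, sizes assumed) that checks the equivalence:

import torch
import torch.nn as nn

in_channels, inner_dim = 8, 16
conv = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
linear = nn.Linear(in_channels, inner_dim)

# Share weights: the (out, in, 1, 1) conv kernel flattens to the (out, in) linear matrix.
with torch.no_grad():
    linear.weight.copy_(conv.weight.squeeze(-1).squeeze(-1))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, in_channels, 4, 4)                        # NCHW
y_conv = conv(x)                                             # project in NCHW
y_lin = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)    # project last axis in NHWC, back to NCHW
assert torch.allclose(y_conv, y_lin, atol=1e-6)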
@@ -191,10 +199,18 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         if self.is_input_continuous:
             batch, channel, height, weight = hidden_states.shape
             residual = hidden_states

-            hidden_states = self.norm(hidden_states)
-            hidden_states = self.proj_in(hidden_states)
-            inner_dim = hidden_states.shape[1]
-            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+            hidden_states = self.norm(hidden_states)
+
+            if not self.use_linear_projection:
+                hidden_states = self.proj_in(hidden_states)
+                inner_dim = hidden_states.shape[1]
+                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+            else:
+                inner_dim = hidden_states.shape[1]
+                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+                hidden_states = self.proj_in(hidden_states)
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
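Shape flow of the two input branches, as a standalone sketch with small assumed sizes: the conv branch projects while still in NCHW and then flattens the spatial grid into a token axis; the linear branch flattens first and projects the last axis.

import torch
import torch.nn as nn

batch, channels, height, weight, inner_dim = 2, 8, 4, 5, 16
hidden_states = torch.randn(batch, channels, height, weight)

# not use_linear_projection: Conv2d, then permute/reshape to (batch, tokens, inner_dim)
tokens_conv = (
    nn.Conv2d(channels, inner_dim, kernel_size=1)(hidden_states)
    .permute(0, 2, 3, 1)
    .reshape(batch, height * weight, inner_dim)
)

# use_linear_projection: permute/reshape to (batch, tokens, channels), then Linear
tokens_lin = nn.Linear(channels, inner_dim)(
    hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channels)
)

print(tokens_conv.shape, tokens_lin.shape)  # both torch.Size([2, 20, 16])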
@@ -204,8 +220,13 @@ class Transformer2DModel(ModelMixin, ConfigMixin):

         # 3. Output
         if self.is_input_continuous:
-            hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
-            hidden_states = self.proj_out(hidden_states)
+            if not self.use_linear_projection:
+                hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
+                hidden_states = self.proj_out(hidden_states)
+            else:
+                hidden_states = self.proj_out(hidden_states)
+                hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
+
             output = hidden_states + residual
         elif self.is_input_vectorized:
             hidden_states = self.norm_out(hidden_states)
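A companion sketch of the output path under the same assumed sizes. Note the sketch writes the projection as Linear(inner_dim, in_channels); the proj_out hunk above spells the arguments the other way around, which only type-checks because inner_dim equals in_channels in these UNets (num_attention_heads * attention_head_dim matches the block's channel count).

import torch
import torch.nn as nn

batch, inner_dim, in_channels, height, weight = 2, 16, 8, 4, 5
tokens = torch.randn(batch, height * weight, inner_dim)

# not use_linear_projection: un-flatten back to NCHW first, then 1x1 conv
out_conv = nn.Conv2d(inner_dim, in_channels, kernel_size=1)(
    tokens.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
)

# use_linear_projection: project the token axis first, then un-flatten
projected = nn.Linear(inner_dim, in_channels)(tokens)
out_lin = projected.reshape(batch, height, weight, in_channels).permute(0, 3, 1, 2)

print(out_conv.shape, out_lin.shape)  # both torch.Size([2, 8, 4, 5])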
@@ -33,6 +33,7 @@ def get_down_block(
     cross_attention_dim=None,
     downsample_padding=None,
     dual_cross_attention=False,
+    use_linear_projection=False,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":

@@ -76,6 +77,7 @@ def get_down_block(
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
             dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
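The first line of the factory normalizes a possible "UNetRes" prefix before dispatching on the block name; a tiny standalone illustration:

down_block_type = "UNetResCrossAttnDownBlock2D"
# len("UNetRes") == 7, so the slice strips exactly that prefix
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
print(down_block_type)  # CrossAttnDownBlock2D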
@@ -140,6 +142,7 @@ def get_up_block(
     resnet_groups=None,
     cross_attention_dim=None,
     dual_cross_attention=False,
+    use_linear_projection=False,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":

@@ -170,6 +173,7 @@ def get_up_block(
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
             dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -327,6 +331,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         dual_cross_attention=False,
+        use_linear_projection=False,
         **kwargs,
     ):
         super().__init__()

@@ -362,6 +367,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
@@ -523,6 +529,7 @@ class CrossAttnDownBlock2D(nn.Module):
         downsample_padding=1,
         add_downsample=True,
         dual_cross_attention=False,
+        use_linear_projection=False,
     ):
         super().__init__()
         resnets = []

@@ -556,6 +563,7 @@ class CrossAttnDownBlock2D(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
@@ -1120,6 +1128,7 @@ class CrossAttnUpBlock2D(nn.Module):
         output_scale_factor=1.0,
         add_upsample=True,
         dual_cross_attention=False,
+        use_linear_projection=False,
     ):
         super().__init__()
         resnets = []

@@ -1155,6 +1164,7 @@ class CrossAttnUpBlock2D(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
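Every cross-attention block above forwards the flag into the Transformer2DModel it builds. A hedged sketch of the effect (import path as of this commit is assumed; sizes chosen so that inner_dim = 8 * 8 matches in_channels, as in the real blocks):

from diffusers.models.attention import Transformer2DModel

attn = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=8,
    in_channels=64,
    num_layers=1,
    cross_attention_dim=32,
    norm_num_groups=32,
    use_linear_projection=True,
)
print(type(attn.proj_in).__name__)  # Linear (Conv2d when the flag is False)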
@@ -61,7 +61,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
         out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
-        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
             Whether to flip the sin to cos in the time embedding.
         freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
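For reference, flip_sin_to_cos controls whether the cosine half of the sinusoidal time embedding comes before the sine half. A minimal standalone sketch (the helper name and exact frequency schedule are assumptions, not the diffusers implementation):

import torch

def timestep_embedding(timesteps, dim, flip_sin_to_cos=False):
    half = dim // 2
    freqs = torch.exp(-torch.log(torch.tensor(10000.0)) * torch.arange(half) / half)
    args = timesteps[:, None].float() * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half:], emb[:, :half]], dim=-1)  # cos half first
    return emb

print(timestep_embedding(torch.tensor([1, 10]), 8, flip_sin_to_cos=True).shape)  # torch.Size([2, 8])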
@@ -106,8 +106,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         norm_num_groups: int = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
-        attention_head_dim: int = 8,
+        attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
     ):
         super().__init__()
@@ -127,6 +128,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         self.mid_block = None
         self.up_blocks = nn.ModuleList([])

+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
         # down
         output_channel = block_out_channels[0]
         for i, down_block_type in enumerate(down_block_types):
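The int case is broadcast to one entry per down block, so attention_head_dim[i] in the down loop and reversed_attention_head_dim[i] in the up loop are valid for both spellings. A standalone illustration:

down_block_types = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")

attention_head_dim = 8  # legacy int spelling
if isinstance(attention_head_dim, int):
    attention_head_dim = (attention_head_dim,) * len(down_block_types)

print(attention_head_dim)                  # (8, 8, 8, 8)
print(list(reversed(attention_head_dim)))  # what the up-block loop indexes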
@@ -145,9 +149,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim,
+                attn_num_head_channels=attention_head_dim[i],
                 downsample_padding=downsample_padding,
                 dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
             )
             self.down_blocks.append(down_block)

@@ -160,9 +165,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             output_scale_factor=mid_block_scale_factor,
             resnet_time_scale_shift="default",
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attention_head_dim,
+            attn_num_head_channels=attention_head_dim[-1],
             resnet_groups=norm_num_groups,
             dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
         )

         # count how many layers upsample the images
@@ -170,6 +176,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):

         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
+        reversed_attention_head_dim = list(reversed(attention_head_dim))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
             is_final_block = i == len(block_out_channels) - 1
@@ -197,8 +204,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim,
+                attn_num_head_channels=reversed_attention_head_dim[i],
                 dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -256,8 +264,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         Args:
             sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
             timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
-            encoder_hidden_states (`torch.FloatTensor`):
-                (batch_size, sequence_length, hidden_size) encoder hidden states
+            encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
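Putting the two new constructor options together, a hedged usage sketch (requires a diffusers build containing this commit; the small sizes are chosen for speed and are not the defaults):

import torch
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=(8, 16),   # new: a per-block tuple is accepted
    use_linear_projection=True,   # new: Linear proj_in/proj_out in the transformers
)

sample = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 32)
out = model(sample, timestep, encoder_hidden_states).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])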
@@ -124,7 +124,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
         out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
-        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
             Whether to flip the sin to cos in the time embedding.
         freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
@@ -174,8 +174,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         norm_num_groups: int = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
-        attention_head_dim: int = 8,
+        attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
     ):
         super().__init__()
@@ -195,6 +196,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         self.mid_block = None
         self.up_blocks = nn.ModuleList([])

+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
         # down
         output_channel = block_out_channels[0]
         for i, down_block_type in enumerate(down_block_types):
@@ -213,9 +217,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim,
+                attn_num_head_channels=attention_head_dim[i],
                 downsample_padding=downsample_padding,
                 dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
             )
             self.down_blocks.append(down_block)

@@ -228,9 +233,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             output_scale_factor=mid_block_scale_factor,
             resnet_time_scale_shift="default",
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attention_head_dim,
+            attn_num_head_channels=attention_head_dim[-1],
             resnet_groups=norm_num_groups,
             dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
         )

         # count how many layers upsample the images
@@ -238,6 +244,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
+        reversed_attention_head_dim = list(reversed(attention_head_dim))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
             is_final_block = i == len(block_out_channels) - 1
@@ -265,8 +272,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim,
+                attn_num_head_channels=reversed_attention_head_dim[i],
                 dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -324,8 +332,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         Args:
             sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
             timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
-            encoder_hidden_states (`torch.FloatTensor`):
-                (batch_size, sequence_length, hidden_size) encoder hidden states
+            encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
@@ -640,6 +647,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         downsample_padding=1,
         add_downsample=True,
         dual_cross_attention=False,
+        use_linear_projection=False,
     ):
         super().__init__()
         resnets = []

@@ -673,6 +681,7 @@ class CrossAttnDownBlockFlat(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
@@ -851,6 +860,7 @@ class CrossAttnUpBlockFlat(nn.Module):
         output_scale_factor=1.0,
         add_upsample=True,
         dual_cross_attention=False,
+        use_linear_projection=False,
     ):
         super().__init__()
         resnets = []

@@ -886,6 +896,7 @@ class CrossAttnUpBlockFlat(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
@@ -988,6 +999,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         dual_cross_attention=False,
+        use_linear_projection=False,
         **kwargs,
     ):
         super().__init__()

@@ -1023,6 +1035,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
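The UNetFlatConditionModel hunks above repeat the UNet2DConditionModel changes verbatim because the flat variant is kept in sync via "# Copied from" markers; the "fix copies" step in the commit message refers to regenerating those mirrors, presumably with the repo's standard tooling (make fix-copies, which runs python utils/check_copies.py --fix_and_overwrite).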
@@ -296,6 +296,44 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         for name, param in named_params.items():
             self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))

+    def test_model_with_attention_head_dim_tuple(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.sample
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+    def test_model_with_use_linear_projection(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["use_linear_projection"] = True
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.sample
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+

 class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = UNet2DModel