
[UNet2DConditionModel, UNet2DModel] pass norm_num_groups to all the blocks (#442)

* pass norm_num_groups to unet blocks and attention

* fix UNet2DConditionModel

* add norm_num_groups arg in vae

* add tests

* remove comment

* Apply suggestions from code review
Author: Suraj Patil
Committed: 2022-09-15 16:35:14 +02:00 (via GitHub)
Parent: b34be039f9
Commit: d144c46a59
7 changed files with 59 additions and 7 deletions
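The net effect: the group count used by every GroupNorm in these models is now configurable from the top-level config instead of being hard-coded to 32. A minimal usage sketch with toy sizes (hypothetical values; assumes a diffusers version containing this commit). Every entry of block_out_channels must stay divisible by norm_num_groups:

from diffusers import UNet2DModel

# Channel widths like 16 were impossible before this commit, because the
# blocks hard-coded num_groups=32 in their GroupNorm layers.
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    block_out_channels=(16, 32),  # each entry divisible by norm_num_groups
    norm_num_groups=16,           # now forwarded to every block
)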


@@ -113,6 +113,7 @@ class SpatialTransformer(nn.Module):
d_head: int,
depth: int = 1,
dropout: float = 0.0,
+num_groups: int = 32,
context_dim: Optional[int] = None,
):
super().__init__()
@@ -120,7 +121,7 @@ class SpatialTransformer(nn.Module):
self.d_head = d_head
self.in_channels = in_channels
inner_dim = n_heads * d_head
-self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
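This hard-coded num_groups=32 was the practical blocker: torch.nn.GroupNorm requires num_channels to be divisible by num_groups and raises a ValueError at construction otherwise, so channel widths not divisible by 32 could not pass through SpatialTransformer. A quick illustration:

import torch

norm = torch.nn.GroupNorm(num_groups=16, num_channels=48, eps=1e-6, affine=True)  # 48 % 16 == 0
x = torch.randn(2, 48, 8, 8)
print(norm(x).shape)  # torch.Size([2, 48, 8, 8]); normalization preserves shape

# torch.nn.GroupNorm(num_groups=32, num_channels=48)  # ValueError: 48 is not divisible by 32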


@@ -114,6 +114,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
+resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
)
@@ -151,6 +152,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
+resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
)
self.up_blocks.append(up_block)


@@ -114,6 +114,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
+resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
@@ -153,6 +154,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
+resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
)
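The cross-attention UNet gets the same knob. A sketch with toy sizes (argument names as in the post-merge signature; values hypothetical):

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(16, 32),
    cross_attention_dim=32,
    norm_num_groups=16,  # reaches the resnets and the SpatialTransformer norm shown above
)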


@@ -31,6 +31,7 @@ def get_down_block(
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
+resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
):
@@ -44,6 +45,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
+resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
)
elif down_block_type == "AttnDownBlock2D":
@@ -55,6 +57,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
+resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
@@ -69,6 +72,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
+resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
@@ -104,6 +108,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
+resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
)
@@ -119,6 +124,7 @@ def get_up_block(
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
+resnet_groups=None,
cross_attention_dim=None,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
@@ -132,6 +138,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
+resnet_groups=resnet_groups,
)
elif up_block_type == "CrossAttnUpBlock2D":
if cross_attention_dim is None:
@@ -145,6 +152,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
+resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
)
@@ -158,6 +166,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
+resnet_groups=resnet_groups,
attn_num_head_channels=attn_num_head_channels,
)
elif up_block_type == "SkipUpBlock2D":
@@ -191,6 +200,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
+resnet_groups=resnet_groups,
)
raise ValueError(f"{up_block_type} does not exist.")
@@ -323,6 +333,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
in_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
+num_groups=resnet_groups,
)
)
resnets.append(
@@ -414,6 +425,7 @@ class AttnDownBlock2D(nn.Module):
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
+num_groups=resnet_groups,
)
)
@@ -498,6 +510,7 @@ class CrossAttnDownBlock2D(nn.Module):
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
+num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
@@ -966,6 +979,7 @@ class AttnUpBlock2D(nn.Module):
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
+num_groups=resnet_groups,
)
)
@@ -1047,6 +1061,7 @@ class CrossAttnUpBlock2D(nn.Module):
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
+num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
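Note that the factories default resnet_groups=None, so existing callers that never pass it keep the blocks' built-in default of 32; a value, when given, is threaded into every ResnetBlock2D, AttentionBlock, and SpatialTransformer norm the block constructs. A hedged sketch of a direct call (module path and keyword names as of this commit; sizes are toy):

from diffusers.models.unet_blocks import get_down_block

down_block = get_down_block(
    "AttnDownBlock2D",
    num_layers=1,
    in_channels=16,
    out_channels=32,
    temb_channels=64,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    attn_num_head_channels=8,
    resnet_groups=16,  # new in this commit; None falls back to the block default
    downsample_padding=1,
)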


@@ -59,6 +59,7 @@ class Encoder(nn.Module):
down_block_types=("DownEncoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
+norm_num_groups=32,
act_fn="silu",
double_z=True,
):
@@ -86,6 +87,7 @@ class Encoder(nn.Module):
resnet_eps=1e-6,
downsample_padding=0,
resnet_act_fn=act_fn,
+resnet_groups=norm_num_groups,
attn_num_head_channels=None,
temb_channels=None,
)
@@ -99,13 +101,12 @@ class Encoder(nn.Module):
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
-resnet_groups=32,
+resnet_groups=norm_num_groups,
temb_channels=None,
)
# out
-num_groups_out = 32
-self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6)
+self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
@@ -138,6 +139,7 @@ class Decoder(nn.Module):
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
+norm_num_groups=32,
act_fn="silu",
):
super().__init__()
@@ -156,7 +158,7 @@ class Decoder(nn.Module):
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
-resnet_groups=32,
+resnet_groups=norm_num_groups,
temb_channels=None,
)
@@ -178,6 +180,7 @@ class Decoder(nn.Module):
add_upsample=not is_final_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
+resnet_groups=norm_num_groups,
attn_num_head_channels=None,
temb_channels=None,
)
@@ -185,8 +188,7 @@ class Decoder(nn.Module):
prev_output_channel = output_channel
# out
-num_groups_out = 32
-self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6)
+self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
@@ -405,6 +407,7 @@ class VQModel(ModelMixin, ConfigMixin):
latent_channels: int = 3,
sample_size: int = 32,
num_vq_embeddings: int = 256,
+norm_num_groups: int = 32,
):
super().__init__()
@@ -416,6 +419,7 @@ class VQModel(ModelMixin, ConfigMixin):
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
+norm_num_groups=norm_num_groups,
double_z=False,
)
@@ -433,6 +437,7 @@ class VQModel(ModelMixin, ConfigMixin):
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
+norm_num_groups=norm_num_groups,
)
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
@@ -509,6 +514,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
+norm_num_groups: int = 32,
sample_size: int = 32,
):
super().__init__()
@@ -521,6 +527,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
+norm_num_groups=norm_num_groups,
double_z=True,
)
@@ -531,6 +538,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
+norm_num_groups=norm_num_groups,
act_fn=act_fn,
)
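Both autoencoders now expose the argument as well; a small sketch (toy sizes, other arguments left at their defaults):

from diffusers import AutoencoderKL

vae = AutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D",),
    up_block_types=("UpDecoderBlock2D",),
    block_out_channels=(32,),
    layers_per_block=1,
    latent_channels=4,
    norm_num_groups=16,  # previously fixed at 32 inside Encoder and Decoder
)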


@@ -99,6 +99,26 @@ class ModelTesterMixin:
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+    def test_forward_with_norm_groups(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["norm_num_groups"] = 16
+        init_dict["block_out_channels"] = (16, 32)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+            if isinstance(output, dict):
+                output = output.sample
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_forward_signature(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
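The test pairs norm_num_groups=16 with block_out_channels=(16, 32) deliberately: every channel width in the model must stay divisible by the group count, or the GroupNorm layers fail to construct. The constraint in one line:

norm_num_groups = 16
block_out_channels = (16, 32)
assert all(c % norm_num_groups == 0 for c in block_out_channels)  # required by GroupNorm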


@@ -293,3 +293,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
+    def test_forward_with_norm_groups(self):
+        # not required for this model
+        pass