Mirror of https://github.com/huggingface/diffusers.git
[UNet2DConditionModel, UNet2DModel] pass norm_num_groups to all the blocks (#442)
* pass norm_num_groups to unet blocks and attention
* fix UNet2DConditionModel
* add norm_num_groups arg in vae
* add tests
* remove comment
* Apply suggestions from code review
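What this enables, in practice: models built from narrow blocks can now use a GroupNorm group count that divides their channel widths instead of the previously hardcoded 32. A minimal sketch of the resulting API (configuration values and the forward call are illustrative of the 0.x-era API, not taken from the diff):

import torch
from diffusers import UNet2DModel

# norm_num_groups must evenly divide every entry of block_out_channels,
# since each GroupNorm splits its channels into that many groups.
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(16, 32),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    norm_num_groups=16,  # previously fixed at 32 inside the blocks
)

sample = torch.randn(1, 3, 32, 32)
timestep = torch.tensor([10])
out = model(sample, timestep).sample  # same spatial shape as the input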
@@ -113,6 +113,7 @@ class SpatialTransformer(nn.Module):
         d_head: int,
         depth: int = 1,
         dropout: float = 0.0,
+        num_groups: int = 32,
         context_dim: Optional[int] = None,
     ):
         super().__init__()
@@ -120,7 +121,7 @@ class SpatialTransformer(nn.Module):
         self.d_head = d_head
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+        self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)

         self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

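The motivation for threading num_groups through: torch.nn.GroupNorm requires num_channels to be divisible by num_groups, so the hardcoded num_groups=32 above fails for any block narrower than 32 channels. A standalone check in plain PyTorch:

import torch

# The old hardcoded value breaks on 16 channels with a ValueError:
# "num_channels must be divisible by num_groups".
try:
    torch.nn.GroupNorm(num_groups=32, num_channels=16)
except ValueError as err:
    print(err)

# With a configurable group count, narrow blocks normalize fine.
norm = torch.nn.GroupNorm(num_groups=16, num_channels=16, eps=1e-6, affine=True)
x = torch.randn(1, 16, 8, 8)
print(norm(x).shape)  # torch.Size([1, 16, 8, 8])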
@@ -114,6 +114,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
                 add_downsample=not is_final_block,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
                 attn_num_head_channels=attention_head_dim,
                 downsample_padding=downsample_padding,
             )
@@ -151,6 +152,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
                 add_upsample=not is_final_block,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
                 attn_num_head_channels=attention_head_dim,
             )
             self.up_blocks.append(up_block)

@@ -114,6 +114,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
                 add_downsample=not is_final_block,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
                 attn_num_head_channels=attention_head_dim,
                 downsample_padding=downsample_padding,
@@ -153,6 +154,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
                 add_upsample=not is_final_block,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
                 attn_num_head_channels=attention_head_dim,
             )

@@ -31,6 +31,7 @@ def get_down_block(
     resnet_eps,
     resnet_act_fn,
     attn_num_head_channels,
+    resnet_groups=None,
     cross_attention_dim=None,
     downsample_padding=None,
 ):
@@ -44,6 +45,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
         )
     elif down_block_type == "AttnDownBlock2D":
@@ -55,6 +57,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             attn_num_head_channels=attn_num_head_channels,
         )
@@ -69,6 +72,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
@@ -104,6 +108,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
         )

@@ -119,6 +124,7 @@ def get_up_block(
     resnet_eps,
     resnet_act_fn,
     attn_num_head_channels,
+    resnet_groups=None,
     cross_attention_dim=None,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
@@ -132,6 +138,7 @@ def get_up_block(
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
         )
     elif up_block_type == "CrossAttnUpBlock2D":
         if cross_attention_dim is None:
@@ -145,6 +152,7 @@ def get_up_block(
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
         )
@@ -158,6 +166,7 @@ def get_up_block(
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
             attn_num_head_channels=attn_num_head_channels,
         )
     elif up_block_type == "SkipUpBlock2D":
@@ -191,6 +200,7 @@ def get_up_block(
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
         )
     raise ValueError(f"{up_block_type} does not exist.")

@@ -323,6 +333,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                     in_channels // attn_num_head_channels,
                     depth=1,
                     context_dim=cross_attention_dim,
+                    num_groups=resnet_groups,
                 )
             )
             resnets.append(
@@ -414,6 +425,7 @@ class AttnDownBlock2D(nn.Module):
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
                     eps=resnet_eps,
+                    num_groups=resnet_groups,
                 )
             )

@@ -498,6 +510,7 @@ class CrossAttnDownBlock2D(nn.Module):
                     out_channels // attn_num_head_channels,
                     depth=1,
                     context_dim=cross_attention_dim,
+                    num_groups=resnet_groups,
                 )
             )
         self.attentions = nn.ModuleList(attentions)
@@ -966,6 +979,7 @@ class AttnUpBlock2D(nn.Module):
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
                     eps=resnet_eps,
+                    num_groups=resnet_groups,
                 )
             )

@@ -1047,6 +1061,7 @@ class CrossAttnUpBlock2D(nn.Module):
                     out_channels // attn_num_head_channels,
                     depth=1,
                     context_dim=cross_attention_dim,
+                    num_groups=resnet_groups,
                 )
             )
         self.attentions = nn.ModuleList(attentions)

@@ -59,6 +59,7 @@ class Encoder(nn.Module):
         down_block_types=("DownEncoderBlock2D",),
         block_out_channels=(64,),
         layers_per_block=2,
+        norm_num_groups=32,
         act_fn="silu",
         double_z=True,
     ):
@@ -86,6 +87,7 @@ class Encoder(nn.Module):
             resnet_eps=1e-6,
             downsample_padding=0,
             resnet_act_fn=act_fn,
+            resnet_groups=norm_num_groups,
             attn_num_head_channels=None,
             temb_channels=None,
         )
@@ -99,13 +101,12 @@ class Encoder(nn.Module):
             output_scale_factor=1,
             resnet_time_scale_shift="default",
             attn_num_head_channels=None,
-            resnet_groups=32,
+            resnet_groups=norm_num_groups,
             temb_channels=None,
         )

         # out
-        num_groups_out = 32
-        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6)
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
         self.conv_act = nn.SiLU()

         conv_out_channels = 2 * out_channels if double_z else out_channels
@@ -138,6 +139,7 @@ class Decoder(nn.Module):
         up_block_types=("UpDecoderBlock2D",),
         block_out_channels=(64,),
         layers_per_block=2,
+        norm_num_groups=32,
         act_fn="silu",
     ):
         super().__init__()
@@ -156,7 +158,7 @@ class Decoder(nn.Module):
             output_scale_factor=1,
             resnet_time_scale_shift="default",
             attn_num_head_channels=None,
-            resnet_groups=32,
+            resnet_groups=norm_num_groups,
             temb_channels=None,
         )
@@ -178,6 +180,7 @@ class Decoder(nn.Module):
             add_upsample=not is_final_block,
             resnet_eps=1e-6,
             resnet_act_fn=act_fn,
+            resnet_groups=norm_num_groups,
             attn_num_head_channels=None,
             temb_channels=None,
         )
@@ -185,8 +188,7 @@ class Decoder(nn.Module):
             prev_output_channel = output_channel

         # out
-        num_groups_out = 32
-        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6)
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
@@ -405,6 +407,7 @@ class VQModel(ModelMixin, ConfigMixin):
         latent_channels: int = 3,
         sample_size: int = 32,
         num_vq_embeddings: int = 256,
+        norm_num_groups: int = 32,
     ):
         super().__init__()
@@ -416,6 +419,7 @@ class VQModel(ModelMixin, ConfigMixin):
             block_out_channels=block_out_channels,
             layers_per_block=layers_per_block,
             act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
             double_z=False,
         )
@@ -433,6 +437,7 @@ class VQModel(ModelMixin, ConfigMixin):
             block_out_channels=block_out_channels,
             layers_per_block=layers_per_block,
             act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
         )

     def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
@@ -509,6 +514,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         layers_per_block: int = 1,
         act_fn: str = "silu",
         latent_channels: int = 4,
+        norm_num_groups: int = 32,
         sample_size: int = 32,
     ):
         super().__init__()
@@ -521,6 +527,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             block_out_channels=block_out_channels,
             layers_per_block=layers_per_block,
             act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
             double_z=True,
         )
@@ -531,6 +538,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             up_block_types=up_block_types,
             block_out_channels=block_out_channels,
             layers_per_block=layers_per_block,
+            norm_num_groups=norm_num_groups,
             act_fn=act_fn,
         )

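The same knob now reaches the VAE classes end to end. A sketch of constructing a small AutoencoderKL with the new argument (values are illustrative, not taken from the diff):

from diffusers import AutoencoderKL

# Before this change the Encoder/Decoder block GroupNorms were fixed at
# 32 groups, so a 16-channel block width could not be used at all.
vae = AutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
    block_out_channels=(16, 32),
    latent_channels=4,
    norm_num_groups=16,
)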
@@ -99,6 +99,26 @@ class ModelTesterMixin:
         expected_shape = inputs_dict["sample"].shape
         self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

+    def test_forward_with_norm_groups(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["norm_num_groups"] = 16
+        init_dict["block_out_channels"] = (16, 32)
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.sample
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
     def test_forward_signature(self):
         init_dict, _ = self.prepare_init_args_and_inputs_for_common()

@@ -293,3 +293,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         # fmt: on

         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
+
+    def test_forward_with_norm_groups(self):
+        # not required for this model
+        pass
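Every concrete model test class inherits the new check through ModelTesterMixin; models for which it does not apply opt out by overriding it with a pass, as NCSNppModelTests does above. A minimal sketch of that opt-out pattern (class and model names hypothetical):

import unittest

class MyModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = MyModel  # hypothetical model where the group check does not apply

    def test_forward_with_norm_groups(self):
        # not applicable to this model
        pass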