diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 55062c322e..e4cedbff8c 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -113,6 +113,7 @@ class SpatialTransformer(nn.Module): d_head: int, depth: int = 1, dropout: float = 0.0, + num_groups: int = 32, context_dim: Optional[int] = None, ): super().__init__() @@ -120,7 +121,7 @@ class SpatialTransformer(nn.Module): self.d_head = d_head self.in_channels = in_channels inner_dim = n_heads * d_head - self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index c3ab621a2c..89321a5503 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -114,6 +114,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, attn_num_head_channels=attention_head_dim, downsample_padding=downsample_padding, ) @@ -151,6 +152,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): add_upsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, attn_num_head_channels=attention_head_dim, ) self.up_blocks.append(up_block) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 92caaca92e..a9989a5ed2 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -114,6 +114,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim, downsample_padding=downsample_padding, @@ -153,6 +154,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): add_upsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim, ) diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_blocks.py index 88349075d2..d76c79762c 100644 --- a/src/diffusers/models/unet_blocks.py +++ b/src/diffusers/models/unet_blocks.py @@ -31,6 +31,7 @@ def get_down_block( resnet_eps, resnet_act_fn, attn_num_head_channels, + resnet_groups=None, cross_attention_dim=None, downsample_padding=None, ): @@ -44,6 +45,7 @@ def get_down_block( add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, downsample_padding=downsample_padding, ) elif down_block_type == "AttnDownBlock2D": @@ -55,6 +57,7 @@ def get_down_block( add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, downsample_padding=downsample_padding, attn_num_head_channels=attn_num_head_channels, ) @@ -69,6 +72,7 @@ def get_down_block( add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, @@ -104,6 +108,7 @@ def get_down_block( add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, downsample_padding=downsample_padding, ) @@ -119,6 +124,7 @@ def get_up_block( resnet_eps, resnet_act_fn, attn_num_head_channels, + resnet_groups=None, cross_attention_dim=None, ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type @@ -132,6 +138,7 @@ def get_up_block( add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, ) elif up_block_type == "CrossAttnUpBlock2D": if cross_attention_dim is None: @@ -145,6 +152,7 @@ def get_up_block( add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, ) @@ -158,6 +166,7 @@ def get_up_block( add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, attn_num_head_channels=attn_num_head_channels, ) elif up_block_type == "SkipUpBlock2D": @@ -191,6 +200,7 @@ def get_up_block( add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, ) raise ValueError(f"{up_block_type} does not exist.") @@ -323,6 +333,7 @@ class UNetMidBlock2DCrossAttn(nn.Module): in_channels // attn_num_head_channels, depth=1, context_dim=cross_attention_dim, + num_groups=resnet_groups, ) ) resnets.append( @@ -414,6 +425,7 @@ class AttnDownBlock2D(nn.Module): num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, + num_groups=resnet_groups, ) ) @@ -498,6 +510,7 @@ class CrossAttnDownBlock2D(nn.Module): out_channels // attn_num_head_channels, depth=1, context_dim=cross_attention_dim, + num_groups=resnet_groups, ) ) self.attentions = nn.ModuleList(attentions) @@ -966,6 +979,7 @@ class AttnUpBlock2D(nn.Module): num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, + num_groups=resnet_groups, ) ) @@ -1047,6 +1061,7 @@ class CrossAttnUpBlock2D(nn.Module): out_channels // attn_num_head_channels, depth=1, context_dim=cross_attention_dim, + num_groups=resnet_groups, ) ) self.attentions = nn.ModuleList(attentions) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 599c88feeb..fe89b41c07 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -59,6 +59,7 @@ class Encoder(nn.Module): down_block_types=("DownEncoderBlock2D",), block_out_channels=(64,), layers_per_block=2, + norm_num_groups=32, act_fn="silu", double_z=True, ): @@ -86,6 +87,7 @@ class Encoder(nn.Module): resnet_eps=1e-6, downsample_padding=0, resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, attn_num_head_channels=None, temb_channels=None, ) @@ -99,13 +101,12 @@ class Encoder(nn.Module): output_scale_factor=1, resnet_time_scale_shift="default", attn_num_head_channels=None, - resnet_groups=32, + resnet_groups=norm_num_groups, temb_channels=None, ) # out - num_groups_out = 32 - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6) + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) self.conv_act = nn.SiLU() conv_out_channels = 2 * out_channels if double_z else out_channels @@ -138,6 +139,7 @@ class Decoder(nn.Module): up_block_types=("UpDecoderBlock2D",), block_out_channels=(64,), layers_per_block=2, + norm_num_groups=32, act_fn="silu", ): super().__init__() @@ -156,7 +158,7 @@ class Decoder(nn.Module): output_scale_factor=1, resnet_time_scale_shift="default", attn_num_head_channels=None, - resnet_groups=32, + resnet_groups=norm_num_groups, temb_channels=None, ) @@ -178,6 +180,7 @@ class Decoder(nn.Module): add_upsample=not is_final_block, resnet_eps=1e-6, resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, attn_num_head_channels=None, temb_channels=None, ) @@ -185,8 +188,7 @@ class Decoder(nn.Module): prev_output_channel = output_channel # out - num_groups_out = 32 - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6) + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) @@ -405,6 +407,7 @@ class VQModel(ModelMixin, ConfigMixin): latent_channels: int = 3, sample_size: int = 32, num_vq_embeddings: int = 256, + norm_num_groups: int = 32, ): super().__init__() @@ -416,6 +419,7 @@ class VQModel(ModelMixin, ConfigMixin): block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, + norm_num_groups=norm_num_groups, double_z=False, ) @@ -433,6 +437,7 @@ class VQModel(ModelMixin, ConfigMixin): block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, + norm_num_groups=norm_num_groups, ) def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: @@ -509,6 +514,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): layers_per_block: int = 1, act_fn: str = "silu", latent_channels: int = 4, + norm_num_groups: int = 32, sample_size: int = 32, ): super().__init__() @@ -521,6 +527,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, + norm_num_groups=norm_num_groups, double_z=True, ) @@ -531,6 +538,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): up_block_types=up_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, act_fn=act_fn, ) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 5635646108..9095e39123 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -99,6 +99,26 @@ class ModelTesterMixin: expected_shape = inputs_dict["sample"].shape self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + def test_forward_with_norm_groups(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["norm_num_groups"] = 16 + init_dict["block_out_channels"] = (16, 32) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + def test_forward_signature(self): init_dict, _ = self.prepare_init_args_and_inputs_for_common() diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index d9c4967b57..b16a4e1c44 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -293,3 +293,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): # fmt: on self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) + + def test_forward_with_norm_groups(self): + # not required for this model + pass