From 470f51cd26c75974ef88c697c0a94412a20f2264 Mon Sep 17 00:00:00 2001
From: Saurav Maheshkar
Date: Wed, 19 Jul 2023 18:14:44 +0530
Subject: [PATCH] feat: add `act_fn` param to `OutValueFunctionBlock` (#3994)

* feat: add act_fn param to OutValueFunctionBlock

* feat: update unet1d tests to not use mish

* feat: add `mish` as the default activation function

Co-authored-by: Patrick von Platen

* feat: drop mish tests from unet1d

---------

Co-authored-by: Patrick von Platen
---
 src/diffusers/models/unet_1d_blocks.py |  6 +++---
 tests/models/test_models_unet_1d.py    | 18 +-----------------
 2 files changed, 4 insertions(+), 20 deletions(-)

diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py
index 3c04bffeea..84ae48e0f8 100644
--- a/src/diffusers/models/unet_1d_blocks.py
+++ b/src/diffusers/models/unet_1d_blocks.py
@@ -235,12 +235,12 @@ class OutConv1DBlock(nn.Module):
 
 
 class OutValueFunctionBlock(nn.Module):
-    def __init__(self, fc_dim, embed_dim):
+    def __init__(self, fc_dim, embed_dim, act_fn="mish"):
         super().__init__()
         self.final_block = nn.ModuleList(
             [
                 nn.Linear(fc_dim + embed_dim, fc_dim // 2),
-                nn.Mish(),
+                get_activation(act_fn),
                 nn.Linear(fc_dim // 2, 1),
             ]
         )
@@ -652,5 +652,5 @@ def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, ac
     if out_block_type == "OutConv1DBlock":
         return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
     elif out_block_type == "ValueFunction":
-        return OutValueFunctionBlock(fc_dim, embed_dim)
+        return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
     return None
diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py
index 99a243e911..1b58f9e616 100644
--- a/tests/models/test_models_unet_1d.py
+++ b/tests/models/test_models_unet_1d.py
@@ -52,27 +52,21 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_training(self):
         pass
 
-    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_determinism(self):
         super().test_determinism()
 
-    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_outputs_equivalence(self):
         super().test_outputs_equivalence()
 
-    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_from_save_pretrained(self):
         super().test_from_save_pretrained()
 
-    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_from_save_pretrained_variant(self):
         super().test_from_save_pretrained_variant()
 
-    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_model_from_pretrained(self):
         super().test_model_from_pretrained()
 
-    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_output(self):
         super().test_output()
 
@@ -89,12 +83,11 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
             "mid_block_type": "MidResTemporalBlock1D",
             "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
             "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
-            "act_fn": "mish",
+            "act_fn": "swish",
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
-    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_from_pretrained_hub(self):
         model, loading_info = UNet1DModel.from_pretrained(
             "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
@@ -107,7 +100,6 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
 
         assert image is not None, "Make sure output is not None"
 
-    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_output_pretrained(self):
         model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
         torch.manual_seed(0)
@@ -177,27 +169,21 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def output_shape(self):
         return (4, 14, 1)
 
-    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_determinism(self):
         super().test_determinism()
 
-    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_outputs_equivalence(self):
         super().test_outputs_equivalence()
 
-    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_from_save_pretrained(self):
         super().test_from_save_pretrained()
 
-    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_from_save_pretrained_variant(self):
         super().test_from_save_pretrained_variant()
 
-    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_model_from_pretrained(self):
         super().test_model_from_pretrained()
 
-    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_output(self):
         # UNetRL is a value-function is different output shape
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -241,7 +227,6 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
-    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_from_pretrained_hub(self):
         value_function, vf_loading_info = UNet1DModel.from_pretrained(
             "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
@@ -254,7 +239,6 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
 
         assert image is not None, "Make sure output is not None"
 
-    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_output_pretrained(self):
         value_function, vf_loading_info = UNet1DModel.from_pretrained(
             "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
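
Net effect of the patch: `OutValueFunctionBlock` now builds its hidden
activation via `get_activation(act_fn)` instead of a hard-coded `nn.Mish()`,
with `"mish"` kept as the default so existing checkpoints behave unchanged,
and `get_out_block` forwards its existing `act_fn` argument to the
value-function head. A minimal usage sketch of the new surface (not part of
the patch; the `fc_dim`/`embed_dim` values and the `"swish"` choice are
illustrative, taken only to mirror the test change above):

    from diffusers.models.unet_1d_blocks import OutValueFunctionBlock, get_out_block

    # Default keeps the old behavior: get_activation("mish") -> nn.Mish().
    head = OutValueFunctionBlock(fc_dim=32, embed_dim=8)

    # Any name known to get_activation can now be passed instead, e.g.
    # "swish", which sidesteps the mish op that is not supported on MPS.
    head = OutValueFunctionBlock(fc_dim=32, embed_dim=8, act_fn="swish")

    # The same knob is reachable through the factory used by UNet1DModel;
    # num_groups_out and out_channels are ignored on the ValueFunction branch.
    head = get_out_block(
        out_block_type="ValueFunction",
        num_groups_out=8,
        embed_dim=8,
        out_channels=14,
        act_fn="swish",
        fc_dim=32,
    )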