1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

feat: add act_fn param to OutValueFunctionBlock (#3994)

* feat: add act_fn param to OutValueFunctionBlock

* feat: update unet1d tests to not use mish

* feat: add `mish` as the default activation function

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* feat: drop mish tests from unet1d

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Saurav Maheshkar
2023-07-19 18:14:44 +05:30
committed by GitHub
parent b7e35dc782
commit 470f51cd26
2 changed files with 4 additions and 20 deletions

View File

@@ -235,12 +235,12 @@ class OutConv1DBlock(nn.Module):
class OutValueFunctionBlock(nn.Module):
    """Final value-head block mapping features + embedding to a scalar value.

    Resolved to the post-change side of the diff: the duplicated pre-change
    lines (`def __init__(self, fc_dim, embed_dim)` and `nn.Mish()`) are
    removed so the snippet is valid Python.

    Args:
        fc_dim: width of the incoming feature vector.
        embed_dim: width of the conditioning embedding concatenated to it.
        act_fn: name of the activation placed between the two linear layers.
            Defaults to "mish", preserving the pre-parameterized behavior.
    """

    def __init__(self, fc_dim, embed_dim, act_fn="mish"):
        super().__init__()
        # NOTE(review): stored as ModuleList rather than Sequential —
        # presumably forward() interleaves an embedding concat before the
        # first linear; forward is not visible in this hunk (confirm
        # against the full source).
        self.final_block = nn.ModuleList(
            [
                nn.Linear(fc_dim + embed_dim, fc_dim // 2),
                get_activation(act_fn),
                nn.Linear(fc_dim // 2, 1),
            ]
        )
def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
    """Factory for the UNet1D output block.

    Resolved to the post-change side of the diff: the duplicated pre-change
    line `return OutValueFunctionBlock(fc_dim, embed_dim)` is removed, so
    `act_fn` is now forwarded to both out-block variants. The keyword-only
    signature is reconstructed from the (truncated) `@@` context line —
    TODO confirm trailing parameters against the full source.

    Args:
        out_block_type: "OutConv1DBlock", "ValueFunction", or anything else
            (which yields None).
        num_groups_out: group count forwarded to OutConv1DBlock.
        embed_dim: embedding width forwarded to either block.
        out_channels: channel count forwarded to OutConv1DBlock.
        act_fn: activation name forwarded to either block.
        fc_dim: feature width forwarded to OutValueFunctionBlock.

    Returns:
        The constructed output block, or None for an unrecognized type.
    """
    if out_block_type == "OutConv1DBlock":
        return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
    elif out_block_type == "ValueFunction":
        return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
    # Unknown type: caller treats a missing out block as "no output head".
    return None

View File

@@ -52,27 +52,21 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
def test_training(self):
pass
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_determinism(self):
super().test_determinism()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_outputs_equivalence(self):
super().test_outputs_equivalence()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_save_pretrained(self):
super().test_from_save_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_model_from_pretrained(self):
super().test_model_from_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output(self):
super().test_output()
@@ -89,12 +83,11 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"mid_block_type": "MidResTemporalBlock1D",
"down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
"up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
"act_fn": "mish",
"act_fn": "swish",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_pretrained_hub(self):
model, loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
@@ -107,7 +100,6 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None"
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output_pretrained(self):
model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
torch.manual_seed(0)
@@ -177,27 +169,21 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
def output_shape(self):
return (4, 14, 1)
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_determinism(self):
super().test_determinism()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_outputs_equivalence(self):
super().test_outputs_equivalence()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_save_pretrained(self):
super().test_from_save_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_model_from_pretrained(self):
super().test_model_from_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output(self):
# UNetRL is a value-function is different output shape
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -241,7 +227,6 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_pretrained_hub(self):
value_function, vf_loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
@@ -254,7 +239,6 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None"
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output_pretrained(self):
value_function, vf_loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"