mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
feat: add act_fn param to OutValueFunctionBlock (#3994)
* feat: add act_fn param to OutValueFunctionBlock * feat: update unet1d tests to not use mish * feat: add `mish` as the default activation function Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * feat: drop mish tests from unet1d --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -235,12 +235,12 @@ class OutConv1DBlock(nn.Module):
|
||||
|
||||
|
||||
class OutValueFunctionBlock(nn.Module):
|
||||
def __init__(self, fc_dim, embed_dim):
|
||||
def __init__(self, fc_dim, embed_dim, act_fn="mish"):
|
||||
super().__init__()
|
||||
self.final_block = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
|
||||
nn.Mish(),
|
||||
get_activation(act_fn),
|
||||
nn.Linear(fc_dim // 2, 1),
|
||||
]
|
||||
)
|
||||
@@ -652,5 +652,5 @@ def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, ac
|
||||
if out_block_type == "OutConv1DBlock":
|
||||
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
|
||||
elif out_block_type == "ValueFunction":
|
||||
return OutValueFunctionBlock(fc_dim, embed_dim)
|
||||
return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
|
||||
return None
|
||||
|
||||
@@ -52,27 +52,21 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_determinism(self):
|
||||
super().test_determinism()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_outputs_equivalence(self):
|
||||
super().test_outputs_equivalence()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_save_pretrained(self):
|
||||
super().test_from_save_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_save_pretrained_variant(self):
|
||||
super().test_from_save_pretrained_variant()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_model_from_pretrained(self):
|
||||
super().test_model_from_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_output(self):
|
||||
super().test_output()
|
||||
|
||||
@@ -89,12 +83,11 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
"mid_block_type": "MidResTemporalBlock1D",
|
||||
"down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
|
||||
"up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
|
||||
"act_fn": "mish",
|
||||
"act_fn": "swish",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNet1DModel.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
|
||||
@@ -107,7 +100,6 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_output_pretrained(self):
|
||||
model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
|
||||
torch.manual_seed(0)
|
||||
@@ -177,27 +169,21 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
def output_shape(self):
|
||||
return (4, 14, 1)
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_determinism(self):
|
||||
super().test_determinism()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_outputs_equivalence(self):
|
||||
super().test_outputs_equivalence()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_save_pretrained(self):
|
||||
super().test_from_save_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_save_pretrained_variant(self):
|
||||
super().test_from_save_pretrained_variant()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_model_from_pretrained(self):
|
||||
super().test_model_from_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_output(self):
|
||||
# UNetRL is a value-function is different output shape
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
@@ -241,7 +227,6 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_pretrained_hub(self):
|
||||
value_function, vf_loading_info = UNet1DModel.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
|
||||
@@ -254,7 +239,6 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_output_pretrained(self):
|
||||
value_function, vf_loading_info = UNet1DModel.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
|
||||
|
||||
Reference in New Issue
Block a user