diff --git a/certbot/tests/error_handler_test.py b/certbot/tests/error_handler_test.py index 2eb1506be..2e2ffe2d9 100644 --- a/certbot/tests/error_handler_test.py +++ b/certbot/tests/error_handler_test.py @@ -7,18 +7,25 @@ import unittest import mock +def get_signals(signums): + """Get the handlers for an iterable of signums.""" + return dict((s, signal.getsignal(s)) for s in signums) + + +def set_signals(sig_handler_dict): + """Set the signal (keys) with the handler (values) from the input dict.""" + tuple(signal.signal(s, h) for (s, h) in sig_handler_dict.items()) + @contextlib.contextmanager def signal_receiver(signums): """Context manager to catch signals""" signals = [] prev_handlers = {} - for signum in signums: - prev_handlers[signum] = signal.getsignal(signum) - signal.signal(signum, lambda signum, _: signals.append(signum)) + prev_handlers = get_signals(signums) + set_signals(dict((s, lambda s, _: signals.append(s)) for s in signums)) yield signals - for signum in signums: - signal.signal(signum, prev_handlers[signum]) + set_signals(dict((s, prev_handlers[s]) for s in signums)) def send_signal(signum): @@ -54,6 +61,7 @@ class ErrorHandlerTest(unittest.TestCase): **self.init_kwargs) def test_context_manager_with_signal(self): + init_signals = get_signals(self.signals) with signal_receiver(self.signals) as signals_received: with self.handler: should_be_42 = 42 @@ -68,8 +76,7 @@ class ErrorHandlerTest(unittest.TestCase): self.init_func.assert_called_once_with(*self.init_args, **self.init_kwargs) for signum in self.signals: - sig = signal.getsignal(signum) - self.assertTrue((sig == signal.SIG_DFL) or (sig == signal.SIG_IGN)) + self.assertEqual(init_signals[signum], signal.getsignal(signum)) def test_bad_recovery(self): bad_func = mock.MagicMock(side_effect=[ValueError])