From aa216a96d4ec2ede40dda8dfea81330669dca150 Mon Sep 17 00:00:00 2001 From: Brad Warren Date: Tue, 22 Sep 2015 18:24:22 -0700 Subject: [PATCH] Finished error_handler --- letsencrypt/error_handler.py | 51 +++++++++++++++---------- letsencrypt/tests/error_handler_test.py | 25 ++++++++++++ 2 files changed, 56 insertions(+), 20 deletions(-) create mode 100644 letsencrypt/tests/error_handler_test.py diff --git a/letsencrypt/error_handler.py b/letsencrypt/error_handler.py index 884c73927..b82f49b5a 100644 --- a/letsencrypt/error_handler.py +++ b/letsencrypt/error_handler.py @@ -3,44 +3,55 @@ import os import signal -_SIGNALS = [signal.SIGTERM] if os.name == "nt" else - [signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT, - signal.SIGXCPU, signal.SIGXFSZ, signal.SIGPWR,] +_SIGNALS = ([signal.SIGTERM] if os.name == "nt" else + [signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT, + signal.SIGXCPU, signal.SIGXFSZ, signal.SIGPWR]) -class ErrorHandler(): +class ErrorHandler(object): """Registers and calls cleanup functions in case of an error.""" def __init__(self, func=None): - self.funcs = [] - if func: - self.funcs.append(func) + self.funcs = [func] if func else [] + self.prev_handlers = {} def __enter__(self): self.set_signal_handlers() def __exit__(self, exec_type, exec_value, traceback): if exec_value is not None: - self.cleanup() + self.call_registered() self.reset_signal_handlers() def register(self, func): """Registers func to be called if an error occurs.""" self.funcs.append(func) - - def cleanup(self): - """Calls all registered functions.""" - while self.funcs: - self.funcs.pop()() + + def call_registered(self): + """Calls all functions in the order they were registered.""" + for func in self.funcs: + func() def set_signal_handlers(self): - for signal_type in _SIGNALS: - signal.signal(signal_type, self._signal_handler) + """Sets signal handlers for signals in _SIGNALS.""" + for signum in _SIGNALS: + prev_handler = signal.getsignal(signum) + # If prev_handler is None, the handler was set outside of Python + if prev_handler is not None: + self.prev_handlers[signum] = prev_handler + signal.signal(signum, self._signal_handler) def reset_signal_handlers(self): - for signal_type in _SIGNALS: - signal.signal(signal_type, signal.SIG_DFL) + """Resets signal handlers for signals in _SIGNALS.""" + for signum in self.prev_handlers: + signal.signal(signum, self.prev_handlers[signum]) + self.prev_handlers.clear() - def _signal_handler(self, signum, frame): - self.cleanup() - signal.signal(signal_type, signal.SIG_DFL) + def _signal_handler(self, signum, _): + """Calls registered functions and the previous signal handler. + + :param int signum: number of current signal + + """ + self.call_registered() + signal.signal(signum, self.prev_handlers[signum]) os.kill(os.getpid(), signum) diff --git a/letsencrypt/tests/error_handler_test.py b/letsencrypt/tests/error_handler_test.py new file mode 100644 index 000000000..6c6d02ec3 --- /dev/null +++ b/letsencrypt/tests/error_handler_test.py @@ -0,0 +1,25 @@ +"""Tests for letsencrypt.error_handler.""" +import unittest + +import mock + + +class ErrorHandlerTest(unittest.TestCase): + """Tests for letsencrypt.error_handler.""" + + def setUp(self): + from letsencrypt import error_handler + self.init_func = mock.MagicMock() + self.error_handler = error_handler.ErrorHandler(self.init_func) + + def test_context_manager(self): + try: + with self.error_handler: + raise ValueError + except ValueError: + pass + self.init_func.assert_called_once_with() + + +if __name__ == "__main__": + unittest.main() # pragma: no cover