diff --git a/certbot/auth_handler.py b/certbot/auth_handler.py index 4b8160ef9..68389d1f8 100644 --- a/certbot/auth_handler.py +++ b/certbot/auth_handler.py @@ -118,7 +118,7 @@ class AuthHandler(object): """Get Responses for challenges from authenticators.""" resp = [] all_achalls = self._get_all_achalls(aauthzrs) - with error_handler.ErrorHandler(self._cleanup_challenges, all_achalls): + with error_handler.ErrorHandler(self._cleanup_challenges, aauthzrs, all_achalls): try: if all_achalls: resp = self.auth.perform(all_achalls) @@ -294,19 +294,18 @@ class AuthHandler(object): chall_prefs.extend(plugin_pref) return chall_prefs - def _cleanup_challenges(self, aauthzrs, achall_list=None): + def _cleanup_challenges(self, aauthzrs, achalls): """Cleanup challenges. - If achall_list is not provided, cleanup all achallenges. + :param aauthzrs: authorizations and their selected annotated + challenges + :type aauthzrs: `list` of `AnnotatedAuthzr` + :param achalls: annotated challenges to cleanup + :type achalls: `list` of :class:`certbot.achallenges.AnnotatedChallenge` """ logger.info("Cleaning up challenges") - if achall_list is None: - achalls = self._get_all_achalls(aauthzrs) - else: - achalls = achall_list - if achalls: self.auth.cleanup(achalls) for achall in achalls: diff --git a/certbot/tests/auth_handler_test.py b/certbot/tests/auth_handler_test.py index 54e284d9e..a4ac9eb73 100644 --- a/certbot/tests/auth_handler_test.py +++ b/certbot/tests/auth_handler_test.py @@ -278,6 +278,17 @@ class HandleAuthorizationsTest(unittest.TestCase): self.assertRaises( errors.AuthorizationError, self.handler.handle_authorizations, mock_order) + def test_perform_error(self): + self.mock_auth.perform.side_effect = errors.AuthorizationError + + authzr = gen_dom_authzr(domain="0", challs=acme_util.CHALLENGES, combos=True) + mock_order = mock.MagicMock(authorizations=[authzr]) + self.assertRaises(errors.AuthorizationError, self.handler.handle_authorizations, mock_order) + + self.assertEqual(self.mock_auth.cleanup.call_count, 1) + self.assertEqual( + self.mock_auth.cleanup.call_args[0][0][0].typ, "tls-sni-01") + def _validate_all(self, aauthzrs, unused_1, unused_2): for i, aauthzr in enumerate(aauthzrs): azr = aauthzr.authzr