1
0
mirror of https://github.com/certbot/certbot.git synced 2026-01-21 19:01:07 +03:00

Don't expose threads from ServerManager.

This commit is contained in:
Jakub Warmuz
2015-10-07 19:20:47 +00:00
parent f0214ddf9a
commit 7102f9ef4b
2 changed files with 30 additions and 29 deletions

View File

@@ -39,8 +39,10 @@ class ServerManager(object):
will serve the same URLs!
"""
_Instance = collections.namedtuple("_Instance", "server thread")
def __init__(self, certs, simple_http_resources):
self._servers = {}
self._instances = {}
self.certs = certs
self.simple_http_resources = simple_http_resources
@@ -53,13 +55,12 @@ class ServerManager(object):
:param int port: Port to run the server on.
:param bool tls: TLS or non-TLS?
:returns: Server instance (`ACMEServerMixin`) and the
corresponding (already started) thread (`threading.Thread`).
:rtype: tuple
:returns: Server instance.
:rtype: ACMEServerMixin
"""
if port in self._servers:
return self._servers[port]
if port in self._instances:
return self._instances[port].server
logger.debug("Starting new server at %s (tls=%s)", port, tls)
handler = acme_standalone.ACMERequestHandler.partial_init(
@@ -84,8 +85,8 @@ class ServerManager(object):
logger.debug("Starting server at %s:%d", host, real_port)
thread.start()
self._servers[real_port] = (server, thread)
return self._servers[real_port]
self._instances[real_port] = self._Instance(server, thread)
return server
def stop(self, port):
"""Stop ACME server running on the specified ``port``.
@@ -93,10 +94,10 @@ class ServerManager(object):
:param int port:
"""
server, thread = self._servers[port]
server.shutdown2()
thread.join()
del self._servers[port]
instance = self._instances[port]
instance.server.shutdown2()
instance.thread.join()
del self._instances[port]
def running(self):
"""Return all running instances.
@@ -104,11 +105,12 @@ class ServerManager(object):
Once the server is stopped using `stop`, it will not be
returned.
:returns: Mapping from port to ``(server, thread)``.
:returns: Mapping from ``port`` to ``server``.
:rtype: tuple
"""
return self._servers.copy()
return dict((port, instance.server) for port, instance
in six.iteritems(self._instances))
SUPPORTED_CHALLENGES = set([challenges.DVSNI, challenges.SimpleHTTP])
@@ -233,7 +235,7 @@ class Authenticator(common.Plugin):
for achall in achalls:
if isinstance(achall, achallenges.SimpleHTTP):
server, _ = self.servers.run(self.config.simple_http_port, tls=tls)
server = self.servers.run(self.config.simple_http_port, tls=tls)
response, validation = achall.gen_response_and_validation(tls=tls)
self.simple_http_resources.add(
acme_standalone.SimpleHTTPRequestHandler.SimpleHTTPResource(
@@ -242,7 +244,7 @@ class Authenticator(common.Plugin):
cert = self.simple_http_cert
domain = achall.domain
else: # DVSNI
server, _ = self.servers.run(self.config.dvsni_port, tls=True)
server = self.servers.run(self.config.dvsni_port, tls=True)
response, cert, _ = achall.gen_cert_and_response(self.key)
domain = response.z_domain
self.certs[domain] = (self.key, cert)
@@ -257,6 +259,6 @@ class Authenticator(common.Plugin):
for achall in achalls:
if achall in server_achalls:
server_achalls.remove(achall)
for port, (server, _) in six.iteritems(self.servers.running()):
for port, server in six.iteritems(self.servers.running()):
if not self.served[server]:
self.servers.stop(port)

View File

@@ -33,9 +33,9 @@ class ServerManagerTest(unittest.TestCase):
self.mgr.simple_http_resources is self.simple_http_resources)
def _test_run_stop(self, tls):
server, _ = self.mgr.run(port=0, tls=tls)
port = server.socket.getsockname()[1]
self.assertEqual(self.mgr.running(), {port: (server, mock.ANY)})
server = self.mgr.run(port=0, tls=tls)
port = server.socket.getsockname()[1] # pylint: disable=no-member
self.assertEqual(self.mgr.running(), {port: server})
self.mgr.stop(port=port)
self.assertEqual(self.mgr.running(), {})
@@ -46,12 +46,11 @@ class ServerManagerTest(unittest.TestCase):
self._test_run_stop(tls=False)
def test_run_idempotent(self):
server, thread = self.mgr.run(port=0, tls=False)
port = server.socket.getsockname()[1]
server2, thread2 = self.mgr.run(port=port, tls=False)
self.assertEqual(self.mgr.running(), {port: (server, thread)})
server = self.mgr.run(port=0, tls=False)
port = server.socket.getsockname()[1] # pylint: disable=no-member
server2 = self.mgr.run(port=port, tls=False)
self.assertEqual(self.mgr.running(), {port: server})
self.assertTrue(server is server2)
self.assertTrue(thread is thread2)
self.mgr.stop(port)
self.assertEqual(self.mgr.running(), {})
@@ -166,7 +165,7 @@ class AuthenticatorTest(unittest.TestCase):
self.auth.servers = mock.MagicMock()
def _run(port, tls): # pylint: disable=unused-argument
return "server{0}".format(port), "thread{0}".format(port)
return "server{0}".format(port)
self.auth.servers.run.side_effect = _run
responses = self.auth.perform2([simple_http, dvsni])
@@ -191,8 +190,8 @@ class AuthenticatorTest(unittest.TestCase):
def test_cleanup(self):
self.auth.servers = mock.Mock()
self.auth.servers.running.return_value = {
1: ("server1", "thread1"),
2: ("server2", "thread2"),
1: "server1",
2: "server2",
}
self.auth.served["server1"].add("chall1")
self.auth.served["server2"].update(["chall2", "chall3"])
@@ -203,7 +202,7 @@ class AuthenticatorTest(unittest.TestCase):
self.auth.servers.stop.assert_called_once_with(1)
self.auth.servers.running.return_value = {
2: ("server2", "thread2"),
2: "server2",
}
self.auth.cleanup(["chall2"])
self.assertEqual(self.auth.served, {