mirror of
https://github.com/certbot/certbot.git
synced 2026-01-21 19:01:07 +03:00
acme: _serve_sni -> SSLSocket
This commit is contained in:
@@ -26,47 +26,95 @@ logger = logging.getLogger(__name__)
|
||||
_DEFAULT_DVSNI_SSL_METHOD = OpenSSL.SSL.SSLv23_METHOD
|
||||
|
||||
|
||||
def _serve_sni(certs, sock, reuseaddr=True, method=_DEFAULT_DVSNI_SSL_METHOD,
|
||||
accept=None):
|
||||
"""Start SNI-enabled server, that drops connection after handshake.
|
||||
class SSLSocket(object): # pylint: disable=too-few-public-methods
|
||||
"""SSL wrapper for sockets."""
|
||||
|
||||
:param certs: Mapping from SNI name to ``(key, cert)`` `tuple`.
|
||||
:param sock: Already bound socket.
|
||||
:param bool reuseaddr: Should `socket.SO_REUSEADDR` be set?
|
||||
:param method: See `OpenSSL.SSL.Context` for allowed values.
|
||||
:param accept: Callable that doesn't take any arguments and
|
||||
returns ``True`` if more connections should be served.
|
||||
def __init__(self, sock, certs, method=_DEFAULT_DVSNI_SSL_METHOD):
|
||||
self.sock = sock
|
||||
self.certs = certs
|
||||
self.method = method
|
||||
|
||||
"""
|
||||
def _pick_certificate(connection):
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.sock, name)
|
||||
|
||||
def _pick_certificate_cb(self, connection):
|
||||
"""SNI certificate callback.
|
||||
|
||||
This method will set a new OpenSSL context object for this
|
||||
connection when an incoming connection provides an SNI name
|
||||
(in order to serve the appropriate certificate, if any).
|
||||
|
||||
:param connection: The TLS connection object on which the SNI
|
||||
extension was received.
|
||||
:type connection: :class:`OpenSSL.Connection`
|
||||
|
||||
"""
|
||||
server_name = connection.get_servername()
|
||||
try:
|
||||
key, cert = certs[connection.get_servername()]
|
||||
key, cert = self.certs[server_name]
|
||||
except KeyError:
|
||||
logger.debug("Server name (%s) not recognized, dropping SSL",
|
||||
server_name)
|
||||
return
|
||||
new_context = OpenSSL.SSL.Context(method)
|
||||
new_context = OpenSSL.SSL.Context(self.method)
|
||||
new_context.use_privatekey(key)
|
||||
new_context.use_certificate(cert)
|
||||
connection.set_context(new_context)
|
||||
|
||||
if reuseaddr:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.listen(1) # TODO: add func arg?
|
||||
class FakeConnection(object):
|
||||
"""Fake OpenSSL.SSL.Connection."""
|
||||
|
||||
while accept is None or accept():
|
||||
server, addr = sock.accept()
|
||||
logger.debug('Received connection from %s', addr)
|
||||
# pylint: disable=missing-docstring
|
||||
|
||||
with contextlib.closing(server):
|
||||
context = OpenSSL.SSL.Context(method)
|
||||
context.set_tlsext_servername_callback(_pick_certificate)
|
||||
def __init__(self, connection):
|
||||
self._wrapped = connection
|
||||
self._makefile_refs = 0
|
||||
|
||||
server_ssl = OpenSSL.SSL.Connection(context, server)
|
||||
server_ssl.set_accept_state()
|
||||
try:
|
||||
server_ssl.do_handshake()
|
||||
server_ssl.shutdown()
|
||||
except OpenSSL.SSL.Error as error:
|
||||
raise errors.Error(error)
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._wrapped, name)
|
||||
|
||||
def shutdown(self, *unused_args):
|
||||
# OpenSSL.SSL.Connection.shutdown doesn't accept any args
|
||||
return self._wrapped.shutdown()
|
||||
|
||||
# stuff below ripped off from
|
||||
# https://hg.python.org/cpython/file/2.7/Lib/ssl.py
|
||||
# XXX: this uses Python's internal API
|
||||
|
||||
def makefile(self, mode='r', bufsize=-1):
|
||||
self._makefile_refs += 1
|
||||
# SocketServer.StreamRequesthandler.finish will try to
|
||||
# close the wfile/rfile. close=True causes curl: (56)
|
||||
# GnuTLS recv error (-110): The TLS connection was
|
||||
# non-properly terminated.
|
||||
# TODO: doesn't work in Python3
|
||||
# pylint: disable=protected-access
|
||||
return socket._fileobject(self._wrapped, mode, bufsize, close=False)
|
||||
|
||||
def close(self):
|
||||
if self._makefile_refs < 1:
|
||||
self._wrapped.close()
|
||||
else:
|
||||
self._makefile_refs -= 1
|
||||
|
||||
def accept(self): # pylint: disable=missing-docstring
|
||||
sock, addr = self.sock.accept()
|
||||
|
||||
context = OpenSSL.SSL.Context(self.method)
|
||||
context.set_tlsext_servername_callback(self._pick_certificate_cb)
|
||||
|
||||
ssl_sock = self.FakeConnection(OpenSSL.SSL.Connection(context, sock))
|
||||
ssl_sock.set_accept_state()
|
||||
|
||||
logger.debug("Performing handshake with %s", addr)
|
||||
try:
|
||||
ssl_sock.do_handshake()
|
||||
except OpenSSL.SSL.Error as error:
|
||||
# _pick_certificate_cb might have returned without
|
||||
# creating SSL context (wrong server name)
|
||||
raise socket.error(error)
|
||||
|
||||
return ssl_sock, addr
|
||||
|
||||
|
||||
def probe_sni(name, host, port=443, timeout=300,
|
||||
|
||||
@@ -4,45 +4,43 @@ import threading
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import mock
|
||||
import OpenSSL
|
||||
from six.moves import socketserver # pylint: disable=import-error
|
||||
|
||||
from acme import errors
|
||||
from acme import jose
|
||||
from acme import test_util
|
||||
|
||||
|
||||
class ServeProbeSNITest(unittest.TestCase):
|
||||
"""Tests for acme.crypto_util._serve_sni/probe_sni."""
|
||||
class SSLSocketAndProbeSNITest(unittest.TestCase):
|
||||
"""Tests for acme.crypto_util.SSLSocket/probe_sni."""
|
||||
|
||||
def setUp(self):
|
||||
self.cert = test_util.load_cert('cert.pem')
|
||||
key = OpenSSL.crypto.load_privatekey(
|
||||
OpenSSL.crypto.FILETYPE_PEM,
|
||||
test_util.load_vector('rsa512_key.pem'))
|
||||
key = test_util.load_pyopenssl_private_key('rsa512_key.pem')
|
||||
# pylint: disable=protected-access
|
||||
certs = {b'foo': (key, self.cert._wrapped)}
|
||||
|
||||
sock = socket.socket()
|
||||
sock.bind(('', 0)) # pick random port
|
||||
self.port = sock.getsockname()[1]
|
||||
from acme.crypto_util import SSLSocket
|
||||
|
||||
self.server = threading.Thread(target=self._run_server, args=(certs, sock))
|
||||
self.server.start()
|
||||
class _TestServer(socketserver.TCPServer):
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
# six.moves.* | pylint: disable=attribute-defined-outside-init,no-init
|
||||
|
||||
def server_bind(self): # pylint: disable=missing-docstring
|
||||
self.socket = SSLSocket(socket.socket(), certs=certs)
|
||||
socketserver.TCPServer.server_bind(self)
|
||||
|
||||
self.server = _TestServer(('', 0), socketserver.BaseRequestHandler)
|
||||
self.port = self.server.socket.getsockname()[1]
|
||||
self.server_thread = threading.Thread(
|
||||
# pylint: disable=no-member
|
||||
target=self.server.handle_request)
|
||||
self.server_thread.start()
|
||||
time.sleep(1) # TODO: avoid race conditions in other way
|
||||
|
||||
@classmethod
|
||||
def _run_server(cls, certs, sock):
|
||||
from acme.crypto_util import _serve_sni
|
||||
# TODO: improve testing of server errors and their conditions
|
||||
try:
|
||||
return _serve_sni(
|
||||
certs, sock, accept=mock.Mock(side_effect=[True, False]))
|
||||
except errors.Error:
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
self.server.join()
|
||||
self.server_thread.join()
|
||||
|
||||
def _probe(self, name):
|
||||
from acme.crypto_util import probe_sni
|
||||
|
||||
Reference in New Issue
Block a user