diff --git a/letsencrypt/crypto_util.py b/letsencrypt/crypto_util.py index 94617eef6..db4b629d2 100644 --- a/letsencrypt/crypto_util.py +++ b/letsencrypt/crypto_util.py @@ -234,42 +234,78 @@ def make_ss_cert(key_str, domains, not_before=None, return cert.as_pem() -def _request_san(req): # TODO: implement directly in PyOpenSSL! +def _pyopenssl_cert_or_req_san(cert_or_req): + """Get Subject Alternative Names from certificate or CSR using pyOpenSSL. + + .. todo:: Implement directly in PyOpenSSL! + + :param cert_or_req: Certificate or CSR. + :type cert_or_req: `OpenSSL.crypto.X509` or `OpenSSL.crypto.X509Req`. + + :returns: A list of Subject Alternative Names. + :rtype: list + + """ # constants based on implementation of # OpenSSL.crypto.X509Error._subjectAltNameString parts_separator = ", " part_separator = ":" extension_short_name = "subjectAltName" + if hasattr(cert_or_req, 'get_extensions'): # X509Req + extensions = cert_or_req.get_extensions() + else: # X509 + extensions = [cert_or_req.get_extension(i) + for i in xrange(cert_or_req.get_extension_count())] + # pylint: disable=protected-access,no-member label = OpenSSL.crypto.X509Extension._prefixes[OpenSSL.crypto._lib.GEN_DNS] assert parts_separator not in label prefix = label + part_separator - extensions = [ext._subjectAltNameString().split(parts_separator) - for ext in req.get_extensions() - if ext.get_short_name() == extension_short_name] + san_extensions = [ + ext._subjectAltNameString().split(parts_separator) + for ext in extensions if ext.get_short_name() == extension_short_name] # WARNING: this function assumes that no SAN can include # parts_separator, hence the split! - return [part.split(part_separator)[1] for parts in extensions + return [part.split(part_separator)[1] for parts in san_extensions for part in parts if part.startswith(prefix)] -def get_sans_from_csr(csr, typ=OpenSSL.crypto.FILETYPE_PEM): - """Get list of Subject Alternative Names from signing request. - - :param str csr: Certificate Signing Request in PEM format (must contain - one or more subjectAlternativeNames, or the function will fail, - raising ValueError) - - :returns: List of referenced subject alternative names - :rtype: list - - """ +def _get_sans_from_cert_or_req( + cert_or_req_str, load_func, typ=OpenSSL.crypto.FILETYPE_PEM): try: - request = OpenSSL.crypto.load_certificate_request(typ, csr) + cert_or_req = load_func(typ, cert_or_req_str) except OpenSSL.crypto.Error as error: logging.exception(error) raise - return _request_san(request) + return _pyopenssl_cert_or_req_san(cert_or_req) + + +def get_sans_from_cert(cert, typ=OpenSSL.crypto.FILETYPE_PEM): + """Get a list of Subject Alternative Names from a certificate. + + :param str csr: Certificate (encoded). + :param typ: `OpenSSL.crypto.FILETYPE_PEM` or `OpenSSL.crypto.FILETYPE_ASN1` + + :returns: A list of Subject Alternative Names. + :rtype: list + + """ + return _get_sans_from_cert_or_req( + cert, OpenSSL.crypto.load_certificate, typ) + + +def get_sans_from_csr(csr, typ=OpenSSL.crypto.FILETYPE_PEM): + """Get a list of Subject Alternative Names from a CSR. + + :param str csr: CSR (encoded). + :param typ: `OpenSSL.crypto.FILETYPE_PEM` or `OpenSSL.crypto.FILETYPE_ASN1` + + :returns: A list of Subject Alternative Names. + :rtype: list + + """ + return _get_sans_from_cert_or_req( + csr, OpenSSL.crypto.load_certificate_request, typ) diff --git a/letsencrypt/tests/crypto_util_test.py b/letsencrypt/tests/crypto_util_test.py index 92cb4014b..a9f9da012 100644 --- a/letsencrypt/tests/crypto_util_test.py +++ b/letsencrypt/tests/crypto_util_test.py @@ -68,6 +68,27 @@ class InitSaveCSRTest(unittest.TestCase): self.assertEqual(csr.data, 'csr_der') self.assertTrue('csr-letsencrypt.pem' in csr.file) + +class MakeCSRTest(unittest.TestCase): + """Tests for letsencrypt.crypto_util.make_csr.""" + + @classmethod + def _call(cls, *args, **kwargs): + from letsencrypt.crypto_util import make_csr + return make_csr(*args, **kwargs) + + def test_san(self): + from letsencrypt.crypto_util import get_sans_from_csr + # TODO: Fails for RSA256_KEY + csr_pem, csr_der = self._call( + RSA512_KEY, ['example.com', 'www.example.com']) + self.assertEqual( + ['example.com', 'www.example.com'], get_sans_from_csr(csr_pem)) + self.assertEqual( + ['example.com', 'www.example.com'], get_sans_from_csr( + csr_der, OpenSSL.crypto.FILETYPE_ASN1)) + + class ValidCSRTest(unittest.TestCase): """Tests for letsencrypt.crypto_util.valid_csr.""" @@ -151,7 +172,26 @@ class MakeSSCertTest(unittest.TestCase): make_ss_cert(RSA512_KEY, ['example.com', 'www.example.com']) -class GetSansFromCsrTest(unittest.TestCase): +class GetSANsFromCertTest(unittest.TestCase): + """Tests for letsencrypt.crypto_util.get_sans_from_cert.""" + + @classmethod + def _call(cls, *args, **kwargs): + from letsencrypt.crypto_util import get_sans_from_cert + return get_sans_from_cert(*args, **kwargs) + + def test_single(self): + self.assertEqual([], self._call(pkg_resources.resource_string( + __name__, os.path.join('testdata', 'cert.pem')))) + + def test_san(self): + self.assertEqual( + ['example.com', 'www.example.com'], + self._call(pkg_resources.resource_string( + __name__, os.path.join('testdata', 'cert-san.pem')))) + + +class GetSANsFromCSRTest(unittest.TestCase): """Tests for letsencrypt.crypto_util.get_sans_from_csr.""" def test_extract_one_san(self): from letsencrypt.crypto_util import get_sans_from_csr