1
0
mirror of https://github.com/Mbed-TLS/mbedtls.git synced 2025-08-08 17:42:09 +03:00

Merge pull request #7512 from lpy4105/issue/7014/cert_audit-improvement

cert_audit: Improvements of audit script
This commit is contained in:
Gilles Peskine
2023-05-24 20:24:48 +02:00
committed by GitHub

View File

@@ -31,6 +31,7 @@ import argparse
import datetime import datetime
import glob import glob
import logging import logging
import hashlib
from enum import Enum from enum import Enum
# The script requires cryptography >= 35.0.0 which is only available # The script requires cryptography >= 35.0.0 which is only available
@@ -45,7 +46,7 @@ from mbedtls_dev import build_tree
def check_cryptography_version(): def check_cryptography_version():
match = re.match(r'^[0-9]+', cryptography.__version__) match = re.match(r'^[0-9]+', cryptography.__version__)
if match is None or int(match[0]) < 35: if match is None or int(match.group(0)) < 35:
raise Exception("audit-validity-dates requires cryptography >= 35.0.0" raise Exception("audit-validity-dates requires cryptography >= 35.0.0"
+ "({} is too old)".format(cryptography.__version__)) + "({} is too old)".format(cryptography.__version__))
@@ -65,8 +66,20 @@ class AuditData:
#pylint: disable=too-few-public-methods #pylint: disable=too-few-public-methods
def __init__(self, data_type: DataType, x509_obj): def __init__(self, data_type: DataType, x509_obj):
self.data_type = data_type self.data_type = data_type
self.location = "" # the locations that the x509 object could be found
self.locations = [] # type: typing.List[str]
self.fill_validity_duration(x509_obj) self.fill_validity_duration(x509_obj)
self._obj = x509_obj
encoding = cryptography.hazmat.primitives.serialization.Encoding.DER
self._identifier = hashlib.sha1(self._obj.public_bytes(encoding)).hexdigest()
@property
def identifier(self):
"""
Identifier of the underlying X.509 object, which is consistent across
different runs.
"""
return self._identifier
def fill_validity_duration(self, x509_obj): def fill_validity_duration(self, x509_obj):
"""Read validity period from an X.509 object.""" """Read validity period from an X.509 object."""
@@ -90,7 +103,7 @@ class AuditData:
class X509Parser: class X509Parser:
"""A parser class to parse crt/crl/csr file or data in PEM/DER format.""" """A parser class to parse crt/crl/csr file or data in PEM/DER format."""
PEM_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}\n(?P<data>.*?)-{5}END (?P=type)-{5}\n' PEM_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}(?P<data>.*?)-{5}END (?P=type)-{5}'
PEM_TAG_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}\n' PEM_TAG_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}\n'
PEM_TAGS = { PEM_TAGS = {
DataType.CRT: 'CERTIFICATE', DataType.CRT: 'CERTIFICATE',
@@ -193,13 +206,11 @@ class Auditor:
X.509 data(DER/PEM format) to an X.509 object. X.509 data(DER/PEM format) to an X.509 object.
- walk_all: Defaultly, it iterates over all the files in the provided - walk_all: Defaultly, it iterates over all the files in the provided
file name list, calls `parse_file` for each file and stores the results file name list, calls `parse_file` for each file and stores the results
by extending Auditor.audit_data. by extending the `results` passed to the function.
""" """
def __init__(self, logger): def __init__(self, logger):
self.logger = logger self.logger = logger
self.default_files = self.collect_default_files() self.default_files = self.collect_default_files()
# A list to store the parsed audit_data.
self.audit_data = [] # type: typing.List[AuditData]
self.parser = X509Parser({ self.parser = X509Parser({
DataType.CRT: { DataType.CRT: {
DataFormat.PEM: x509.load_pem_x509_certificate, DataFormat.PEM: x509.load_pem_x509_certificate,
@@ -241,15 +252,27 @@ class Auditor:
return audit_data return audit_data
return None return None
def walk_all(self, file_list: typing.Optional[typing.List[str]] = None): def walk_all(self,
results: typing.Dict[str, AuditData],
file_list: typing.Optional[typing.List[str]] = None) \
-> None:
""" """
Iterate over all the files in the list and get audit data. Iterate over all the files in the list and get audit data. The
results will be written to `results` passed to this function.
:param results: The dictionary used to store the parsed
AuditData. The keys of this dictionary should
be the identifier of the AuditData.
""" """
if file_list is None: if file_list is None:
file_list = self.default_files file_list = self.default_files
for filename in file_list: for filename in file_list:
data_list = self.parse_file(filename) data_list = self.parse_file(filename)
self.audit_data.extend(data_list) for d in data_list:
if d.identifier in results:
results[d.identifier].locations.extend(d.locations)
else:
results[d.identifier] = d
@staticmethod @staticmethod
def find_test_dir(): def find_test_dir():
@@ -277,12 +300,25 @@ class TestDataAuditor(Auditor):
""" """
with open(filename, 'rb') as f: with open(filename, 'rb') as f:
data = f.read() data = f.read()
result = self.parse_bytes(data)
if result is not None: results = []
result.location = filename # Try to parse all PEM blocks.
return [result] is_pem = False
else: for idx, m in enumerate(re.finditer(X509Parser.PEM_REGEX, data, flags=re.S), 1):
return [] is_pem = True
result = self.parse_bytes(data[m.start():m.end()])
if result is not None:
result.locations.append("{}#{}".format(filename, idx))
results.append(result)
# Might be DER format.
if not is_pem:
result = self.parse_bytes(data)
if result is not None:
result.locations.append("{}".format(filename))
results.append(result)
return results
def parse_suite_data(data_f): def parse_suite_data(data_f):
@@ -339,20 +375,22 @@ class SuiteDataAuditor(Auditor):
audit_data = self.parse_bytes(bytes.fromhex(match.group('data'))) audit_data = self.parse_bytes(bytes.fromhex(match.group('data')))
if audit_data is None: if audit_data is None:
continue continue
audit_data.location = "{}:{}:#{}".format(filename, audit_data.locations.append("{}:{}:#{}".format(filename,
data_f.line_no, data_f.line_no,
idx + 1) idx + 1))
audit_data_list.append(audit_data) audit_data_list.append(audit_data)
return audit_data_list return audit_data_list
def list_all(audit_data: AuditData): def list_all(audit_data: AuditData):
print("{}\t{}\t{}\t{}".format( for loc in audit_data.locations:
audit_data.not_valid_before.isoformat(timespec='seconds'), print("{}\t{:20}\t{:20}\t{:3}\t{}".format(
audit_data.not_valid_after.isoformat(timespec='seconds'), audit_data.identifier,
audit_data.data_type.name, audit_data.not_valid_before.isoformat(timespec='seconds'),
audit_data.location)) audit_data.not_valid_after.isoformat(timespec='seconds'),
audit_data.data_type.name,
loc))
def configure_logger(logger: logging.Logger) -> None: def configure_logger(logger: logging.Logger) -> None:
@@ -448,20 +486,24 @@ def main():
end_date = start_date end_date = start_date
# go through all the files # go through all the files
td_auditor.walk_all(data_files) audit_results = {}
sd_auditor.walk_all(suite_data_files) td_auditor.walk_all(audit_results, data_files)
audit_results = td_auditor.audit_data + sd_auditor.audit_data sd_auditor.walk_all(audit_results, suite_data_files)
logger.info("Total: {} objects found!".format(len(audit_results)))
# we filter out the files whose validity duration covers the provided # we filter out the files whose validity duration covers the provided
# duration. # duration.
filter_func = lambda d: (start_date < d.not_valid_before) or \ filter_func = lambda d: (start_date < d.not_valid_before) or \
(d.not_valid_after < end_date) (d.not_valid_after < end_date)
sortby_end = lambda d: d.not_valid_after
if args.all: if args.all:
filter_func = None filter_func = None
# filter and output the results # filter and output the results
for d in filter(filter_func, audit_results): for d in sorted(filter(filter_func, audit_results.values()), key=sortby_end):
list_all(d) list_all(d)
logger.debug("Done!") logger.debug("Done!")