1
0
mirror of https://github.com/Mbed-TLS/mbedtls.git synced 2025-08-05 19:35:48 +03:00

cert_audit: Use dictionary to store parsed AuditData

Signed-off-by: Pengyu Lv <pengyu.lv@arm.com>
This commit is contained in:
Pengyu Lv
2023-05-05 17:29:12 +08:00
parent 31e3d12be9
commit e09d27e723

View File

@@ -209,13 +209,11 @@ class Auditor:
X.509 data(DER/PEM format) to an X.509 object. X.509 data(DER/PEM format) to an X.509 object.
- walk_all: Defaultly, it iterates over all the files in the provided - walk_all: Defaultly, it iterates over all the files in the provided
file name list, calls `parse_file` for each file and stores the results file name list, calls `parse_file` for each file and stores the results
by extending Auditor.audit_data. by extending the `results` passed to the function.
""" """
def __init__(self, logger): def __init__(self, logger):
self.logger = logger self.logger = logger
self.default_files = self.collect_default_files() self.default_files = self.collect_default_files()
# A list to store the parsed audit_data.
self.audit_data = [] # type: typing.List[AuditData]
self.parser = X509Parser({ self.parser = X509Parser({
DataType.CRT: { DataType.CRT: {
DataFormat.PEM: x509.load_pem_x509_certificate, DataFormat.PEM: x509.load_pem_x509_certificate,
@@ -257,15 +255,27 @@ class Auditor:
return audit_data return audit_data
return None return None
def walk_all(self, file_list: typing.Optional[typing.List[str]] = None): def walk_all(self,
results: typing.Dict[str, AuditData],
file_list: typing.Optional[typing.List[str]] = None) \
-> None:
""" """
Iterate over all the files in the list and get audit data. Iterate over all the files in the list and get audit data. The
results will be written to `results` passed to this function.
:param results: The dictionary used to store the parsed
AuditData. The keys of this dictionary should
be the identifier of the AuditData.
""" """
if file_list is None: if file_list is None:
file_list = self.default_files file_list = self.default_files
for filename in file_list: for filename in file_list:
data_list = self.parse_file(filename) data_list = self.parse_file(filename)
self.audit_data.extend(data_list) for d in data_list:
if d.identifier in results:
results[d.identifier].locations.extend(d.locations)
else:
results[d.identifier] = d
@staticmethod @staticmethod
def find_test_dir(): def find_test_dir():
@@ -485,11 +495,9 @@ def main():
end_date = start_date end_date = start_date
# go through all the files # go through all the files
td_auditor.walk_all(data_files) audit_results = {}
sd_auditor.walk_all(suite_data_files) td_auditor.walk_all(audit_results, data_files)
audit_results = td_auditor.audit_data + sd_auditor.audit_data sd_auditor.walk_all(audit_results, suite_data_files)
audit_results = merge_auditdata(audit_results)
logger.info("Total: {} objects found!".format(len(audit_results))) logger.info("Total: {} objects found!".format(len(audit_results)))
@@ -504,7 +512,7 @@ def main():
filter_func = None filter_func = None
# filter and output the results # filter and output the results
for d in sorted(filter(filter_func, audit_results), key=sortby_end): for d in sorted(filter(filter_func, audit_results.values()), key=sortby_end):
list_all(d) list_all(d)
logger.debug("Done!") logger.debug("Done!")