From 8e6794ad56066cfcbd4168ca3d2b92a1cedf2367 Mon Sep 17 00:00:00 2001 From: Pengyu Lv Date: Tue, 18 Apr 2023 17:00:47 +0800 Subject: [PATCH] cert_audit: Code refinement This commit is a collection of code refinements from review comments. Signed-off-by: Pengyu Lv --- tests/scripts/audit-validity-dates.py | 30 +++++++++++++++------------ 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/tests/scripts/audit-validity-dates.py b/tests/scripts/audit-validity-dates.py index 9ab8806d64..575da12d0c 100755 --- a/tests/scripts/audit-validity-dates.py +++ b/tests/scripts/audit-validity-dates.py @@ -86,7 +86,12 @@ class X509Parser: DataType.CSR: 'CERTIFICATE REQUEST' } - def __init__(self, backends: dict): + def __init__(self, + backends: + typing.Dict[DataType, + typing.Dict[DataFormat, + typing.Callable[[bytes], object]]]) \ + -> None: self.backends = backends self.__generate_parsers() @@ -122,7 +127,7 @@ class X509Parser: return self.parsers[item] @staticmethod - def pem_data_type(data: bytes) -> str: + def pem_data_type(data: bytes) -> typing.Optional[str]: """Get the tag from the data in PEM format :param data: data to be checked in binary mode. @@ -132,7 +137,7 @@ class X509Parser: if m is not None: return m.group('type').decode('UTF-8') else: - return "" + return None @staticmethod def check_hex_string(hex_str: str) -> bool: @@ -165,6 +170,7 @@ class Auditor: def __init__(self, verbose): self.verbose = verbose self.default_files = [] + # A list to store the parsed audit_data. self.audit_data = [] self.parser = X509Parser({ DataType.CRT: { @@ -198,12 +204,12 @@ class Auditor: """ with open(filename, 'rb') as f: data = f.read() - result_list = [] result = self.parse_bytes(data) if result is not None: result.location = filename - result_list.append(result) - return result_list + return [result] + else: + return [] def parse_bytes(self, data: bytes): """Parse AuditData from bytes.""" @@ -218,11 +224,11 @@ class Auditor: return audit_data return None - def walk_all(self, file_list): + def walk_all(self, file_list: typing.Optional[typing.List[str]] = None): """ Iterate over all the files in the list and get audit data. """ - if not file_list: + if file_list is None: file_list = self.default_files for filename in file_list: data_list = self.parse_file(filename) @@ -250,11 +256,9 @@ class TestDataAuditor(Auditor): def collect_default_files(self): """Collect all files in tests/data_files/""" test_dir = self.find_test_dir() - test_data_folder = os.path.join(test_dir, 'data_files') - data_files = [] - for (dir_path, _, file_names) in os.walk(test_data_folder): - data_files.extend(os.path.join(dir_path, file_name) - for file_name in file_names) + test_data_glob = os.path.join(test_dir, 'data_files/**') + data_files = [f for f in glob.glob(test_data_glob, recursive=True) + if os.path.isfile(f)] return data_files class FileWrapper():