1
0
mirror of https://github.com/Mbed-TLS/mbedtls.git synced 2025-07-29 11:41:15 +03:00

Separate common test generation classes/functions

Signed-off-by: Werner Lewis <werner.lewis@arm.com>
This commit is contained in:
Werner Lewis
2022-08-24 11:30:03 +01:00
parent 383461c92f
commit fbb75e3fc5
3 changed files with 167 additions and 192 deletions

View File

@ -20,22 +20,17 @@ generate only the specified files.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import enum
import os
import posixpath
import re
import sys
from typing import Callable, Dict, FrozenSet, Iterable, Iterator, List, Optional, TypeVar
from typing import Callable, Dict, FrozenSet, Iterable, Iterator, List, Optional
import scripts_path # pylint: disable=unused-import
from mbedtls_dev import build_tree
from mbedtls_dev import crypto_knowledge
from mbedtls_dev import macro_collector
from mbedtls_dev import psa_storage
from mbedtls_dev import test_case
T = TypeVar('T') #pylint: disable=invalid-name
from mbedtls_dev import test_generation
def psa_want_symbol(name: str) -> str:
@ -897,32 +892,8 @@ class StorageFormatV0(StorageFormat):
yield from super().generate_all_keys()
yield from self.all_keys_for_implicit_usage()
class TestGenerator:
"""Generate test data."""
def __init__(self, options) -> None:
self.test_suite_directory = self.get_option(options, 'directory',
'tests/suites')
self.info = Information()
@staticmethod
def get_option(options, name: str, default: T) -> T:
value = getattr(options, name, None)
return default if value is None else value
def filename_for(self, basename: str) -> str:
"""The location of the data file with the specified base name."""
return posixpath.join(self.test_suite_directory, basename + '.data')
def write_test_data_file(self, basename: str,
test_cases: Iterable[test_case.TestCase]) -> None:
"""Write the test cases to a .data file.
The output file is ``basename + '.data'`` in the test suite directory.
"""
filename = self.filename_for(basename)
test_case.write_data_file(filename, test_cases)
class PSATestGenerator(test_generation.TestGenerator):
"""Test generator subclass including PSA targets and info."""
# Note that targets whose names contain 'test_format' have their content
# validated by `abi_check.py`.
TARGETS = {
@ -938,44 +909,12 @@ class TestGenerator:
lambda info: StorageFormatV0(info).all_test_cases(),
} #type: Dict[str, Callable[[Information], Iterable[test_case.TestCase]]]
def generate_target(self, name: str) -> None:
test_cases = self.TARGETS[name](self.info)
self.write_test_data_file(name, test_cases)
def __init__(self, options):
super().__init__(options)
self.info = Information()
def main(args):
"""Command line entry point."""
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--list', action='store_true',
help='List available targets and exit')
parser.add_argument('--list-for-cmake', action='store_true',
help='Print \';\'-separated list of available targets and exit')
parser.add_argument('--directory', metavar='DIR',
help='Output directory (default: tests/suites)')
parser.add_argument('targets', nargs='*', metavar='TARGET',
help='Target file to generate (default: all; "-": none)')
options = parser.parse_args(args)
build_tree.chdir_to_root()
generator = TestGenerator(options)
if options.list:
for name in sorted(generator.TARGETS):
print(generator.filename_for(name))
return
# List in a cmake list format (i.e. ';'-separated)
if options.list_for_cmake:
print(';'.join(generator.filename_for(name)
for name in sorted(generator.TARGETS)), end='')
return
if options.targets:
# Allow "-" as a special case so you can run
# ``generate_psa_tests.py - $targets`` and it works uniformly whether
# ``$targets`` is empty or not.
options.targets = [os.path.basename(re.sub(r'\.data\Z', r'', target))
for target in options.targets
if target != '-']
else:
options.targets = sorted(generator.TARGETS)
for target in options.targets:
generator.generate_target(target)
def generate_target(self, name: str, *target_args) -> None:
super().generate_target(name, self.info)
if __name__ == '__main__':
main(sys.argv[1:])
test_generation.main(sys.argv[1:], PSATestGenerator)