diff --git a/scripts/mbedtls_dev/psa_information.py b/scripts/mbedtls_dev/psa_information.py index c3a3d487e7..5f62606457 100644 --- a/scripts/mbedtls_dev/psa_information.py +++ b/scripts/mbedtls_dev/psa_information.py @@ -9,6 +9,7 @@ import re from typing import Dict, FrozenSet, List, Optional from . import macro_collector +from . import test_case def psa_want_symbol(name: str) -> str: @@ -105,3 +106,29 @@ class Information: self.remove_unwanted_macros(constructors) constructors.gather_arguments() return constructors + + +class TestCase(test_case.TestCase): + """A PSA test case with automatically inferred dependencies.""" + + def __init__(self) -> None: + super().__init__() + self.key_bits = None #type: Optional[int] + + def set_key_bits(self, key_bits: Optional[int]) -> None: + """Use the given key size for automatic dependency generation. + + Call this function before set_arguments() if relevant. + + This is only relevant for ECC and DH keys. For other key types, + this information is ignored. + """ + self.key_bits = key_bits + + def set_arguments(self, arguments: List[str]) -> None: + """Set test case arguments and automatically infer dependencies.""" + super().set_arguments(arguments) + dependencies = automatic_dependencies(*arguments) + if self.key_bits is not None: + dependencies = finish_family_dependencies(dependencies, self.key_bits) + self.dependencies += dependencies diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py index 44a2cb54ea..ca2a6fb9f6 100755 --- a/tests/scripts/generate_psa_tests.py +++ b/tests/scripts/generate_psa_tests.py @@ -461,14 +461,9 @@ class StorageFormat: correctly. """ verb = 'save' if self.forward else 'read' - tc = test_case.TestCase() + tc = psa_information.TestCase() tc.set_description(verb + ' ' + key.description) - dependencies = psa_information.automatic_dependencies( - key.lifetime.string, key.type.string, - key.alg.string, key.alg2.string, - ) - dependencies = psa_information.finish_family_dependencies(dependencies, key.bits) - tc.set_dependencies(sorted(dependencies)) + tc.set_key_bits(key.bits) tc.set_function('key_storage_' + verb) if self.forward: extra_arguments = []