diff --git a/scripts/mbedtls_dev/bignum_common.py b/scripts/mbedtls_dev/bignum_common.py index ed321d7c3e..6fd42d1e7f 100644 --- a/scripts/mbedtls_dev/bignum_common.py +++ b/scripts/mbedtls_dev/bignum_common.py @@ -273,6 +273,15 @@ class ModOperationCommon(OperationCommon): return False return True + @classmethod + def input_cases_args(cls) -> Iterator[Tuple[Any, Any, Any]]: + if cls.arity == 1: + yield from ((n, a, "0") for a, n in cls.input_cases) + elif cls.arity == 2: + yield from ((n, a, b) for a, b, n in cls.input_cases) + else: + raise ValueError("Unsupported number of operands!") + @classmethod def generate_function_tests(cls) -> Iterator[test_case.TestCase]: if cls.input_style not in cls.input_styles: @@ -284,14 +293,18 @@ class ModOperationCommon(OperationCommon): for n in cls.moduli for a, b in cls.get_value_pairs() for bil in cls.limb_sizes) + special_cases = (cls(*args, bits_in_limb=bil) + for args in cls.input_cases_args() + for bil in cls.limb_sizes) else: test_objects = (cls(n, a, b) for n in cls.moduli for a, b in cls.get_value_pairs()) + special_cases = (cls(*args) for args in cls.input_cases_args()) yield from (valid_test_object.create_test_case() for valid_test_object in filter( lambda test_object: test_object.is_valid, - test_objects + chain(test_objects, special_cases) )) # BEGIN MERGE SLOT 1