1
0
mirror of https://github.com/Mbed-TLS/mbedtls.git synced 2025-07-28 00:21:48 +03:00

Merge pull request #4214 from gilles-peskine-arm/psa-storage-format-test-types

PSA storage format test case generator
This commit is contained in:
Gilles Peskine
2021-03-22 12:16:17 +01:00
committed by GitHub
12 changed files with 2255 additions and 163 deletions

View File

@ -304,7 +304,7 @@ class CaseBuilder(macro_collector.PSAMacroCollector):
def _make_key_usage_code(self):
return '\n'.join([self._make_bit_test('usage', bit)
for bit in sorted(self.key_usages)])
for bit in sorted(self.key_usage_flags)])
def write_file(self, output_file):
"""Generate the pretty-printer function code from the gathered

View File

@ -19,14 +19,14 @@ This module is entirely based on the PSA API.
# limitations under the License.
import re
from typing import List, Optional, Tuple
from typing import Iterable, Optional, Tuple
from mbedtls_dev.asymmetric_key_data import ASYMMETRIC_KEY_DATA
class KeyType:
"""Knowledge about a PSA key type."""
def __init__(self, name: str, params: Optional[List[str]] = None):
def __init__(self, name: str, params: Optional[Iterable[str]] = None):
"""Analyze a key type.
The key type must be specified in PSA syntax. In its simplest form,

View File

@ -16,13 +16,121 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import re
from typing import Dict, Iterable, Iterator, List, Set
class PSAMacroCollector:
class PSAMacroEnumerator:
"""Information about constructors of various PSA Crypto types.
This includes macro names as well as information about their arguments
when applicable.
This class only provides ways to enumerate expressions that evaluate to
values of the covered types. Derived classes are expected to populate
the set of known constructors of each kind, as well as populate
`self.arguments_for` for arguments that are not of a kind that is
enumerated here.
"""
def __init__(self) -> None:
"""Set up an empty set of known constructor macros.
"""
self.statuses = set() #type: Set[str]
self.algorithms = set() #type: Set[str]
self.ecc_curves = set() #type: Set[str]
self.dh_groups = set() #type: Set[str]
self.key_types = set() #type: Set[str]
self.key_usage_flags = set() #type: Set[str]
self.hash_algorithms = set() #type: Set[str]
self.mac_algorithms = set() #type: Set[str]
self.ka_algorithms = set() #type: Set[str]
self.kdf_algorithms = set() #type: Set[str]
self.aead_algorithms = set() #type: Set[str]
# macro name -> list of argument names
self.argspecs = {} #type: Dict[str, List[str]]
# argument name -> list of values
self.arguments_for = {
'mac_length': [],
'min_mac_length': [],
'tag_length': [],
'min_tag_length': [],
} #type: Dict[str, List[str]]
def gather_arguments(self) -> None:
"""Populate the list of values for macro arguments.
Call this after parsing all the inputs.
"""
self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
self.arguments_for['curve'] = sorted(self.ecc_curves)
self.arguments_for['group'] = sorted(self.dh_groups)
@staticmethod
def _format_arguments(name: str, arguments: Iterable[str]) -> str:
"""Format a macro call with arguments.."""
return name + '(' + ', '.join(arguments) + ')'
_argument_split_re = re.compile(r' *, *')
@classmethod
def _argument_split(cls, arguments: str) -> List[str]:
return re.split(cls._argument_split_re, arguments)
def distribute_arguments(self, name: str) -> Iterator[str]:
"""Generate macro calls with each tested argument set.
If name is a macro without arguments, just yield "name".
If name is a macro with arguments, yield a series of
"name(arg1,...,argN)" where each argument takes each possible
value at least once.
"""
try:
if name not in self.argspecs:
yield name
return
argspec = self.argspecs[name]
if argspec == []:
yield name + '()'
return
argument_lists = [self.arguments_for[arg] for arg in argspec]
arguments = [values[0] for values in argument_lists]
yield self._format_arguments(name, arguments)
# Dear Pylint, enumerate won't work here since we're modifying
# the array.
# pylint: disable=consider-using-enumerate
for i in range(len(arguments)):
for value in argument_lists[i][1:]:
arguments[i] = value
yield self._format_arguments(name, arguments)
arguments[i] = argument_lists[0][0]
except BaseException as e:
raise Exception('distribute_arguments({})'.format(name)) from e
def generate_expressions(self, names: Iterable[str]) -> Iterator[str]:
"""Generate expressions covering values constructed from the given names.
`names` can be any iterable collection of macro names.
For example:
* ``generate_expressions(['PSA_ALG_CMAC', 'PSA_ALG_HMAC'])``
generates ``'PSA_ALG_CMAC'`` as well as ``'PSA_ALG_HMAC(h)'`` for
every known hash algorithm ``h``.
* ``macros.generate_expressions(macros.key_types)`` generates all
key types.
"""
return itertools.chain(*map(self.distribute_arguments, names))
class PSAMacroCollector(PSAMacroEnumerator):
"""Collect PSA crypto macro definitions from C header files.
"""
def __init__(self, include_intermediate=False):
def __init__(self, include_intermediate: bool = False) -> None:
"""Set up an object to collect PSA macro definitions.
Call the read_file method of the constructed object on each header file.
@ -30,20 +138,13 @@ class PSAMacroCollector:
* include_intermediate: if true, include intermediate macros such as
PSA_XXX_BASE that do not designate semantic values.
"""
super().__init__()
self.include_intermediate = include_intermediate
self.statuses = set()
self.key_types = set()
self.key_types_from_curve = {}
self.key_types_from_group = {}
self.ecc_curves = set()
self.dh_groups = set()
self.algorithms = set()
self.hash_algorithms = set()
self.ka_algorithms = set()
self.algorithms_from_hash = {}
self.key_usages = set()
self.key_types_from_curve = {} #type: Dict[str, str]
self.key_types_from_group = {} #type: Dict[str, str]
self.algorithms_from_hash = {} #type: Dict[str, str]
def is_internal_name(self, name):
def is_internal_name(self, name: str) -> bool:
"""Whether this is an internal macro. Internal macros will be skipped."""
if not self.include_intermediate:
if name.endswith('_BASE') or name.endswith('_NONE'):
@ -52,6 +153,30 @@ class PSAMacroCollector:
return True
return name.endswith('_FLAG') or name.endswith('_MASK')
def record_algorithm_subtype(self, name: str, expansion: str) -> None:
"""Record the subtype of an algorithm constructor.
Given a ``PSA_ALG_xxx`` macro name and its expansion, if the algorithm
is of a subtype that is tracked in its own set, add it to the relevant
set.
"""
# This code is very ad hoc and fragile. It should be replaced by
# something more robust.
if re.match(r'MAC(?:_|\Z)', name):
self.mac_algorithms.add(name)
elif re.match(r'KDF(?:_|\Z)', name):
self.kdf_algorithms.add(name)
elif re.search(r'0x020000[0-9A-Fa-f]{2}', expansion):
self.hash_algorithms.add(name)
elif re.search(r'0x03[0-9A-Fa-f]{6}', expansion):
self.mac_algorithms.add(name)
elif re.search(r'0x05[0-9A-Fa-f]{6}', expansion):
self.aead_algorithms.add(name)
elif re.search(r'0x09[0-9A-Fa-f]{2}0000', expansion):
self.ka_algorithms.add(name)
elif re.search(r'0x08[0-9A-Fa-f]{6}', expansion):
self.kdf_algorithms.add(name)
# "#define" followed by a macro name with either no parameters
# or a single parameter and a non-empty expansion.
# Grab the macro name in group 1, the parameter name if any in group 2
@ -72,6 +197,8 @@ class PSAMacroCollector:
return
name, parameter, expansion = m.groups()
expansion = re.sub(r'/\*.*?\*/|//.*', r' ', expansion)
if parameter:
self.argspecs[name] = [parameter]
if re.match(self._deprecated_definition_re, expansion):
# Skip deprecated values, which are assumed to be
# backward compatibility aliases that share
@ -99,12 +226,7 @@ class PSAMacroCollector:
# Ad hoc skipping of duplicate names for some numerical values
return
self.algorithms.add(name)
# Ad hoc detection of hash algorithms
if re.search(r'0x020000[0-9A-Fa-f]{2}', expansion):
self.hash_algorithms.add(name)
# Ad hoc detection of key agreement algorithms
if re.search(r'0x09[0-9A-Fa-f]{2}0000', expansion):
self.ka_algorithms.add(name)
self.record_algorithm_subtype(name, expansion)
elif name.startswith('PSA_ALG_') and parameter == 'hash_alg':
if name in ['PSA_ALG_DSA', 'PSA_ALG_ECDSA']:
# A naming irregularity
@ -113,7 +235,7 @@ class PSAMacroCollector:
tester = name[:8] + 'IS_' + name[8:]
self.algorithms_from_hash[name] = tester
elif name.startswith('PSA_KEY_USAGE_') and not parameter:
self.key_usages.add(name)
self.key_usage_flags.add(name)
else:
# Other macro without parameter
return

View File

@ -0,0 +1,199 @@
"""Knowledge about the PSA key store as implemented in Mbed TLS.
"""
# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import struct
from typing import Dict, List, Optional, Set, Union
import unittest
from mbedtls_dev import c_build_helper
class Expr:
"""Representation of a C expression with a known or knowable numerical value."""
def __init__(self, content: Union[int, str]):
if isinstance(content, int):
digits = 8 if content > 0xffff else 4
self.string = '{0:#0{1}x}'.format(content, digits + 2)
self.value_if_known = content #type: Optional[int]
else:
self.string = content
self.unknown_values.add(self.normalize(content))
self.value_if_known = None
value_cache = {} #type: Dict[str, int]
"""Cache of known values of expressions."""
unknown_values = set() #type: Set[str]
"""Expressions whose values are not present in `value_cache` yet."""
def update_cache(self) -> None:
"""Update `value_cache` for expressions registered in `unknown_values`."""
expressions = sorted(self.unknown_values)
values = c_build_helper.get_c_expression_values(
'unsigned long', '%lu',
expressions,
header="""
#include <psa/crypto.h>
""",
include_path=['include']) #type: List[str]
for e, v in zip(expressions, values):
self.value_cache[e] = int(v, 0)
self.unknown_values.clear()
@staticmethod
def normalize(string: str) -> str:
"""Put the given C expression in a canonical form.
This function is only intended to give correct results for the
relatively simple kind of C expression typically used with this
module.
"""
return re.sub(r'\s+', r'', string)
def value(self) -> int:
"""Return the numerical value of the expression."""
if self.value_if_known is None:
if re.match(r'([0-9]+|0x[0-9a-f]+)\Z', self.string, re.I):
return int(self.string, 0)
normalized = self.normalize(self.string)
if normalized not in self.value_cache:
self.update_cache()
self.value_if_known = self.value_cache[normalized]
return self.value_if_known
Exprable = Union[str, int, Expr]
"""Something that can be converted to a C expression with a known numerical value."""
def as_expr(thing: Exprable) -> Expr:
"""Return an `Expr` object for `thing`.
If `thing` is already an `Expr` object, return it. Otherwise build a new
`Expr` object from `thing`. `thing` can be an integer or a string that
contains a C expression.
"""
if isinstance(thing, Expr):
return thing
else:
return Expr(thing)
class Key:
"""Representation of a PSA crypto key object and its storage encoding.
"""
LATEST_VERSION = 0
"""The latest version of the storage format."""
def __init__(self, *,
version: Optional[int] = None,
id: Optional[int] = None, #pylint: disable=redefined-builtin
lifetime: Exprable = 'PSA_KEY_LIFETIME_PERSISTENT',
type: Exprable, #pylint: disable=redefined-builtin
bits: int,
usage: Exprable, alg: Exprable, alg2: Exprable,
material: bytes #pylint: disable=used-before-assignment
) -> None:
self.version = self.LATEST_VERSION if version is None else version
self.id = id #pylint: disable=invalid-name #type: Optional[int]
self.lifetime = as_expr(lifetime) #type: Expr
self.type = as_expr(type) #type: Expr
self.bits = bits #type: int
self.usage = as_expr(usage) #type: Expr
self.alg = as_expr(alg) #type: Expr
self.alg2 = as_expr(alg2) #type: Expr
self.material = material #type: bytes
MAGIC = b'PSA\000KEY\000'
@staticmethod
def pack(
fmt: str,
*args: Union[int, Expr]
) -> bytes: #pylint: disable=used-before-assignment
"""Pack the given arguments into a byte string according to the given format.
This function is similar to `struct.pack`, but with the following differences:
* All integer values are encoded with standard sizes and in
little-endian representation. `fmt` must not include an endianness
prefix.
* Arguments can be `Expr` objects instead of integers.
* Only integer-valued elements are supported.
"""
return struct.pack('<' + fmt, # little-endian, standard sizes
*[arg.value() if isinstance(arg, Expr) else arg
for arg in args])
def bytes(self) -> bytes:
"""Return the representation of the key in storage as a byte array.
This is the content of the PSA storage file. When PSA storage is
implemented over stdio files, this does not include any wrapping made
by the PSA-storage-over-stdio-file implementation.
"""
header = self.MAGIC + self.pack('L', self.version)
if self.version == 0:
attributes = self.pack('LHHLLL',
self.lifetime, self.type, self.bits,
self.usage, self.alg, self.alg2)
material = self.pack('L', len(self.material)) + self.material
else:
raise NotImplementedError
return header + attributes + material
def hex(self) -> str:
"""Return the representation of the key as a hexadecimal string.
This is the hexadecimal representation of `self.bytes`.
"""
return self.bytes().hex()
class TestKey(unittest.TestCase):
# pylint: disable=line-too-long
"""A few smoke tests for the functionality of the `Key` class."""
def test_numerical(self):
key = Key(version=0,
id=1, lifetime=0x00000001,
type=0x2400, bits=128,
usage=0x00000300, alg=0x05500200, alg2=0x04c01000,
material=b'@ABCDEFGHIJKLMNO')
expected_hex = '505341004b45590000000000010000000024800000030000000250050010c00410000000404142434445464748494a4b4c4d4e4f'
self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
self.assertEqual(key.hex(), expected_hex)
def test_names(self):
length = 0xfff8 // 8 # PSA_MAX_KEY_BITS in bytes
key = Key(version=0,
id=1, lifetime='PSA_KEY_LIFETIME_PERSISTENT',
type='PSA_KEY_TYPE_RAW_DATA', bits=length*8,
usage=0, alg=0, alg2=0,
material=b'\x00' * length)
expected_hex = '505341004b45590000000000010000000110f8ff000000000000000000000000ff1f0000' + '00' * length
self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
self.assertEqual(key.hex(), expected_hex)
def test_defaults(self):
key = Key(type=0x1001, bits=8,
usage=0, alg=0, alg2=0,
material=b'\x2a')
expected_hex = '505341004b455900000000000100000001100800000000000000000000000000010000002a'
self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
self.assertEqual(key.hex(), expected_hex)