diff --git a/library/psa_crypto.c b/library/psa_crypto.c index 55adfa7698..458affd2e8 100644 --- a/library/psa_crypto.c +++ b/library/psa_crypto.c @@ -73,10 +73,6 @@ #include "mbedtls/sha512.h" #include "mbedtls/xtea.h" -#if defined(MBEDTLS_TEST_HOOKS) -#include "test/memory.h" -#endif - #define ARRAY_LENGTH(array) (sizeof(array) / sizeof(*(array))) /****************************************************************/ @@ -5532,6 +5528,14 @@ exit: return status; } +/* Memory copying test hooks */ +#if defined(MBEDTLS_TEST_HOOKS) +void (*psa_input_pre_copy_hook)(const uint8_t *input, size_t input_len) = NULL; +void (*psa_input_post_copy_hook)(const uint8_t *input, size_t input_len) = NULL; +void (*psa_output_pre_copy_hook)(const uint8_t *output, size_t output_len) = NULL; +void (*psa_output_post_copy_hook)(const uint8_t *output, size_t output_len) = NULL; +#endif + /** Copy from an input buffer to a local copy. * * \param[in] input Pointer to input buffer. @@ -5553,7 +5557,9 @@ psa_status_t psa_crypto_copy_input(const uint8_t *input, size_t input_len, } #if defined(MBEDTLS_TEST_HOOKS) - MBEDTLS_TEST_MEMORY_UNPOISON(input, input_len); + if (psa_input_pre_copy_hook != NULL) { + psa_input_pre_copy_hook(input, input_len); + } #endif if (input_len > 0) { @@ -5561,7 +5567,9 @@ psa_status_t psa_crypto_copy_input(const uint8_t *input, size_t input_len, } #if defined(MBEDTLS_TEST_HOOKS) - MBEDTLS_TEST_MEMORY_POISON(input, input_len); + if (psa_input_post_copy_hook != NULL) { + psa_input_post_copy_hook(input, input_len); + } #endif return PSA_SUCCESS; @@ -5588,7 +5596,9 @@ psa_status_t psa_crypto_copy_output(const uint8_t *output_copy, size_t output_co } #if defined(MBEDTLS_TEST_HOOKS) - MBEDTLS_TEST_MEMORY_UNPOISON(output, output_len); + if (psa_output_pre_copy_hook != NULL) { + psa_output_pre_copy_hook(output, output_len); + } #endif if (output_copy_len > 0) { @@ -5596,7 +5606,9 @@ psa_status_t psa_crypto_copy_output(const uint8_t *output_copy, size_t output_co } #if defined(MBEDTLS_TEST_HOOKS) - MBEDTLS_TEST_MEMORY_POISON(output, output_len); + if (psa_output_post_copy_hook != NULL) { + psa_output_post_copy_hook(output, output_len); + } #endif return PSA_SUCCESS; diff --git a/library/psa_crypto_invasive.h b/library/psa_crypto_invasive.h index e7ab9b3133..a1281d14fd 100644 --- a/library/psa_crypto_invasive.h +++ b/library/psa_crypto_invasive.h @@ -76,6 +76,14 @@ psa_status_t psa_crypto_copy_input(const uint8_t *input, size_t input_len, psa_status_t psa_crypto_copy_output(const uint8_t *output_copy, size_t output_copy_len, uint8_t *output, size_t output_len); +/* + * Test hooks to use for memory unpoisoning/poisoning in copy functions. + */ +extern void (*psa_input_pre_copy_hook)(const uint8_t *input, size_t input_len); +extern void (*psa_input_post_copy_hook)(const uint8_t *input, size_t input_len); +extern void (*psa_output_pre_copy_hook)(const uint8_t *output, size_t output_len); +extern void (*psa_output_post_copy_hook)(const uint8_t *output, size_t output_len); + #endif /* MBEDTLS_TEST_HOOKS && MBEDTLS_PSA_CRYPTO_C */ #endif /* PSA_CRYPTO_INVASIVE_H */ diff --git a/tests/include/test/psa_memory_poisoning_wrappers.h b/tests/include/test/psa_memory_poisoning_wrappers.h index 08234b4948..e1642d2c17 100644 --- a/tests/include/test/psa_memory_poisoning_wrappers.h +++ b/tests/include/test/psa_memory_poisoning_wrappers.h @@ -2,6 +2,26 @@ #include "test/memory.h" +#include "psa_crypto_invasive.h" + +#if defined(MBEDTLS_TEST_MEMORY_CAN_POISON) + +static void setup_test_hooks(void) +{ + psa_input_pre_copy_hook = mbedtls_test_memory_unpoison; + psa_input_post_copy_hook = mbedtls_test_memory_poison; + psa_output_pre_copy_hook = mbedtls_test_memory_unpoison; + psa_output_post_copy_hook = mbedtls_test_memory_poison; +} + +static void teardown_test_hooks(void) +{ + psa_input_pre_copy_hook = NULL; + psa_input_post_copy_hook = NULL; + psa_output_pre_copy_hook = NULL; + psa_output_post_copy_hook = NULL; +} + psa_status_t wrap_psa_cipher_encrypt(mbedtls_svc_key_id_t key, psa_algorithm_t alg, const uint8_t *input, @@ -10,6 +30,7 @@ psa_status_t wrap_psa_cipher_encrypt(mbedtls_svc_key_id_t key, size_t output_size, size_t *output_length) { + setup_test_hooks(); MBEDTLS_TEST_MEMORY_POISON(input, input_length); MBEDTLS_TEST_MEMORY_POISON(output, output_size); psa_status_t status = psa_cipher_encrypt(key, @@ -21,7 +42,10 @@ psa_status_t wrap_psa_cipher_encrypt(mbedtls_svc_key_id_t key, output_length); MBEDTLS_TEST_MEMORY_UNPOISON(input, input_length); MBEDTLS_TEST_MEMORY_UNPOISON(output, output_size); + teardown_test_hooks(); return status; } #define psa_cipher_encrypt(...) wrap_psa_cipher_encrypt(__VA_ARGS__) + +#endif /* MBEDTLS_TEST_MEMORY_CAN_POISON */