diff --git a/library/ssl_tls.c b/library/ssl_tls.c index ff56dcb3b2..44b9c85af2 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -56,6 +56,9 @@ #include "mbedtls/oid.h" #endif +/* Convert key bits to byte size */ +#define KEY_BYTES( bits ) ( ( (size_t) bits + 7 ) / 8 ) + #if defined(MBEDTLS_SSL_PROTO_DTLS) #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID) @@ -720,6 +723,14 @@ static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform, const mbedtls_cipher_info_t *cipher_info; const mbedtls_md_info_t *md_info; +#if defined(MBEDTLS_USE_PSA_CRYPTO) + psa_key_type_t key_type; + psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT; + psa_algorithm_t alg; + size_t key_bits; + psa_status_t status; +#endif + #if !defined(MBEDTLS_DEBUG_C) && \ !defined(MBEDTLS_SSL_DTLS_CONNECTION_ID) if( ssl->f_export_keys == NULL ) @@ -1077,6 +1088,40 @@ static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform, goto end; } +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if( ( status = mbedtls_cipher_to_psa( cipher_info->type, + transform->taglen, + &alg, + &key_type, + &key_bits ) ) != PSA_SUCCESS ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_to_psa", status ); + goto end; + } + + psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT ); + psa_set_key_algorithm( &attributes, alg ); + + transform->psa_alg = alg; + + if( ( status = psa_import_key( &attributes, + key1, + KEY_BYTES( key_bits ), + &transform->psa_key_enc ) ) != PSA_SUCCESS ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "psa_import_key", status ); + goto end; + } + if( ( status = psa_import_key( &attributes, + key2, + KEY_BYTES( key_bits ), + &transform->psa_key_dec ) ) != PSA_SUCCESS ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "psa_import_key", status ); + goto end; + } +#endif /* MBEDTLS_USE_PSA_CRYPTO */ + #if defined(MBEDTLS_CIPHER_MODE_CBC) if( mbedtls_cipher_info_get_mode( cipher_info ) == MBEDTLS_MODE_CBC ) {