diff --git a/library/ssl_misc.h b/library/ssl_misc.h index 82a951a581..0f43a18f42 100644 --- a/library/ssl_misc.h +++ b/library/ssl_misc.h @@ -2378,6 +2378,11 @@ static inline int psa_ssl_status_to_mbedtls( psa_status_t status ) MBEDTLS_SSL_ECJPAKE_PSA_PRIMITIVE, \ step ) +typedef enum { + MBEDTLS_ECJPAKE_ROUND_ONE, + MBEDTLS_ECJPAKE_ROUND_TWO +} mbedtls_ecjpake_rounds_t; + /** * \brief Parse the provided input buffer for getting the first round * of key exchange. This code is common between server and client @@ -2385,27 +2390,15 @@ static inline int psa_ssl_status_to_mbedtls( psa_status_t status ) * \param pake_ctx [in] the PAKE's operation/context structure * \param buf [in] input buffer to parse * \param len [in] length of the input buffer + * \param round [in] either MBEDTLS_ECJPAKE_ROUND_ONE or + * MBEDTLS_ECJPAKE_ROUND_TWO * * \return 0 on success or a negative error code in case of failure */ -int mbedtls_psa_ecjpake_read_round_one( +int mbedtls_psa_ecjpake_read_round( psa_pake_operation_t *pake_ctx, const unsigned char *buf, - size_t len ); - -/** - * \brief Parse the provided input buffer for getting the second round - * of key exchange. This code is common between server and client - * - * \param pake_ctx [in] the PAKE's operation/context structure - * \param buf [in] input buffer to parse - * \param len [in] length of the input buffer - * - * \return 0 on success or a negative error code in case of failure - */ -int mbedtls_psa_ecjpake_read_round_two( - psa_pake_operation_t *pake_ctx, - const unsigned char *buf, size_t len ); + size_t len, mbedtls_ecjpake_rounds_t round ); /** * \brief Write the first round of key exchange into the provided output @@ -2415,29 +2408,16 @@ int mbedtls_psa_ecjpake_read_round_two( * \param buf [out] the output buffer in which data will be written to * \param len [in] length of the output buffer * \param olen [out] the length of the data really written on the buffer + * \param round [in] either MBEDTLS_ECJPAKE_ROUND_ONE or + * MBEDTLS_ECJPAKE_ROUND_TWO * * \return 0 on success or a negative error code in case of failure */ -int mbedtls_psa_ecjpake_write_round_one( +int mbedtls_psa_ecjpake_write_round( psa_pake_operation_t *pake_ctx, unsigned char *buf, - size_t len, size_t *olen ); - -/** - * \brief Write the second round of key exchange into the provided output - * buffer. This code is common between server and client - * - * \param pake_ctx [in] the PAKE's operation/context structure - * \param buf [out] the output buffer in which data will be written to - * \param len [in] length of the output buffer - * \param olen [out] the length of the data really written on the buffer - * - * \return 0 on success or a negative error code in case of failure - */ -int mbedtls_psa_ecjpake_write_round_two( - psa_pake_operation_t *pake_ctx, - unsigned char *buf, - size_t len, size_t *olen ); + size_t len, size_t *olen, + mbedtls_ecjpake_rounds_t round ); #endif //MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED && MBEDTLS_USE_PSA_CRYPTO diff --git a/library/ssl_tls.c b/library/ssl_tls.c index ae12c7ebdf..a1fa8697b0 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -8196,16 +8196,20 @@ end: #if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) && \ defined(MBEDTLS_USE_PSA_CRYPTO) -int mbedtls_psa_ecjpake_read_round_one( +int mbedtls_psa_ecjpake_read_round( psa_pake_operation_t *pake_ctx, const unsigned char *buf, - size_t len ) + size_t len, mbedtls_ecjpake_rounds_t round ) { psa_status_t status; size_t input_offset = 0; + /* + * At round one repeat the KEY_SHARE, ZK_PUBLIC & ZF_PROOF twice + * At round two perform a single cycle + */ + unsigned int remaining_steps = ( round == MBEDTLS_ECJPAKE_ROUND_ONE) ? 2 : 1; - /* Repeat the KEY_SHARE, ZK_PUBLIC & ZF_PROOF twice */ - for( unsigned int x = 1; x <= 2; ++x ) + for( ; remaining_steps > 0; remaining_steps-- ) { for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE; step <= PSA_PAKE_STEP_ZK_PROOF; @@ -8237,59 +8241,25 @@ int mbedtls_psa_ecjpake_read_round_one( return( 0 ); } -int mbedtls_psa_ecjpake_read_round_two( - psa_pake_operation_t *pake_ctx, - const unsigned char *buf, - size_t len ) -{ - psa_status_t status; - size_t input_offset = 0; - - for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE ; - step <= PSA_PAKE_STEP_ZK_PROOF ; - ++step ) - { - size_t length; - - /* Length is stored at the first byte */ - length = buf[input_offset]; - input_offset += 1; - - if( input_offset + length > len ) - { - return MBEDTLS_ERR_SSL_BAD_INPUT_DATA; - } - - status = psa_pake_input( pake_ctx, step, - buf + input_offset, length ); - if( status != PSA_SUCCESS) - { - return psa_ssl_status_to_mbedtls( status ); - } - - input_offset += length; - } - - if ( input_offset != len ) - return PSA_ERROR_INVALID_ARGUMENT; - - return( 0 ); -} - -int mbedtls_psa_ecjpake_write_round_one( +int mbedtls_psa_ecjpake_write_round( psa_pake_operation_t *pake_ctx, unsigned char *buf, - size_t len, size_t *olen ) + size_t len, size_t *olen, + mbedtls_ecjpake_rounds_t round ) { psa_status_t status; size_t output_offset = 0; size_t output_len; + /* + * At round one repeat the KEY_SHARE, ZK_PUBLIC & ZF_PROOF twice + * At round two perform a single cycle + */ + unsigned int remaining_steps = ( round == MBEDTLS_ECJPAKE_ROUND_ONE) ? 2 : 1; - /* Repeat the KEY_SHARE, ZK_PUBLIC & ZF_PROOF twice */ - for( unsigned int x = 1 ; x <= 2 ; ++x ) + for( ; remaining_steps > 0; remaining_steps-- ) { - for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE ; - step <= PSA_PAKE_STEP_ZK_PROOF ; + for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE; + step <= PSA_PAKE_STEP_ZK_PROOF; ++step ) { /* For each step, prepend 1 byte with the length of the data */ @@ -8313,39 +8283,6 @@ int mbedtls_psa_ecjpake_write_round_one( return( 0 ); } - -int mbedtls_psa_ecjpake_write_round_two( - psa_pake_operation_t *pake_ctx, - unsigned char *buf, - size_t len, size_t *olen ) -{ - psa_status_t status; - size_t output_offset = 0; - size_t output_len; - - for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE ; - step <= PSA_PAKE_STEP_ZK_PROOF ; - ++step ) - { - /* For each step, prepend 1 byte with the length of the data */ - *(buf + output_offset) = MBEDTLS_SSL_ECJPAKE_OUTPUT_SIZE( step ); - output_offset += 1; - status = psa_pake_output( pake_ctx, - step, buf + output_offset, - len - output_offset, - &output_len ); - if( status != PSA_SUCCESS ) - { - return( psa_ssl_status_to_mbedtls( status ) ); - } - - output_offset += output_len; - } - - *olen = output_offset; - - return( 0 ); -} #endif //MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED && MBEDTLS_USE_PSA_CRYPTO #if defined(MBEDTLS_USE_PSA_CRYPTO) diff --git a/library/ssl_tls12_client.c b/library/ssl_tls12_client.c index 6dd8ef50fe..8fcf5a4f5e 100644 --- a/library/ssl_tls12_client.c +++ b/library/ssl_tls12_client.c @@ -164,8 +164,9 @@ static int ssl_write_ecjpake_kkpp_ext( mbedtls_ssl_context *ssl, MBEDTLS_SSL_DEBUG_MSG( 3, ( "generating new ecjpake parameters" ) ); #if defined(MBEDTLS_USE_PSA_CRYPTO) - ret = mbedtls_psa_ecjpake_write_round_one(&ssl->handshake->psa_pake_ctx, - p + 2, end - p - 2, &kkpp_len ); + ret = mbedtls_psa_ecjpake_write_round(&ssl->handshake->psa_pake_ctx, + p + 2, end - p - 2, &kkpp_len, + MBEDTLS_ECJPAKE_ROUND_ONE ); if ( ret != 0 ) { psa_destroy_key( ssl->handshake->psa_pake_password ); @@ -908,8 +909,9 @@ static int ssl_parse_ecjpake_kkpp( mbedtls_ssl_context *ssl, ssl->handshake->ecjpake_cache_len = 0; #if defined(MBEDTLS_USE_PSA_CRYPTO) - if( ( ret = mbedtls_psa_ecjpake_read_round_one( - &ssl->handshake->psa_pake_ctx, buf, len ) ) != 0 ) + if( ( ret = mbedtls_psa_ecjpake_read_round( + &ssl->handshake->psa_pake_ctx, buf, len, + MBEDTLS_ECJPAKE_ROUND_ONE ) ) != 0 ) { psa_destroy_key( ssl->handshake->psa_pake_password ); psa_pake_abort( &ssl->handshake->psa_pake_ctx ); @@ -2356,8 +2358,9 @@ start_processing: p += 3; - if( ( ret = mbedtls_psa_ecjpake_read_round_two( - &ssl->handshake->psa_pake_ctx, p, end - p ) ) != 0 ) + if( ( ret = mbedtls_psa_ecjpake_read_round( + &ssl->handshake->psa_pake_ctx, p, end - p, + MBEDTLS_ECJPAKE_ROUND_TWO ) ) != 0 ) { psa_destroy_key( ssl->handshake->psa_pake_password ); psa_pake_abort( &ssl->handshake->psa_pake_ctx ); @@ -3314,8 +3317,9 @@ ecdh_calc_secret: unsigned char *out_p = ssl->out_msg + header_len; unsigned char *end_p = ssl->out_msg + MBEDTLS_SSL_OUT_CONTENT_LEN - header_len; - ret = mbedtls_psa_ecjpake_write_round_two( &ssl->handshake->psa_pake_ctx, - out_p, end_p - out_p, &content_len ); + ret = mbedtls_psa_ecjpake_write_round( &ssl->handshake->psa_pake_ctx, + out_p, end_p - out_p, &content_len, + MBEDTLS_ECJPAKE_ROUND_TWO ); if ( ret != 0 ) { psa_destroy_key( ssl->handshake->psa_pake_password ); diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c index 3bc7217b79..e6dee49c14 100644 --- a/library/ssl_tls12_server.c +++ b/library/ssl_tls12_server.c @@ -305,8 +305,9 @@ static int ssl_parse_ecjpake_kkpp( mbedtls_ssl_context *ssl, } #if defined(MBEDTLS_USE_PSA_CRYPTO) - if ( ( ret = mbedtls_psa_ecjpake_read_round_one( - &ssl->handshake->psa_pake_ctx, buf, len ) ) != 0 ) + if ( ( ret = mbedtls_psa_ecjpake_read_round( + &ssl->handshake->psa_pake_ctx, buf, len, + MBEDTLS_ECJPAKE_ROUND_ONE ) ) != 0 ) { psa_destroy_key( ssl->handshake->psa_pake_password ); psa_pake_abort( &ssl->handshake->psa_pake_ctx ); @@ -2019,8 +2020,9 @@ static void ssl_write_ecjpake_kkpp_ext( mbedtls_ssl_context *ssl, p += 2; #if defined(MBEDTLS_USE_PSA_CRYPTO) - ret = mbedtls_psa_ecjpake_write_round_one( &ssl->handshake->psa_pake_ctx, - p + 2, end - p - 2, &kkpp_len ); + ret = mbedtls_psa_ecjpake_write_round( &ssl->handshake->psa_pake_ctx, + p + 2, end - p - 2, &kkpp_len, + MBEDTLS_ECJPAKE_ROUND_ONE ); if ( ret != 0 ) { psa_destroy_key( ssl->handshake->psa_pake_password ); @@ -2867,9 +2869,10 @@ static int ssl_prepare_server_key_exchange( mbedtls_ssl_context *ssl, MBEDTLS_PUT_UINT16_BE( curve_info->tls_id, out_p, 1 ); output_offset += sizeof( uint8_t ) + sizeof( uint16_t ); - ret = mbedtls_psa_ecjpake_write_round_two( &ssl->handshake->psa_pake_ctx, + ret = mbedtls_psa_ecjpake_write_round( &ssl->handshake->psa_pake_ctx, out_p + output_offset, - end_p - out_p - output_offset, &output_len ); + end_p - out_p - output_offset, &output_len, + MBEDTLS_ECJPAKE_ROUND_TWO ); if( ret != 0 ) { psa_destroy_key( ssl->handshake->psa_pake_password ); @@ -4114,8 +4117,9 @@ static int ssl_parse_client_key_exchange( mbedtls_ssl_context *ssl ) if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECJPAKE ) { #if defined(MBEDTLS_USE_PSA_CRYPTO) - if( ( ret = mbedtls_psa_ecjpake_read_round_two( - &ssl->handshake->psa_pake_ctx, p, end - p ) ) != 0 ) + if( ( ret = mbedtls_psa_ecjpake_read_round( + &ssl->handshake->psa_pake_ctx, p, end - p, + MBEDTLS_ECJPAKE_ROUND_TWO ) ) != 0 ) { psa_destroy_key( ssl->handshake->psa_pake_password ); psa_pake_abort( &ssl->handshake->psa_pake_ctx );