diff --git a/docs/architecture/tls13-support.md b/docs/architecture/tls13-support.md index f30590bd47..85482ba9ed 100644 --- a/docs/architecture/tls13-support.md +++ b/docs/architecture/tls13-support.md @@ -478,3 +478,175 @@ outbound message on server side as well. * state change: the state change is done in the main state handler to ease the navigation of the state machine transitions. + + +Writing and reading early or 0-RTT data +--------------------------------------- + +An application function to write and send a buffer of data to a server through +TLS may plausibly look like: + +``` +int write_data( mbedtls_ssl_context *ssl, + const unsigned char *data_to_write, + size_t data_to_write_len, + size_t *data_written ) +{ + *data_written = 0; + + while( *data_written < data_to_write_len ) + { + ret = mbedtls_ssl_write( ssl, data_to_write + *data_written, + data_to_write_len - *data_written ); + + if( ret < 0 && + ret != MBEDTLS_ERR_SSL_WANT_READ && + ret != MBEDTLS_ERR_SSL_WANT_WRITE ) + { + return( ret ); + } + + *data_written += ret; + } + + return( 0 ); +} +``` +where ssl is the SSL context to use, data_to_write the address of the data +buffer and data_to_write_len the number of data bytes. The handshake may +not be completed, not even started for the SSL context ssl when the function is +called and in that case the mbedtls_ssl_write() API takes care transparently of +completing the handshake before to write and send data to the server. The +mbedtls_ssl_write() may not been able to write and send all data in one go thus +the need for a loop calling it as long as there are still data to write and +send. + +An application function to write and send early data and only early data, +data sent during the first flight of client messages while the handshake is in +its initial phase, would look completely similar but the call to +mbedtls_ssl_write_early_data() instead of mbedtls_ssl_write(). +``` +int write_early_data( mbedtls_ssl_context *ssl, + const unsigned char *data_to_write, + size_t data_to_write_len, + size_t *data_written ) +{ + *data_written = 0; + + while( *data_written < data_to_write_len ) + { + ret = mbedtls_ssl_write_early_data( ssl, data_to_write + *data_written, + data_to_write_len - *data_written ); + + if( ret < 0 && + ret != MBEDTLS_ERR_SSL_WANT_READ && + ret != MBEDTLS_ERR_SSL_WANT_WRITE ) + { + return( ret ); + } + + *data_written += ret; + } + + return( 0 ); +} +``` +Note that compared to write_data(), write_early_data() can also return +MBEDTLS_ERR_SSL_CANNOT_WRITE_EARLY_DATA and that should be handled +specifically by the user of write_early_data(). A fresh SSL context (typically +just after a call to mbedtls_ssl_setup() or mbedtls_ssl_session_reset()) would +be expected when calling `write_early_data`. + +All together, code to write and send a buffer of data as long as possible as +early data and then as standard post-handshake application data could +plausibly look like: + +``` +ret = write_early_data( ssl, data_to_write, data_to_write_len, + &early_data_written ); +if( ret < 0 && + ret != MBEDTLS_ERR_SSL_CANNOT_WRITE_EARLY_DATA ) +{ + goto error; +} + +ret = write_data( ssl, data_to_write + early_data_written, + data_to_write_len - early_data_written, &data_written ); +if( ret < 0 ) + goto error; + +data_written += early_data_written; +``` + +Finally, taking into account that the server may reject early data, application +code to write and send a buffer of data could plausibly look like: +``` +ret = write_early_data( ssl, data_to_write, data_to_write_len, + &early_data_written ); +if( ret < 0 && + ret != MBEDTLS_ERR_SSL_CANNOT_WRITE_EARLY_DATA ) +{ + goto error; +} + +/* + * Make sure the handshake is completed as it is a requisite to + * mbedtls_ssl_get_early_data_status(). + */ +while( !mbedtls_ssl_is_handshake_over( ssl ) ) +{ + ret = mbedtls_ssl_handshake( ssl ); + if( ret < 0 && + ret != MBEDTLS_ERR_SSL_WANT_READ && + ret != MBEDTLS_ERR_SSL_WANT_WRITE ) + { + goto error; + } +} + +ret = mbedtls_ssl_get_early_data_status( ssl ); +if( ret < 0 ) + goto error; + +if( ret == MBEDTLS_SSL_EARLY_DATA_STATUS_REJECTED ) + early_data_written = 0; + +ret = write_data( ssl, data_to_write + early_data_written, + data_to_write_len - early_data_written, &data_written ); +if( ret < 0 ) + goto error; + +data_written += early_data_written; +``` + +Basically, the same holds for reading early data on the server side without the +complication of possible rejection. An application function to read early data +into a given buffer could plausibly look like: +``` +int read_early_data( mbedtls_ssl_context *ssl, + unsigned char *buffer, + size_t buffer_size, + size_t *data_len ) +{ + *data_len = 0; + + while( *data_len < buffer_size ) + { + ret = mbedtls_ssl_read_early_data( ssl, buffer + *data_len, + buffer_size - *data_len ); + + if( ret < 0 && + ret != MBEDTLS_ERR_SSL_WANT_READ && + ret != MBEDTLS_ERR_SSL_WANT_WRITE ) + { + return( ret ); + } + + *data_len += ret; + } + + return( 0 ); +} +``` +with again calls to read_early_data() expected to be done with a fresh SSL +context. diff --git a/docs/use-psa-crypto.md b/docs/use-psa-crypto.md index b22d37f65f..11442ed66d 100644 --- a/docs/use-psa-crypto.md +++ b/docs/use-psa-crypto.md @@ -86,7 +86,6 @@ is enabled, no change required on the application side. Current exceptions: -- EC J-PAKE (when `MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED` is defined) - finite-field (non-EC) Diffie-Hellman (used in key exchanges: DHE-RSA, DHE-PSK) diff --git a/include/mbedtls/ecjpake.h b/include/mbedtls/ecjpake.h index e7ca1b2354..3dd3361a1b 100644 --- a/include/mbedtls/ecjpake.h +++ b/include/mbedtls/ecjpake.h @@ -113,7 +113,7 @@ void mbedtls_ecjpake_init( mbedtls_ecjpake_context *ctx ); * \param curve The identifier of the elliptic curve to use, * for example #MBEDTLS_ECP_DP_SECP256R1. * \param secret The pre-shared secret (passphrase). This must be - * a readable buffer of length \p len Bytes. It need + * a readable not empty buffer of length \p len Bytes. It need * only be valid for the duration of this call. * \param len The length of the pre-shared secret \p secret. * diff --git a/include/mbedtls/mbedtls_config.h b/include/mbedtls/mbedtls_config.h index 3137395206..bd3872f6a7 100644 --- a/include/mbedtls/mbedtls_config.h +++ b/include/mbedtls/mbedtls_config.h @@ -1648,7 +1648,7 @@ * production. * */ -#define MBEDTLS_SSL_EARLY_DATA +//#define MBEDTLS_SSL_EARLY_DATA /** * \def MBEDTLS_SSL_PROTO_DTLS diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h index d0558511a8..ea58661088 100644 --- a/include/mbedtls/ssl.h +++ b/include/mbedtls/ssl.h @@ -96,12 +96,16 @@ /* Error space gap */ /** Processing of the Certificate handshake message failed. */ #define MBEDTLS_ERR_SSL_BAD_CERTIFICATE -0x7A00 -/** Received NewSessionTicket Post Handshake Message */ +/* Error space gap */ +/** + * Received NewSessionTicket Post Handshake Message. + * This error code is experimental and may be changed or removed without notice. + */ #define MBEDTLS_ERR_SSL_RECEIVED_NEW_SESSION_TICKET -0x7B00 -/* Error space gap */ -/* Error space gap */ -/* Error space gap */ -/* Error space gap */ +/** Not possible to read early data */ +#define MBEDTLS_ERR_SSL_CANNOT_READ_EARLY_DATA -0x7B80 +/** Not possible to write early data */ +#define MBEDTLS_ERR_SSL_CANNOT_WRITE_EARLY_DATA -0x7C00 /* Error space gap */ /* Error space gap */ /* Error space gap */ @@ -661,7 +665,7 @@ typedef enum MBEDTLS_SSL_SERVER_FINISHED, MBEDTLS_SSL_FLUSH_BUFFERS, MBEDTLS_SSL_HANDSHAKE_WRAPUP, - MBEDTLS_SSL_HANDSHAKE_OVER, + MBEDTLS_SSL_NEW_SESSION_TICKET, MBEDTLS_SSL_SERVER_HELLO_VERIFY_REQUEST_SENT, MBEDTLS_SSL_HELLO_RETRY_REQUEST, @@ -671,7 +675,9 @@ typedef enum MBEDTLS_SSL_CLIENT_CCS_BEFORE_2ND_CLIENT_HELLO, MBEDTLS_SSL_SERVER_CCS_AFTER_SERVER_HELLO, MBEDTLS_SSL_SERVER_CCS_AFTER_HELLO_RETRY_REQUEST, - MBEDTLS_SSL_NEW_SESSION_TICKET_FLUSH, + MBEDTLS_SSL_HANDSHAKE_OVER, + MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET, + MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH, } mbedtls_ssl_states; @@ -801,14 +807,6 @@ typedef struct mbedtls_ssl_key_cert mbedtls_ssl_key_cert; typedef struct mbedtls_ssl_flight_item mbedtls_ssl_flight_item; #endif -#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_CLI_C) -#define MBEDTLS_SSL_EARLY_DATA_STATUS_UNKNOWN 0 -#define MBEDTLS_SSL_EARLY_DATA_STATUS_NOT_SENT 1 -#define MBEDTLS_SSL_EARLY_DATA_STATUS_INDICATION_SENT 2 -#define MBEDTLS_SSL_EARLY_DATA_STATUS_REJECTED 3 -#define MBEDTLS_SSL_EARLY_DATA_STATUS_ACCEPTED 4 -#endif - #if defined(MBEDTLS_SSL_PROTO_TLS1_3) && defined(MBEDTLS_SSL_SESSION_TICKETS) typedef uint8_t mbedtls_ssl_tls13_ticket_flags; @@ -3851,9 +3849,10 @@ void mbedtls_ssl_conf_sni( mbedtls_ssl_config *conf, * \note The SSL context needs to be already set up. The right place * to call this function is between \c mbedtls_ssl_setup() or * \c mbedtls_ssl_reset() and \c mbedtls_ssl_handshake(). + * Password cannot be empty (see RFC 8236). * * \param ssl SSL context - * \param pw EC J-PAKE password (pre-shared secret) + * \param pw EC J-PAKE password (pre-shared secret). It cannot be empty * \param pw_len length of pw in bytes * * \return 0 on success, or a negative error code. @@ -4651,7 +4650,7 @@ int mbedtls_ssl_handshake( mbedtls_ssl_context *ssl ); */ static inline int mbedtls_ssl_is_handshake_over( mbedtls_ssl_context *ssl ) { - return( ssl->MBEDTLS_PRIVATE( state ) == MBEDTLS_SSL_HANDSHAKE_OVER ); + return( ssl->MBEDTLS_PRIVATE( state ) >= MBEDTLS_SSL_HANDSHAKE_OVER ); } /** @@ -4891,6 +4890,151 @@ int mbedtls_ssl_send_alert_message( mbedtls_ssl_context *ssl, */ int mbedtls_ssl_close_notify( mbedtls_ssl_context *ssl ); +#if defined(MBEDTLS_SSL_EARLY_DATA) + +#if defined(MBEDTLS_SSL_SRV_C) +/** + * \brief Read at most 'len' application data bytes while performing + * the handshake (early data). + * + * \note This function behaves mainly as mbedtls_ssl_read(). The + * specification of mbedtls_ssl_read() relevant to TLS 1.3 + * (thus not the parts specific to (D)TLS 1.2) applies to this + * function and the present documentation is restricted to the + * differences with mbedtls_ssl_read(). + * + * \param ssl SSL context + * \param buf buffer that will hold the data + * \param len maximum number of bytes to read + * + * \return One additional specific return value: + * #MBEDTLS_ERR_SSL_CANNOT_READ_EARLY_DATA. + * + * #MBEDTLS_ERR_SSL_CANNOT_READ_EARLY_DATA is returned when it + * is not possible to read early data for the SSL context + * \p ssl. + * + * It may have been possible and it is not possible + * anymore because the server received the End of Early Data + * message or the maximum number of allowed early data for the + * PSK in use has been reached. + * + * It may never have been possible and will never be possible + * for the SSL context \p ssl because the use of early data + * is disabled for that context or more generally the context + * is not suitably configured to enable early data or the + * client does not use early data or the first call to the + * function was done while the handshake was already too + * advanced to gather and accept early data. + * + * It is not possible to read early data for the SSL context + * \p ssl but this does not preclude for using it with + * mbedtls_ssl_write(), mbedtls_ssl_read() or + * mbedtls_ssl_handshake(). + * + * \note When a server wants to retrieve early data, it is expected + * that this function starts the handshake for the SSL context + * \p ssl. But this is not mandatory. + * + */ +int mbedtls_ssl_read_early_data( mbedtls_ssl_context *ssl, + unsigned char *buf, size_t len ); +#endif /* MBEDTLS_SSL_SRV_C */ + +#if defined(MBEDTLS_SSL_CLI_C) +/** + * \brief Try to write exactly 'len' application data bytes while + * performing the handshake (early data). + * + * \note This function behaves mainly as mbedtls_ssl_write(). The + * specification of mbedtls_ssl_write() relevant to TLS 1.3 + * (thus not the parts specific to (D)TLS1.2) applies to this + * function and the present documentation is restricted to the + * differences with mbedtls_ssl_write(). + * + * \param ssl SSL context + * \param buf buffer holding the data + * \param len how many bytes must be written + * + * \return One additional specific return value: + * #MBEDTLS_ERR_SSL_CANNOT_WRITE_EARLY_DATA. + * + * #MBEDTLS_ERR_SSL_CANNOT_WRITE_EARLY_DATA is returned when it + * is not possible to write early data for the SSL context + * \p ssl. + * + * It may have been possible and it is not possible + * anymore because the client received the server Finished + * message, the server rejected early data or the maximum + * number of allowed early data for the PSK in use has been + * reached. + * + * It may never have been possible and will never be possible + * for the SSL context \p ssl because the use of early data + * is disabled for that context or more generally the context + * is not suitably configured to enable early data or the first + * call to the function was done while the handshake was + * already completed. + * + * It is not possible to write early data for the SSL context + * \p ssl but this does not preclude for using it with + * mbedtls_ssl_write(), mbedtls_ssl_read() or + * mbedtls_ssl_handshake(). + * + * \note This function may write early data only if the SSL context + * has been configured for the handshake with a PSK for which + * early data is allowed. + * + * \note To maximize the number of early data that can be written in + * the course of the handshake, it is expected that this + * function starts the handshake for the SSL context \p ssl. + * But this is not mandatory. + * + * \note This function does not provide any information on whether + * the server has accepted or will accept early data or not. + * When it returns a positive value, it just means that it + * has written early data to the server. To know whether the + * server has accepted early data or not, you should call + * mbedtls_ssl_get_early_data_status() with the handshake + * completed. + */ +int mbedtls_ssl_write_early_data( mbedtls_ssl_context *ssl, + const unsigned char *buf, size_t len ); + +#define MBEDTLS_SSL_EARLY_DATA_STATUS_NOT_SENT 0 +#define MBEDTLS_SSL_EARLY_DATA_STATUS_ACCEPTED 1 +#define MBEDTLS_SSL_EARLY_DATA_STATUS_REJECTED 2 +/** + * \brief Get the status of the negotiation of the use of early data. + * + * \param ssl The SSL context to query + * + * \return #MBEDTLS_ERR_SSL_BAD_INPUT_DATA if this function is called + * from the server-side. + * + * \return #MBEDTLS_ERR_SSL_BAD_INPUT_DATA if this function is called + * prior to completion of the handshake. + * + * \return #MBEDTLS_SSL_EARLY_DATA_STATUS_NOT_SENT if the client has + * not indicated the use of early data to the server. + * + * \return #MBEDTLS_SSL_EARLY_DATA_STATUS_ACCEPTED if the client has + * indicated the use of early data and the server has accepted + * it. + * + * \return #MBEDTLS_SSL_EARLY_DATA_STATUS_REJECTED if the client has + * indicated the use of early data but the server has rejected + * it. In this situation, the client may want to re-send the + * early data it may have tried to send by calling + * mbedtls_ssl_write_early_data() as ordinary post-handshake + * application data by calling mbedtls_ssl_write(). + * + */ +int mbedtls_ssl_get_early_data_status( mbedtls_ssl_context *ssl ); +#endif /* MBEDTLS_SSL_CLI_C */ + +#endif /* MBEDTLS_SSL_EARLY_DATA */ + /** * \brief Free referenced items in an SSL context and clear memory * diff --git a/library/bignum.c b/library/bignum.c index ba03988254..a68957a534 100644 --- a/library/bignum.c +++ b/library/bignum.c @@ -968,17 +968,15 @@ int mbedtls_mpi_sub_abs( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi carry = mbedtls_mpi_core_sub( X->p, A->p, B->p, n ); if( carry != 0 ) { - /* Propagate the carry to the first nonzero limb of X. */ - for( ; n < X->n && X->p[n] == 0; n++ ) - --X->p[n]; - /* If we ran out of space for the carry, it means that the result - * is negative. */ - if( n == X->n ) + /* Propagate the carry through the rest of X. */ + carry = mbedtls_mpi_core_sub_int( X->p + n, X->p + n, carry, X->n - n ); + + /* If we have further carry/borrow, the result is negative. */ + if( carry != 0 ) { ret = MBEDTLS_ERR_MPI_NEGATIVE_VALUE; goto cleanup; } - --X->p[n]; } /* X should always be positive as a result of unsigned subtractions. */ diff --git a/library/bignum_core.c b/library/bignum_core.c index 34aecda501..41d3239688 100644 --- a/library/bignum_core.c +++ b/library/bignum_core.c @@ -590,6 +590,22 @@ cleanup: /* BEGIN MERGE SLOT 3 */ +mbedtls_mpi_uint mbedtls_mpi_core_sub_int( mbedtls_mpi_uint *X, + const mbedtls_mpi_uint *A, + mbedtls_mpi_uint c, /* doubles as carry */ + size_t limbs ) +{ + for( size_t i = 0; i < limbs; i++ ) + { + mbedtls_mpi_uint s = A[i]; + mbedtls_mpi_uint t = s - c; + c = ( t > s ); + X[i] = t; + } + + return( c ); +} + /* END MERGE SLOT 3 */ /* BEGIN MERGE SLOT 4 */ diff --git a/library/bignum_core.h b/library/bignum_core.h index ad04e08283..d48e7053bb 100644 --- a/library/bignum_core.h +++ b/library/bignum_core.h @@ -504,6 +504,23 @@ int mbedtls_mpi_core_fill_random( mbedtls_mpi_uint *X, size_t X_limbs, /* BEGIN MERGE SLOT 3 */ +/** + * \brief Subtract unsigned integer from known-size large unsigned integers. + * Return the borrow. + * + * \param[out] X The result of the subtraction. + * \param[in] A The left operand. + * \param b The unsigned scalar to subtract. + * \param limbs Number of limbs of \p X and \p A. + * + * \return 1 if `A < b`. + * 0 if `A >= b`. + */ +mbedtls_mpi_uint mbedtls_mpi_core_sub_int( mbedtls_mpi_uint *X, + const mbedtls_mpi_uint *A, + mbedtls_mpi_uint b, + size_t limbs ); + /* END MERGE SLOT 3 */ /* BEGIN MERGE SLOT 4 */ diff --git a/library/ssl_misc.h b/library/ssl_misc.h index 4d7f63547d..9998e5b910 100644 --- a/library/ssl_misc.h +++ b/library/ssl_misc.h @@ -50,7 +50,8 @@ #include "mbedtls/sha512.h" #endif -#if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) +#if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) && \ + !defined(MBEDTLS_USE_PSA_CRYPTO) #include "mbedtls/ecjpake.h" #endif @@ -776,7 +777,13 @@ struct mbedtls_ssl_handshake_params #endif /* MBEDTLS_ECDH_C || MBEDTLS_ECDSA_C */ #if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) +#if defined(MBEDTLS_USE_PSA_CRYPTO) + psa_pake_operation_t psa_pake_ctx; /*!< EC J-PAKE key exchange */ + mbedtls_svc_key_id_t psa_pake_password; + uint8_t psa_pake_ctx_is_ok; +#else mbedtls_ecjpake_context ecjpake_ctx; /*!< EC J-PAKE key exchange */ +#endif /* MBEDTLS_USE_PSA_CRYPTO */ #if defined(MBEDTLS_SSL_CLI_C) unsigned char *ecjpake_cache; /*!< Cache for ClientHello ext */ size_t ecjpake_cache_len; /*!< Length of cached data */ @@ -2493,6 +2500,52 @@ static inline int psa_ssl_status_to_mbedtls( psa_status_t status ) } #endif /* MBEDTLS_USE_PSA_CRYPTO || MBEDTLS_SSL_PROTO_TLS1_3 */ +#if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) && \ + defined(MBEDTLS_USE_PSA_CRYPTO) + +typedef enum { + MBEDTLS_ECJPAKE_ROUND_ONE, + MBEDTLS_ECJPAKE_ROUND_TWO +} mbedtls_ecjpake_rounds_t; + +/** + * \brief Parse the provided input buffer for getting the first round + * of key exchange. This code is common between server and client + * + * \param pake_ctx [in] the PAKE's operation/context structure + * \param buf [in] input buffer to parse + * \param len [in] length of the input buffer + * \param round [in] either MBEDTLS_ECJPAKE_ROUND_ONE or + * MBEDTLS_ECJPAKE_ROUND_TWO + * + * \return 0 on success or a negative error code in case of failure + */ +int mbedtls_psa_ecjpake_read_round( + psa_pake_operation_t *pake_ctx, + const unsigned char *buf, + size_t len, mbedtls_ecjpake_rounds_t round ); + +/** + * \brief Write the first round of key exchange into the provided output + * buffer. This code is common between server and client + * + * \param pake_ctx [in] the PAKE's operation/context structure + * \param buf [out] the output buffer in which data will be written to + * \param len [in] length of the output buffer + * \param olen [out] the length of the data really written on the buffer + * \param round [in] either MBEDTLS_ECJPAKE_ROUND_ONE or + * MBEDTLS_ECJPAKE_ROUND_TWO + * + * \return 0 on success or a negative error code in case of failure + */ +int mbedtls_psa_ecjpake_write_round( + psa_pake_operation_t *pake_ctx, + unsigned char *buf, + size_t len, size_t *olen, + mbedtls_ecjpake_rounds_t round ); + +#endif //MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED && MBEDTLS_USE_PSA_CRYPTO + /** * \brief TLS record protection modes */ diff --git a/library/ssl_msg.c b/library/ssl_msg.c index dbc6391885..80471d4c5d 100644 --- a/library/ssl_msg.c +++ b/library/ssl_msg.c @@ -1907,7 +1907,7 @@ int mbedtls_ssl_fetch_input( mbedtls_ssl_context *ssl, size_t nb_want ) MBEDTLS_SSL_DEBUG_MSG( 2, ( "timeout" ) ); mbedtls_ssl_set_timer( ssl, 0 ); - if( mbedtls_ssl_is_handshake_over( ssl ) == 0 ) + if( ssl->state != MBEDTLS_SSL_HANDSHAKE_OVER ) { if( ssl_double_retransmit_timeout( ssl ) != 0 ) { @@ -5299,7 +5299,7 @@ static int ssl_tls13_check_new_session_ticket( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_DEBUG_MSG( 3, ( "NewSessionTicket received" ) ); mbedtls_ssl_handshake_set_state( ssl, - MBEDTLS_SSL_NEW_SESSION_TICKET ); + MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET ); return( MBEDTLS_ERR_SSL_WANT_READ ); } @@ -5502,7 +5502,7 @@ int mbedtls_ssl_read( mbedtls_ssl_context *ssl, unsigned char *buf, size_t len ) } #endif - if( mbedtls_ssl_is_handshake_over( ssl ) == 0 ) + if( ssl->state != MBEDTLS_SSL_HANDSHAKE_OVER ) { ret = mbedtls_ssl_handshake( ssl ); if( ret != MBEDTLS_ERR_SSL_WAITING_SERVER_HELLO_RENEGO && @@ -5758,7 +5758,7 @@ int mbedtls_ssl_write( mbedtls_ssl_context *ssl, const unsigned char *buf, size_ } #endif - if( mbedtls_ssl_is_handshake_over( ssl ) == 0 ) + if( ssl->state != MBEDTLS_SSL_HANDSHAKE_OVER ) { if( ( ret = mbedtls_ssl_handshake( ssl ) ) != 0 ) { diff --git a/library/ssl_tls.c b/library/ssl_tls.c index da90b2350f..3d3491bc6c 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -907,7 +907,12 @@ static void ssl_handshake_params_init( mbedtls_ssl_handshake_params *handshake ) mbedtls_ecdh_init( &handshake->ecdh_ctx ); #endif #if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) +#if defined(MBEDTLS_USE_PSA_CRYPTO) + handshake->psa_pake_ctx = psa_pake_operation_init(); + handshake->psa_pake_password = MBEDTLS_SVC_KEY_ID_INIT; +#else mbedtls_ecjpake_init( &handshake->ecjpake_ctx ); +#endif /* MBEDTLS_USE_PSA_CRYPTO */ #if defined(MBEDTLS_SSL_CLI_C) handshake->ecjpake_cache = NULL; handshake->ecjpake_cache_len = 0; @@ -1850,6 +1855,73 @@ void mbedtls_ssl_set_verify( mbedtls_ssl_context *ssl, /* * Set EC J-PAKE password for current handshake */ +#if defined(MBEDTLS_USE_PSA_CRYPTO) +int mbedtls_ssl_set_hs_ecjpake_password( mbedtls_ssl_context *ssl, + const unsigned char *pw, + size_t pw_len ) +{ + psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init(); + psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT; + psa_pake_role_t psa_role; + psa_status_t status; + + if( ssl->handshake == NULL || ssl->conf == NULL ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + + if( ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER ) + psa_role = PSA_PAKE_ROLE_SERVER; + else + psa_role = PSA_PAKE_ROLE_CLIENT; + + /* Empty password is not valid */ + if( ( pw == NULL) || ( pw_len == 0 ) ) + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + + psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_DERIVE ); + psa_set_key_algorithm( &attributes, PSA_ALG_JPAKE ); + psa_set_key_type( &attributes, PSA_KEY_TYPE_PASSWORD ); + + status = psa_import_key( &attributes, pw, pw_len, + &ssl->handshake->psa_pake_password ); + if( status != PSA_SUCCESS ) + return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED ); + + psa_pake_cs_set_algorithm( &cipher_suite, PSA_ALG_JPAKE ); + psa_pake_cs_set_primitive( &cipher_suite, + PSA_PAKE_PRIMITIVE( PSA_PAKE_PRIMITIVE_TYPE_ECC, + PSA_ECC_FAMILY_SECP_R1, + 256) ); + psa_pake_cs_set_hash( &cipher_suite, PSA_ALG_SHA_256 ); + + status = psa_pake_setup( &ssl->handshake->psa_pake_ctx, &cipher_suite ); + if( status != PSA_SUCCESS ) + { + psa_destroy_key( ssl->handshake->psa_pake_password ); + return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED ); + } + + status = psa_pake_set_role( &ssl->handshake->psa_pake_ctx, psa_role ); + if( status != PSA_SUCCESS ) + { + psa_destroy_key( ssl->handshake->psa_pake_password ); + psa_pake_abort( &ssl->handshake->psa_pake_ctx ); + return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED ); + } + + psa_pake_set_password_key( &ssl->handshake->psa_pake_ctx, + ssl->handshake->psa_pake_password ); + if( status != PSA_SUCCESS ) + { + psa_destroy_key( ssl->handshake->psa_pake_password ); + psa_pake_abort( &ssl->handshake->psa_pake_ctx ); + return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED ); + } + + ssl->handshake->psa_pake_ctx_is_ok = 1; + + return( 0 ); +} +#else /* MBEDTLS_USE_PSA_CRYPTO */ int mbedtls_ssl_set_hs_ecjpake_password( mbedtls_ssl_context *ssl, const unsigned char *pw, size_t pw_len ) @@ -1870,6 +1942,7 @@ int mbedtls_ssl_set_hs_ecjpake_password( mbedtls_ssl_context *ssl, MBEDTLS_ECP_DP_SECP256R1, pw, pw_len ) ); } +#endif /* MBEDTLS_USE_PSA_CRYPTO */ #endif /* MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED */ #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED) @@ -3602,7 +3675,7 @@ int mbedtls_ssl_handshake_step( mbedtls_ssl_context *ssl ) if( ssl == NULL || ssl->conf == NULL || ssl->handshake == NULL || - mbedtls_ssl_is_handshake_over( ssl ) == 1 ) + ssl->state == MBEDTLS_SSL_HANDSHAKE_OVER ) { return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); } @@ -3706,7 +3779,7 @@ int mbedtls_ssl_handshake( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> handshake" ) ); /* Main handshake loop */ - while( mbedtls_ssl_is_handshake_over( ssl ) == 0 ) + while( ssl->state != MBEDTLS_SSL_HANDSHAKE_OVER ) { ret = mbedtls_ssl_handshake_step( ssl ); @@ -3908,8 +3981,15 @@ void mbedtls_ssl_handshake_free( mbedtls_ssl_context *ssl ) #if !defined(MBEDTLS_USE_PSA_CRYPTO) && defined(MBEDTLS_ECDH_C) mbedtls_ecdh_free( &handshake->ecdh_ctx ); #endif + #if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) +#if defined(MBEDTLS_USE_PSA_CRYPTO) + psa_pake_abort( &handshake->psa_pake_ctx ); + psa_destroy_key( handshake->psa_pake_password ); + handshake->psa_pake_password = MBEDTLS_SVC_KEY_ID_INIT; +#else mbedtls_ecjpake_free( &handshake->ecjpake_ctx ); +#endif /* MBEDTLS_USE_PSA_CRYPTO */ #if defined(MBEDTLS_SSL_CLI_C) mbedtls_free( handshake->ecjpake_cache ); handshake->ecjpake_cache = NULL; @@ -6123,6 +6203,55 @@ static int ssl_compute_master( mbedtls_ssl_handshake_params *handshake, else #endif { +#if defined(MBEDTLS_USE_PSA_CRYPTO) && \ + defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) + if( handshake->ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECJPAKE ) + { + psa_status_t status; + psa_algorithm_t alg = PSA_ALG_TLS12_ECJPAKE_TO_PMS; + psa_key_derivation_operation_t derivation = + PSA_KEY_DERIVATION_OPERATION_INIT; + + MBEDTLS_SSL_DEBUG_MSG( 2, ( "perform PSA-based PMS KDF for ECJPAKE" ) ); + + handshake->pmslen = PSA_TLS12_ECJPAKE_TO_PMS_DATA_SIZE; + + status = psa_key_derivation_setup( &derivation, alg ); + if( status != PSA_SUCCESS ) + return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED ); + + status = psa_key_derivation_set_capacity( &derivation, + PSA_TLS12_ECJPAKE_TO_PMS_DATA_SIZE ); + if( status != PSA_SUCCESS ) + { + psa_key_derivation_abort( &derivation ); + return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED ); + } + + status = psa_pake_get_implicit_key( &handshake->psa_pake_ctx, + &derivation ); + if( status != PSA_SUCCESS ) + { + psa_key_derivation_abort( &derivation ); + return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED ); + } + + status = psa_key_derivation_output_bytes( &derivation, + handshake->premaster, + handshake->pmslen ); + if( status != PSA_SUCCESS ) + { + psa_key_derivation_abort( &derivation ); + return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED ); + } + + status = psa_key_derivation_abort( &derivation ); + if( status != PSA_SUCCESS ) + { + return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED ); + } + } +#endif ret = handshake->tls_prf( handshake->premaster, handshake->pmslen, lbl, seed, seed_len, master, @@ -7544,7 +7673,7 @@ void mbedtls_ssl_handshake_wrapup( mbedtls_ssl_context *ssl ) #endif mbedtls_ssl_handshake_wrapup_free_hs_transform( ssl ); - ssl->state++; + ssl->state = MBEDTLS_SSL_HANDSHAKE_OVER; MBEDTLS_SSL_DEBUG_MSG( 3, ( "<= handshake wrapup" ) ); } @@ -8306,6 +8435,99 @@ end: return( ret ); } +#if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) && \ + defined(MBEDTLS_USE_PSA_CRYPTO) +int mbedtls_psa_ecjpake_read_round( + psa_pake_operation_t *pake_ctx, + const unsigned char *buf, + size_t len, mbedtls_ecjpake_rounds_t round ) +{ + psa_status_t status; + size_t input_offset = 0; + /* + * At round one repeat the KEY_SHARE, ZK_PUBLIC & ZF_PROOF twice + * At round two perform a single cycle + */ + unsigned int remaining_steps = ( round == MBEDTLS_ECJPAKE_ROUND_ONE) ? 2 : 1; + + for( ; remaining_steps > 0; remaining_steps-- ) + { + for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE; + step <= PSA_PAKE_STEP_ZK_PROOF; + ++step ) + { + /* Length is stored at the first byte */ + size_t length = buf[input_offset]; + input_offset += 1; + + if( input_offset + length > len ) + { + return MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE; + } + + status = psa_pake_input( pake_ctx, step, + buf + input_offset, length ); + if( status != PSA_SUCCESS) + { + return psa_ssl_status_to_mbedtls( status ); + } + + input_offset += length; + } + } + + if( input_offset != len ) + return MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE; + + return( 0 ); +} + +int mbedtls_psa_ecjpake_write_round( + psa_pake_operation_t *pake_ctx, + unsigned char *buf, + size_t len, size_t *olen, + mbedtls_ecjpake_rounds_t round ) +{ + psa_status_t status; + size_t output_offset = 0; + size_t output_len; + /* + * At round one repeat the KEY_SHARE, ZK_PUBLIC & ZF_PROOF twice + * At round two perform a single cycle + */ + unsigned int remaining_steps = ( round == MBEDTLS_ECJPAKE_ROUND_ONE) ? 2 : 1; + + for( ; remaining_steps > 0; remaining_steps-- ) + { + for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE; + step <= PSA_PAKE_STEP_ZK_PROOF; + ++step ) + { + /* + * For each step, prepend 1 byte with the length of the data as + * given by psa_pake_output(). + */ + status = psa_pake_output( pake_ctx, step, + buf + output_offset + 1, + len - output_offset - 1, + &output_len ); + if( status != PSA_SUCCESS ) + { + return( psa_ssl_status_to_mbedtls( status ) ); + } + + *(buf + output_offset) = (uint8_t) output_len; + + output_offset += output_len + 1; + } + } + + *olen = output_offset; + + return( 0 ); +} +#endif //MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED && MBEDTLS_USE_PSA_CRYPTO + #if defined(MBEDTLS_USE_PSA_CRYPTO) int mbedtls_ssl_get_key_exchange_md_tls1_2( mbedtls_ssl_context *ssl, unsigned char *hash, size_t *hashlen, @@ -8864,8 +9086,13 @@ int mbedtls_ssl_validate_ciphersuite( #if defined(MBEDTLS_SSL_PROTO_TLS1_2) && defined(MBEDTLS_SSL_CLI_C) #if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if( suite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECJPAKE && + ssl->handshake->psa_pake_ctx_is_ok != 1 ) +#else if( suite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECJPAKE && mbedtls_ecjpake_check( &ssl->handshake->ecjpake_ctx ) != 0 ) +#endif /* MBEDTLS_USE_PSA_CRYPTO */ { return( -1 ); } diff --git a/library/ssl_tls12_client.c b/library/ssl_tls12_client.c index 21b3ba6216..79c884b18c 100644 --- a/library/ssl_tls12_client.c +++ b/library/ssl_tls12_client.c @@ -132,13 +132,18 @@ static int ssl_write_ecjpake_kkpp_ext( mbedtls_ssl_context *ssl, { int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; unsigned char *p = buf; - size_t kkpp_len; + size_t kkpp_len = 0; *olen = 0; /* Skip costly extension if we can't use EC J-PAKE anyway */ +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if( ssl->handshake->psa_pake_ctx_is_ok != 1 ) + return( 0 ); +#else if( mbedtls_ecjpake_check( &ssl->handshake->ecjpake_ctx ) != 0 ) return( 0 ); +#endif /* MBEDTLS_USE_PSA_CRYPTO */ MBEDTLS_SSL_DEBUG_MSG( 3, ( "client hello, adding ecjpake_kkpp extension" ) ); @@ -158,6 +163,18 @@ static int ssl_write_ecjpake_kkpp_ext( mbedtls_ssl_context *ssl, { MBEDTLS_SSL_DEBUG_MSG( 3, ( "generating new ecjpake parameters" ) ); +#if defined(MBEDTLS_USE_PSA_CRYPTO) + ret = mbedtls_psa_ecjpake_write_round(&ssl->handshake->psa_pake_ctx, + p + 2, end - p - 2, &kkpp_len, + MBEDTLS_ECJPAKE_ROUND_ONE ); + if ( ret != 0 ) + { + psa_destroy_key( ssl->handshake->psa_pake_password ); + psa_pake_abort( &ssl->handshake->psa_pake_ctx ); + MBEDTLS_SSL_DEBUG_RET( 1 , "psa_pake_output", ret ); + return( ret ); + } +#else ret = mbedtls_ecjpake_write_round_one( &ssl->handshake->ecjpake_ctx, p + 2, end - p - 2, &kkpp_len, ssl->conf->f_rng, ssl->conf->p_rng ); @@ -167,6 +184,7 @@ static int ssl_write_ecjpake_kkpp_ext( mbedtls_ssl_context *ssl, "mbedtls_ecjpake_write_round_one", ret ); return( ret ); } +#endif /* MBEDTLS_USE_PSA_CRYPTO */ ssl->handshake->ecjpake_cache = mbedtls_calloc( 1, kkpp_len ); if( ssl->handshake->ecjpake_cache == NULL ) @@ -849,10 +867,11 @@ static int ssl_parse_supported_point_formats_ext( mbedtls_ssl_context *ssl, ssl->handshake->ecdh_ctx.point_format = p[0]; #endif /* !MBEDTLS_USE_PSA_CRYPTO && ( MBEDTLS_ECDH_C || MBEDTLS_ECDSA_C ) */ -#if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) +#if !defined(MBEDTLS_USE_PSA_CRYPTO) && \ + defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) mbedtls_ecjpake_set_point_format( &ssl->handshake->ecjpake_ctx, p[0] ); -#endif +#endif /* !MBEDTLS_USE_PSA_CRYPTO && MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED */ MBEDTLS_SSL_DEBUG_MSG( 4, ( "point format selected: %d", p[0] ) ); return( 0 ); } @@ -889,6 +908,24 @@ static int ssl_parse_ecjpake_kkpp( mbedtls_ssl_context *ssl, ssl->handshake->ecjpake_cache = NULL; ssl->handshake->ecjpake_cache_len = 0; +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if( ( ret = mbedtls_psa_ecjpake_read_round( + &ssl->handshake->psa_pake_ctx, buf, len, + MBEDTLS_ECJPAKE_ROUND_ONE ) ) != 0 ) + { + psa_destroy_key( ssl->handshake->psa_pake_password ); + psa_pake_abort( &ssl->handshake->psa_pake_ctx ); + + MBEDTLS_SSL_DEBUG_RET( 1, "psa_pake_input round one", ret ); + mbedtls_ssl_send_alert_message( + ssl, + MBEDTLS_SSL_ALERT_LEVEL_FATAL, + MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE ); + return( ret ); + } + + return( 0 ); +#else if( ( ret = mbedtls_ecjpake_read_round_one( &ssl->handshake->ecjpake_ctx, buf, len ) ) != 0 ) { @@ -901,6 +938,7 @@ static int ssl_parse_ecjpake_kkpp( mbedtls_ssl_context *ssl, } return( 0 ); +#endif /* MBEDTLS_USE_PSA_CRYPTO */ } #endif /* MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED */ @@ -2296,6 +2334,47 @@ start_processing: #if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECJPAKE ) { +#if defined(MBEDTLS_USE_PSA_CRYPTO) + /* + * The first 3 bytes are: + * [0] MBEDTLS_ECP_TLS_NAMED_CURVE + * [1, 2] elliptic curve's TLS ID + * + * However since we only support secp256r1 for now, we check only + * that TLS ID here + */ + uint16_t read_tls_id = MBEDTLS_GET_UINT16_BE( p, 1 ); + const mbedtls_ecp_curve_info *curve_info; + + if( ( curve_info = mbedtls_ecp_curve_info_from_grp_id( + MBEDTLS_ECP_DP_SECP256R1 ) ) == NULL ) + { + return( MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE ); + } + + if( ( *p != MBEDTLS_ECP_TLS_NAMED_CURVE ) || + ( read_tls_id != curve_info->tls_id ) ) + { + return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER ); + } + + p += 3; + + if( ( ret = mbedtls_psa_ecjpake_read_round( + &ssl->handshake->psa_pake_ctx, p, end - p, + MBEDTLS_ECJPAKE_ROUND_TWO ) ) != 0 ) + { + psa_destroy_key( ssl->handshake->psa_pake_password ); + psa_pake_abort( &ssl->handshake->psa_pake_ctx ); + + MBEDTLS_SSL_DEBUG_RET( 1, "psa_pake_input round two", ret ); + mbedtls_ssl_send_alert_message( + ssl, + MBEDTLS_SSL_ALERT_LEVEL_FATAL, + MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE ); + return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE ); + } +#else ret = mbedtls_ecjpake_read_round_two( &ssl->handshake->ecjpake_ctx, p, end - p ); if( ret != 0 ) @@ -2307,6 +2386,7 @@ start_processing: MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE ); return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE ); } +#endif /* MBEDTLS_USE_PSA_CRYPTO */ } else #endif /* MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED */ @@ -3227,6 +3307,21 @@ ecdh_calc_secret: { header_len = 4; +#if defined(MBEDTLS_USE_PSA_CRYPTO) + unsigned char *out_p = ssl->out_msg + header_len; + unsigned char *end_p = ssl->out_msg + MBEDTLS_SSL_OUT_CONTENT_LEN - + header_len; + ret = mbedtls_psa_ecjpake_write_round( &ssl->handshake->psa_pake_ctx, + out_p, end_p - out_p, &content_len, + MBEDTLS_ECJPAKE_ROUND_TWO ); + if ( ret != 0 ) + { + psa_destroy_key( ssl->handshake->psa_pake_password ); + psa_pake_abort( &ssl->handshake->psa_pake_ctx ); + MBEDTLS_SSL_DEBUG_RET( 1 , "psa_pake_output", ret ); + return( ret ); + } +#else ret = mbedtls_ecjpake_write_round_two( &ssl->handshake->ecjpake_ctx, ssl->out_msg + header_len, MBEDTLS_SSL_OUT_CONTENT_LEN - header_len, @@ -3246,6 +3341,7 @@ ecdh_calc_secret: MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ecjpake_derive_secret", ret ); return( ret ); } +#endif /* MBEDTLS_USE_PSA_CRYPTO */ } else #endif /* MBEDTLS_KEY_EXCHANGE_RSA_ENABLED */ diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c index 3dab2467c6..8aa89c67e5 100644 --- a/library/ssl_tls12_server.c +++ b/library/ssl_tls12_server.c @@ -268,10 +268,11 @@ static int ssl_parse_supported_point_formats( mbedtls_ssl_context *ssl, ssl->handshake->ecdh_ctx.point_format = p[0]; #endif /* !MBEDTLS_USE_PSA_CRYPTO && ( MBEDTLS_ECDH_C || MBEDTLS_ECDSA_C ) */ -#if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) +#if !defined(MBEDTLS_USE_PSA_CRYPTO) && \ + defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) mbedtls_ecjpake_set_point_format( &ssl->handshake->ecjpake_ctx, p[0] ); -#endif +#endif /* !MBEDTLS_USE_PSA_CRYPTO && MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED */ MBEDTLS_SSL_DEBUG_MSG( 4, ( "point format selected: %d", p[0] ) ); return( 0 ); } @@ -289,16 +290,37 @@ static int ssl_parse_supported_point_formats( mbedtls_ssl_context *ssl, MBEDTLS_CHECK_RETURN_CRITICAL static int ssl_parse_ecjpake_kkpp( mbedtls_ssl_context *ssl, const unsigned char *buf, - size_t len ) + size_t len) { int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if( ssl->handshake->psa_pake_ctx_is_ok != 1 ) +#else if( mbedtls_ecjpake_check( &ssl->handshake->ecjpake_ctx ) != 0 ) +#endif /* MBEDTLS_USE_PSA_CRYPTO */ { MBEDTLS_SSL_DEBUG_MSG( 3, ( "skip ecjpake kkpp extension" ) ); return( 0 ); } +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if( ( ret = mbedtls_psa_ecjpake_read_round( + &ssl->handshake->psa_pake_ctx, buf, len, + MBEDTLS_ECJPAKE_ROUND_ONE ) ) != 0 ) + { + psa_destroy_key( ssl->handshake->psa_pake_password ); + psa_pake_abort( &ssl->handshake->psa_pake_ctx ); + + MBEDTLS_SSL_DEBUG_RET( 1, "psa_pake_input round one", ret ); + mbedtls_ssl_send_alert_message( + ssl, + MBEDTLS_SSL_ALERT_LEVEL_FATAL, + MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE ); + + return( ret ); + } +#else if( ( ret = mbedtls_ecjpake_read_round_one( &ssl->handshake->ecjpake_ctx, buf, len ) ) != 0 ) { @@ -307,6 +329,7 @@ static int ssl_parse_ecjpake_kkpp( mbedtls_ssl_context *ssl, MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER ); return( ret ); } +#endif /* MBEDTLS_USE_PSA_CRYPTO */ /* Only mark the extension as OK when we're sure it is */ ssl->handshake->cli_exts |= MBEDTLS_TLS_EXT_ECJPAKE_KKPP_OK; @@ -1996,6 +2019,18 @@ static void ssl_write_ecjpake_kkpp_ext( mbedtls_ssl_context *ssl, MBEDTLS_PUT_UINT16_BE( MBEDTLS_TLS_EXT_ECJPAKE_KKPP, p, 0 ); p += 2; +#if defined(MBEDTLS_USE_PSA_CRYPTO) + ret = mbedtls_psa_ecjpake_write_round( &ssl->handshake->psa_pake_ctx, + p + 2, end - p - 2, &kkpp_len, + MBEDTLS_ECJPAKE_ROUND_ONE ); + if ( ret != 0 ) + { + psa_destroy_key( ssl->handshake->psa_pake_password ); + psa_pake_abort( &ssl->handshake->psa_pake_ctx ); + MBEDTLS_SSL_DEBUG_RET( 1 , "psa_pake_output", ret ); + return; + } +#else ret = mbedtls_ecjpake_write_round_one( &ssl->handshake->ecjpake_ctx, p + 2, end - p - 2, &kkpp_len, ssl->conf->f_rng, ssl->conf->p_rng ); @@ -2004,6 +2039,7 @@ static void ssl_write_ecjpake_kkpp_ext( mbedtls_ssl_context *ssl, MBEDTLS_SSL_DEBUG_RET( 1 , "mbedtls_ecjpake_write_round_one", ret ); return; } +#endif /* MBEDTLS_USE_PSA_CRYPTO */ MBEDTLS_PUT_UINT16_BE( kkpp_len, p, 0 ); p += 2; @@ -2813,6 +2849,46 @@ static int ssl_prepare_server_key_exchange( mbedtls_ssl_context *ssl, if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECJPAKE ) { int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; +#if defined(MBEDTLS_USE_PSA_CRYPTO) + unsigned char *out_p = ssl->out_msg + ssl->out_msglen; + unsigned char *end_p = ssl->out_msg + MBEDTLS_SSL_OUT_CONTENT_LEN - + ssl->out_msglen; + size_t output_offset = 0; + size_t output_len = 0; + const mbedtls_ecp_curve_info *curve_info; + + /* + * The first 3 bytes are: + * [0] MBEDTLS_ECP_TLS_NAMED_CURVE + * [1, 2] elliptic curve's TLS ID + * + * However since we only support secp256r1 for now, we hardcode its + * TLS ID here + */ + if( ( curve_info = mbedtls_ecp_curve_info_from_grp_id( + MBEDTLS_ECP_DP_SECP256R1 ) ) == NULL ) + { + return( MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE ); + } + *out_p = MBEDTLS_ECP_TLS_NAMED_CURVE; + MBEDTLS_PUT_UINT16_BE( curve_info->tls_id, out_p, 1 ); + output_offset += 3; + + ret = mbedtls_psa_ecjpake_write_round( &ssl->handshake->psa_pake_ctx, + out_p + output_offset, + end_p - out_p - output_offset, &output_len, + MBEDTLS_ECJPAKE_ROUND_TWO ); + if( ret != 0 ) + { + psa_destroy_key( ssl->handshake->psa_pake_password ); + psa_pake_abort( &ssl->handshake->psa_pake_ctx ); + MBEDTLS_SSL_DEBUG_RET( 1 , "psa_pake_output", ret ); + return( ret ); + } + + output_offset += output_len; + ssl->out_msglen += output_offset; +#else size_t len = 0; ret = mbedtls_ecjpake_write_round_two( @@ -2827,6 +2903,7 @@ static int ssl_prepare_server_key_exchange( mbedtls_ssl_context *ssl, } ssl->out_msglen += len; +#endif /* MBEDTLS_USE_PSA_CRYPTO */ } #endif /* MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED */ @@ -4044,6 +4121,18 @@ static int ssl_parse_client_key_exchange( mbedtls_ssl_context *ssl ) #if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECJPAKE ) { +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if( ( ret = mbedtls_psa_ecjpake_read_round( + &ssl->handshake->psa_pake_ctx, p, end - p, + MBEDTLS_ECJPAKE_ROUND_TWO ) ) != 0 ) + { + psa_destroy_key( ssl->handshake->psa_pake_password ); + psa_pake_abort( &ssl->handshake->psa_pake_ctx ); + + MBEDTLS_SSL_DEBUG_RET( 1, "psa_pake_input round two", ret ); + return( ret ); + } +#else ret = mbedtls_ecjpake_read_round_two( &ssl->handshake->ecjpake_ctx, p, end - p ); if( ret != 0 ) @@ -4060,6 +4149,7 @@ static int ssl_parse_client_key_exchange( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ecjpake_derive_secret", ret ); return( ret ); } +#endif /* MBEDTLS_USE_PSA_CRYPTO */ } else #endif /* MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED */ diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c index 0372f2d98d..0109f776c0 100644 --- a/library/ssl_tls13_client.c +++ b/library/ssl_tls13_client.c @@ -1183,11 +1183,11 @@ int mbedtls_ssl_tls13_write_client_hello_exts( mbedtls_ssl_context *ssl, return( ret ); p += ext_len; - /* Initializes the status to `indication sent`. It will be updated to - * `accepted` or `rejected` depending on whether the EncryptedExtension - * message will contain an early data indication extension or not. + /* Initializes the status to `rejected`. It will be updated to + * `accepted` if the EncryptedExtension message contain an early data + * indication extension. */ - ssl->early_data_status = MBEDTLS_SSL_EARLY_DATA_STATUS_INDICATION_SENT; + ssl->early_data_status = MBEDTLS_SSL_EARLY_DATA_STATUS_REJECTED; } else { @@ -2060,6 +2060,21 @@ static int ssl_tls13_parse_encrypted_extensions( mbedtls_ssl_context *ssl, break; #endif /* MBEDTLS_SSL_ALPN */ + +#if defined(MBEDTLS_SSL_EARLY_DATA) + case MBEDTLS_TLS_EXT_EARLY_DATA: + + if( extension_data_len != 0 ) + { + /* The message must be empty. */ + MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_DECODE_ERROR, + MBEDTLS_ERR_SSL_DECODE_ERROR ); + return( MBEDTLS_ERR_SSL_DECODE_ERROR ); + } + + break; +#endif /* MBEDTLS_SSL_EARLY_DATA */ + default: MBEDTLS_SSL_PRINT_EXT( 3, MBEDTLS_SSL_HS_ENCRYPTED_EXTENSIONS, @@ -2102,6 +2117,14 @@ static int ssl_tls13_process_encrypted_extensions( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_PROC_CHK( ssl_tls13_parse_encrypted_extensions( ssl, buf, buf + buf_len ) ); +#if defined(MBEDTLS_SSL_EARLY_DATA) + if( ssl->handshake->received_extensions & + MBEDTLS_SSL_EXT_MASK( EARLY_DATA ) ) + { + ssl->early_data_status = MBEDTLS_SSL_EARLY_DATA_STATUS_ACCEPTED; + } +#endif + mbedtls_ssl_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_ENCRYPTED_EXTENSIONS, buf, buf_len ); @@ -2743,7 +2766,7 @@ static int ssl_tls13_postprocess_new_session_ticket( mbedtls_ssl_context *ssl, } /* - * Handler for MBEDTLS_SSL_NEW_SESSION_TICKET + * Handler for MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET */ MBEDTLS_CHECK_RETURN_CRITICAL static int ssl_tls13_process_new_session_ticket( mbedtls_ssl_context *ssl ) @@ -2857,7 +2880,7 @@ int mbedtls_ssl_tls13_handshake_client_step( mbedtls_ssl_context *ssl ) #endif /* MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE */ #if defined(MBEDTLS_SSL_SESSION_TICKETS) - case MBEDTLS_SSL_NEW_SESSION_TICKET: + case MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET: ret = ssl_tls13_process_new_session_ticket( ssl ); if( ret != 0 ) break; diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c index 3cd03108f6..ce8767c5fd 100644 --- a/library/ssl_tls13_server.c +++ b/library/ssl_tls13_server.c @@ -2628,7 +2628,7 @@ static int ssl_tls13_handshake_wrapup( mbedtls_ssl_context *ssl ) mbedtls_ssl_tls13_handshake_wrapup( ssl ); #if defined(MBEDTLS_SSL_SESSION_TICKETS) - mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_NEW_SESSION_TICKET ); + mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET ); #else mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_HANDSHAKE_OVER ); #endif @@ -2636,7 +2636,7 @@ static int ssl_tls13_handshake_wrapup( mbedtls_ssl_context *ssl ) } /* - * Handler for MBEDTLS_SSL_NEW_SESSION_TICKET + * Handler for MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET */ #define SSL_NEW_SESSION_TICKET_SKIP 0 #define SSL_NEW_SESSION_TICKET_WRITE 1 @@ -2872,7 +2872,7 @@ static int ssl_tls13_write_new_session_ticket_body( mbedtls_ssl_context *ssl, } /* - * Handler for MBEDTLS_SSL_NEW_SESSION_TICKET + * Handler for MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET */ static int ssl_tls13_write_new_session_ticket( mbedtls_ssl_context *ssl ) { @@ -2908,8 +2908,8 @@ static int ssl_tls13_write_new_session_ticket( mbedtls_ssl_context *ssl ) else ssl->handshake->new_session_tickets_count--; - mbedtls_ssl_handshake_set_state( ssl, - MBEDTLS_SSL_NEW_SESSION_TICKET_FLUSH ); + mbedtls_ssl_handshake_set_state( + ssl, MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH ); } else { @@ -3045,7 +3045,7 @@ int mbedtls_ssl_tls13_handshake_server_step( mbedtls_ssl_context *ssl ) #endif /* MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_EPHEMERAL_ENABLED */ #if defined(MBEDTLS_SSL_SESSION_TICKETS) - case MBEDTLS_SSL_NEW_SESSION_TICKET: + case MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET: ret = ssl_tls13_write_new_session_ticket( ssl ); if( ret != 0 ) { @@ -3054,9 +3054,9 @@ int mbedtls_ssl_tls13_handshake_server_step( mbedtls_ssl_context *ssl ) ret ); } break; - case MBEDTLS_SSL_NEW_SESSION_TICKET_FLUSH: + case MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH: /* This state is necessary to do the flush of the New Session - * Ticket message written in MBEDTLS_SSL_NEW_SESSION_TICKET + * Ticket message written in MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET * as part of ssl_prepare_handshake_step. */ ret = 0; @@ -3064,7 +3064,7 @@ int mbedtls_ssl_tls13_handshake_server_step( mbedtls_ssl_context *ssl ) if( ssl->handshake->new_session_tickets_count == 0 ) mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_HANDSHAKE_OVER ); else - mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_NEW_SESSION_TICKET ); + mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET ); break; #endif /* MBEDTLS_SSL_SESSION_TICKETS */ diff --git a/scripts/mbedtls_dev/bignum_common.py b/scripts/mbedtls_dev/bignum_common.py index 8b11bc283c..67ea78db46 100644 --- a/scripts/mbedtls_dev/bignum_common.py +++ b/scripts/mbedtls_dev/bignum_common.py @@ -15,7 +15,12 @@ # limitations under the License. from abc import abstractmethod -from typing import Iterator, List, Tuple, TypeVar +from typing import Iterator, List, Tuple, TypeVar, Any +from itertools import chain + +from . import test_case +from . import test_data_generation +from .bignum_data import INPUTS_DEFAULT, MODULI_DEFAULT T = TypeVar('T') #pylint: disable=invalid-name @@ -63,8 +68,7 @@ def combination_pairs(values: List[T]) -> List[Tuple[T, T]]: """Return all pair combinations from input values.""" return [(x, y) for x in values for y in values] - -class OperationCommon: +class OperationCommon(test_data_generation.BaseTest): """Common features for bignum binary operations. This adds functionality common in binary operation tests. @@ -78,22 +82,106 @@ class OperationCommon: unique_combinations_only: Boolean to select if test case combinations must be unique. If True, only A,B or B,A would be included as a test case. If False, both A,B and B,A would be included. + input_style: Controls the way how test data is passed to the functions + in the generated test cases. "variable" passes them as they are + defined in the python source. "arch_split" pads the values with + zeroes depending on the architecture/limb size. If this is set, + test cases are generated for all architectures. + arity: the number of operands for the operation. Currently supported + values are 1 and 2. """ symbol = "" - input_values = [] # type: List[str] - input_cases = [] # type: List[Tuple[str, str]] - unique_combinations_only = True + input_values = INPUTS_DEFAULT # type: List[str] + input_cases = [] # type: List[Any] + unique_combinations_only = False + input_styles = ["variable", "fixed", "arch_split"] # type: List[str] + input_style = "variable" # type: str + limb_sizes = [32, 64] # type: List[int] + arities = [1, 2] + arity = 2 - def __init__(self, val_a: str, val_b: str) -> None: - self.arg_a = val_a - self.arg_b = val_b + def __init__(self, val_a: str, val_b: str = "0", bits_in_limb: int = 32) -> None: + self.val_a = val_a + self.val_b = val_b + # Setting the int versions here as opposed to making them @properties + # provides earlier/more robust input validation. self.int_a = hex_to_int(val_a) self.int_b = hex_to_int(val_b) + if bits_in_limb not in self.limb_sizes: + raise ValueError("Invalid number of bits in limb!") + if self.input_style == "arch_split": + self.dependencies = ["MBEDTLS_HAVE_INT{:d}".format(bits_in_limb)] + self.bits_in_limb = bits_in_limb + + @property + def boundary(self) -> int: + if self.arity == 1: + return self.int_a + elif self.arity == 2: + return max(self.int_a, self.int_b) + raise ValueError("Unsupported number of operands!") + + @property + def limb_boundary(self) -> int: + return bound_mpi(self.boundary, self.bits_in_limb) + + @property + def limbs(self) -> int: + return limbs_mpi(self.boundary, self.bits_in_limb) + + @property + def hex_digits(self) -> int: + return 2 * (self.limbs * self.bits_in_limb // 8) + + def format_arg(self, val) -> str: + if self.input_style not in self.input_styles: + raise ValueError("Unknown input style!") + if self.input_style == "variable": + return val + else: + return val.zfill(self.hex_digits) + + def format_result(self, res) -> str: + res_str = '{:x}'.format(res) + return quote_str(self.format_arg(res_str)) + + @property + def arg_a(self) -> str: + return self.format_arg(self.val_a) + + @property + def arg_b(self) -> str: + if self.arity == 1: + raise AttributeError("Operation is unary and doesn't have arg_b!") + return self.format_arg(self.val_b) def arguments(self) -> List[str]: - return [ - quote_str(self.arg_a), quote_str(self.arg_b) - ] + self.result() + args = [quote_str(self.arg_a)] + if self.arity == 2: + args.append(quote_str(self.arg_b)) + return args + self.result() + + def description(self) -> str: + """Generate a description for the test case. + + If not set, case_description uses the form A `symbol` B, where symbol + is used to represent the operation. Descriptions of each value are + generated to provide some context to the test case. + """ + if not self.case_description: + if self.arity == 1: + self.case_description = "{} {:x}".format( + self.symbol, self.int_a + ) + elif self.arity == 2: + self.case_description = "{:x} {} {:x}".format( + self.int_a, self.symbol, self.int_b + ) + return super().description() + + @property + def is_valid(self) -> bool: + return True @abstractmethod def result(self) -> List[str]: @@ -111,15 +199,134 @@ class OperationCommon: Combinations are first generated from all input values, and then specific cases provided. """ - if cls.unique_combinations_only: - yield from combination_pairs(cls.input_values) + if cls.arity == 1: + yield from ((a, "0") for a in cls.input_values) + elif cls.arity == 2: + if cls.unique_combinations_only: + yield from combination_pairs(cls.input_values) + else: + yield from ( + (a, b) + for a in cls.input_values + for b in cls.input_values + ) else: - yield from ( - (a, b) - for a in cls.input_values - for b in cls.input_values - ) - yield from cls.input_cases + raise ValueError("Unsupported number of operands!") + + @classmethod + def generate_function_tests(cls) -> Iterator[test_case.TestCase]: + if cls.input_style not in cls.input_styles: + raise ValueError("Unknown input style!") + if cls.arity not in cls.arities: + raise ValueError("Unsupported number of operands!") + if cls.input_style == "arch_split": + test_objects = (cls(a, b, bits_in_limb=bil) + for a, b in cls.get_value_pairs() + for bil in cls.limb_sizes) + special_cases = (cls(*args, bits_in_limb=bil) # type: ignore + for args in cls.input_cases + for bil in cls.limb_sizes) + else: + test_objects = (cls(a, b) + for a, b in cls.get_value_pairs()) + special_cases = (cls(*args) for args in cls.input_cases) + yield from (valid_test_object.create_test_case() + for valid_test_object in filter( + lambda test_object: test_object.is_valid, + chain(test_objects, special_cases) + ) + ) + + +class ModOperationCommon(OperationCommon): + #pylint: disable=abstract-method + """Target for bignum mod_raw test case generation.""" + moduli = MODULI_DEFAULT # type: List[str] + + def __init__(self, val_n: str, val_a: str, val_b: str = "0", + bits_in_limb: int = 64) -> None: + super().__init__(val_a=val_a, val_b=val_b, bits_in_limb=bits_in_limb) + self.val_n = val_n + # Setting the int versions here as opposed to making them @properties + # provides earlier/more robust input validation. + self.int_n = hex_to_int(val_n) + + @property + def boundary(self) -> int: + return self.int_n + + @property + def arg_n(self) -> str: + return self.format_arg(self.val_n) + + def arguments(self) -> List[str]: + return [quote_str(self.arg_n)] + super().arguments() + + @property + def r(self) -> int: # pylint: disable=invalid-name + l = limbs_mpi(self.int_n, self.bits_in_limb) + return bound_mpi_limbs(l, self.bits_in_limb) + + @property + def r_inv(self) -> int: + return invmod(self.r, self.int_n) + + @property + def r2(self) -> int: # pylint: disable=invalid-name + return pow(self.r, 2) + + @property + def is_valid(self) -> bool: + if self.int_a >= self.int_n: + return False + if self.arity == 2 and self.int_b >= self.int_n: + return False + return True + + def description(self) -> str: + """Generate a description for the test case. + + It uses the form A `symbol` B mod N, where symbol is used to represent + the operation. + """ + + if not self.case_description: + return super().description() + " mod {:x}".format(self.int_n) + return super().description() + + @classmethod + def input_cases_args(cls) -> Iterator[Tuple[Any, Any, Any]]: + if cls.arity == 1: + yield from ((n, a, "0") for a, n in cls.input_cases) + elif cls.arity == 2: + yield from ((n, a, b) for a, b, n in cls.input_cases) + else: + raise ValueError("Unsupported number of operands!") + + @classmethod + def generate_function_tests(cls) -> Iterator[test_case.TestCase]: + if cls.input_style not in cls.input_styles: + raise ValueError("Unknown input style!") + if cls.arity not in cls.arities: + raise ValueError("Unsupported number of operands!") + if cls.input_style == "arch_split": + test_objects = (cls(n, a, b, bits_in_limb=bil) + for n in cls.moduli + for a, b in cls.get_value_pairs() + for bil in cls.limb_sizes) + special_cases = (cls(*args, bits_in_limb=bil) + for args in cls.input_cases_args() + for bil in cls.limb_sizes) + else: + test_objects = (cls(n, a, b) + for n in cls.moduli + for a, b in cls.get_value_pairs()) + special_cases = (cls(*args) for args in cls.input_cases_args()) + yield from (valid_test_object.create_test_case() + for valid_test_object in filter( + lambda test_object: test_object.is_valid, + chain(test_objects, special_cases) + )) # BEGIN MERGE SLOT 1 diff --git a/scripts/mbedtls_dev/bignum_core.py b/scripts/mbedtls_dev/bignum_core.py index 0cc86b8096..b8e2a31239 100644 --- a/scripts/mbedtls_dev/bignum_core.py +++ b/scripts/mbedtls_dev/bignum_core.py @@ -16,20 +16,19 @@ import random -from abc import ABCMeta from typing import Dict, Iterator, List, Tuple from . import test_case from . import test_data_generation from . import bignum_common -class BignumCoreTarget(test_data_generation.BaseTarget, metaclass=ABCMeta): - #pylint: disable=abstract-method +class BignumCoreTarget(test_data_generation.BaseTarget): + #pylint: disable=abstract-method, too-few-public-methods """Target for bignum core test case generation.""" target_basename = 'test_suite_bignum_core.generated' -class BignumCoreShiftR(BignumCoreTarget, metaclass=ABCMeta): +class BignumCoreShiftR(BignumCoreTarget, test_data_generation.BaseTest): """Test cases for mbedtls_bignum_core_shift_r().""" count = 0 test_function = "mpi_core_shift_r" @@ -69,7 +68,7 @@ class BignumCoreShiftR(BignumCoreTarget, metaclass=ABCMeta): for count in counts: yield cls(input_hex, descr, count).create_test_case() -class BignumCoreCTLookup(BignumCoreTarget, metaclass=ABCMeta): +class BignumCoreCTLookup(BignumCoreTarget, test_data_generation.BaseTest): """Test cases for mbedtls_mpi_core_ct_uint_table_lookup().""" test_function = "mpi_core_ct_uint_table_lookup" test_name = "Constant time MPI table lookup" @@ -107,104 +106,33 @@ class BignumCoreCTLookup(BignumCoreTarget, metaclass=ABCMeta): yield (cls(bitsize, bitsize_description, window_size) .create_test_case()) -class BignumCoreOperation(bignum_common.OperationCommon, BignumCoreTarget, metaclass=ABCMeta): - #pylint: disable=abstract-method - """Common features for bignum core operations.""" - input_values = [ - "0", "1", "3", "f", "fe", "ff", "100", "ff00", "fffe", "ffff", "10000", - "fffffffe", "ffffffff", "100000000", "1f7f7f7f7f7f7f", - "8000000000000000", "fefefefefefefefe", "fffffffffffffffe", - "ffffffffffffffff", "10000000000000000", "1234567890abcdef0", - "fffffffffffffffffefefefefefefefe", "fffffffffffffffffffffffffffffffe", - "ffffffffffffffffffffffffffffffff", "100000000000000000000000000000000", - "1234567890abcdef01234567890abcdef0", - "fffffffffffffffffffffffffffffffffffffffffffffffffefefefefefefefe", - "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe", - "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", - "10000000000000000000000000000000000000000000000000000000000000000", - "1234567890abcdef01234567890abcdef01234567890abcdef01234567890abcdef0", - ( - "4df72d07b4b71c8dacb6cffa954f8d88254b6277099308baf003fab73227f34029" - "643b5a263f66e0d3c3fa297ef71755efd53b8fb6cb812c6bbf7bcf179298bd9947" - "c4c8b14324140a2c0f5fad7958a69050a987a6096e9f055fb38edf0c5889eca4a0" - "cfa99b45fbdeee4c696b328ddceae4723945901ec025076b12b" - ) - ] - def description(self) -> str: - """Generate a description for the test case. - - If not set, case_description uses the form A `symbol` B, where symbol - is used to represent the operation. Descriptions of each value are - generated to provide some context to the test case. - """ - if not self.case_description: - self.case_description = "{:x} {} {:x}".format( - self.int_a, self.symbol, self.int_b - ) - return super().description() - - @classmethod - def generate_function_tests(cls) -> Iterator[test_case.TestCase]: - for a_value, b_value in cls.get_value_pairs(): - yield cls(a_value, b_value).create_test_case() - - -class BignumCoreOperationArchSplit(BignumCoreOperation): - #pylint: disable=abstract-method - """Common features for bignum core operations where the result depends on - the limb size.""" - - def __init__(self, val_a: str, val_b: str, bits_in_limb: int) -> None: - super().__init__(val_a, val_b) - bound_val = max(self.int_a, self.int_b) - self.bits_in_limb = bits_in_limb - self.bound = bignum_common.bound_mpi(bound_val, self.bits_in_limb) - limbs = bignum_common.limbs_mpi(bound_val, self.bits_in_limb) - byte_len = limbs * self.bits_in_limb // 8 - self.hex_digits = 2 * byte_len - if self.bits_in_limb == 32: - self.dependencies = ["MBEDTLS_HAVE_INT32"] - elif self.bits_in_limb == 64: - self.dependencies = ["MBEDTLS_HAVE_INT64"] - else: - raise ValueError("Invalid number of bits in limb!") - self.arg_a = self.arg_a.zfill(self.hex_digits) - self.arg_b = self.arg_b.zfill(self.hex_digits) - - def pad_to_limbs(self, val) -> str: - return "{:x}".format(val).zfill(self.hex_digits) - - @classmethod - def generate_function_tests(cls) -> Iterator[test_case.TestCase]: - for a_value, b_value in cls.get_value_pairs(): - yield cls(a_value, b_value, 32).create_test_case() - yield cls(a_value, b_value, 64).create_test_case() - -class BignumCoreAddAndAddIf(BignumCoreOperationArchSplit): +class BignumCoreAddAndAddIf(BignumCoreTarget, bignum_common.OperationCommon): """Test cases for bignum core add and add-if.""" count = 0 symbol = "+" test_function = "mpi_core_add_and_add_if" test_name = "mpi_core_add_and_add_if" + input_style = "arch_split" + unique_combinations_only = True def result(self) -> List[str]: result = self.int_a + self.int_b - carry, result = divmod(result, self.bound) + carry, result = divmod(result, self.limb_boundary) return [ - bignum_common.quote_str(self.pad_to_limbs(result)), + self.format_result(result), str(carry) ] -class BignumCoreSub(BignumCoreOperation): + +class BignumCoreSub(BignumCoreTarget, bignum_common.OperationCommon): """Test cases for bignum core sub.""" count = 0 symbol = "-" test_function = "mpi_core_sub" test_name = "mbedtls_mpi_core_sub" - unique_combinations_only = False def result(self) -> List[str]: if self.int_a >= self.int_b: @@ -224,12 +152,11 @@ class BignumCoreSub(BignumCoreOperation): ] -class BignumCoreMLA(BignumCoreOperation): +class BignumCoreMLA(BignumCoreTarget, bignum_common.OperationCommon): """Test cases for fixed-size multiply accumulate.""" count = 0 test_function = "mpi_core_mla" test_name = "mbedtls_mpi_core_mla" - unique_combinations_only = False input_values = [ "0", "1", "fffe", "ffffffff", "100000000", "20000000000000", @@ -288,6 +215,16 @@ class BignumCoreMLA(BignumCoreOperation): "\"{:x}\"".format(carry_8) ] + @classmethod + def get_value_pairs(cls) -> Iterator[Tuple[str, str]]: + """Generator to yield pairs of inputs. + + Combinations are first generated from all input values, and then + specific cases provided. + """ + yield from super().get_value_pairs() + yield from cls.input_cases + @classmethod def generate_function_tests(cls) -> Iterator[test_case.TestCase]: """Override for additional scalar input.""" @@ -297,7 +234,7 @@ class BignumCoreMLA(BignumCoreOperation): yield cur_op.create_test_case() -class BignumCoreMontmul(BignumCoreTarget): +class BignumCoreMontmul(BignumCoreTarget, test_data_generation.BaseTest): """Test cases for Montgomery multiplication.""" count = 0 test_function = "mpi_core_montmul" @@ -826,6 +763,37 @@ def mpi_modmul_case_generate() -> None: # BEGIN MERGE SLOT 3 +class BignumCoreSubInt(BignumCoreTarget, bignum_common.OperationCommon): + """Test cases for bignum core sub int.""" + count = 0 + symbol = "-" + test_function = "mpi_core_sub_int" + test_name = "mpi_core_sub_int" + input_style = "arch_split" + + @property + def is_valid(self) -> bool: + # This is "sub int", so b is only one limb + if bignum_common.limbs_mpi(self.int_b, self.bits_in_limb) > 1: + return False + return True + + # Overriding because we don't want leading zeros on b + @property + def arg_b(self) -> str: + return self.val_b + + def result(self) -> List[str]: + result = self.int_a - self.int_b + + borrow, result = divmod(result, self.limb_boundary) + + # Borrow will be -1 if non-zero, but we want it to be 1 in the test data + return [ + self.format_result(result), + str(-borrow) + ] + # END MERGE SLOT 3 # BEGIN MERGE SLOT 4 diff --git a/scripts/mbedtls_dev/bignum_data.py b/scripts/mbedtls_dev/bignum_data.py new file mode 100644 index 0000000000..74d21d0ca5 --- /dev/null +++ b/scripts/mbedtls_dev/bignum_data.py @@ -0,0 +1,136 @@ +"""Base values and datasets for bignum generated tests and helper functions that +produced them.""" +# Copyright The Mbed TLS Contributors +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +# Functions calling these were used to produce test data and are here only for +# reproducability, they are not used by the test generation framework/classes +try: + from Cryptodome.Util.number import isPrime, getPrime #type: ignore #pylint: disable=import-error +except ImportError: + pass + +# Generated by bignum_common.gen_safe_prime(192,1) +SAFE_PRIME_192_BIT_SEED_1 = "d1c127a667786703830500038ebaef20e5a3e2dc378fb75b" + +# First number generated by random.getrandbits(192) - seed(2,2), not a prime +RANDOM_192_BIT_SEED_2_NO1 = "177219d30e7a269fd95bafc8f2a4d27bdcf4bb99f4bea973" + +# Second number generated by random.getrandbits(192) - seed(2,2), not a prime +RANDOM_192_BIT_SEED_2_NO2 = "cf1822ffbc6887782b491044d5e341245c6e433715ba2bdd" + +# Third number generated by random.getrandbits(192) - seed(2,2), not a prime +RANDOM_192_BIT_SEED_2_NO3 = "3653f8dd9b1f282e4067c3584ee207f8da94e3e8ab73738f" + +# Fourth number generated by random.getrandbits(192) - seed(2,2), not a prime +RANDOM_192_BIT_SEED_2_NO4 = "ffed9235288bc781ae66267594c9c9500925e4749b575bd1" + +# Ninth number generated by random.getrandbits(192) - seed(2,2), not a prime +RANDOM_192_BIT_SEED_2_NO9 = "2a1be9cd8697bbd0e2520e33e44c50556c71c4a66148a86f" + +# Generated by bignum_common.gen_safe_prime(1024,3) +SAFE_PRIME_1024_BIT_SEED_3 = ("c93ba7ec74d96f411ba008bdb78e63ff11bb5df46a51e16b" + "2c9d156f8e4e18abf5e052cb01f47d0d1925a77f60991577" + "e128fb6f52f34a27950a594baadd3d8057abeb222cf3cca9" + "62db16abf79f2ada5bd29ab2f51244bf295eff9f6aaba130" + "2efc449b128be75eeaca04bc3c1a155d11d14e8be32a2c82" + "87b3996cf6ad5223") + +# First number generated by random.getrandbits(1024) - seed(4,2), not a prime +RANDOM_1024_BIT_SEED_4_NO1 = ("6905269ed6f0b09f165c8ce36e2f24b43000de01b2ed40ed" + "3addccb2c33be0ac79d679346d4ac7a5c3902b38963dc6e8" + "534f45738d048ec0f1099c6c3e1b258fd724452ccea71ff4" + "a14876aeaff1a098ca5996666ceab360512bd13110722311" + "710cf5327ac435a7a97c643656412a9b8a1abcd1a6916c74" + "da4f9fc3c6da5d7") + +# Second number generated by random.getrandbits(1024) - seed(4,2), not a prime +RANDOM_1024_BIT_SEED_4_NO2 = ("f1cfd99216df648647adec26793d0e453f5082492d83a823" + "3fb62d2c81862fc9634f806fabf4a07c566002249b191bf4" + "d8441b5616332aca5f552773e14b0190d93936e1daca3c06" + "f5ff0c03bb5d7385de08caa1a08179104a25e4664f5253a0" + "2a3187853184ff27459142deccea264542a00403ce80c4b0" + "a4042bb3d4341aad") + +# Third number generated by random.getrandbits(1024) - seed(4,2), not a prime +RANDOM_1024_BIT_SEED_4_NO3 = ("14c15c910b11ad28cc21ce88d0060cc54278c2614e1bcb38" + "3bb4a570294c4ea3738d243a6e58d5ca49c7b59b995253fd" + "6c79a3de69f85e3131f3b9238224b122c3e4a892d9196ada" + "4fcfa583e1df8af9b474c7e89286a1754abcb06ae8abb93f" + "01d89a024cdce7a6d7288ff68c320f89f1347e0cdd905ecf" + "d160c5d0ef412ed6") + +# Fourth number generated by random.getrandbits(1024) - seed(4,2), not a prime +RANDOM_1024_BIT_SEED_4_NO4 = ("32decd6b8efbc170a26a25c852175b7a96b98b5fbf37a2be" + "6f98bca35b17b9662f0733c846bbe9e870ef55b1a1f65507" + "a2909cb633e238b4e9dd38b869ace91311021c9e32111ac1" + "ac7cc4a4ff4dab102522d53857c49391b36cc9aa78a330a1" + "a5e333cb88dcf94384d4cd1f47ca7883ff5a52f1a05885ac" + "7671863c0bdbc23a") + +# Fifth number generated by random.getrandbits(1024) - seed(4,2), not a prime +RANDOM_1024_BIT_SEED_4_NO5 = ("53be4721f5b9e1f5acdac615bc20f6264922b9ccf469aef8" + "f6e7d078e55b85dd1525f363b281b8885b69dc230af5ac87" + "0692b534758240df4a7a03052d733dcdef40af2e54c0ce68" + "1f44ebd13cc75f3edcb285f89d8cf4d4950b16ffc3e1ac3b" + "4708d9893a973000b54a23020fc5b043d6e4a51519d9c9cc" + "52d32377e78131c1") + +# Adding 192 bit and 1024 bit numbers because these are the shortest required +# for ECC and RSA respectively. +INPUTS_DEFAULT = [ + "0", "1", # corner cases + "2", "3", # small primes + "4", # non-prime even + "38", # small random + SAFE_PRIME_192_BIT_SEED_1, # prime + RANDOM_192_BIT_SEED_2_NO1, # not a prime + RANDOM_192_BIT_SEED_2_NO2, # not a prime + SAFE_PRIME_1024_BIT_SEED_3, # prime + RANDOM_1024_BIT_SEED_4_NO1, # not a prime + RANDOM_1024_BIT_SEED_4_NO3, # not a prime + RANDOM_1024_BIT_SEED_4_NO2, # largest (not a prime) + ] + +# Only odd moduli are present as in the new bignum code only odd moduli are +# supported for now. +MODULI_DEFAULT = [ + "53", # safe prime + "45", # non-prime + SAFE_PRIME_192_BIT_SEED_1, # safe prime + RANDOM_192_BIT_SEED_2_NO4, # not a prime + SAFE_PRIME_1024_BIT_SEED_3, # safe prime + RANDOM_1024_BIT_SEED_4_NO5, # not a prime + ] + +def __gen_safe_prime(bits, seed): + ''' + Generate a safe prime. + + This function is intended for generating constants offline and shouldn't be + used in test generation classes. + + Requires pycryptodomex for getPrime and isPrime and python 3.9 or later for + randbytes. + ''' + rng = random.Random() + # We want reproducability across python versions + rng.seed(seed, version=2) + while True: + prime = 2*getPrime(bits-1, rng.randbytes)+1 #pylint: disable=no-member + if isPrime(prime, 1e-30): + return prime diff --git a/scripts/mbedtls_dev/bignum_mod.py b/scripts/mbedtls_dev/bignum_mod.py index 2bd7fbbda3..a604cc0c59 100644 --- a/scripts/mbedtls_dev/bignum_mod.py +++ b/scripts/mbedtls_dev/bignum_mod.py @@ -14,12 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABCMeta - from . import test_data_generation -class BignumModTarget(test_data_generation.BaseTarget, metaclass=ABCMeta): - #pylint: disable=abstract-method +class BignumModTarget(test_data_generation.BaseTarget): + #pylint: disable=abstract-method, too-few-public-methods """Target for bignum mod test case generation.""" target_basename = 'test_suite_bignum_mod.generated' diff --git a/scripts/mbedtls_dev/bignum_mod_raw.py b/scripts/mbedtls_dev/bignum_mod_raw.py index bd694a6084..60f2feded6 100644 --- a/scripts/mbedtls_dev/bignum_mod_raw.py +++ b/scripts/mbedtls_dev/bignum_mod_raw.py @@ -14,89 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABCMeta -from typing import Dict, Iterator, List +from typing import Dict, List -from . import test_case from . import test_data_generation from . import bignum_common -class BignumModRawTarget(test_data_generation.BaseTarget, metaclass=ABCMeta): - #pylint: disable=abstract-method +class BignumModRawTarget(test_data_generation.BaseTarget): + #pylint: disable=abstract-method, too-few-public-methods """Target for bignum mod_raw test case generation.""" target_basename = 'test_suite_bignum_mod_raw.generated' -class BignumModRawOperation(bignum_common.OperationCommon, BignumModRawTarget, metaclass=ABCMeta): - #pylint: disable=abstract-method - """Target for bignum mod_raw test case generation.""" - - def __init__(self, val_n: str, val_a: str, val_b: str = "0", bits_in_limb: int = 64) -> None: - super().__init__(val_a=val_a, val_b=val_b) - self.val_n = val_n - self.bits_in_limb = bits_in_limb - - @property - def int_n(self) -> int: - return bignum_common.hex_to_int(self.val_n) - - @property - def boundary(self) -> int: - data_in = [self.int_a, self.int_b, self.int_n] - return max([n for n in data_in if n is not None]) - - @property - def limbs(self) -> int: - return bignum_common.limbs_mpi(self.boundary, self.bits_in_limb) - - @property - def hex_digits(self) -> int: - return 2 * (self.limbs * self.bits_in_limb // 8) - - @property - def hex_n(self) -> str: - return "{:x}".format(self.int_n).zfill(self.hex_digits) - - @property - def hex_a(self) -> str: - return "{:x}".format(self.int_a).zfill(self.hex_digits) - - @property - def hex_b(self) -> str: - return "{:x}".format(self.int_b).zfill(self.hex_digits) - - @property - def r(self) -> int: # pylint: disable=invalid-name - l = bignum_common.limbs_mpi(self.int_n, self.bits_in_limb) - return bignum_common.bound_mpi_limbs(l, self.bits_in_limb) - - @property - def r_inv(self) -> int: - return bignum_common.invmod(self.r, self.int_n) - - @property - def r2(self) -> int: # pylint: disable=invalid-name - return pow(self.r, 2) - -class BignumModRawOperationArchSplit(BignumModRawOperation): - #pylint: disable=abstract-method - """Common features for bignum mod raw operations where the result depends on - the limb size.""" - - limb_sizes = [32, 64] # type: List[int] - - def __init__(self, val_n: str, val_a: str, val_b: str = "0", bits_in_limb: int = 64) -> None: - super().__init__(val_n=val_n, val_a=val_a, val_b=val_b, bits_in_limb=bits_in_limb) - - if bits_in_limb not in self.limb_sizes: - raise ValueError("Invalid number of bits in limb!") - - self.dependencies = ["MBEDTLS_HAVE_INT{:d}".format(bits_in_limb)] - - @classmethod - def generate_function_tests(cls) -> Iterator[test_case.TestCase]: - for a_value, b_value in cls.get_value_pairs(): - for bil in cls.limb_sizes: - yield cls(a_value, b_value, bits_in_limb=bil).create_test_case() # BEGIN MERGE SLOT 1 # END MERGE SLOT 1 @@ -122,126 +49,35 @@ class BignumModRawOperationArchSplit(BignumModRawOperation): # END MERGE SLOT 6 # BEGIN MERGE SLOT 7 -class BignumModRawConvertToMont(BignumModRawOperationArchSplit): - """ Test cases for mpi_mod_raw_to_mont_rep(). """ +class BignumModRawConvertToMont(bignum_common.ModOperationCommon, + BignumModRawTarget): + """ Test cases for mpi_mod_raw_to_mont_rep(). """ test_function = "mpi_mod_raw_to_mont_rep" test_name = "Convert into Mont: " - - test_data_moduli = ["b", - "fd", - "eeff99aa37", - "eeff99aa11", - "800000000005", - "7fffffffffffffff", - "80fe000a10000001", - "25a55a46e5da99c71c7", - "1058ad82120c3a10196bb36229c1", - "7e35b84cb19ea5bc57ec37f5e431462fa962d98c1e63738d4657f" - "18ad6532e6adc3eafe67f1e5fa262af94cee8d3e7268593942a2a" - "98df75154f8c914a282f8b", - "8335616aed761f1f7f44e6bd49e807b82e3bf2bf11bfa63", - "ffcece570f2f991013f26dd5b03c4c5b65f97be5905f36cb4664f" - "2c78ff80aa8135a4aaf57ccb8a0aca2f394909a74cef1ef6758a6" - "4d11e2c149c393659d124bfc94196f0ce88f7d7d567efa5a649e2" - "deefaa6e10fdc3deac60d606bf63fc540ac95294347031aefd73d" - "6a9ee10188aaeb7a90d920894553cb196881691cadc51808715a0" - "7e8b24fcb1a63df047c7cdf084dd177ba368c806f3d51ddb5d389" - "8c863e687ecaf7d649a57a46264a582f94d3c8f2edaf59f77a7f6" - "bdaf83c991e8f06abe220ec8507386fce8c3da84c6c3903ab8f3a" - "d4630a204196a7dbcbd9bcca4e40ec5cc5c09938d49f5e1e6181d" - "b8896f33bb12e6ef73f12ec5c5ea7a8a337" - ] - - test_input_numbers = ["0", - "1", - "97", - "f5", - "6f5c3", - "745bfe50f7", - "ffa1f9924123", - "334a8b983c79bd", - "5b84f632b58f3461", - "19acd15bc38008e1", - "ffffffffffffffff", - "54ce6a6bb8247fa0427cfc75a6b0599", - "fecafe8eca052f154ce6a6bb8247fa019558bfeecce9bb9", - "a87d7a56fa4bfdc7da42ef798b9cf6843d4c54794698cb14d72" - "851dec9586a319f4bb6d5695acbd7c92e7a42a5ede6972adcbc" - "f68425265887f2d721f462b7f1b91531bac29fa648facb8e3c6" - "1bd5ae42d5a59ba1c89a95897bfe541a8ce1d633b98f379c481" - "6f25e21f6ac49286b261adb4b78274fe5f61c187581f213e84b" - "2a821e341ef956ecd5de89e6c1a35418cd74a549379d2d4594a" - "577543147f8e35b3514e62cf3e89d1156cdc91ab5f4c928fbd6" - "9148c35df5962fed381f4d8a62852a36823d5425f7487c13a12" - "523473fb823aa9d6ea5f42e794e15f2c1a8785cf6b7d51a4617" - "947fb3baf674f74a673cf1d38126983a19ed52c7439fab42c2185" - ] - - descr_tpl = '{} #{} N: \"{}\" A: \"{}\".' + symbol = "R *" + input_style = "arch_split" + arity = 1 def result(self) -> List[str]: - return [self.hex_x] + result = (self.int_a * self.r) % self.int_n + return [self.format_result(result)] - def arguments(self) -> List[str]: - return [bignum_common.quote_str(n) for n in [self.hex_n, - self.hex_a, - self.hex_x]] - def description(self) -> str: - return self.descr_tpl.format(self.test_name, - self.count, - self.int_n, - self.int_a) - - @classmethod - def generate_function_tests(cls) -> Iterator[test_case.TestCase]: - for bil in [32, 64]: - for n in cls.test_data_moduli: - for i in cls.test_input_numbers: - # Skip invalid combinations where A.limbs > N.limbs - if bignum_common.hex_to_int(i) > bignum_common.hex_to_int(n): - continue - yield cls(n, i, bits_in_limb=bil).create_test_case() - - @property - def x(self) -> int: # pylint: disable=invalid-name - return (self.int_a * self.r) % self.int_n - - @property - def hex_x(self) -> str: - return "{:x}".format(self.x).zfill(self.hex_digits) - -class BignumModRawConvertFromMont(BignumModRawConvertToMont): +class BignumModRawConvertFromMont(bignum_common.ModOperationCommon, + BignumModRawTarget): """ Test cases for mpi_mod_raw_from_mont_rep(). """ - test_function = "mpi_mod_raw_from_mont_rep" test_name = "Convert from Mont: " + symbol = "1/R *" + input_style = "arch_split" + arity = 1 + + def result(self) -> List[str]: + result = (self.int_a * self.r_inv) % self.int_n + return [self.format_result(result)] - test_input_numbers = ["0", - "1", - "3ca", - "539ed428", - "7dfe5c6beb35a2d6", - "dca8de1c2adfc6d7aafb9b48e", - "a7d17b6c4be72f3d5c16bf9c1af6fc933", - "2fec97beec546f9553142ed52f147845463f579", - "378dc83b8bc5a7b62cba495af4919578dce6d4f175cadc4f", - "b6415f2a1a8e48a518345db11f56db3829c8f2c6415ab4a395a" - "b3ac2ea4cbef4af86eb18a84eb6ded4c6ecbfc4b59c2879a675" - "487f687adea9d197a84a5242a5cf6125ce19a6ad2e7341f1c57" - "d43ea4f4c852a51cb63dabcd1c9de2b827a3146a3d175b35bea" - "41ae75d2a286a3e9d43623152ac513dcdea1d72a7da846a8ab3" - "58d9be4926c79cfb287cf1cf25b689de3b912176be5dcaf4d4c" - "6e7cb839a4a3243a6c47c1e2c99d65c59d6fa3672575c2f1ca8" - "de6a32e854ec9d8ec635c96af7679fce26d7d159e4a9da3bd74" - "e1272c376cd926d74fe3fb164a5935cff3d5cdb92b35fe2cea32" - "138a7e6bfbc319ebd1725dacb9a359cbf693f2ecb785efb9d627" - ] - @property - def x(self): # pylint: disable=invalid-name - return (self.int_a * self.r_inv) % self.int_n # END MERGE SLOT 7 # BEGIN MERGE SLOT 8 diff --git a/scripts/mbedtls_dev/test_data_generation.py b/scripts/mbedtls_dev/test_data_generation.py index eec0f9d978..02aa510518 100644 --- a/scripts/mbedtls_dev/test_data_generation.py +++ b/scripts/mbedtls_dev/test_data_generation.py @@ -25,6 +25,7 @@ import argparse import os import posixpath import re +import inspect from abc import ABCMeta, abstractmethod from typing import Callable, Dict, Iterable, Iterator, List, Type, TypeVar @@ -35,12 +36,8 @@ from . import test_case T = TypeVar('T') #pylint: disable=invalid-name -class BaseTarget(metaclass=ABCMeta): - """Base target for test case generation. - - Child classes of this class represent an output file, and can be referred - to as file targets. These indicate where test cases will be written to for - all subclasses of the file target, which is set by `target_basename`. +class BaseTest(metaclass=ABCMeta): + """Base class for test case generation. Attributes: count: Counter for test cases from this class. @@ -48,8 +45,6 @@ class BaseTarget(metaclass=ABCMeta): automatically generated using the class, or manually set. dependencies: A list of dependencies required for the test case. show_test_count: Toggle for inclusion of `count` in the test description. - target_basename: Basename of file to write generated tests to. This - should be specified in a child class of BaseTarget. test_function: Test function which the class generates cases for. test_name: A common name or description of the test function. This can be `test_function`, a clearer equivalent, or a short summary of the @@ -59,7 +54,6 @@ class BaseTarget(metaclass=ABCMeta): case_description = "" dependencies = [] # type: List[str] show_test_count = True - target_basename = "" test_function = "" test_name = "" @@ -121,6 +115,21 @@ class BaseTarget(metaclass=ABCMeta): """ raise NotImplementedError + +class BaseTarget: + #pylint: disable=too-few-public-methods + """Base target for test case generation. + + Child classes of this class represent an output file, and can be referred + to as file targets. These indicate where test cases will be written to for + all subclasses of the file target, which is set by `target_basename`. + + Attributes: + target_basename: Basename of file to write generated tests to. This + should be specified in a child class of BaseTarget. + """ + target_basename = "" + @classmethod def generate_tests(cls) -> Iterator[test_case.TestCase]: """Generate test cases for the class and its subclasses. @@ -132,7 +141,8 @@ class BaseTarget(metaclass=ABCMeta): yield from `generate_tests()` in each. Calling this method on a class X will yield test cases from all classes derived from X. """ - if cls.test_function: + if issubclass(cls, BaseTest) and not inspect.isabstract(cls): + #pylint: disable=no-member yield from cls.generate_function_tests() for subclass in sorted(cls.__subclasses__(), key=lambda c: c.__name__): yield from subclass.generate_tests() diff --git a/tests/opt-testcases/tls13-misc.sh b/tests/opt-testcases/tls13-misc.sh index edece456b3..ed428480c4 100755 --- a/tests/opt-testcases/tls13-misc.sh +++ b/tests/opt-testcases/tls13-misc.sh @@ -301,7 +301,7 @@ run_test "TLS 1.3 m->G: EarlyData: basic check, good" \ -c "NewSessionTicket: early_data(42) extension received." \ -c "ClientHello: early_data(42) extension exists." \ -c "EncryptedExtensions: early_data(42) extension received." \ - -c "EncryptedExtensions: early_data(42) extension ( ignored )." \ + -c "EncryptedExtensions: early_data(42) extension exists." \ -s "Parsing extension 'Early Data/42' (0 bytes)" \ -s "Sending extension Early Data/42 (0 bytes)" \ -s "early data accepted" @@ -322,7 +322,7 @@ run_test "TLS 1.3 m->G: EarlyData: no early_data in NewSessionTicket, good" \ -C "NewSessionTicket: early_data(42) extension received." \ -c "ClientHello: early_data(42) extension does not exist." \ -C "EncryptedExtensions: early_data(42) extension received." \ - -C "EncryptedExtensions: early_data(42) extension ( ignored )." + -C "EncryptedExtensions: early_data(42) extension exists." #TODO: OpenSSL tests don't work now. It might be openssl options issue, cause GnuTLS has worked. skip_next_test diff --git a/tests/scripts/all.sh b/tests/scripts/all.sh index a20bbde574..b43f999d80 100755 --- a/tests/scripts/all.sh +++ b/tests/scripts/all.sh @@ -1438,6 +1438,31 @@ component_test_tls1_2_default_cbc_legacy_cbc_etm_cipher_only_use_psa () { tests/ssl-opt.sh -f "TLS 1.2" } +# We're not aware of any other (open source) implementation of EC J-PAKE in TLS +# that we could use for interop testing. However, we now have sort of two +# implementations ourselves: one using PSA, the other not. At least test that +# these two interoperate with each other. +component_test_tls1_2_ecjpake_compatibility() { + msg "build: TLS1.2 server+client w/ EC-JPAKE w/o USE_PSA" + scripts/config.py set MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED + make -C programs ssl/ssl_server2 ssl/ssl_client2 + cp programs/ssl/ssl_server2 s2_no_use_psa + cp programs/ssl/ssl_client2 c2_no_use_psa + + msg "build: TLS1.2 server+client w/ EC-JPAKE w/ USE_PSA" + scripts/config.py set MBEDTLS_USE_PSA_CRYPTO + make clean + make -C programs ssl/ssl_server2 ssl/ssl_client2 + make -C programs test/udp_proxy test/query_compile_time_config + + msg "test: server w/o USE_PSA - client w/ USE_PSA" + P_SRV=../s2_no_use_psa tests/ssl-opt.sh -f ECJPAKE + msg "test: client w/o USE_PSA - server w/ USE_PSA" + P_CLI=../c2_no_use_psa tests/ssl-opt.sh -f ECJPAKE + + rm s2_no_use_psa c2_no_use_psa +} + component_test_psa_external_rng_use_psa_crypto () { msg "build: full + PSA_CRYPTO_EXTERNAL_RNG + USE_PSA_CRYPTO minus CTR_DRBG" scripts/config.py full @@ -3252,6 +3277,7 @@ component_build_armcc () { component_test_tls13_only () { msg "build: default config with MBEDTLS_SSL_PROTO_TLS1_3, without MBEDTLS_SSL_PROTO_TLS1_2" + scripts/config.py set MBEDTLS_SSL_EARLY_DATA make CFLAGS="'-DMBEDTLS_USER_CONFIG_FILE=\"../tests/configs/tls13-only.h\"'" msg "test: TLS 1.3 only, all key exchange modes enabled" @@ -3272,6 +3298,7 @@ component_test_tls13_only_psk () { scripts/config.py unset MBEDTLS_ECDSA_C scripts/config.py unset MBEDTLS_PKCS1_V21 scripts/config.py unset MBEDTLS_PKCS7_C + scripts/config.py set MBEDTLS_SSL_EARLY_DATA make CFLAGS="'-DMBEDTLS_USER_CONFIG_FILE=\"../tests/configs/tls13-only.h\"'" msg "test_suite_ssl: TLS 1.3 only, only PSK key exchange mode enabled" @@ -3305,6 +3332,7 @@ component_test_tls13_only_psk_ephemeral () { scripts/config.py unset MBEDTLS_ECDSA_C scripts/config.py unset MBEDTLS_PKCS1_V21 scripts/config.py unset MBEDTLS_PKCS7_C + scripts/config.py set MBEDTLS_SSL_EARLY_DATA make CFLAGS="'-DMBEDTLS_USER_CONFIG_FILE=\"../tests/configs/tls13-only.h\"'" msg "test_suite_ssl: TLS 1.3 only, only PSK ephemeral key exchange mode" @@ -3323,6 +3351,7 @@ component_test_tls13_only_psk_all () { scripts/config.py unset MBEDTLS_ECDSA_C scripts/config.py unset MBEDTLS_PKCS1_V21 scripts/config.py unset MBEDTLS_PKCS7_C + scripts/config.py set MBEDTLS_SSL_EARLY_DATA make CFLAGS="'-DMBEDTLS_USER_CONFIG_FILE=\"../tests/configs/tls13-only.h\"'" msg "test_suite_ssl: TLS 1.3 only, PSK and PSK ephemeral key exchange modes" @@ -3335,6 +3364,7 @@ component_test_tls13_only_psk_all () { component_test_tls13_only_ephemeral_all () { msg "build: TLS 1.3 only from default, without PSK key exchange mode" scripts/config.py unset MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK_ENABLED + scripts/config.py set MBEDTLS_SSL_EARLY_DATA make CFLAGS="'-DMBEDTLS_USER_CONFIG_FILE=\"../tests/configs/tls13-only.h\"'" msg "test_suite_ssl: TLS 1.3 only, ephemeral and PSK ephemeral key exchange modes" @@ -3349,6 +3379,7 @@ component_test_tls13 () { scripts/config.py set MBEDTLS_SSL_PROTO_TLS1_3 scripts/config.py set MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE scripts/config.py set MBEDTLS_SSL_CID_TLS1_3_PADDING_GRANULARITY 1 + scripts/config.py set MBEDTLS_SSL_EARLY_DATA CC=gcc cmake -D CMAKE_BUILD_TYPE:String=Asan . make msg "test: default config with MBEDTLS_SSL_PROTO_TLS1_3 enabled, without padding" @@ -3362,6 +3393,7 @@ component_test_tls13_no_compatibility_mode () { scripts/config.py set MBEDTLS_SSL_PROTO_TLS1_3 scripts/config.py unset MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE scripts/config.py set MBEDTLS_SSL_CID_TLS1_3_PADDING_GRANULARITY 1 + scripts/config.py set MBEDTLS_SSL_EARLY_DATA CC=gcc cmake -D CMAKE_BUILD_TYPE:String=Asan . make msg "test: default config with MBEDTLS_SSL_PROTO_TLS1_3 enabled, without padding" diff --git a/tests/scripts/generate_bignum_tests.py b/tests/scripts/generate_bignum_tests.py index eee2f657ad..c3058e98a9 100755 --- a/tests/scripts/generate_bignum_tests.py +++ b/tests/scripts/generate_bignum_tests.py @@ -57,7 +57,7 @@ of BaseTarget in test_data_generation.py. import sys from abc import ABCMeta -from typing import Iterator, List +from typing import List import scripts_path # pylint: disable=unused-import from mbedtls_dev import test_case @@ -68,15 +68,17 @@ from mbedtls_dev import bignum_common # the framework from mbedtls_dev import bignum_core, bignum_mod_raw # pylint: disable=unused-import -class BignumTarget(test_data_generation.BaseTarget, metaclass=ABCMeta): - #pylint: disable=abstract-method +class BignumTarget(test_data_generation.BaseTarget): + #pylint: disable=too-few-public-methods """Target for bignum (legacy) test case generation.""" target_basename = 'test_suite_bignum.generated' -class BignumOperation(bignum_common.OperationCommon, BignumTarget, metaclass=ABCMeta): +class BignumOperation(bignum_common.OperationCommon, BignumTarget, + metaclass=ABCMeta): #pylint: disable=abstract-method """Common features for bignum operations in legacy tests.""" + unique_combinations_only = True input_values = [ "", "0", "-", "-0", "7b", "-7b", @@ -132,11 +134,6 @@ class BignumOperation(bignum_common.OperationCommon, BignumTarget, metaclass=ABC tmp = "large " + tmp return tmp - @classmethod - def generate_function_tests(cls) -> Iterator[test_case.TestCase]: - for a_value, b_value in cls.get_value_pairs(): - yield cls(a_value, b_value).create_test_case() - class BignumCmp(BignumOperation): """Test cases for bignum value comparison.""" diff --git a/tests/ssl-opt.sh b/tests/ssl-opt.sh index 62205274c7..b460c67dc1 100755 --- a/tests/ssl-opt.sh +++ b/tests/ssl-opt.sh @@ -1362,7 +1362,7 @@ do_run_test_once() { if [ -n "$PXY_CMD" ]; then kill $PXY_PID >/dev/null 2>&1 - wait $PXY_PID + wait $PXY_PID >> $PXY_OUT 2>&1 fi } @@ -12945,8 +12945,8 @@ run_test "TLS 1.3: NewSessionTicket: Basic check, O->m" \ "$O_NEXT_CLI -msg -debug -tls1_3 -reconnect" \ 0 \ -s "=> write NewSessionTicket msg" \ - -s "server state: MBEDTLS_SSL_NEW_SESSION_TICKET" \ - -s "server state: MBEDTLS_SSL_NEW_SESSION_TICKET_FLUSH" + -s "server state: MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET" \ + -s "server state: MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH" requires_gnutls_tls1_3 requires_config_enabled MBEDTLS_SSL_SESSION_TICKETS @@ -12962,8 +12962,8 @@ run_test "TLS 1.3: NewSessionTicket: Basic check, G->m" \ -c "Connecting again- trying to resume previous session" \ -c "NEW SESSION TICKET (4) was received" \ -s "=> write NewSessionTicket msg" \ - -s "server state: MBEDTLS_SSL_NEW_SESSION_TICKET" \ - -s "server state: MBEDTLS_SSL_NEW_SESSION_TICKET_FLUSH" \ + -s "server state: MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET" \ + -s "server state: MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH" \ -s "key exchange mode: ephemeral" \ -s "key exchange mode: psk_ephemeral" \ -s "found pre_shared_key extension" @@ -12985,8 +12985,8 @@ run_test "TLS 1.3: NewSessionTicket: Basic check, m->m" \ -c "Reconnecting with saved session" \ -c "HTTP/1.0 200 OK" \ -s "=> write NewSessionTicket msg" \ - -s "server state: MBEDTLS_SSL_NEW_SESSION_TICKET" \ - -s "server state: MBEDTLS_SSL_NEW_SESSION_TICKET_FLUSH" \ + -s "server state: MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET" \ + -s "server state: MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH" \ -s "key exchange mode: ephemeral" \ -s "key exchange mode: psk_ephemeral" \ -s "found pre_shared_key extension" @@ -13040,8 +13040,8 @@ run_test "TLS 1.3: NewSessionTicket: servername check, m->m" \ -c "Reconnecting with saved session" \ -c "HTTP/1.0 200 OK" \ -s "=> write NewSessionTicket msg" \ - -s "server state: MBEDTLS_SSL_NEW_SESSION_TICKET" \ - -s "server state: MBEDTLS_SSL_NEW_SESSION_TICKET_FLUSH" \ + -s "server state: MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET" \ + -s "server state: MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH" \ -s "key exchange mode: ephemeral" \ -s "key exchange mode: psk_ephemeral" \ -s "found pre_shared_key extension" @@ -13064,8 +13064,8 @@ run_test "TLS 1.3: NewSessionTicket: servername negative check, m->m" \ -c "Reconnecting with saved session" \ -c "Hostname mismatch the session ticket, disable session resumption." \ -s "=> write NewSessionTicket msg" \ - -s "server state: MBEDTLS_SSL_NEW_SESSION_TICKET" \ - -s "server state: MBEDTLS_SSL_NEW_SESSION_TICKET_FLUSH" + -s "server state: MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET" \ + -s "server state: MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH" # Test heap memory usage after handshake requires_config_enabled MBEDTLS_SSL_PROTO_TLS1_2 diff --git a/tests/suites/test_suite_bignum_core.function b/tests/suites/test_suite_bignum_core.function index 612a7c6bd4..d5bb420023 100644 --- a/tests/suites/test_suite_bignum_core.function +++ b/tests/suites/test_suite_bignum_core.function @@ -1049,6 +1049,52 @@ exit: /* BEGIN MERGE SLOT 3 */ +/* BEGIN_CASE */ +void mpi_core_sub_int( char * input_A, char * input_B, + char * input_X, int borrow ) +{ + /* We are testing A - b, where A is an MPI and b is a scalar, expecting + * result X with borrow borrow. However, for ease of handling we encode b + * as a 1-limb MPI (B) in the .data file. */ + + mbedtls_mpi_uint *A = NULL; + mbedtls_mpi_uint *B = NULL; + mbedtls_mpi_uint *X = NULL; + mbedtls_mpi_uint *R = NULL; + size_t A_limbs, B_limbs, X_limbs; + + TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &A, &A_limbs, input_A ) ); + TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &B, &B_limbs, input_B ) ); + TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &X, &X_limbs, input_X ) ); + + /* The MPI encoding of scalar b must be only 1 limb */ + TEST_EQUAL( B_limbs, 1 ); + + /* The subtraction is fixed-width, so A and X must have the same number of limbs */ + TEST_EQUAL( A_limbs, X_limbs ); + size_t limbs = A_limbs; + + ASSERT_ALLOC( R, limbs ); + +#define TEST_COMPARE_CORE_MPIS( A, B, limbs ) \ + ASSERT_COMPARE( A, (limbs) * sizeof(mbedtls_mpi_uint), B, (limbs) * sizeof(mbedtls_mpi_uint) ) + + /* 1. R = A - b. Result and borrow should be correct */ + TEST_EQUAL( mbedtls_mpi_core_sub_int( R, A, B[0], limbs ), borrow ); + TEST_COMPARE_CORE_MPIS( R, X, limbs ); + + /* 2. A = A - b. Result and borrow should be correct */ + TEST_EQUAL( mbedtls_mpi_core_sub_int( A, A, B[0], limbs ), borrow ); + TEST_COMPARE_CORE_MPIS( A, X, limbs ); + +exit: + mbedtls_free( A ); + mbedtls_free( B ); + mbedtls_free( X ); + mbedtls_free( R ); +} +/* END_CASE */ + /* END MERGE SLOT 3 */ /* BEGIN MERGE SLOT 4 */