diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c index 36a8119478..91e6f4ef43 100644 --- a/library/ssl_tls13_server.c +++ b/library/ssl_tls13_server.c @@ -136,7 +136,8 @@ static int ssl_tls13_offered_psks_check_identity_match( MBEDTLS_CHECK_RETURN_CRITICAL static int ssl_tls13_offered_psks_check_binder_match( mbedtls_ssl_context *ssl, const unsigned char *binder, - size_t binder_len ) + size_t binder_len, + mbedtls_md_type_t *psk_alg ) { int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; int psk_type; @@ -149,6 +150,7 @@ static int ssl_tls13_offered_psks_check_binder_match( mbedtls_ssl_context *ssl, size_t psk_len; unsigned char server_computed_binder[PSA_HASH_MAX_SIZE]; + *psk_alg = MBEDTLS_MD_NONE; psk_type = MBEDTLS_SSL_TLS1_3_PSK_EXTERNAL; switch( binder_len ) { @@ -192,6 +194,7 @@ static int ssl_tls13_offered_psks_check_binder_match( mbedtls_ssl_context *ssl, if( mbedtls_ct_memcmp( server_computed_binder, binder, binder_len ) == 0 ) { + *psk_alg = md_alg; return( SSL_TLS1_3_OFFERED_PSK_MATCH ); } @@ -223,7 +226,8 @@ static int ssl_tls13_offered_psks_check_binder_match( mbedtls_ssl_context *ssl, MBEDTLS_CHECK_RETURN_CRITICAL static int ssl_tls13_parse_pre_shared_key_ext( mbedtls_ssl_context *ssl, const unsigned char *buf, - const unsigned char *end ) + const unsigned char *end, + mbedtls_md_type_t *psk_alg ) { const unsigned char *identities = buf; const unsigned char *p_identity_len; @@ -236,6 +240,8 @@ static int ssl_tls13_parse_pre_shared_key_ext( mbedtls_ssl_context *ssl, int matched_identity = -1; int identity_id = -1; + *psk_alg = MBEDTLS_MD_NONE; + MBEDTLS_SSL_DEBUG_BUF( 3, "pre_shared_key extension", buf, end - buf ); /* identities_len 2 bytes @@ -266,6 +272,7 @@ static int ssl_tls13_parse_pre_shared_key_ext( mbedtls_ssl_context *ssl, const unsigned char *binder; size_t binder_len; int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; + mbedtls_md_type_t alg; MBEDTLS_SSL_CHK_BUF_READ_PTR( p_identity_len, identities_end, 2 + 1 + 4 ); identity_len = MBEDTLS_GET_UINT16_BE( p_identity_len, 0 ); @@ -286,11 +293,11 @@ static int ssl_tls13_parse_pre_shared_key_ext( mbedtls_ssl_context *ssl, ret = ssl_tls13_offered_psks_check_identity_match( ssl, identity, identity_len ); - if( SSL_TLS1_3_OFFERED_PSK_NOT_MATCH == ret ) + if( ret != SSL_TLS1_3_OFFERED_PSK_MATCH ) continue; ret = ssl_tls13_offered_psks_check_binder_match( - ssl, binder, binder_len ); + ssl, binder, binder_len, &alg ); if( ret != SSL_TLS1_3_OFFERED_PSK_MATCH ) { MBEDTLS_SSL_DEBUG_RET( 1, @@ -300,10 +307,9 @@ static int ssl_tls13_parse_pre_shared_key_ext( mbedtls_ssl_context *ssl, MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE ); return( ret ); } - if( SSL_TLS1_3_OFFERED_PSK_NOT_MATCH == ret ) - continue; matched_identity = identity_id; + *psk_alg = alg; } if( p_identity_len != identities_end || p_binder_len != binders_end ) @@ -914,10 +920,10 @@ static int ssl_tls13_parse_client_hello( mbedtls_ssl_context *ssl, const unsigned char *extensions_end; int hrr_required = 0; - const mbedtls_ssl_ciphersuite_t* ciphersuite_info; #if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED) const unsigned char *pre_shared_key_ext_start = NULL; const unsigned char *pre_shared_key_ext_end = NULL; + mbedtls_md_type_t psk_alg = MBEDTLS_MD_NONE; #endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */ ssl->handshake->extensions_present = MBEDTLS_SSL_EXT_NONE; @@ -1000,7 +1006,7 @@ static int ssl_tls13_parse_client_hello( mbedtls_ssl_context *ssl, p, legacy_session_id_len ); /* * Check we have enough data for the legacy session identifier - * and the ciphersuite list length. + * and the ciphersuite list length. */ MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, legacy_session_id_len + 2 ); @@ -1012,59 +1018,42 @@ static int ssl_tls13_parse_client_hello( mbedtls_ssl_context *ssl, /* Check we have enough data for the ciphersuite list, the legacy * compression methods and the length of the extensions. + * + * cipher_suites cipher_suites_len bytes + * legacy_compression_methods 2 bytes + * extensions_len 2 bytes */ MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, cipher_suites_len + 2 + 2 ); - /* ... - * CipherSuite cipher_suites<2..2^16-2>; - * ... - * with CipherSuite defined as: - * uint8 CipherSuite[2]; + /* + * uint8 CipherSuite[2]; // Cryptographic suite selector + * + * struct { + * ... + * CipherSuite cipher_suites<2..2^16-2>; + * ... + * } ClientHello; */ cipher_suites = p; cipher_suites_end = p + cipher_suites_len; MBEDTLS_SSL_DEBUG_BUF( 3, "client hello, ciphersuitelist", p, cipher_suites_len ); - /* - * Search for a matching ciphersuite - */ - int ciphersuite_match = 0; +#if defined(MBEDTLS_DEBUG_C) for ( ; p < cipher_suites_end; p += 2 ) { uint16_t cipher_suite; + const mbedtls_ssl_ciphersuite_t* ciphersuite_info; MBEDTLS_SSL_CHK_BUF_READ_PTR( p, cipher_suites_end, 2 ); cipher_suite = MBEDTLS_GET_UINT16_BE( p, 0 ); ciphersuite_info = mbedtls_ssl_ciphersuite_from_id( cipher_suite ); - /* - * Check whether this ciphersuite is valid and offered. - */ - if( ( mbedtls_ssl_validate_ciphersuite( - ssl, ciphersuite_info, ssl->tls_version, - ssl->tls_version ) != 0 ) || - ! mbedtls_ssl_tls13_cipher_suite_is_offered( ssl, cipher_suite ) ) - { - continue; - } - - ssl->session_negotiate->ciphersuite = cipher_suite; - ssl->handshake->ciphersuite_info = ciphersuite_info; - ciphersuite_match = 1; - - break; - + MBEDTLS_SSL_DEBUG_MSG( 2, ( "client hello, received ciphersuite: %04x - %s", + cipher_suite, + ciphersuite_info == NULL ? + "Unkown": ciphersuite_info->name ) ); } - - if( ! ciphersuite_match ) - { - MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE, - MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE ); - return ( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE ); - } - - MBEDTLS_SSL_DEBUG_MSG( 2, ( "selected ciphersuite: %s", - ciphersuite_info->name ) ); - - p = cipher_suites + cipher_suites_len; +#else + p = cipher_suites_end; +#endif /* MBEDTLS_DEBUG_C */ /* ... * opaque legacy_compression_methods<1..2^8-1>; @@ -1298,6 +1287,7 @@ static int ssl_tls13_parse_client_hello( mbedtls_ssl_context *ssl, MBEDTLS_SSL_HS_CLIENT_HELLO, p - buf ); +/* TODO: move later */ #if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED) /* Update checksum with either * - The entire content of the CH message, if no PSK extension is present @@ -1311,7 +1301,8 @@ static int ssl_tls13_parse_client_hello( mbedtls_ssl_context *ssl, pre_shared_key_ext_start - buf ); ret = ssl_tls13_parse_pre_shared_key_ext( ssl, pre_shared_key_ext_start, - pre_shared_key_ext_end ); + pre_shared_key_ext_end, + &psk_alg ); if( ret == MBEDTLS_ERR_SSL_UNKNOWN_IDENTITY) { ssl->handshake->extensions_present &= ~MBEDTLS_SSL_EXT_PRE_SHARED_KEY; @@ -1329,6 +1320,51 @@ static int ssl_tls13_parse_client_hello( mbedtls_ssl_context *ssl, ssl->handshake->update_checksum( ssl, buf, p - buf ); } + /* + * Search for a matching ciphersuite + */ + for ( const unsigned char * p_chiper_suite = cipher_suites ; + p_chiper_suite < cipher_suites_end; p_chiper_suite += 2 ) + { + uint16_t cipher_suite; + const mbedtls_ssl_ciphersuite_t* ciphersuite_info; + + MBEDTLS_SSL_CHK_BUF_READ_PTR( p_chiper_suite, cipher_suites_end, 2 ); + + cipher_suite = MBEDTLS_GET_UINT16_BE( p_chiper_suite, 0 ); + if( ! mbedtls_ssl_tls13_cipher_suite_is_offered( ssl, cipher_suite ) ) + continue; + + ciphersuite_info = mbedtls_ssl_ciphersuite_from_id( cipher_suite ); + if( ( mbedtls_ssl_validate_ciphersuite( + ssl, ciphersuite_info, ssl->tls_version, + ssl->tls_version ) != 0 ) ) + { + continue; + } + +#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED) + /* MAC of selected ciphersuite MUST be same with PSK binder if exist. + * Otherwise, client should reject. + */ + if( psk_alg != MBEDTLS_MD_NONE && psk_alg != ciphersuite_info->mac ) + continue; +#endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */ + + ssl->session_negotiate->ciphersuite = cipher_suite; + ssl->handshake->ciphersuite_info = ciphersuite_info; + MBEDTLS_SSL_DEBUG_MSG( 2, ( "selected ciphersuite: %04x - %s", + cipher_suite, + ciphersuite_info->name ) ); + } + + if( ssl->handshake->ciphersuite_info == NULL ) + { + MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE, + MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE ); + return ( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE ); + } + ret = ssl_tls13_determine_key_exchange_mode( ssl ); if( ret < 0 ) return( ret );