diff --git a/library/ssl_tls.c b/library/ssl_tls.c index cf9583ccb7..292e931355 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -1710,8 +1710,12 @@ int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl, else alg = PSA_ALG_TLS12_PSK_TO_MS(PSA_ALG_SHA_256); +#if defined(MBEDTLS_SSL_PROTO_TLS1_3) psa_set_key_usage_flags( &key_attributes, PSA_KEY_USAGE_DERIVE | PSA_KEY_USAGE_EXPORT ); +#else + psa_set_key_usage_flags( &key_attributes, PSA_KEY_USAGE_DERIVE ); +#endif psa_set_key_algorithm( &key_attributes, alg ); psa_set_key_type( &key_attributes, PSA_KEY_TYPE_DERIVE ); diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c index bbc8535790..c5aacc3df9 100644 --- a/library/ssl_tls13_server.c +++ b/library/ssl_tls13_server.c @@ -191,14 +191,14 @@ static int ssl_tls13_offered_psks_check_binder_match( mbedtls_ssl_context *ssl, psk_type = MBEDTLS_SSL_TLS1_3_PSK_EXTERNAL; switch( binder_len ) { - case 32: - md_alg = MBEDTLS_MD_SHA256; - break; - case 48: - md_alg = MBEDTLS_MD_SHA384; - break; - default: - return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); + case 32: + md_alg = MBEDTLS_MD_SHA256; + break; + case 48: + md_alg = MBEDTLS_MD_SHA384; + break; + default: + return( MBEDTLS_SSL_ALERT_MSG_DECRYPT_ERROR ); } psa_md_alg = mbedtls_psa_translate_md( md_alg ); /* Get current state of handshake transcript. */ @@ -264,11 +264,13 @@ static int ssl_tls13_parse_pre_shared_key_ext( mbedtls_ssl_context *ssl, const unsigned char *buf, const unsigned char *end ) { - const unsigned char *next_identity = buf; - uint16_t identities_len; + const unsigned char *identities = buf; + const unsigned char *p_identity_len; + size_t identities_len; const unsigned char *identities_end; - const unsigned char *next_binder; - uint16_t binders_len; + const unsigned char *binders; + const unsigned char *p_binder_len; + size_t binders_len; const unsigned char *binders_end; int matched_identity = -1; int identity_id = -1; @@ -278,47 +280,44 @@ static int ssl_tls13_parse_pre_shared_key_ext( mbedtls_ssl_context *ssl, /* identities_len 2 bytes * identities_data >= 7 bytes */ - MBEDTLS_SSL_CHK_BUF_READ_PTR( next_identity, end, 7 + 2 ); - identities_len = MBEDTLS_GET_UINT16_BE( next_identity, 0 ); - next_identity += 2; - MBEDTLS_SSL_CHK_BUF_READ_PTR( next_identity, end, identities_len ); - identities_end = next_identity + identities_len; + MBEDTLS_SSL_CHK_BUF_READ_PTR( identities, end, 7 + 2 ); + identities_len = MBEDTLS_GET_UINT16_BE( identities, 0 ); + p_identity_len = identities + 2; + MBEDTLS_SSL_CHK_BUF_READ_PTR( p_identity_len, end, identities_len ); + identities_end = p_identity_len + identities_len; /* binders_len 2 bytes * binders >= 33 bytes */ - next_binder = identities_end; - MBEDTLS_SSL_CHK_BUF_READ_PTR( next_binder, end, 33 ); - binders_len = MBEDTLS_GET_UINT16_BE( next_binder, 0 ); - next_binder += 2; - MBEDTLS_SSL_CHK_BUF_READ_PTR( next_binder, end, binders_len ); - binders_end = next_binder + binders_len; + binders = identities_end; + MBEDTLS_SSL_CHK_BUF_READ_PTR( binders, end, 33 ); + binders_len = MBEDTLS_GET_UINT16_BE( binders, 0 ); + p_binder_len = binders + 2; + MBEDTLS_SSL_CHK_BUF_READ_PTR( p_binder_len, end, binders_len ); + binders_end = p_binder_len + binders_len; ssl->handshake->update_checksum( ssl, buf, identities_end - buf ); - while( next_identity < identities_end && next_binder < binders_end ) + while( p_identity_len < identities_end && p_binder_len < binders_end ) { const unsigned char *identity; - uint16_t identity_len; + size_t identity_len; const unsigned char *binder; - uint16_t binder_len; + size_t binder_len; int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; - MBEDTLS_SSL_CHK_BUF_READ_PTR( next_identity, identities_end, 2 ); - identity_len = MBEDTLS_GET_UINT16_BE( next_identity, 0 ); - next_identity += 2; - identity = next_identity; - MBEDTLS_SSL_CHK_BUF_READ_PTR( next_identity, - identities_end, - identity_len + 4 ); - next_identity += identity_len + 4; + MBEDTLS_SSL_CHK_BUF_READ_PTR( p_identity_len, identities_end, 2 + 1 + 4 ); + identity_len = MBEDTLS_GET_UINT16_BE( p_identity_len, 0 ); + identity = p_identity_len + 2; + MBEDTLS_SSL_CHK_BUF_READ_PTR( identity, identities_end, identity_len + 4 ); + p_identity_len += identity_len + 6; - MBEDTLS_SSL_CHK_BUF_READ_PTR( next_binder, binders_end, 2 ); + MBEDTLS_SSL_CHK_BUF_READ_PTR( p_binder_len, binders_end, 1 + 32 ); + binder_len = *p_binder_len; + binder = p_binder_len + 1; + MBEDTLS_SSL_CHK_BUF_READ_PTR( binder, binders_end, binder_len ); + p_binder_len += binder_len + 1; - binder_len = *next_binder++; - binder = next_binder; - MBEDTLS_SSL_CHK_BUF_READ_PTR( next_binder, binders_end, binder_len ); - next_binder += binder_len; identity_id++; if( matched_identity != -1 ) @@ -331,7 +330,7 @@ static int ssl_tls13_parse_pre_shared_key_ext( mbedtls_ssl_context *ssl, ret = ssl_tls13_offered_psks_check_binder_match( ssl, binder, binder_len ); - if( ret < 0 ) + if( ret != SSL_TLS1_3_OFFERED_PSK_MATCH ) { MBEDTLS_SSL_DEBUG_RET( 1, "ssl_tls13_offered_psks_check_binder_match" , ret ); @@ -346,7 +345,7 @@ static int ssl_tls13_parse_pre_shared_key_ext( mbedtls_ssl_context *ssl, matched_identity = identity_id; } - if( next_identity != identities_end || next_binder != binders_end ) + if( p_identity_len != identities_end || p_binder_len != binders_end ) { MBEDTLS_SSL_DEBUG_MSG( 3, ( "pre_shared_key extesion decode error" ) ); MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_DECODE_ERROR,