diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h index a6ee9a4487..b05bfe1b72 100644 --- a/include/mbedtls/ssl.h +++ b/include/mbedtls/ssl.h @@ -1304,6 +1304,10 @@ struct mbedtls_ssl_session { char *MBEDTLS_PRIVATE(hostname); /*!< host name binded with tickets */ #endif /* MBEDTLS_SSL_SERVER_NAME_INDICATION && MBEDTLS_SSL_CLI_C */ +#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN) && defined(MBEDTLS_SSL_SRV_C) + char *ticket_alpn; /*!< ALPN negotiated in the session */ +#endif + #if defined(MBEDTLS_HAVE_TIME) && defined(MBEDTLS_SSL_CLI_C) /*! Time in milliseconds when the last ticket was received. */ mbedtls_ms_time_t MBEDTLS_PRIVATE(ticket_reception_time); @@ -1312,9 +1316,6 @@ struct mbedtls_ssl_session { #if defined(MBEDTLS_SSL_EARLY_DATA) uint32_t MBEDTLS_PRIVATE(max_early_data_size); /*!< maximum amount of early data in tickets */ -#if defined(MBEDTLS_SSL_ALPN) && defined(MBEDTLS_SSL_SRV_C) - char *alpn; /*!< ALPN negotiated in the session */ -#endif #endif #if defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC) diff --git a/library/ssl_misc.h b/library/ssl_misc.h index 2ec898b453..948c802299 100644 --- a/library/ssl_misc.h +++ b/library/ssl_misc.h @@ -2852,6 +2852,13 @@ int mbedtls_ssl_session_set_hostname(mbedtls_ssl_session *session, const char *hostname); #endif +#if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_EARLY_DATA) && \ + defined(MBEDTLS_SSL_ALPN) +MBEDTLS_CHECK_RETURN_CRITICAL +int mbedtls_ssl_session_set_alpn(mbedtls_ssl_session *session, + const char *alpn); +#endif + #if defined(MBEDTLS_SSL_PROTO_TLS1_3) && defined(MBEDTLS_SSL_SESSION_TICKETS) #define MBEDTLS_SSL_TLS1_3_MAX_ALLOWED_TICKET_LIFETIME (604800) diff --git a/library/ssl_tls.c b/library/ssl_tls.c index d7d26ab063..f78b97d444 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -2450,6 +2450,7 @@ mbedtls_ssl_mode_t mbedtls_ssl_get_mode_from_ciphersuite( #if defined(MBEDTLS_USE_PSA_CRYPTO) || defined(MBEDTLS_SSL_PROTO_TLS1_3) + psa_status_t mbedtls_ssl_cipher_to_psa(mbedtls_cipher_type_t mbedtls_cipher_type, size_t taglen, psa_algorithm_t *alg, @@ -3771,8 +3772,8 @@ static int ssl_tls13_session_save(const mbedtls_ssl_session *session, #if defined(MBEDTLS_SSL_SRV_C) && \ defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN) - const uint8_t alpn_len = (session->alpn == NULL) ? - 0 : (uint8_t) strlen(session->alpn) + 1; + const uint8_t alpn_len = (session->ticket_alpn == NULL) ? + 0 : (uint8_t) strlen(session->ticket_alpn) + 1; #endif size_t needed = 4 /* ticket_age_add */ + 1 /* ticket_flags */ @@ -3858,7 +3859,7 @@ static int ssl_tls13_session_save(const mbedtls_ssl_session *session, *p++ = alpn_len; if (alpn_len > 0) { /* save chosen alpn */ - memcpy(p, session->alpn, alpn_len); + memcpy(p, session->ticket_alpn, alpn_len); p += alpn_len; } #endif /* MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_ALPN */ @@ -3951,6 +3952,7 @@ static int ssl_tls13_session_load(mbedtls_ssl_session *session, #if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN) uint8_t alpn_len; + int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; if (end - p < 1) { return MBEDTLS_ERR_SSL_BAD_INPUT_DATA; @@ -3960,12 +3962,12 @@ static int ssl_tls13_session_load(mbedtls_ssl_session *session, if (end - p < alpn_len) { return MBEDTLS_ERR_SSL_BAD_INPUT_DATA; } + if (alpn_len > 0) { - session->alpn = mbedtls_calloc(alpn_len, sizeof(char)); - if (session->alpn == NULL) { - return MBEDTLS_ERR_SSL_ALLOC_FAILED; + ret = mbedtls_ssl_session_set_alpn(session, (const char *) p); + if (ret != 0) { + return ret; } - memcpy(session->alpn, p, alpn_len); p += alpn_len; } #endif /* MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_ALPN */ @@ -4917,11 +4919,12 @@ void mbedtls_ssl_session_free(mbedtls_ssl_session *session) defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) mbedtls_free(session->hostname); #endif + mbedtls_free(session->ticket); +#endif + #if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN) && \ defined(MBEDTLS_SSL_SRV_C) - mbedtls_free(session->alpn); -#endif - mbedtls_free(session->ticket); + mbedtls_free(session->ticket_alpn); #endif mbedtls_platform_zeroize(session, sizeof(mbedtls_ssl_session)); @@ -9870,4 +9873,37 @@ int mbedtls_ssl_session_set_hostname(mbedtls_ssl_session *session, MBEDTLS_SSL_SERVER_NAME_INDICATION && MBEDTLS_SSL_CLI_C */ +#if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_EARLY_DATA) && \ + defined(MBEDTLS_SSL_ALPN) +int mbedtls_ssl_session_set_alpn(mbedtls_ssl_session *session, + const char *alpn) +{ + size_t alpn_len = 0; + + if (alpn != NULL) { + alpn_len = strlen(alpn); + + if (alpn_len > MBEDTLS_SSL_MAX_ALPN_NAME_LEN) { + return MBEDTLS_ERR_SSL_BAD_INPUT_DATA; + } + } + + if (session->ticket_alpn != NULL) { + mbedtls_zeroize_and_free(session->ticket_alpn, + strlen(session->ticket_alpn)); + } + + if (alpn == NULL) { + session->ticket_alpn = NULL; + } else { + session->ticket_alpn = mbedtls_calloc(strlen(alpn) + 1, sizeof(char)); + if (session->ticket_alpn == NULL) { + return MBEDTLS_ERR_SSL_ALLOC_FAILED; + } + memcpy(session->ticket_alpn, alpn, strlen(alpn) + 1); + } + + return 0; +} +#endif /* MBEDTLS_SSL_SRV_C && MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_ALPN */ #endif /* MBEDTLS_SSL_TLS_C */ diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c index 291d64500d..9c73c7a1a5 100644 --- a/library/ssl_tls13_server.c +++ b/library/ssl_tls13_server.c @@ -469,12 +469,10 @@ static int ssl_tls13_session_copy_ticket(mbedtls_ssl_session *dst, dst->max_early_data_size = src->max_early_data_size; #if defined(MBEDTLS_SSL_ALPN) - if (src->alpn != NULL) { - dst->alpn = mbedtls_calloc(strlen(src->alpn) + 1, sizeof(char)); - if (dst->alpn == NULL) { - return MBEDTLS_ERR_SSL_ALLOC_FAILED; - } - memcpy(dst->alpn, src->alpn, strlen(src->alpn) + 1); + int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; + ret = mbedtls_ssl_session_set_alpn(dst, src->ticket_alpn); + if (ret != 0) { + return ret; } #endif /* MBEDTLS_SSL_ALPN */ #endif /* MBEDTLS_SSL_EARLY_DATA*/ @@ -3148,12 +3146,9 @@ static int ssl_tls13_prepare_new_session_ticket(mbedtls_ssl_context *ssl, MBEDTLS_SSL_PRINT_TICKET_FLAGS(4, session->ticket_flags); #if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN) - if (ssl->alpn_chosen != NULL) { - session->alpn = mbedtls_calloc(strlen(ssl->alpn_chosen) + 1, sizeof(char)); - if (session->alpn == NULL) { - return MBEDTLS_ERR_SSL_ALLOC_FAILED; - } - memcpy(session->alpn, ssl->alpn_chosen, strlen(ssl->alpn_chosen) + 1); + ret = mbedtls_ssl_session_set_alpn(session, ssl->alpn_chosen); + if (ret != 0) { + return ret; } #endif diff --git a/tests/src/test_helpers/ssl_helpers.c b/tests/src/test_helpers/ssl_helpers.c index 89c1bbf522..9c1676fc63 100644 --- a/tests/src/test_helpers/ssl_helpers.c +++ b/tests/src/test_helpers/ssl_helpers.c @@ -1794,11 +1794,11 @@ int mbedtls_test_ssl_tls13_populate_session(mbedtls_ssl_session *session, #if defined(MBEDTLS_SSL_EARLY_DATA) session->max_early_data_size = 0x87654321; #if defined(MBEDTLS_SSL_ALPN) && defined(MBEDTLS_SSL_SRV_C) - session->alpn = mbedtls_calloc(strlen("ALPNExample")+1, sizeof(char)); - if (session->alpn == NULL) { + int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; + ret = mbedtls_ssl_session_set_alpn(session, "ALPNExample"); + if (ret != 0) { return -1; } - strcpy(session->alpn, "ALPNExample"); #endif /* MBEDTLS_SSL_ALPN && MBEDTLS_SSL_SRV_C */ #endif /* MBEDTLS_SSL_EARLY_DATA */ diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function index da07f2c62f..e29667d0bc 100644 --- a/tests/suites/test_suite_ssl.function +++ b/tests/suites/test_suite_ssl.function @@ -2106,11 +2106,10 @@ void ssl_serialize_session_save_load(int ticket_len, char *crt_file, original.max_early_data_size == restored.max_early_data_size); #if defined(MBEDTLS_SSL_ALPN) && defined(MBEDTLS_SSL_SRV_C) if (endpoint_type == MBEDTLS_SSL_IS_SERVER) { - TEST_ASSERT(original.alpn != NULL); - TEST_ASSERT(restored.alpn != NULL); - TEST_ASSERT(memcmp(original.alpn, - restored.alpn, - strlen(original.alpn)) == 0); + TEST_ASSERT(original.ticket_alpn != NULL); + TEST_ASSERT(restored.ticket_alpn != NULL); + TEST_MEMORY_COMPARE(original.ticket_alpn, strlen(original.ticket_alpn), + restored.ticket_alpn, strlen(restored.ticket_alpn)); } #endif #endif