1
0
mirror of https://github.com/Mbed-TLS/mbedtls.git synced 2025-07-30 22:43:08 +03:00

Return and propagate errors in calc_finished()

Allow calc_finished to return an error code and propagate that back to
the original function. If an error is returned by a PSA function,
propagate it upwards instead of continuing, so that we do not fail to
properly check the finished message.

Signed-off-by: David Horstmann <david.horstmann@arm.com>
This commit is contained in:
David Horstmann
2025-03-10 14:10:11 +00:00
parent c43a9d5576
commit 68014b2b80
2 changed files with 27 additions and 14 deletions

View File

@ -467,7 +467,7 @@ struct mbedtls_ssl_handshake_params {
void (*update_checksum)(mbedtls_ssl_context *, const unsigned char *, size_t); void (*update_checksum)(mbedtls_ssl_context *, const unsigned char *, size_t);
void (*calc_verify)(const mbedtls_ssl_context *, unsigned char *, size_t *); void (*calc_verify)(const mbedtls_ssl_context *, unsigned char *, size_t *);
void (*calc_finished)(mbedtls_ssl_context *, unsigned char *, int); int (*calc_finished)(mbedtls_ssl_context *, unsigned char *, int);
mbedtls_ssl_tls_prf_cb *tls_prf; mbedtls_ssl_tls_prf_cb *tls_prf;
#if defined(MBEDTLS_DHM_C) #if defined(MBEDTLS_DHM_C)

View File

@ -892,25 +892,25 @@ static void ssl_update_checksum_md5sha1(mbedtls_ssl_context *, const unsigned ch
#if defined(MBEDTLS_SSL_PROTO_SSL3) #if defined(MBEDTLS_SSL_PROTO_SSL3)
static void ssl_calc_verify_ssl(const mbedtls_ssl_context *, unsigned char *, size_t *); static void ssl_calc_verify_ssl(const mbedtls_ssl_context *, unsigned char *, size_t *);
static void ssl_calc_finished_ssl(mbedtls_ssl_context *, unsigned char *, int); static int ssl_calc_finished_ssl(mbedtls_ssl_context *, unsigned char *, int);
#endif #endif
#if defined(MBEDTLS_SSL_PROTO_TLS1) || defined(MBEDTLS_SSL_PROTO_TLS1_1) #if defined(MBEDTLS_SSL_PROTO_TLS1) || defined(MBEDTLS_SSL_PROTO_TLS1_1)
static void ssl_calc_verify_tls(const mbedtls_ssl_context *, unsigned char *, size_t *); static void ssl_calc_verify_tls(const mbedtls_ssl_context *, unsigned char *, size_t *);
static void ssl_calc_finished_tls(mbedtls_ssl_context *, unsigned char *, int); static int ssl_calc_finished_tls(mbedtls_ssl_context *, unsigned char *, int);
#endif #endif
#if defined(MBEDTLS_SSL_PROTO_TLS1_2) #if defined(MBEDTLS_SSL_PROTO_TLS1_2)
#if defined(MBEDTLS_SHA256_C) #if defined(MBEDTLS_SHA256_C)
static void ssl_update_checksum_sha256(mbedtls_ssl_context *, const unsigned char *, size_t); static void ssl_update_checksum_sha256(mbedtls_ssl_context *, const unsigned char *, size_t);
static void ssl_calc_verify_tls_sha256(const mbedtls_ssl_context *, unsigned char *, size_t *); static void ssl_calc_verify_tls_sha256(const mbedtls_ssl_context *, unsigned char *, size_t *);
static void ssl_calc_finished_tls_sha256(mbedtls_ssl_context *, unsigned char *, int); static int ssl_calc_finished_tls_sha256(mbedtls_ssl_context *, unsigned char *, int);
#endif #endif
#if defined(MBEDTLS_SHA512_C) && !defined(MBEDTLS_SHA512_NO_SHA384) #if defined(MBEDTLS_SHA512_C) && !defined(MBEDTLS_SHA512_NO_SHA384)
static void ssl_update_checksum_sha384(mbedtls_ssl_context *, const unsigned char *, size_t); static void ssl_update_checksum_sha384(mbedtls_ssl_context *, const unsigned char *, size_t);
static void ssl_calc_verify_tls_sha384(const mbedtls_ssl_context *, unsigned char *, size_t *); static void ssl_calc_verify_tls_sha384(const mbedtls_ssl_context *, unsigned char *, size_t *);
static void ssl_calc_finished_tls_sha384(mbedtls_ssl_context *, unsigned char *, int); static int ssl_calc_finished_tls_sha384(mbedtls_ssl_context *, unsigned char *, int);
#endif #endif
#endif /* MBEDTLS_SSL_PROTO_TLS1_2 */ #endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
@ -3136,7 +3136,7 @@ static void ssl_update_checksum_sha384(mbedtls_ssl_context *ssl,
#endif /* MBEDTLS_SSL_PROTO_TLS1_2 */ #endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
#if defined(MBEDTLS_SSL_PROTO_SSL3) #if defined(MBEDTLS_SSL_PROTO_SSL3)
static void ssl_calc_finished_ssl( static int ssl_calc_finished_ssl(
mbedtls_ssl_context *ssl, unsigned char *buf, int from) mbedtls_ssl_context *ssl, unsigned char *buf, int from)
{ {
const char *sender; const char *sender;
@ -3218,11 +3218,13 @@ static void ssl_calc_finished_ssl(
mbedtls_platform_zeroize(sha1sum, sizeof(sha1sum)); mbedtls_platform_zeroize(sha1sum, sizeof(sha1sum));
MBEDTLS_SSL_DEBUG_MSG(2, ("<= calc finished")); MBEDTLS_SSL_DEBUG_MSG(2, ("<= calc finished"));
return 0;
} }
#endif /* MBEDTLS_SSL_PROTO_SSL3 */ #endif /* MBEDTLS_SSL_PROTO_SSL3 */
#if defined(MBEDTLS_SSL_PROTO_TLS1) || defined(MBEDTLS_SSL_PROTO_TLS1_1) #if defined(MBEDTLS_SSL_PROTO_TLS1) || defined(MBEDTLS_SSL_PROTO_TLS1_1)
static void ssl_calc_finished_tls( static int ssl_calc_finished_tls(
mbedtls_ssl_context *ssl, unsigned char *buf, int from) mbedtls_ssl_context *ssl, unsigned char *buf, int from)
{ {
int len = 12; int len = 12;
@ -3278,12 +3280,14 @@ static void ssl_calc_finished_tls(
mbedtls_platform_zeroize(padbuf, sizeof(padbuf)); mbedtls_platform_zeroize(padbuf, sizeof(padbuf));
MBEDTLS_SSL_DEBUG_MSG(2, ("<= calc finished")); MBEDTLS_SSL_DEBUG_MSG(2, ("<= calc finished"));
return 0;
} }
#endif /* MBEDTLS_SSL_PROTO_TLS1 || MBEDTLS_SSL_PROTO_TLS1_1 */ #endif /* MBEDTLS_SSL_PROTO_TLS1 || MBEDTLS_SSL_PROTO_TLS1_1 */
#if defined(MBEDTLS_SSL_PROTO_TLS1_2) #if defined(MBEDTLS_SSL_PROTO_TLS1_2)
#if defined(MBEDTLS_SHA256_C) #if defined(MBEDTLS_SHA256_C)
static void ssl_calc_finished_tls_sha256( static int ssl_calc_finished_tls_sha256(
mbedtls_ssl_context *ssl, unsigned char *buf, int from) mbedtls_ssl_context *ssl, unsigned char *buf, int from)
{ {
int len = 12; int len = 12;
@ -3314,13 +3318,13 @@ static void ssl_calc_finished_tls_sha256(
status = psa_hash_clone(&ssl->handshake->fin_sha256_psa, &sha256_psa); status = psa_hash_clone(&ssl->handshake->fin_sha256_psa, &sha256_psa);
if (status != PSA_SUCCESS) { if (status != PSA_SUCCESS) {
MBEDTLS_SSL_DEBUG_MSG(2, ("PSA hash clone failed")); MBEDTLS_SSL_DEBUG_MSG(2, ("PSA hash clone failed"));
return; return status;
} }
status = psa_hash_finish(&sha256_psa, padbuf, sizeof(padbuf), &hash_size); status = psa_hash_finish(&sha256_psa, padbuf, sizeof(padbuf), &hash_size);
if (status != PSA_SUCCESS) { if (status != PSA_SUCCESS) {
MBEDTLS_SSL_DEBUG_MSG(2, ("PSA hash finish failed")); MBEDTLS_SSL_DEBUG_MSG(2, ("PSA hash finish failed"));
return; return status;
} }
MBEDTLS_SSL_DEBUG_BUF(3, "PSA calculated padbuf", padbuf, 32); MBEDTLS_SSL_DEBUG_BUF(3, "PSA calculated padbuf", padbuf, 32);
#else #else
@ -3354,12 +3358,14 @@ static void ssl_calc_finished_tls_sha256(
mbedtls_platform_zeroize(padbuf, sizeof(padbuf)); mbedtls_platform_zeroize(padbuf, sizeof(padbuf));
MBEDTLS_SSL_DEBUG_MSG(2, ("<= calc finished")); MBEDTLS_SSL_DEBUG_MSG(2, ("<= calc finished"));
return 0;
} }
#endif /* MBEDTLS_SHA256_C */ #endif /* MBEDTLS_SHA256_C */
#if defined(MBEDTLS_SHA512_C) && !defined(MBEDTLS_SHA512_NO_SHA384) #if defined(MBEDTLS_SHA512_C) && !defined(MBEDTLS_SHA512_NO_SHA384)
static void ssl_calc_finished_tls_sha384( static int ssl_calc_finished_tls_sha384(
mbedtls_ssl_context *ssl, unsigned char *buf, int from) mbedtls_ssl_context *ssl, unsigned char *buf, int from)
{ {
int len = 12; int len = 12;
@ -3390,13 +3396,13 @@ static void ssl_calc_finished_tls_sha384(
status = psa_hash_clone(&ssl->handshake->fin_sha384_psa, &sha384_psa); status = psa_hash_clone(&ssl->handshake->fin_sha384_psa, &sha384_psa);
if (status != PSA_SUCCESS) { if (status != PSA_SUCCESS) {
MBEDTLS_SSL_DEBUG_MSG(2, ("PSA hash clone failed")); MBEDTLS_SSL_DEBUG_MSG(2, ("PSA hash clone failed"));
return; return status;
} }
status = psa_hash_finish(&sha384_psa, padbuf, sizeof(padbuf), &hash_size); status = psa_hash_finish(&sha384_psa, padbuf, sizeof(padbuf), &hash_size);
if (status != PSA_SUCCESS) { if (status != PSA_SUCCESS) {
MBEDTLS_SSL_DEBUG_MSG(2, ("PSA hash finish failed")); MBEDTLS_SSL_DEBUG_MSG(2, ("PSA hash finish failed"));
return; return status;
} }
MBEDTLS_SSL_DEBUG_BUF(3, "PSA calculated padbuf", padbuf, 48); MBEDTLS_SSL_DEBUG_BUF(3, "PSA calculated padbuf", padbuf, 48);
#else #else
@ -3441,6 +3447,8 @@ static void ssl_calc_finished_tls_sha384(
mbedtls_platform_zeroize(padbuf, sizeof(padbuf)); mbedtls_platform_zeroize(padbuf, sizeof(padbuf));
MBEDTLS_SSL_DEBUG_MSG(2, ("<= calc finished")); MBEDTLS_SSL_DEBUG_MSG(2, ("<= calc finished"));
return 0;
} }
#endif /* MBEDTLS_SHA512_C && !MBEDTLS_SHA512_NO_SHA384 */ #endif /* MBEDTLS_SHA512_C && !MBEDTLS_SHA512_NO_SHA384 */
#endif /* MBEDTLS_SSL_PROTO_TLS1_2 */ #endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
@ -3535,7 +3543,12 @@ int mbedtls_ssl_write_finished(mbedtls_ssl_context *ssl)
mbedtls_ssl_update_out_pointers(ssl, ssl->transform_negotiate); mbedtls_ssl_update_out_pointers(ssl, ssl->transform_negotiate);
ssl->handshake->calc_finished(ssl, ssl->out_msg + 4, ssl->conf->endpoint); ret = ssl->handshake->calc_finished(ssl, ssl->out_msg + 4,
ssl->conf->endpoint);
if (ret != 0) {
MBEDTLS_SSL_DEBUG_RET(1, "calc_finished", ret);
return ret;
}
/* /*
* RFC 5246 7.4.9 (Page 63) says 12 is the default length and ciphersuites * RFC 5246 7.4.9 (Page 63) says 12 is the default length and ciphersuites