diff --git a/docs/architecture/tls13-support.md b/docs/architecture/tls13-support.md index 39e46c4562..3154ac1077 100644 --- a/docs/architecture/tls13-support.md +++ b/docs/architecture/tls13-support.md @@ -388,10 +388,10 @@ General coding rules: Example: ``` - int mbedtls_ssl_tls13_start_handshake_msg( mbedtls_ssl_context *ssl, - unsigned hs_type, - unsigned char **buf, - size_t *buf_len ); + int mbedtls_ssl_start_handshake_msg( mbedtls_ssl_context *ssl, + unsigned hs_type, + unsigned char **buf, + size_t *buf_len ); ``` - When a function's parameters span several lines, group related parameters @@ -400,12 +400,12 @@ General coding rules: For example, prefer: ``` - mbedtls_ssl_tls13_start_handshake_msg( ssl, hs_type, - buf, buf_len ); + mbedtls_ssl_start_handshake_msg( ssl, hs_type, + buf, buf_len ); ``` over ``` - mbedtls_ssl_tls13_start_handshake_msg( ssl, hs_type, buf, - buf_len ); + mbedtls_ssl_start_handshake_msg( ssl, hs_type, buf, + buf_len ); ``` even if it fits. diff --git a/library/ssl_misc.h b/library/ssl_misc.h index 289a64a1d6..6329e0d5e5 100644 --- a/library/ssl_misc.h +++ b/library/ssl_misc.h @@ -1152,6 +1152,11 @@ int mbedtls_ssl_write_hostname_ext( mbedtls_ssl_context *ssl, int mbedtls_ssl_handshake_client_step( mbedtls_ssl_context *ssl ); int mbedtls_ssl_handshake_server_step( mbedtls_ssl_context *ssl ); void mbedtls_ssl_handshake_wrapup( mbedtls_ssl_context *ssl ); +static inline void mbedtls_ssl_handshake_set_state( mbedtls_ssl_context *ssl, + mbedtls_ssl_states state ) +{ + ssl->state = ( int ) state; +} int mbedtls_ssl_send_fatal_handshake_failure( mbedtls_ssl_context *ssl ); @@ -1245,6 +1250,12 @@ int mbedtls_ssl_read_record( mbedtls_ssl_context *ssl, unsigned update_hs_digest ); int mbedtls_ssl_fetch_input( mbedtls_ssl_context *ssl, size_t nb_want ); +/* + * Write handshake message header + */ +int mbedtls_ssl_start_handshake_msg( mbedtls_ssl_context *ssl, unsigned hs_type, + unsigned char **buf, size_t *buf_len ); + int mbedtls_ssl_write_handshake_msg_ext( mbedtls_ssl_context *ssl, int update_checksum, int force_flush ); @@ -1253,6 +1264,12 @@ static inline int mbedtls_ssl_write_handshake_msg( mbedtls_ssl_context *ssl ) return( mbedtls_ssl_write_handshake_msg_ext( ssl, 1 /* update checksum */, 1 /* force flush */ ) ); } +/* + * Write handshake message tail + */ +int mbedtls_ssl_finish_handshake_msg( mbedtls_ssl_context *ssl, + size_t buf_len, size_t msg_len ); + int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl, int force_flush ); int mbedtls_ssl_flush_output( mbedtls_ssl_context *ssl ); @@ -1268,8 +1285,17 @@ int mbedtls_ssl_write_finished( mbedtls_ssl_context *ssl ); void mbedtls_ssl_optimize_checksum( mbedtls_ssl_context *ssl, const mbedtls_ssl_ciphersuite_t *ciphersuite_info ); +/* + * Update checksum of handshake messages. + */ +void mbedtls_ssl_add_hs_msg_to_checksum( mbedtls_ssl_context *ssl, + unsigned hs_type, + unsigned char const *msg, + size_t msg_len ); + #if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED) -int mbedtls_ssl_psk_derive_premaster( mbedtls_ssl_context *ssl, mbedtls_key_exchange_type_t key_ex ); +int mbedtls_ssl_psk_derive_premaster( mbedtls_ssl_context *ssl, + mbedtls_key_exchange_type_t key_ex ); /** * Get the first defined PSK by order of precedence: @@ -1727,13 +1753,6 @@ static inline int mbedtls_ssl_tls13_some_psk_enabled( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK_ALL ) ); } - -static inline void mbedtls_ssl_handshake_set_state( mbedtls_ssl_context *ssl, - mbedtls_ssl_states state ) -{ - ssl->state = ( int ) state; -} - /* * Fetch TLS 1.3 handshake message header */ @@ -1742,14 +1761,6 @@ int mbedtls_ssl_tls13_fetch_handshake_msg( mbedtls_ssl_context *ssl, unsigned char **buf, size_t *buf_len ); -/* - * Write TLS 1.3 handshake message header - */ -int mbedtls_ssl_tls13_start_handshake_msg( mbedtls_ssl_context *ssl, - unsigned hs_type, - unsigned char **buf, - size_t *buf_len ); - /* * Handler of TLS 1.3 server certificate message */ @@ -1778,25 +1789,6 @@ int mbedtls_ssl_tls13_process_certificate_verify( mbedtls_ssl_context *ssl ); */ int mbedtls_ssl_tls13_write_change_cipher_spec( mbedtls_ssl_context *ssl ); -/* - * Write TLS 1.3 handshake message tail - */ -int mbedtls_ssl_tls13_finish_handshake_msg( mbedtls_ssl_context *ssl, - size_t buf_len, - size_t msg_len ); - -void mbedtls_ssl_tls13_add_hs_hdr_to_checksum( mbedtls_ssl_context *ssl, - unsigned hs_type, - size_t total_hs_len ); - -/* - * Update checksum of handshake messages. - */ -void mbedtls_ssl_tls13_add_hs_msg_to_checksum( mbedtls_ssl_context *ssl, - unsigned hs_type, - unsigned char const *msg, - size_t msg_len ); - int mbedtls_ssl_reset_transcript_for_hrr( mbedtls_ssl_context *ssl ); #endif /* MBEDTLS_SSL_PROTO_TLS1_3 */ diff --git a/library/ssl_msg.c b/library/ssl_msg.c index c2effb6d23..ae3dc13303 100644 --- a/library/ssl_msg.c +++ b/library/ssl_msg.c @@ -2445,6 +2445,24 @@ void mbedtls_ssl_send_flight_completed( mbedtls_ssl_context *ssl ) /* * Handshake layer functions */ +int mbedtls_ssl_start_handshake_msg( mbedtls_ssl_context *ssl, unsigned hs_type, + unsigned char **buf, size_t *buf_len ) +{ + /* + * Reserve 4 bytes for hanshake header. ( Section 4,RFC 8446 ) + * ... + * HandshakeType msg_type; + * uint24 length; + * ... + */ + *buf = ssl->out_msg + 4; + *buf_len = MBEDTLS_SSL_OUT_CONTENT_LEN - 4; + + ssl->out_msgtype = MBEDTLS_SSL_MSG_HANDSHAKE; + ssl->out_msg[0] = hs_type; + + return( 0 ); +} /* * Write (DTLS: or queue) current handshake (including CCS) message. @@ -2609,6 +2627,22 @@ int mbedtls_ssl_write_handshake_msg_ext( mbedtls_ssl_context *ssl, return( 0 ); } +int mbedtls_ssl_finish_handshake_msg( mbedtls_ssl_context *ssl, + size_t buf_len, size_t msg_len ) +{ + int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; + size_t msg_with_header_len; + ((void) buf_len); + + /* Add reserved 4 bytes for handshake header */ + msg_with_header_len = msg_len + 4; + ssl->out_msglen = msg_with_header_len; + MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_write_handshake_msg_ext( ssl, 0, 0 ) ); + +cleanup: + return( ret ); +} + /* * Record layer functions */ diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 6a1bfa8f5a..b14848527f 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -475,6 +475,30 @@ void mbedtls_ssl_optimize_checksum( mbedtls_ssl_context *ssl, } } +static void mbedtls_ssl_add_hs_hdr_to_checksum( mbedtls_ssl_context *ssl, + unsigned hs_type, + size_t total_hs_len ) +{ + unsigned char hs_hdr[4]; + + /* Build HS header for checksum update. */ + hs_hdr[0] = MBEDTLS_BYTE_0( hs_type ); + hs_hdr[1] = MBEDTLS_BYTE_2( total_hs_len ); + hs_hdr[2] = MBEDTLS_BYTE_1( total_hs_len ); + hs_hdr[3] = MBEDTLS_BYTE_0( total_hs_len ); + + ssl->handshake->update_checksum( ssl, hs_hdr, sizeof( hs_hdr ) ); +} + +void mbedtls_ssl_add_hs_msg_to_checksum( mbedtls_ssl_context *ssl, + unsigned hs_type, + unsigned char const *msg, + size_t msg_len ) +{ + mbedtls_ssl_add_hs_hdr_to_checksum( ssl, hs_type, msg_len ); + ssl->handshake->update_checksum( ssl, msg, msg_len ); +} + void mbedtls_ssl_reset_checksum( mbedtls_ssl_context *ssl ) { ((void) ssl); diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c index 7d3bdf2408..9f22e1dcc6 100644 --- a/library/ssl_tls13_client.c +++ b/library/ssl_tls13_client.c @@ -1060,7 +1060,7 @@ static int ssl_tls13_write_client_hello( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_PROC_CHK( ssl_tls13_prepare_client_hello( ssl ) ); - MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_start_handshake_msg( + MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_start_handshake_msg( ssl, MBEDTLS_SSL_HS_CLIENT_HELLO, &buf, &buf_len ) ); @@ -1068,14 +1068,12 @@ static int ssl_tls13_write_client_hello( mbedtls_ssl_context *ssl ) buf + buf_len, &msg_len ) ); - mbedtls_ssl_tls13_add_hs_hdr_to_checksum( ssl, - MBEDTLS_SSL_HS_CLIENT_HELLO, - msg_len ); - ssl->handshake->update_checksum( ssl, buf, msg_len ); + mbedtls_ssl_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_CLIENT_HELLO, + buf, msg_len ); - MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_finish_handshake_msg( ssl, - buf_len, - msg_len ) ); + MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_finish_handshake_msg( ssl, + buf_len, + msg_len ) ); mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_SERVER_HELLO ); @@ -1707,9 +1705,8 @@ static int ssl_tls13_process_server_hello( mbedtls_ssl_context *ssl ) if( is_hrr ) MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_reset_transcript_for_hrr( ssl ) ); - mbedtls_ssl_tls13_add_hs_msg_to_checksum( ssl, - MBEDTLS_SSL_HS_SERVER_HELLO, - buf, buf_len ); + mbedtls_ssl_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_SERVER_HELLO, + buf, buf_len ); if( is_hrr ) MBEDTLS_SSL_PROC_CHK( ssl_tls13_postprocess_hrr( ssl ) ); @@ -1762,8 +1759,8 @@ static int ssl_tls13_process_encrypted_extensions( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_PROC_CHK( ssl_tls13_parse_encrypted_extensions( ssl, buf, buf + buf_len ) ); - mbedtls_ssl_tls13_add_hs_msg_to_checksum( - ssl, MBEDTLS_SSL_HS_ENCRYPTED_EXTENSIONS, buf, buf_len ); + mbedtls_ssl_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_ENCRYPTED_EXTENSIONS, + buf, buf_len ); MBEDTLS_SSL_PROC_CHK( ssl_tls13_postprocess_encrypted_extensions( ssl ) ); @@ -2059,8 +2056,8 @@ static int ssl_tls13_process_certificate_request( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_PROC_CHK( ssl_tls13_parse_certificate_request( ssl, buf, buf + buf_len ) ); - mbedtls_ssl_tls13_add_hs_msg_to_checksum( - ssl, MBEDTLS_SSL_HS_CERTIFICATE_REQUEST, buf, buf_len ); + mbedtls_ssl_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_CERTIFICATE_REQUEST, + buf, buf_len ); } else if( ret == SSL_CERTIFICATE_REQUEST_SKIP ) { diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c index dab98a34f3..6623e7f705 100644 --- a/library/ssl_tls13_generic.c +++ b/library/ssl_tls13_generic.c @@ -72,68 +72,6 @@ cleanup: return( ret ); } -int mbedtls_ssl_tls13_start_handshake_msg( mbedtls_ssl_context *ssl, - unsigned hs_type, - unsigned char **buf, - size_t *buf_len ) -{ - /* - * Reserve 4 bytes for hanshake header. ( Section 4,RFC 8446 ) - * ... - * HandshakeType msg_type; - * uint24 length; - * ... - */ - *buf = ssl->out_msg + 4; - *buf_len = MBEDTLS_SSL_OUT_CONTENT_LEN - 4; - - ssl->out_msgtype = MBEDTLS_SSL_MSG_HANDSHAKE; - ssl->out_msg[0] = hs_type; - - return( 0 ); -} - -int mbedtls_ssl_tls13_finish_handshake_msg( mbedtls_ssl_context *ssl, - size_t buf_len, - size_t msg_len ) -{ - int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; - size_t msg_with_header_len; - ((void) buf_len); - - /* Add reserved 4 bytes for handshake header */ - msg_with_header_len = msg_len + 4; - ssl->out_msglen = msg_with_header_len; - MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_write_handshake_msg_ext( ssl, 0, 0 ) ); - -cleanup: - return( ret ); -} - -void mbedtls_ssl_tls13_add_hs_msg_to_checksum( mbedtls_ssl_context *ssl, - unsigned hs_type, - unsigned char const *msg, - size_t msg_len ) -{ - mbedtls_ssl_tls13_add_hs_hdr_to_checksum( ssl, hs_type, msg_len ); - ssl->handshake->update_checksum( ssl, msg, msg_len ); -} - -void mbedtls_ssl_tls13_add_hs_hdr_to_checksum( mbedtls_ssl_context *ssl, - unsigned hs_type, - size_t total_hs_len ) -{ - unsigned char hs_hdr[4]; - - /* Build HS header for checksum update. */ - hs_hdr[0] = MBEDTLS_BYTE_0( hs_type ); - hs_hdr[1] = MBEDTLS_BYTE_2( total_hs_len ); - hs_hdr[2] = MBEDTLS_BYTE_1( total_hs_len ); - hs_hdr[3] = MBEDTLS_BYTE_0( total_hs_len ); - - ssl->handshake->update_checksum( ssl, hs_hdr, sizeof( hs_hdr ) ); -} - #if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED) /* mbedtls_ssl_tls13_parse_sig_alg_ext() * @@ -479,8 +417,8 @@ int mbedtls_ssl_tls13_process_certificate_verify( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_PROC_CHK( ssl_tls13_parse_certificate_verify( ssl, buf, buf + buf_len, verify_buffer, verify_buffer_len ) ); - mbedtls_ssl_tls13_add_hs_msg_to_checksum( ssl, - MBEDTLS_SSL_HS_CERTIFICATE_VERIFY, buf, buf_len ); + mbedtls_ssl_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_CERTIFICATE_VERIFY, + buf, buf_len ); cleanup: @@ -796,8 +734,8 @@ int mbedtls_ssl_tls13_process_certificate( mbedtls_ssl_context *ssl ) /* Validate the certificate chain and set the verification results. */ MBEDTLS_SSL_PROC_CHK( ssl_tls13_validate_certificate( ssl ) ); - mbedtls_ssl_tls13_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_CERTIFICATE, - buf, buf_len ); + mbedtls_ssl_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_CERTIFICATE, + buf, buf_len ); cleanup: @@ -904,7 +842,7 @@ int mbedtls_ssl_tls13_write_certificate( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> write certificate" ) ); - MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_start_handshake_msg( ssl, + MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_start_handshake_msg( ssl, MBEDTLS_SSL_HS_CERTIFICATE, &buf, &buf_len ) ); MBEDTLS_SSL_PROC_CHK( ssl_tls13_write_certificate_body( ssl, @@ -912,12 +850,10 @@ int mbedtls_ssl_tls13_write_certificate( mbedtls_ssl_context *ssl ) buf + buf_len, &msg_len ) ); - mbedtls_ssl_tls13_add_hs_msg_to_checksum( ssl, - MBEDTLS_SSL_HS_CERTIFICATE, - buf, - msg_len ); + mbedtls_ssl_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_CERTIFICATE, + buf, msg_len ); - MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_finish_handshake_msg( + MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_finish_handshake_msg( ssl, buf_len, msg_len ) ); cleanup: @@ -1161,16 +1097,16 @@ int mbedtls_ssl_tls13_write_certificate_verify( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> write certificate verify" ) ); - MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_start_handshake_msg( ssl, + MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_start_handshake_msg( ssl, MBEDTLS_SSL_HS_CERTIFICATE_VERIFY, &buf, &buf_len ) ); MBEDTLS_SSL_PROC_CHK( ssl_tls13_write_certificate_verify_body( ssl, buf, buf + buf_len, &msg_len ) ); - mbedtls_ssl_tls13_add_hs_msg_to_checksum( - ssl, MBEDTLS_SSL_HS_CERTIFICATE_VERIFY, buf, msg_len ); + mbedtls_ssl_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_CERTIFICATE_VERIFY, + buf, msg_len ); - MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_finish_handshake_msg( + MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_finish_handshake_msg( ssl, buf_len, msg_len ) ); cleanup: @@ -1340,8 +1276,8 @@ int mbedtls_ssl_tls13_process_finished_message( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_HS_FINISHED, &buf, &buf_len ) ); MBEDTLS_SSL_PROC_CHK( ssl_tls13_parse_finished_message( ssl, buf, buf + buf_len ) ); - mbedtls_ssl_tls13_add_hs_msg_to_checksum( - ssl, MBEDTLS_SSL_HS_FINISHED, buf, buf_len ); + mbedtls_ssl_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_FINISHED, + buf, buf_len ); MBEDTLS_SSL_PROC_CHK( ssl_tls13_postprocess_finished_message( ssl ) ); cleanup: @@ -1418,19 +1354,18 @@ int mbedtls_ssl_tls13_write_finished_message( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_PROC_CHK( ssl_tls13_prepare_finished_message( ssl ) ); - MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_start_handshake_msg( ssl, + MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_start_handshake_msg( ssl, MBEDTLS_SSL_HS_FINISHED, &buf, &buf_len ) ); MBEDTLS_SSL_PROC_CHK( ssl_tls13_write_finished_message_body( ssl, buf, buf + buf_len, &msg_len ) ); - mbedtls_ssl_tls13_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_FINISHED, - buf, msg_len ); + mbedtls_ssl_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_FINISHED, + buf, msg_len ); MBEDTLS_SSL_PROC_CHK( ssl_tls13_finalize_finished_message( ssl ) ); - MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_finish_handshake_msg( ssl, - buf_len, msg_len ) ); - + MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_finish_handshake_msg( + ssl, buf_len, msg_len ) ); cleanup: MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= write finished message" ) );