diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h index 1c98a5e5a8..7fec65e1dd 100644 --- a/include/mbedtls/ssl.h +++ b/include/mbedtls/ssl.h @@ -2682,6 +2682,9 @@ int mbedtls_ssl_conf_own_cert( mbedtls_ssl_config *conf, * \note This is mainly useful for clients. Servers will usually * want to use \c mbedtls_ssl_conf_psk_cb() instead. * + * \note A PSK set by \c mbedtls_ssl_set_hs_psk() in the PSK callback + * takes precedence over a PSK configured by this function. + * * \warning Currently, clients can only register a single pre-shared key. * Calling this function or mbedtls_ssl_conf_psk_opaque() more * than once will overwrite values configured in previous calls. @@ -2715,6 +2718,10 @@ int mbedtls_ssl_conf_psk( mbedtls_ssl_config *conf, * \note This is mainly useful for clients. Servers will usually * want to use \c mbedtls_ssl_conf_psk_cb() instead. * + * \note An opaque PSK set by \c mbedtls_ssl_set_hs_psk_opaque() in + * the PSK callback takes precedence over an opaque PSK + * configured by this function. + * * \warning Currently, clients can only register a single pre-shared key. * Calling this function or mbedtls_ssl_conf_psk() more than * once will overwrite values configured in previous calls. @@ -2752,6 +2759,9 @@ int mbedtls_ssl_conf_psk_opaque( mbedtls_ssl_config *conf, * \note This should only be called inside the PSK callback, * i.e. the function passed to \c mbedtls_ssl_conf_psk_cb(). * + * \note A PSK set by this function takes precedence over a PSK + * configured by \c mbedtls_ssl_conf_psk(). + * * \param ssl The SSL context to configure a PSK for. * \param psk The pointer to the pre-shared key. * \param psk_len The length of the pre-shared key in bytes. @@ -2769,6 +2779,9 @@ int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl, * \note This should only be called inside the PSK callback, * i.e. the function passed to \c mbedtls_ssl_conf_psk_cb(). * + * \note An opaque PSK set by this function takes precedence over an + * opaque PSK configured by \c mbedtls_ssl_conf_psk_opaque(). + * * \param ssl The SSL context to configure a PSK for. * \param psk The identifier of the key slot holding the PSK. * For the duration of the current handshake, the key slot @@ -2807,9 +2820,14 @@ int mbedtls_ssl_set_hs_psk_opaque( mbedtls_ssl_context *ssl, * on the SSL context to set the correct PSK and return \c 0. * Any other return value will result in a denied PSK identity. * - * \note If you set a PSK callback using this function, then you - * don't need to set a PSK key and identity using - * \c mbedtls_ssl_conf_psk(). + * \note A dynamic PSK (i.e. set by the PSK callback) takes + * precedence over a static PSK (i.e. set by + * \c mbedtls_ssl_conf_psk() or + * \c mbedtls_ssl_conf_psk_opaque()). + * This means that if you set a PSK callback using this + * function, you don't need to set a PSK using + * \c mbedtls_ssl_conf_psk() or + * \c mbedtls_ssl_conf_psk_opaque()). * * \param conf The SSL configuration to register the callback with. * \param f_psk The callback for selecting and setting the PSK based diff --git a/include/mbedtls/ssl_internal.h b/include/mbedtls/ssl_internal.h index f83d01454f..e92381c33d 100644 --- a/include/mbedtls/ssl_internal.h +++ b/include/mbedtls/ssl_internal.h @@ -921,7 +921,60 @@ void mbedtls_ssl_optimize_checksum( mbedtls_ssl_context *ssl, #if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED) int mbedtls_ssl_psk_derive_premaster( mbedtls_ssl_context *ssl, mbedtls_key_exchange_type_t key_ex ); -#endif + +/** + * Get the first defined PSK by order of precedence: + * 1. handshake PSK set by \c mbedtls_ssl_set_hs_psk() in the PSK callback + * 2. static PSK configured by \c mbedtls_ssl_conf_psk() + * Return a code and update the pair (PSK, PSK length) passed to this function + */ +static inline int mbedtls_ssl_get_psk( const mbedtls_ssl_context *ssl, + const unsigned char **psk, size_t *psk_len ) +{ + if( ssl->handshake->psk != NULL && ssl->handshake->psk_len > 0 ) + { + *psk = ssl->handshake->psk; + *psk_len = ssl->handshake->psk_len; + } + + else if( ssl->conf->psk != NULL && ssl->conf->psk_len > 0 ) + { + *psk = ssl->conf->psk; + *psk_len = ssl->conf->psk_len; + } + + else + { + *psk = NULL; + *psk_len = 0; + return( MBEDTLS_ERR_SSL_PRIVATE_KEY_REQUIRED ); + } + + return( 0 ); +} + +#if defined(MBEDTLS_USE_PSA_CRYPTO) +/** + * Get the first defined opaque PSK by order of precedence: + * 1. handshake PSK set by \c mbedtls_ssl_set_hs_psk_opaque() in the PSK + * callback + * 2. static PSK configured by \c mbedtls_ssl_conf_psk_opaque() + * Return an opaque PSK + */ +static inline psa_key_handle_t mbedtls_ssl_get_opaque_psk( + const mbedtls_ssl_context *ssl ) +{ + if( ssl->handshake->psk_opaque != 0 ) + return( ssl->handshake->psk_opaque ); + + if( ssl->conf->psk_opaque != 0 ) + return( ssl->conf->psk_opaque ); + + return( 0 ); +} +#endif /* MBEDTLS_USE_PSA_CRYPTO */ + +#endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */ #if defined(MBEDTLS_PK_C) unsigned char mbedtls_ssl_sig_from_pk( mbedtls_pk_context *pk ); diff --git a/library/ssl_tls.c b/library/ssl_tls.c index d0adfd9515..dbc5a3e880 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -1514,9 +1514,7 @@ static int ssl_compute_master( mbedtls_ssl_handshake_params *handshake, MBEDTLS_SSL_DEBUG_MSG( 2, ( "perform PSA-based PSK-to-MS expansion" ) ); - psk = ssl->conf->psk_opaque; - if( handshake->psk_opaque != 0 ) - psk = handshake->psk_opaque; + psk = mbedtls_ssl_get_opaque_psk( ssl ); if( hash_alg == MBEDTLS_MD_SHA384 ) alg = PSA_ALG_TLS12_PSK_TO_MS(PSA_ALG_SHA_384); @@ -1850,14 +1848,18 @@ int mbedtls_ssl_psk_derive_premaster( mbedtls_ssl_context *ssl, mbedtls_key_exch { unsigned char *p = ssl->handshake->premaster; unsigned char *end = p + sizeof( ssl->handshake->premaster ); - const unsigned char *psk = ssl->conf->psk; - size_t psk_len = ssl->conf->psk_len; + const unsigned char *psk = NULL; + size_t psk_len = 0; - /* If the psk callback was called, use its result */ - if( ssl->handshake->psk != NULL ) + if( mbedtls_ssl_get_psk( ssl, &psk, &psk_len ) + == MBEDTLS_ERR_SSL_PRIVATE_KEY_REQUIRED ) { - psk = ssl->handshake->psk; - psk_len = ssl->handshake->psk_len; + /* + * This should never happen because the existence of a PSK is always + * checked before calling this function + */ + MBEDTLS_SSL_DEBUG_MSG( 1, ( "should never happen" ) ); + return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); } /*