From 0d76341eace52010cbf1d8f6e2b5054455614eb9 Mon Sep 17 00:00:00 2001 From: Neil Armstrong Date: Thu, 11 Aug 2022 10:32:22 +0200 Subject: [PATCH] Remove md_info by md_type in ecjpake context, use mbedtls_hash_info_get_size() to get hash length Signed-off-by: Neil Armstrong --- include/mbedtls/ecjpake.h | 2 +- library/ecjpake.c | 60 +++++++++++++----------- tests/suites/test_suite_ecjpake.function | 6 +-- 3 files changed, 37 insertions(+), 31 deletions(-) diff --git a/include/mbedtls/ecjpake.h b/include/mbedtls/ecjpake.h index 7853a6a837..ffdea05bcf 100644 --- a/include/mbedtls/ecjpake.h +++ b/include/mbedtls/ecjpake.h @@ -70,7 +70,7 @@ typedef enum { */ typedef struct mbedtls_ecjpake_context { - const mbedtls_md_info_t *MBEDTLS_PRIVATE(md_info); /**< Hash to use */ + mbedtls_md_type_t MBEDTLS_PRIVATE(md_type); /**< Hash to use */ mbedtls_ecp_group MBEDTLS_PRIVATE(grp); /**< Elliptic curve */ mbedtls_ecjpake_role MBEDTLS_PRIVATE(role); /**< Are we client or server? */ int MBEDTLS_PRIVATE(point_format); /**< Format for point export */ diff --git a/library/ecjpake.c b/library/ecjpake.c index c591924b77..10286c27d7 100644 --- a/library/ecjpake.c +++ b/library/ecjpake.c @@ -30,6 +30,8 @@ #include "mbedtls/platform_util.h" #include "mbedtls/error.h" +#include "hash_info.h" + #include #if !defined(MBEDTLS_ECJPAKE_ALT) @@ -50,7 +52,7 @@ static const char * const ecjpake_id[] = { */ void mbedtls_ecjpake_init( mbedtls_ecjpake_context *ctx ) { - ctx->md_info = NULL; + ctx->md_type = MBEDTLS_MD_NONE; mbedtls_ecp_group_init( &ctx->grp ); ctx->point_format = MBEDTLS_ECP_PF_UNCOMPRESSED; @@ -73,7 +75,7 @@ void mbedtls_ecjpake_free( mbedtls_ecjpake_context *ctx ) if( ctx == NULL ) return; - ctx->md_info = NULL; + ctx->md_type = MBEDTLS_MD_NONE; mbedtls_ecp_group_free( &ctx->grp ); mbedtls_ecp_point_free( &ctx->Xm1 ); @@ -104,9 +106,11 @@ int mbedtls_ecjpake_setup( mbedtls_ecjpake_context *ctx, ctx->role = role; - if( ( ctx->md_info = mbedtls_md_info_from_type( hash ) ) == NULL ) + if( ( mbedtls_md_info_from_type( hash ) ) == NULL ) return( MBEDTLS_ERR_MD_FEATURE_UNAVAILABLE ); + ctx->md_type = hash; + MBEDTLS_MPI_CHK( mbedtls_ecp_group_load( &ctx->grp, curve ) ); MBEDTLS_MPI_CHK( mbedtls_mpi_read_binary( &ctx->s, secret, len ) ); @@ -137,7 +141,7 @@ int mbedtls_ecjpake_set_point_format( mbedtls_ecjpake_context *ctx, */ int mbedtls_ecjpake_check( const mbedtls_ecjpake_context *ctx ) { - if( ctx->md_info == NULL || + if( ctx->md_type == MBEDTLS_MD_NONE || ctx->grp.id == MBEDTLS_ECP_DP_NONE || ctx->s.p == NULL ) { @@ -184,7 +188,7 @@ static int ecjpake_write_len_point( unsigned char **p, /* * Compute hash for ZKP (7.4.2.2.2.1) */ -static int ecjpake_hash( const mbedtls_md_info_t *md_info, +static int ecjpake_hash( const mbedtls_md_type_t md_type, const mbedtls_ecp_group *grp, const int pf, const mbedtls_ecp_point *G, @@ -218,11 +222,12 @@ static int ecjpake_hash( const mbedtls_md_info_t *md_info, p += id_len; /* Compute hash */ - MBEDTLS_MPI_CHK( mbedtls_md( md_info, buf, p - buf, hash ) ); + MBEDTLS_MPI_CHK( mbedtls_md( mbedtls_md_info_from_type( md_type ), + buf, p - buf, hash ) ); /* Turn it into an integer mod n */ MBEDTLS_MPI_CHK( mbedtls_mpi_read_binary( h, hash, - mbedtls_md_get_size( md_info ) ) ); + mbedtls_hash_info_get_size( md_type ) ) ); MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( h, h, &grp->N ) ); cleanup: @@ -232,7 +237,7 @@ cleanup: /* * Parse a ECShnorrZKP (7.4.2.2.2) and verify it (7.4.2.3.3) */ -static int ecjpake_zkp_read( const mbedtls_md_info_t *md_info, +static int ecjpake_zkp_read( const mbedtls_md_type_t md_type, const mbedtls_ecp_group *grp, const int pf, const mbedtls_ecp_point *G, @@ -282,7 +287,7 @@ static int ecjpake_zkp_read( const mbedtls_md_info_t *md_info, /* * Verification */ - MBEDTLS_MPI_CHK( ecjpake_hash( md_info, grp, pf, G, &V, X, id, &h ) ); + MBEDTLS_MPI_CHK( ecjpake_hash( md_type, grp, pf, G, &V, X, id, &h ) ); MBEDTLS_MPI_CHK( mbedtls_ecp_muladd( (mbedtls_ecp_group *) grp, &VV, &h, X, &r, G ) ); @@ -304,7 +309,7 @@ cleanup: /* * Generate ZKP (7.4.2.3.2) and write it as ECSchnorrZKP (7.4.2.2.2) */ -static int ecjpake_zkp_write( const mbedtls_md_info_t *md_info, +static int ecjpake_zkp_write( const mbedtls_md_type_t md_type, const mbedtls_ecp_group *grp, const int pf, const mbedtls_ecp_point *G, @@ -332,7 +337,7 @@ static int ecjpake_zkp_write( const mbedtls_md_info_t *md_info, /* Compute signature */ MBEDTLS_MPI_CHK( mbedtls_ecp_gen_keypair_base( (mbedtls_ecp_group *) grp, G, &v, &V, f_rng, p_rng ) ); - MBEDTLS_MPI_CHK( ecjpake_hash( md_info, grp, pf, G, &V, X, id, &h ) ); + MBEDTLS_MPI_CHK( ecjpake_hash( md_type, grp, pf, G, &V, X, id, &h ) ); MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &h, &h, x ) ); /* x*h */ MBEDTLS_MPI_CHK( mbedtls_mpi_sub_mpi( &h, &v, &h ) ); /* v - x*h */ MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &h, &h, &grp->N ) ); /* r */ @@ -365,7 +370,7 @@ cleanup: * Parse a ECJPAKEKeyKP (7.4.2.2.1) and check proof * Output: verified public key X */ -static int ecjpake_kkp_read( const mbedtls_md_info_t *md_info, +static int ecjpake_kkp_read( const mbedtls_md_type_t md_type, const mbedtls_ecp_group *grp, const int pf, const mbedtls_ecp_point *G, @@ -392,7 +397,7 @@ static int ecjpake_kkp_read( const mbedtls_md_info_t *md_info, goto cleanup; } - MBEDTLS_MPI_CHK( ecjpake_zkp_read( md_info, grp, pf, G, X, id, p, end ) ); + MBEDTLS_MPI_CHK( ecjpake_zkp_read( md_type, grp, pf, G, X, id, p, end ) ); cleanup: return( ret ); @@ -402,7 +407,7 @@ cleanup: * Generate an ECJPAKEKeyKP * Output: the serialized structure, plus private/public key pair */ -static int ecjpake_kkp_write( const mbedtls_md_info_t *md_info, +static int ecjpake_kkp_write( const mbedtls_md_type_t md_type, const mbedtls_ecp_group *grp, const int pf, const mbedtls_ecp_point *G, @@ -428,7 +433,7 @@ static int ecjpake_kkp_write( const mbedtls_md_info_t *md_info, *p += len; /* Generate and write proof */ - MBEDTLS_MPI_CHK( ecjpake_zkp_write( md_info, grp, pf, G, x, X, id, + MBEDTLS_MPI_CHK( ecjpake_zkp_write( md_type, grp, pf, G, x, X, id, p, end, f_rng, p_rng ) ); cleanup: @@ -439,7 +444,7 @@ cleanup: * Read a ECJPAKEKeyKPPairList (7.4.2.3) and check proofs * Outputs: verified peer public keys Xa, Xb */ -static int ecjpake_kkpp_read( const mbedtls_md_info_t *md_info, +static int ecjpake_kkpp_read( const mbedtls_md_type_t md_type, const mbedtls_ecp_group *grp, const int pf, const mbedtls_ecp_point *G, @@ -458,8 +463,8 @@ static int ecjpake_kkpp_read( const mbedtls_md_info_t *md_info, * ECJPAKEKeyKP ecjpake_key_kp_pair_list[2]; * } ECJPAKEKeyKPPairList; */ - MBEDTLS_MPI_CHK( ecjpake_kkp_read( md_info, grp, pf, G, Xa, id, &p, end ) ); - MBEDTLS_MPI_CHK( ecjpake_kkp_read( md_info, grp, pf, G, Xb, id, &p, end ) ); + MBEDTLS_MPI_CHK( ecjpake_kkp_read( md_type, grp, pf, G, Xa, id, &p, end ) ); + MBEDTLS_MPI_CHK( ecjpake_kkp_read( md_type, grp, pf, G, Xb, id, &p, end ) ); if( p != end ) ret = MBEDTLS_ERR_ECP_BAD_INPUT_DATA; @@ -472,7 +477,7 @@ cleanup: * Generate a ECJPAKEKeyKPPairList * Outputs: the serialized structure, plus two private/public key pairs */ -static int ecjpake_kkpp_write( const mbedtls_md_info_t *md_info, +static int ecjpake_kkpp_write( const mbedtls_md_type_t md_type, const mbedtls_ecp_group *grp, const int pf, const mbedtls_ecp_point *G, @@ -491,9 +496,9 @@ static int ecjpake_kkpp_write( const mbedtls_md_info_t *md_info, unsigned char *p = buf; const unsigned char *end = buf + len; - MBEDTLS_MPI_CHK( ecjpake_kkp_write( md_info, grp, pf, G, xm1, Xa, id, + MBEDTLS_MPI_CHK( ecjpake_kkp_write( md_type, grp, pf, G, xm1, Xa, id, &p, end, f_rng, p_rng ) ); - MBEDTLS_MPI_CHK( ecjpake_kkp_write( md_info, grp, pf, G, xm2, Xb, id, + MBEDTLS_MPI_CHK( ecjpake_kkp_write( md_type, grp, pf, G, xm2, Xb, id, &p, end, f_rng, p_rng ) ); *olen = p - buf; @@ -509,7 +514,7 @@ int mbedtls_ecjpake_read_round_one( mbedtls_ecjpake_context *ctx, const unsigned char *buf, size_t len ) { - return( ecjpake_kkpp_read( ctx->md_info, &ctx->grp, ctx->point_format, + return( ecjpake_kkpp_read( ctx->md_type, &ctx->grp, ctx->point_format, &ctx->grp.G, &ctx->Xp1, &ctx->Xp2, ID_PEER, buf, len ) ); @@ -523,7 +528,7 @@ int mbedtls_ecjpake_write_round_one( mbedtls_ecjpake_context *ctx, int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ) { - return( ecjpake_kkpp_write( ctx->md_info, &ctx->grp, ctx->point_format, + return( ecjpake_kkpp_write( ctx->md_type, &ctx->grp, ctx->point_format, &ctx->grp.G, &ctx->xm1, &ctx->Xm1, &ctx->xm2, &ctx->Xm2, ID_MINE, buf, len, olen, f_rng, p_rng ) ); @@ -593,7 +598,7 @@ int mbedtls_ecjpake_read_round_two( mbedtls_ecjpake_context *ctx, } } - MBEDTLS_MPI_CHK( ecjpake_kkp_read( ctx->md_info, &ctx->grp, + MBEDTLS_MPI_CHK( ecjpake_kkp_read( ctx->md_type, &ctx->grp, ctx->point_format, &G, &ctx->Xp, ID_PEER, &p, end ) ); @@ -703,7 +708,7 @@ int mbedtls_ecjpake_write_round_two( mbedtls_ecjpake_context *ctx, ctx->point_format, &ec_len, p, end - p ) ); p += ec_len; - MBEDTLS_MPI_CHK( ecjpake_zkp_write( ctx->md_info, &ctx->grp, + MBEDTLS_MPI_CHK( ecjpake_zkp_write( ctx->md_type, &ctx->grp, ctx->point_format, &G, &xm, &Xm, ID_MINE, &p, end, f_rng, p_rng ) ); @@ -732,7 +737,7 @@ int mbedtls_ecjpake_derive_secret( mbedtls_ecjpake_context *ctx, unsigned char kx[MBEDTLS_ECP_MAX_BYTES]; size_t x_bytes; - *olen = mbedtls_md_get_size( ctx->md_info ); + *olen = mbedtls_hash_info_get_size( ctx->md_type ); if( len < *olen ) return( MBEDTLS_ERR_ECP_BUFFER_TOO_SMALL ); @@ -758,7 +763,8 @@ int mbedtls_ecjpake_derive_secret( mbedtls_ecjpake_context *ctx, /* PMS = SHA-256( K.X ) */ x_bytes = ( ctx->grp.pbits + 7 ) / 8; MBEDTLS_MPI_CHK( mbedtls_mpi_write_binary( &K.X, kx, x_bytes ) ); - MBEDTLS_MPI_CHK( mbedtls_md( ctx->md_info, kx, x_bytes, buf ) ); + MBEDTLS_MPI_CHK( mbedtls_md( mbedtls_md_info_from_type( ctx->md_type ), + kx, x_bytes, buf ) ); cleanup: mbedtls_ecp_point_free( &K ); diff --git a/tests/suites/test_suite_ecjpake.function b/tests/suites/test_suite_ecjpake.function index e8aaa6cd6b..ab6737f82e 100644 --- a/tests/suites/test_suite_ecjpake.function +++ b/tests/suites/test_suite_ecjpake.function @@ -137,10 +137,10 @@ void read_bad_md( data_t *msg ) mbedtls_ecjpake_init( &corrupt_ctx ); TEST_ASSERT( mbedtls_ecjpake_setup( &corrupt_ctx, any_role, MBEDTLS_MD_SHA256, MBEDTLS_ECP_DP_SECP256R1, pw, pw_len ) == 0 ); - corrupt_ctx.md_info = NULL; + corrupt_ctx.md_type = MBEDTLS_MD_NONE; - TEST_ASSERT( mbedtls_ecjpake_read_round_one( &corrupt_ctx, msg->x, - msg->len ) == MBEDTLS_ERR_MD_BAD_INPUT_DATA ); + TEST_EQUAL( mbedtls_ecjpake_read_round_one( &corrupt_ctx, msg->x, + msg->len ), MBEDTLS_ERR_MD_BAD_INPUT_DATA ); exit: mbedtls_ecjpake_free( &corrupt_ctx );