1
0
mirror of https://github.com/Mbed-TLS/mbedtls.git synced 2025-07-29 11:41:15 +03:00

Merge pull request #6141 from mpg/driver-hashes-rsa-v21

Driver hashes rsa v21
This commit is contained in:
Dave Rodgman
2022-08-16 09:52:39 +01:00
committed by GitHub
12 changed files with 607 additions and 247 deletions

View File

@ -22,6 +22,7 @@
#include "hash_info.h"
#include "legacy_or_psa.h"
#include "mbedtls/error.h"
typedef struct
{
@ -107,3 +108,20 @@ mbedtls_md_type_t mbedtls_hash_info_md_from_psa( psa_algorithm_t psa_alg )
return entry->md_type;
}
int mbedtls_md_error_from_psa( psa_status_t status )
{
switch( status )
{
case PSA_SUCCESS:
return( 0 );
case PSA_ERROR_NOT_SUPPORTED:
return( MBEDTLS_ERR_MD_FEATURE_UNAVAILABLE );
case PSA_ERROR_INVALID_ARGUMENT:
return( MBEDTLS_ERR_MD_BAD_INPUT_DATA );
case PSA_ERROR_INSUFFICIENT_MEMORY:
return( MBEDTLS_ERR_MD_ALLOC_FAILED );
default:
return( MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED );
}
}

View File

@ -74,4 +74,12 @@ psa_algorithm_t mbedtls_hash_info_psa_from_md( mbedtls_md_type_t md_type );
*/
mbedtls_md_type_t mbedtls_hash_info_md_from_psa( psa_algorithm_t psa_alg );
/** Convert PSA status to MD error code.
*
* \param status PSA status.
*
* \return The corresponding MD error code,
*/
int mbedtls_md_error_from_psa( psa_status_t status );
#endif /* MBEDTLS_HASH_INFO_H */

View File

@ -27,7 +27,9 @@
* - low-level module API (aes.h, sha256.h), or
* - an abstraction layer (md.h, cipher.h);
* - <condition> will be either:
* - depending on what's available in the build, or
* - depending on what's available in the build:
* legacy API used if available, PSA otherwise
* (this is done to ensure backwards compatibility); or
* - depending on whether MBEDTLS_USE_PSA_CRYPTO is defined.
*
* Examples:
@ -125,31 +127,38 @@
/* Hashes using MD or PSA based on availability */
#if ( defined(MBEDTLS_MD_C) && defined(MBEDTLS_MD5_C) ) || \
( defined(MBEDTLS_PSA_CRYPTO_C) && defined(PSA_WANT_ALG_MD5) )
( !defined(MBEDTLS_MD_C) && \
defined(MBEDTLS_PSA_CRYPTO_C) && defined(PSA_WANT_ALG_MD5) )
#define MBEDTLS_HAS_ALG_MD5_VIA_MD_OR_PSA
#endif
#if ( defined(MBEDTLS_MD_C) && defined(MBEDTLS_RIPEMD160_C) ) || \
( defined(MBEDTLS_PSA_CRYPTO_C) && defined(PSA_WANT_ALG_RIPEMD160) )
( !defined(MBEDTLS_MD_C) && \
defined(MBEDTLS_PSA_CRYPTO_C) && defined(PSA_WANT_ALG_RIPEMD160) )
#define MBEDTLS_HAS_ALG_RIPEMD160_VIA_MD_OR_PSA
#endif
#if ( defined(MBEDTLS_MD_C) && defined(MBEDTLS_SHA1_C) ) || \
( defined(MBEDTLS_PSA_CRYPTO_C) && defined(PSA_WANT_ALG_SHA_1) )
( !defined(MBEDTLS_MD_C) && \
defined(MBEDTLS_PSA_CRYPTO_C) && defined(PSA_WANT_ALG_SHA_1) )
#define MBEDTLS_HAS_ALG_SHA_1_VIA_MD_OR_PSA
#endif
#if ( defined(MBEDTLS_MD_C) && defined(MBEDTLS_SHA224_C) ) || \
( defined(MBEDTLS_PSA_CRYPTO_C) && defined(PSA_WANT_ALG_SHA_224) )
( !defined(MBEDTLS_MD_C) && \
defined(MBEDTLS_PSA_CRYPTO_C) && defined(PSA_WANT_ALG_SHA_224) )
#define MBEDTLS_HAS_ALG_SHA_224_VIA_MD_OR_PSA
#endif
#if ( defined(MBEDTLS_MD_C) && defined(MBEDTLS_SHA256_C) ) || \
( defined(MBEDTLS_PSA_CRYPTO_C) && defined(PSA_WANT_ALG_SHA_256) )
( !defined(MBEDTLS_MD_C) && \
defined(MBEDTLS_PSA_CRYPTO_C) && defined(PSA_WANT_ALG_SHA_256) )
#define MBEDTLS_HAS_ALG_SHA_256_VIA_MD_OR_PSA
#endif
#if ( defined(MBEDTLS_MD_C) && defined(MBEDTLS_SHA384_C) ) || \
( defined(MBEDTLS_PSA_CRYPTO_C) && defined(PSA_WANT_ALG_SHA_384) )
( !defined(MBEDTLS_MD_C) && \
defined(MBEDTLS_PSA_CRYPTO_C) && defined(PSA_WANT_ALG_SHA_384) )
#define MBEDTLS_HAS_ALG_SHA_384_VIA_MD_OR_PSA
#endif
#if ( defined(MBEDTLS_MD_C) && defined(MBEDTLS_SHA512_C) ) || \
( defined(MBEDTLS_PSA_CRYPTO_C) && defined(PSA_WANT_ALG_SHA_512) )
( !defined(MBEDTLS_MD_C) && \
defined(MBEDTLS_PSA_CRYPTO_C) && defined(PSA_WANT_ALG_SHA_512) )
#define MBEDTLS_HAS_ALG_SHA_512_VIA_MD_OR_PSA
#endif

View File

@ -54,6 +54,18 @@
#include <stdlib.h>
#endif
/* We use MD first if it's available (for compatibility reasons)
* and "fall back" to PSA otherwise (which needs psa_crypto_init()). */
#if defined(MBEDTLS_PKCS1_V21)
#if defined(MBEDTLS_MD_C)
#define HASH_MAX_SIZE MBEDTLS_MD_MAX_SIZE
#else /* MBEDTLS_MD_C */
#include "psa/crypto.h"
#include "mbedtls/psa_util.h"
#define HASH_MAX_SIZE PSA_HASH_MAX_SIZE
#endif /* MBEDTLS_MD_C */
#endif /* MBEDTLS_PKCS1_V21 */
#if defined(MBEDTLS_PLATFORM_C)
#include "mbedtls/platform.h"
#else
@ -502,10 +514,8 @@ int mbedtls_rsa_set_padding( mbedtls_rsa_context *ctx, int padding,
if( ( padding == MBEDTLS_RSA_PKCS_V21 ) &&
( hash_id != MBEDTLS_MD_NONE ) )
{
const mbedtls_md_info_t *md_info;
md_info = mbedtls_md_info_from_type( hash_id );
if( md_info == NULL )
/* Just make sure this hash is supported in this build. */
if( mbedtls_hash_info_psa_from_md( hash_id ) == PSA_ALG_NONE )
return( MBEDTLS_ERR_RSA_INVALID_PADDING );
}
#endif /* MBEDTLS_PKCS1_V21 */
@ -1095,23 +1105,43 @@ cleanup:
* \param dlen length of destination buffer
* \param src source of the mask generation
* \param slen length of the source buffer
* \param md_ctx message digest context to use
* \param md_alg message digest to use
*/
static int mgf_mask( unsigned char *dst, size_t dlen, unsigned char *src,
size_t slen, mbedtls_md_context_t *md_ctx )
size_t slen, mbedtls_md_type_t md_alg )
{
unsigned char mask[MBEDTLS_MD_MAX_SIZE];
unsigned char counter[4];
unsigned char *p;
unsigned int hlen;
size_t i, use_len;
unsigned char mask[HASH_MAX_SIZE];
#if defined(MBEDTLS_MD_C)
int ret = 0;
const mbedtls_md_info_t *md_info;
mbedtls_md_context_t md_ctx;
memset( mask, 0, MBEDTLS_MD_MAX_SIZE );
mbedtls_md_init( &md_ctx );
md_info = mbedtls_md_info_from_type( md_alg );
if( md_info == NULL )
return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
mbedtls_md_init( &md_ctx );
if( ( ret = mbedtls_md_setup( &md_ctx, md_info, 0 ) ) != 0 )
goto exit;
hlen = mbedtls_md_get_size( md_info );
#else
psa_hash_operation_t op = PSA_HASH_OPERATION_INIT;
psa_algorithm_t alg = mbedtls_psa_translate_md( md_alg );
psa_status_t status = PSA_SUCCESS;
size_t out_len;
hlen = PSA_HASH_LENGTH( alg );
#endif
memset( mask, 0, sizeof( mask ) );
memset( counter, 0, 4 );
hlen = mbedtls_md_get_size( md_ctx->md_info );
/* Generate and apply dbMask */
p = dst;
@ -1121,14 +1151,26 @@ static int mgf_mask( unsigned char *dst, size_t dlen, unsigned char *src,
if( dlen < hlen )
use_len = dlen;
if( ( ret = mbedtls_md_starts( md_ctx ) ) != 0 )
#if defined(MBEDTLS_MD_C)
if( ( ret = mbedtls_md_starts( &md_ctx ) ) != 0 )
goto exit;
if( ( ret = mbedtls_md_update( md_ctx, src, slen ) ) != 0 )
if( ( ret = mbedtls_md_update( &md_ctx, src, slen ) ) != 0 )
goto exit;
if( ( ret = mbedtls_md_update( md_ctx, counter, 4 ) ) != 0 )
if( ( ret = mbedtls_md_update( &md_ctx, counter, 4 ) ) != 0 )
goto exit;
if( ( ret = mbedtls_md_finish( md_ctx, mask ) ) != 0 )
if( ( ret = mbedtls_md_finish( &md_ctx, mask ) ) != 0 )
goto exit;
#else
if( ( status = psa_hash_setup( &op, alg ) ) != PSA_SUCCESS )
goto exit;
if( ( status = psa_hash_update( &op, src, slen ) ) != PSA_SUCCESS )
goto exit;
if( ( status = psa_hash_update( &op, counter, 4 ) ) != PSA_SUCCESS )
goto exit;
status = psa_hash_finish( &op, mask, sizeof( mask ), &out_len );
if( status != PSA_SUCCESS )
goto exit;
#endif
for( i = 0; i < use_len; ++i )
*p++ ^= mask[i];
@ -1140,8 +1182,115 @@ static int mgf_mask( unsigned char *dst, size_t dlen, unsigned char *src,
exit:
mbedtls_platform_zeroize( mask, sizeof( mask ) );
#if defined(MBEDTLS_MD_C)
mbedtls_md_free( &md_ctx );
return( ret );
#else
psa_hash_abort( &op );
return( mbedtls_md_error_from_psa( status ) );
#endif
}
/**
* Generate Hash(M') as in RFC 8017 page 43 points 5 and 6.
*
* \param hash the input hash
* \param hlen length of the input hash
* \param salt the input salt
* \param slen length of the input salt
* \param out the output buffer - must be large enough for \p md_alg
* \param md_alg message digest to use
*/
static int hash_mprime( const unsigned char *hash, size_t hlen,
const unsigned char *salt, size_t slen,
unsigned char *out, mbedtls_md_type_t md_alg )
{
const unsigned char zeros[8] = { 0, 0, 0, 0, 0, 0, 0, 0 };
#if defined(MBEDTLS_MD_C)
mbedtls_md_context_t md_ctx;
int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
const mbedtls_md_info_t *md_info = mbedtls_md_info_from_type( md_alg );
if( md_info == NULL )
return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
mbedtls_md_init( &md_ctx );
if( ( ret = mbedtls_md_setup( &md_ctx, md_info, 0 ) ) != 0 )
goto exit;
if( ( ret = mbedtls_md_starts( &md_ctx ) ) != 0 )
goto exit;
if( ( ret = mbedtls_md_update( &md_ctx, zeros, sizeof( zeros ) ) ) != 0 )
goto exit;
if( ( ret = mbedtls_md_update( &md_ctx, hash, hlen ) ) != 0 )
goto exit;
if( ( ret = mbedtls_md_update( &md_ctx, salt, slen ) ) != 0 )
goto exit;
if( ( ret = mbedtls_md_finish( &md_ctx, out ) ) != 0 )
goto exit;
exit:
mbedtls_md_free( &md_ctx );
return( ret );
#else
psa_hash_operation_t op = PSA_HASH_OPERATION_INIT;
psa_algorithm_t alg = mbedtls_psa_translate_md( md_alg );
psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED ;
size_t out_size = PSA_HASH_LENGTH( alg );
size_t out_len;
if( ( status = psa_hash_setup( &op, alg ) ) != PSA_SUCCESS )
goto exit;
if( ( status = psa_hash_update( &op, zeros, sizeof( zeros ) ) ) != PSA_SUCCESS )
goto exit;
if( ( status = psa_hash_update( &op, hash, hlen ) ) != PSA_SUCCESS )
goto exit;
if( ( status = psa_hash_update( &op, salt, slen ) ) != PSA_SUCCESS )
goto exit;
status = psa_hash_finish( &op, out, out_size, &out_len );
if( status != PSA_SUCCESS )
goto exit;
exit:
psa_hash_abort( &op );
return( mbedtls_md_error_from_psa( status ) );
#endif /* !MBEDTLS_MD_C */
}
/**
* Compute a hash.
*
* \param md_alg algorithm to use
* \param input input message to hash
* \param ilen input length
* \param output the output buffer - must be large enough for \p md_alg
*/
static int compute_hash( mbedtls_md_type_t md_alg,
const unsigned char *input, size_t ilen,
unsigned char *output )
{
#if defined(MBEDTLS_MD_C)
const mbedtls_md_info_t *md_info;
md_info = mbedtls_md_info_from_type( md_alg );
if( md_info == NULL )
return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
return( mbedtls_md( md_info, input, ilen, output ) );
#else
psa_algorithm_t alg = mbedtls_psa_translate_md( md_alg );
psa_status_t status;
size_t out_size = PSA_HASH_LENGTH( alg );
size_t out_len;
status = psa_hash_compute( alg, input, ilen, output, out_size, &out_len );
return( mbedtls_md_error_from_psa( status ) );
#endif /* !MBEDTLS_MD_C */
}
#endif /* MBEDTLS_PKCS1_V21 */
@ -1161,8 +1310,6 @@ int mbedtls_rsa_rsaes_oaep_encrypt( mbedtls_rsa_context *ctx,
int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
unsigned char *p = output;
unsigned int hlen;
const mbedtls_md_info_t *md_info;
mbedtls_md_context_t md_ctx;
RSA_VALIDATE_RET( ctx != NULL );
RSA_VALIDATE_RET( output != NULL );
@ -1172,12 +1319,11 @@ int mbedtls_rsa_rsaes_oaep_encrypt( mbedtls_rsa_context *ctx,
if( f_rng == NULL )
return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
md_info = mbedtls_md_info_from_type( (mbedtls_md_type_t) ctx->hash_id );
if( md_info == NULL )
hlen = mbedtls_hash_info_get_size( (mbedtls_md_type_t) ctx->hash_id );
if( hlen == 0 )
return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
olen = ctx->len;
hlen = mbedtls_md_get_size( md_info );
/* first comparison checks for overflow */
if( ilen + 2 * hlen + 2 < ilen || olen < ilen + 2 * hlen + 2 )
@ -1194,7 +1340,8 @@ int mbedtls_rsa_rsaes_oaep_encrypt( mbedtls_rsa_context *ctx,
p += hlen;
/* Construct DB */
if( ( ret = mbedtls_md( md_info, label, label_len, p ) ) != 0 )
ret = compute_hash( (mbedtls_md_type_t) ctx->hash_id, label, label_len, p );
if( ret != 0 )
return( ret );
p += hlen;
p += olen - 2 * hlen - 2 - ilen;
@ -1202,24 +1349,14 @@ int mbedtls_rsa_rsaes_oaep_encrypt( mbedtls_rsa_context *ctx,
if( ilen != 0 )
memcpy( p, input, ilen );
mbedtls_md_init( &md_ctx );
if( ( ret = mbedtls_md_setup( &md_ctx, md_info, 0 ) ) != 0 )
goto exit;
/* maskedDB: Apply dbMask to DB */
if( ( ret = mgf_mask( output + hlen + 1, olen - hlen - 1, output + 1, hlen,
&md_ctx ) ) != 0 )
goto exit;
ctx->hash_id ) ) != 0 )
return( ret );
/* maskedSeed: Apply seedMask to seed */
if( ( ret = mgf_mask( output + 1, hlen, output + hlen + 1, olen - hlen - 1,
&md_ctx ) ) != 0 )
goto exit;
exit:
mbedtls_md_free( &md_ctx );
if( ret != 0 )
ctx->hash_id ) ) != 0 )
return( ret );
return( mbedtls_rsa_public( ctx, output, output ) );
@ -1332,10 +1469,8 @@ int mbedtls_rsa_rsaes_oaep_decrypt( mbedtls_rsa_context *ctx,
size_t ilen, i, pad_len;
unsigned char *p, bad, pad_done;
unsigned char buf[MBEDTLS_MPI_MAX_SIZE];
unsigned char lhash[MBEDTLS_MD_MAX_SIZE];
unsigned char lhash[HASH_MAX_SIZE];
unsigned int hlen;
const mbedtls_md_info_t *md_info;
mbedtls_md_context_t md_ctx;
RSA_VALIDATE_RET( ctx != NULL );
RSA_VALIDATE_RET( output_max_len == 0 || output != NULL );
@ -1354,12 +1489,10 @@ int mbedtls_rsa_rsaes_oaep_decrypt( mbedtls_rsa_context *ctx,
if( ilen < 16 || ilen > sizeof( buf ) )
return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
md_info = mbedtls_md_info_from_type( (mbedtls_md_type_t) ctx->hash_id );
if( md_info == NULL )
hlen = mbedtls_hash_info_get_size( (mbedtls_md_type_t) ctx->hash_id );
if( hlen == 0 )
return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
hlen = mbedtls_md_get_size( md_info );
// checking for integer underflow
if( 2 * hlen + 2 > ilen )
return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
@ -1375,28 +1508,20 @@ int mbedtls_rsa_rsaes_oaep_decrypt( mbedtls_rsa_context *ctx,
/*
* Unmask data and generate lHash
*/
mbedtls_md_init( &md_ctx );
if( ( ret = mbedtls_md_setup( &md_ctx, md_info, 0 ) ) != 0 )
{
mbedtls_md_free( &md_ctx );
goto cleanup;
}
/* seed: Apply seedMask to maskedSeed */
if( ( ret = mgf_mask( buf + 1, hlen, buf + hlen + 1, ilen - hlen - 1,
&md_ctx ) ) != 0 ||
ctx->hash_id ) ) != 0 ||
/* DB: Apply dbMask to maskedDB */
( ret = mgf_mask( buf + hlen + 1, ilen - hlen - 1, buf + 1, hlen,
&md_ctx ) ) != 0 )
ctx->hash_id ) ) != 0 )
{
mbedtls_md_free( &md_ctx );
goto cleanup;
}
mbedtls_md_free( &md_ctx );
/* Generate lHash */
if( ( ret = mbedtls_md( md_info, label, label_len, lhash ) ) != 0 )
ret = compute_hash( (mbedtls_md_type_t) ctx->hash_id,
label, label_len, lhash );
if( ret != 0 )
goto cleanup;
/*
@ -1553,8 +1678,7 @@ static int rsa_rsassa_pss_sign( mbedtls_rsa_context *ctx,
size_t slen, min_slen, hlen, offset = 0;
int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
size_t msb;
const mbedtls_md_info_t *md_info;
mbedtls_md_context_t md_ctx;
RSA_VALIDATE_RET( ctx != NULL );
RSA_VALIDATE_RET( ( md_alg == MBEDTLS_MD_NONE &&
hashlen == 0 ) ||
@ -1572,20 +1696,18 @@ static int rsa_rsassa_pss_sign( mbedtls_rsa_context *ctx,
if( md_alg != MBEDTLS_MD_NONE )
{
/* Gather length of hash to sign */
md_info = mbedtls_md_info_from_type( md_alg );
if( md_info == NULL )
size_t exp_hashlen = mbedtls_hash_info_get_size( md_alg );
if( exp_hashlen == 0 )
return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
if( hashlen != mbedtls_md_get_size( md_info ) )
if( hashlen != exp_hashlen )
return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
}
md_info = mbedtls_md_info_from_type( (mbedtls_md_type_t) ctx->hash_id );
if( md_info == NULL )
hlen = mbedtls_hash_info_get_size( (mbedtls_md_type_t) ctx->hash_id );
if( hlen == 0 )
return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
hlen = mbedtls_md_get_size( md_info );
if (saltlen == MBEDTLS_RSA_SALT_LEN_ANY)
{
/* Calculate the largest possible salt length, up to the hash size.
@ -1626,30 +1748,20 @@ static int rsa_rsassa_pss_sign( mbedtls_rsa_context *ctx,
p += slen;
mbedtls_md_init( &md_ctx );
if( ( ret = mbedtls_md_setup( &md_ctx, md_info, 0 ) ) != 0 )
goto exit;
/* Generate H = Hash( M' ) */
if( ( ret = mbedtls_md_starts( &md_ctx ) ) != 0 )
goto exit;
if( ( ret = mbedtls_md_update( &md_ctx, p, 8 ) ) != 0 )
goto exit;
if( ( ret = mbedtls_md_update( &md_ctx, hash, hashlen ) ) != 0 )
goto exit;
if( ( ret = mbedtls_md_update( &md_ctx, salt, slen ) ) != 0 )
goto exit;
if( ( ret = mbedtls_md_finish( &md_ctx, p ) ) != 0 )
goto exit;
ret = hash_mprime( hash, hashlen, salt, slen, p, ctx->hash_id );
if( ret != 0 )
return( ret );
/* Compensate for boundary condition when applying mask */
if( msb % 8 == 0 )
offset = 1;
/* maskedDB: Apply dbMask to DB */
if( ( ret = mgf_mask( sig + offset, olen - hlen - 1 - offset, p, hlen,
&md_ctx ) ) != 0 )
goto exit;
ret = mgf_mask( sig + offset, olen - hlen - 1 - offset, p, hlen,
ctx->hash_id );
if( ret != 0 )
return( ret );
msb = mbedtls_mpi_bitlen( &ctx->N ) - 1;
sig[0] &= 0xFF >> ( olen * 8 - msb );
@ -1657,12 +1769,6 @@ static int rsa_rsassa_pss_sign( mbedtls_rsa_context *ctx,
p += hlen;
*p++ = 0xBC;
exit:
mbedtls_md_free( &md_ctx );
if( ret != 0 )
return( ret );
return mbedtls_rsa_private( ctx, f_rng, p_rng, sig, sig );
}
@ -1958,12 +2064,9 @@ int mbedtls_rsa_rsassa_pss_verify_ext( mbedtls_rsa_context *ctx,
size_t siglen;
unsigned char *p;
unsigned char *hash_start;
unsigned char result[MBEDTLS_MD_MAX_SIZE];
unsigned char zeros[8];
unsigned char result[HASH_MAX_SIZE];
unsigned int hlen;
size_t observed_salt_len, msb;
const mbedtls_md_info_t *md_info;
mbedtls_md_context_t md_ctx;
unsigned char buf[MBEDTLS_MPI_MAX_SIZE] = {0};
RSA_VALIDATE_RET( ctx != NULL );
@ -1990,22 +2093,18 @@ int mbedtls_rsa_rsassa_pss_verify_ext( mbedtls_rsa_context *ctx,
if( md_alg != MBEDTLS_MD_NONE )
{
/* Gather length of hash to sign */
md_info = mbedtls_md_info_from_type( md_alg );
if( md_info == NULL )
size_t exp_hashlen = mbedtls_hash_info_get_size( md_alg );
if( exp_hashlen == 0 )
return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
if( hashlen != mbedtls_md_get_size( md_info ) )
if( hashlen != exp_hashlen )
return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
}
md_info = mbedtls_md_info_from_type( mgf1_hash_id );
if( md_info == NULL )
hlen = mbedtls_hash_info_get_size( mgf1_hash_id );
if( hlen == 0 )
return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
hlen = mbedtls_md_get_size( md_info );
memset( zeros, 0, 8 );
/*
* Note: EMSA-PSS verification is over the length of N - 1 bits
*/
@ -2025,13 +2124,9 @@ int mbedtls_rsa_rsassa_pss_verify_ext( mbedtls_rsa_context *ctx,
return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
hash_start = p + siglen - hlen - 1;
mbedtls_md_init( &md_ctx );
if( ( ret = mbedtls_md_setup( &md_ctx, md_info, 0 ) ) != 0 )
goto exit;
ret = mgf_mask( p, siglen - hlen - 1, hash_start, hlen, &md_ctx );
ret = mgf_mask( p, siglen - hlen - 1, hash_start, hlen, mgf1_hash_id );
if( ret != 0 )
goto exit;
return( ret );
buf[0] &= 0xFF >> ( siglen * 8 - msb );
@ -2039,49 +2134,28 @@ int mbedtls_rsa_rsassa_pss_verify_ext( mbedtls_rsa_context *ctx,
p++;
if( *p++ != 0x01 )
{
ret = MBEDTLS_ERR_RSA_INVALID_PADDING;
goto exit;
}
return( MBEDTLS_ERR_RSA_INVALID_PADDING );
observed_salt_len = hash_start - p;
if( expected_salt_len != MBEDTLS_RSA_SALT_LEN_ANY &&
observed_salt_len != (size_t) expected_salt_len )
{
ret = MBEDTLS_ERR_RSA_INVALID_PADDING;
goto exit;
return( MBEDTLS_ERR_RSA_INVALID_PADDING );
}
/*
* Generate H = Hash( M' )
*/
ret = mbedtls_md_starts( &md_ctx );
if ( ret != 0 )
goto exit;
ret = mbedtls_md_update( &md_ctx, zeros, 8 );
if ( ret != 0 )
goto exit;
ret = mbedtls_md_update( &md_ctx, hash, hashlen );
if ( ret != 0 )
goto exit;
ret = mbedtls_md_update( &md_ctx, p, observed_salt_len );
if ( ret != 0 )
goto exit;
ret = mbedtls_md_finish( &md_ctx, result );
if ( ret != 0 )
goto exit;
ret = hash_mprime( hash, hashlen, p, observed_salt_len,
result, mgf1_hash_id );
if( ret != 0 )
return( ret );
if( memcmp( hash_start, result, hlen ) != 0 )
{
ret = MBEDTLS_ERR_RSA_VERIFY_FAILED;
goto exit;
}
return( MBEDTLS_ERR_RSA_VERIFY_FAILED );
exit:
mbedtls_md_free( &md_ctx );
return( ret );
return( 0 );
}
/*