diff --git a/library/pk_wrap.c b/library/pk_wrap.c index c232650229..ff8eeb14cc 100644 --- a/library/pk_wrap.c +++ b/library/pk_wrap.c @@ -32,6 +32,7 @@ #if defined(MBEDTLS_RSA_C) #include "pkwrite.h" +#include "rsa_internal.h" #endif #if defined(MBEDTLS_PK_CAN_ECDSA_SOME) @@ -69,9 +70,9 @@ static int rsa_verify_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg, psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT; mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT; psa_status_t status; - mbedtls_pk_context key; int key_len; unsigned char buf[MBEDTLS_PK_RSA_PUB_DER_MAX_BYTES]; + unsigned char *p = buf + sizeof(buf); psa_algorithm_t psa_alg_md = PSA_ALG_RSA_PKCS1V15_SIGN(mbedtls_md_psa_alg_from_type(md_alg)); size_t rsa_len = mbedtls_rsa_get_len(rsa); @@ -86,11 +87,7 @@ static int rsa_verify_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg, return MBEDTLS_ERR_RSA_VERIFY_FAILED; } - /* mbedtls_pk_write_pubkey_der() expects a full PK context; - * re-construct one to make it happy */ - key.pk_info = &mbedtls_rsa_info; - key.pk_ctx = rsa; - key_len = mbedtls_pk_write_pubkey_der(&key, buf, sizeof(buf)); + key_len = mbedtls_rsa_pubkey_write(rsa, buf, &p); if (key_len <= 0) { return MBEDTLS_ERR_PK_BAD_INPUT_DATA; } @@ -172,14 +169,15 @@ int mbedtls_pk_psa_rsa_sign_ext(psa_algorithm_t alg, psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT; mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT; psa_status_t status; - mbedtls_pk_context key; int key_len; unsigned char *buf = NULL; + unsigned char *p; + buf = mbedtls_calloc(1, MBEDTLS_PK_RSA_PRV_DER_MAX_BYTES); if (buf == NULL) { return MBEDTLS_ERR_PK_ALLOC_FAILED; } - mbedtls_pk_info_t pk_info = mbedtls_rsa_info; + p = buf + MBEDTLS_PK_RSA_PRV_DER_MAX_BYTES; *sig_len = mbedtls_rsa_get_len(rsa_ctx); if (sig_size < *sig_len) { @@ -187,11 +185,7 @@ int mbedtls_pk_psa_rsa_sign_ext(psa_algorithm_t alg, return MBEDTLS_ERR_PK_BUFFER_TOO_SMALL; } - /* mbedtls_pk_write_key_der() expects a full PK context; - * re-construct one to make it happy */ - key.pk_info = &pk_info; - key.pk_ctx = rsa_ctx; - key_len = mbedtls_pk_write_key_der(&key, buf, MBEDTLS_PK_RSA_PRV_DER_MAX_BYTES); + key_len = mbedtls_rsa_key_write(rsa_ctx, buf, &p); if (key_len <= 0) { mbedtls_free(buf); return MBEDTLS_ERR_PK_BAD_INPUT_DATA; @@ -282,9 +276,9 @@ static int rsa_decrypt_wrap(mbedtls_pk_context *pk, psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT; mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT; psa_status_t status; - mbedtls_pk_context key; int key_len; unsigned char buf[MBEDTLS_PK_RSA_PRV_DER_MAX_BYTES]; + unsigned char *p = buf + sizeof(buf); ((void) f_rng); ((void) p_rng); @@ -299,11 +293,7 @@ static int rsa_decrypt_wrap(mbedtls_pk_context *pk, return MBEDTLS_ERR_RSA_BAD_INPUT_DATA; } - /* mbedtls_pk_write_key_der() expects a full PK context; - * re-construct one to make it happy */ - key.pk_info = &mbedtls_rsa_info; - key.pk_ctx = rsa; - key_len = mbedtls_pk_write_key_der(&key, buf, sizeof(buf)); + key_len = mbedtls_rsa_key_write(rsa, buf, &p); if (key_len <= 0) { return MBEDTLS_ERR_PK_BAD_INPUT_DATA; } @@ -368,9 +358,9 @@ static int rsa_encrypt_wrap(mbedtls_pk_context *pk, psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT; mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT; psa_status_t status; - mbedtls_pk_context key; int key_len; unsigned char buf[MBEDTLS_PK_RSA_PUB_DER_MAX_BYTES]; + unsigned char *p = buf + sizeof(buf); ((void) f_rng); ((void) p_rng); @@ -385,11 +375,7 @@ static int rsa_encrypt_wrap(mbedtls_pk_context *pk, return MBEDTLS_ERR_RSA_OUTPUT_TOO_LARGE; } - /* mbedtls_pk_write_pubkey_der() expects a full PK context; - * re-construct one to make it happy */ - key.pk_info = &mbedtls_rsa_info; - key.pk_ctx = rsa; - key_len = mbedtls_pk_write_pubkey_der(&key, buf, sizeof(buf)); + key_len = mbedtls_rsa_pubkey_write(rsa, buf, &p); if (key_len <= 0) { return MBEDTLS_ERR_PK_BAD_INPUT_DATA; }