diff --git a/library/psa_crypto.c b/library/psa_crypto.c index ec23830a2e..acb39a1bcf 100644 --- a/library/psa_crypto.c +++ b/library/psa_crypto.c @@ -634,6 +634,23 @@ psa_status_t psa_import_key_into_slot( return PSA_SUCCESS; } else if (PSA_KEY_TYPE_IS_ASYMMETRIC(type)) { +#if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR) || \ + defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_PUBLIC_KEY) + if (PSA_KEY_TYPE_IS_DH(type)) { + if (psa_is_dh_key_size_valid(PSA_BYTES_TO_BITS(data_length)) == 0) { + return PSA_ERROR_INVALID_ARGUMENT; + } + + /* Copy the key material. */ + memcpy(key_buffer, data, data_length); + *key_buffer_length = data_length; + *bits = PSA_BYTES_TO_BITS(data_length); + (void) key_buffer_size; + + return PSA_SUCCESS; + } +#endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR) || + * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_PUBLIC_KEY) */ #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_ECC_KEY_PAIR) || \ defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_ECC_PUBLIC_KEY) if (PSA_KEY_TYPE_IS_ECC(type)) { @@ -1403,7 +1420,14 @@ psa_status_t psa_export_public_key_internal( { psa_key_type_t type = attributes->core.type; - if (PSA_KEY_TYPE_IS_RSA(type)) { + if (PSA_KEY_TYPE_IS_PUBLIC_KEY(type) && + (PSA_KEY_TYPE_IS_RSA(type) || PSA_KEY_TYPE_IS_ECC(type) || + PSA_KEY_TYPE_IS_DH(type))) { + /* Exporting public -> public */ + return psa_export_key_buffer_internal( + key_buffer, key_buffer_size, + data, data_size, data_length); + } else if (PSA_KEY_TYPE_IS_RSA(type)) { #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) || \ defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) return mbedtls_psa_rsa_export_public_key(attributes, @@ -1489,23 +1513,9 @@ psa_status_t psa_export_public_key(mbedtls_svc_key_id_t key, psa_key_attributes_t attributes = { .core = slot->attr }; - - psa_key_location_t location = PSA_KEY_LIFETIME_GET_LOCATION( - psa_get_key_lifetime(&attributes)); - - if (location == PSA_KEY_LOCATION_LOCAL_STORAGE && - PSA_KEY_TYPE_IS_PUBLIC_KEY(slot->attr.type) && - (PSA_KEY_TYPE_IS_RSA(slot->attr.type) || PSA_KEY_TYPE_IS_ECC(slot->attr.type) || - PSA_KEY_TYPE_IS_DH(slot->attr.type))) { - /* Exporting public -> public */ - status = psa_export_key_buffer_internal( - slot->key.data, slot->key.bytes, - data, data_size, data_length); - } else { - status = psa_driver_wrapper_export_public_key( - &attributes, slot->key.data, slot->key.bytes, - data, data_size, data_length); - } + status = psa_driver_wrapper_export_public_key( + &attributes, slot->key.data, slot->key.bytes, + data, data_size, data_length); exit: unlock_status = psa_unlock_key_slot(slot); @@ -2000,27 +2010,12 @@ psa_status_t psa_import_key(const psa_key_attributes_t *attributes, } } - if (PSA_KEY_TYPE_IS_ASYMMETRIC(attributes->core.type) && - PSA_KEY_TYPE_IS_DH(attributes->core.type)) { - if (psa_is_dh_key_size_valid(PSA_BYTES_TO_BITS(data_length)) == 0) { - status = PSA_ERROR_INVALID_ARGUMENT; - goto exit; - } - - /* Copy the key material. */ - memcpy(slot->key.data, data, data_length); - bits = PSA_BYTES_TO_BITS(data_length); - - status = PSA_SUCCESS; - } else { - bits = slot->attr.bits; - status = psa_driver_wrapper_import_key(attributes, - data, data_length, - slot->key.data, - slot->key.bytes, - &slot->key.bytes, &bits); - } - + bits = slot->attr.bits; + status = psa_driver_wrapper_import_key(attributes, + data, data_length, + slot->key.data, + slot->key.bytes, + &slot->key.bytes, &bits); if (status != PSA_SUCCESS) { goto exit; } @@ -5835,25 +5830,11 @@ static psa_status_t psa_generate_derived_key_internal( goto exit; } - if (PSA_KEY_TYPE_IS_ASYMMETRIC(attributes.core.type) && - PSA_KEY_TYPE_IS_DH(attributes.core.type)) { - if (psa_is_dh_key_size_valid(PSA_BYTES_TO_BITS(bytes)) == 0) { - status = PSA_ERROR_INVALID_ARGUMENT; - goto exit; - } - - /* Copy the key material. */ - memcpy(slot->key.data, data, bytes); - bits = PSA_BYTES_TO_BITS(bytes); - - status = PSA_SUCCESS; - } else { - status = psa_driver_wrapper_import_key(&attributes, - data, bytes, - slot->key.data, - slot->key.bytes, - &slot->key.bytes, &bits); - } + status = psa_driver_wrapper_import_key(&attributes, + data, bytes, + slot->key.data, + slot->key.bytes, + &slot->key.bytes, &bits); if (bits != slot->attr.bits) { status = PSA_ERROR_INVALID_ARGUMENT; }