diff --git a/src/mbedtls.c b/src/mbedtls.c index dc76ef59..4ff751b3 100644 --- a/src/mbedtls.c +++ b/src/mbedtls.c @@ -455,28 +455,58 @@ _libssh2_mbedtls_rsa_new_private_frommemory(libssh2_rsa_ctx **rsa, } int -_libssh2_mbedtls_rsa_sha1_verify(libssh2_rsa_ctx *rsa, - const unsigned char *sig, - unsigned long sig_len, - const unsigned char *m, - unsigned long m_len) +_libssh2_mbedtls_rsa_sha2_verify(libssh2_rsa_ctx * rsactx, + size_t hash_len, + const unsigned char *sig, + unsigned long sig_len, + const unsigned char *m, unsigned long m_len) { - unsigned char hash[SHA_DIGEST_LENGTH]; int ret; + int md_type; + unsigned char *hash = malloc(hash_len); + if(hash == NULL) + return -1; - ret = _libssh2_mbedtls_hash(m, m_len, MBEDTLS_MD_SHA1, hash); - if(ret) + if(hash_len == SHA_DIGEST_LENGTH) { + md_type = MBEDTLS_MD_SHA1; + } + else if(hash_len == SHA256_DIGEST_LENGTH) { + md_type = MBEDTLS_MD_SHA256; + } + else if(hash_len == SHA512_DIGEST_LENGTH) { + md_type = MBEDTLS_MD_SHA512; + } + else{ + free(hash); + return -1; /* unsupported digest */ + } + ret = _libssh2_mbedtls_hash(m, m_len, md_type, hash); + + if(ret != 0) { + free(hash); return -1; /* failure */ + } - ret = mbedtls_rsa_pkcs1_verify(rsa, NULL, NULL, MBEDTLS_RSA_PUBLIC, - MBEDTLS_MD_SHA1, SHA_DIGEST_LENGTH, + ret = mbedtls_rsa_pkcs1_verify(rsactx, NULL, NULL, MBEDTLS_RSA_PUBLIC, + md_type, hash_len, hash, sig); + free(hash); - return (ret == 0) ? 0 : -1; + return (ret == 1) ? 0 : -1; } int -_libssh2_mbedtls_rsa_sha1_sign(LIBSSH2_SESSION *session, +_libssh2_mbedtls_rsa_sha1_verify(libssh2_rsa_ctx * rsactx, + const unsigned char *sig, + unsigned long sig_len, + const unsigned char *m, unsigned long m_len) +{ + return _libssh2_mbedtls_rsa_sha2_verify(rsactx, SHA_DIGEST_LENGTH, + sig, sig_len, m, m_len); +} + +int +_libssh2_mbedtls_rsa_sha2_sign(LIBSSH2_SESSION *session, libssh2_rsa_ctx *rsa, const unsigned char *hash, size_t hash_len, @@ -486,7 +516,7 @@ _libssh2_mbedtls_rsa_sha1_sign(LIBSSH2_SESSION *session, int ret; unsigned char *sig; unsigned int sig_len; - + int md_type; (void)hash_len; sig_len = rsa->len; @@ -494,9 +524,22 @@ _libssh2_mbedtls_rsa_sha1_sign(LIBSSH2_SESSION *session, if(!sig) { return -1; } - + if(hash_len == SHA_DIGEST_LENGTH) { + md_type = MBEDTLS_MD_SHA1; + } + else if(hash_len == SHA256_DIGEST_LENGTH) { + md_type = MBEDTLS_MD_SHA256; + } + else if(hash_len == SHA512_DIGEST_LENGTH) { + md_type = MBEDTLS_MD_SHA512; + } + else { + _libssh2_error(session, LIBSSH2_ERROR_PROTO, + "Unsupported hash digest length"); + ret = -1; + } ret = mbedtls_rsa_pkcs1_sign(rsa, NULL, NULL, MBEDTLS_RSA_PRIVATE, - MBEDTLS_MD_SHA1, SHA_DIGEST_LENGTH, + md_type, hash_len, hash, sig); if(ret) { LIBSSH2_FREE(session, sig); @@ -509,6 +552,17 @@ _libssh2_mbedtls_rsa_sha1_sign(LIBSSH2_SESSION *session, return (ret == 0) ? 0 : -1; } +int +_libssh2_mbedtls_rsa_sha1_sign(LIBSSH2_SESSION * session, + libssh2_rsa_ctx * rsactx, + const unsigned char *hash, + size_t hash_len, + unsigned char **signature, size_t *signature_len) +{ + return _libssh2_mbedtls_rsa_sha2_sign(session, rsactx, hash, hash_len, + signature, signature_len); +} + void _libssh2_mbedtls_rsa_free(libssh2_rsa_ctx *ctx) { @@ -1260,8 +1314,13 @@ _libssh2_supported_key_sign_algorithms(LIBSSH2_SESSION *session, size_t key_method_len) { (void)session; - (void)key_method; - (void)key_method_len; + +#if LIBSSH2_RSA_SHA2 + if(key_method_len == 7 && + memcmp(key_method, "ssh-rsa", key_method_len) == 0) { + return "rsa-sha2-512,rsa-sha2-256,ssh-rsa"; + } +#endif return NULL; } diff --git a/src/mbedtls.h b/src/mbedtls.h index 0450113f..e86ebd26 100644 --- a/src/mbedtls.h +++ b/src/mbedtls.h @@ -71,7 +71,7 @@ #define LIBSSH2_3DES 1 #define LIBSSH2_RSA 1 -#define LIBSSH2_RSA_SHA2 0 +#define LIBSSH2_RSA_SHA2 1 #define LIBSSH2_DSA 0 #ifdef MBEDTLS_ECDSA_C # define LIBSSH2_ECDSA 1 @@ -243,9 +243,16 @@ #define _libssh2_rsa_sha1_sign(s, rsactx, hash, hash_len, sig, sig_len) \ _libssh2_mbedtls_rsa_sha1_sign(s, rsactx, hash, hash_len, sig, sig_len) +#define _libssh2_rsa_sha2_sign(s, rsactx, hash, hash_len, sig, sig_len) \ + _libssh2_mbedtls_rsa_sha2_sign(s, rsactx, hash, hash_len, sig, sig_len) + + #define _libssh2_rsa_sha1_verify(rsactx, sig, sig_len, m, m_len) \ _libssh2_mbedtls_rsa_sha1_verify(rsactx, sig, sig_len, m, m_len) +#define _libssh2_rsa_sha2_verify(rsactx, hash_len, sig, sig_len, m, m_len) \ + _libssh2_mbedtls_rsa_sha2_verify(rsactx, hash_len, sig, sig_len, m, m_len) + #define _libssh2_rsa_free(rsactx) \ _libssh2_mbedtls_rsa_free(rsactx)