From 548dfff0aef25e36e971af96b49ce7fbb72d840e Mon Sep 17 00:00:00 2001 From: yhirose Date: Sat, 9 Mar 2024 22:26:17 -0500 Subject: [PATCH] Fix #1793 --- httplib.h | 33 ++++++++++++++++------- test/Makefile | 4 ++- test/test.cc | 74 ++++++++++++++++++++++++++++++++++----------------- 3 files changed, 76 insertions(+), 35 deletions(-) diff --git a/httplib.h b/httplib.h index cabdb48..a77d6d9 100644 --- a/httplib.h +++ b/httplib.h @@ -145,11 +145,11 @@ using ssize_t = long; #endif // _MSC_VER #ifndef S_ISREG -#define S_ISREG(m) (((m) & S_IFREG) == S_IFREG) +#define S_ISREG(m) (((m)&S_IFREG) == S_IFREG) #endif // S_ISREG #ifndef S_ISDIR -#define S_ISDIR(m) (((m) & S_IFDIR) == S_IFDIR) +#define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR) #endif // S_ISDIR #ifndef NOMINMAX @@ -1745,10 +1745,12 @@ public: explicit SSLClient(const std::string &host, int port, const std::string &client_cert_path, - const std::string &client_key_path); + const std::string &client_key_path, + const std::string &private_key_password = std::string()); explicit SSLClient(const std::string &host, int port, X509 *client_cert, - EVP_PKEY *client_key); + EVP_PKEY *client_key, + const std::string &private_key_password = std::string()); ~SSLClient() override; @@ -2700,8 +2702,8 @@ inline bool mmap::open(const char *path) { if (!::GetFileSizeEx(hFile_, &size)) { return false; } size_ = static_cast(size.QuadPart); - hMapping_ = ::CreateFileMappingFromApp(hFile_, NULL, PAGE_READONLY, size_, - NULL); + hMapping_ = + ::CreateFileMappingFromApp(hFile_, NULL, PAGE_READONLY, size_, NULL); if (hMapping_ == NULL) { close(); @@ -8438,7 +8440,6 @@ inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, SSL_CTX_set_min_proto_version(ctx_, TLS1_1_VERSION); - // add default password callback before opening encrypted private key if (private_key_password != nullptr && (private_key_password[0] != '\0')) { SSL_CTX_set_default_passwd_cb_userdata( ctx_, @@ -8544,7 +8545,8 @@ inline SSLClient::SSLClient(const std::string &host, int port) inline SSLClient::SSLClient(const std::string &host, int port, const std::string &client_cert_path, - const std::string &client_key_path) + const std::string &client_key_path, + const std::string &private_key_password) : ClientImpl(host, port, client_cert_path, client_key_path) { ctx_ = SSL_CTX_new(TLS_client_method()); @@ -8554,6 +8556,12 @@ inline SSLClient::SSLClient(const std::string &host, int port, }); if (!client_cert_path.empty() && !client_key_path.empty()) { + if (!private_key_password.empty()) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, reinterpret_cast( + const_cast(private_key_password.c_str()))); + } + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), SSL_FILETYPE_PEM) != 1 || SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), @@ -8565,7 +8573,8 @@ inline SSLClient::SSLClient(const std::string &host, int port, } inline SSLClient::SSLClient(const std::string &host, int port, - X509 *client_cert, EVP_PKEY *client_key) + X509 *client_cert, EVP_PKEY *client_key, + const std::string &private_key_password) : ClientImpl(host, port) { ctx_ = SSL_CTX_new(TLS_client_method()); @@ -8575,6 +8584,12 @@ inline SSLClient::SSLClient(const std::string &host, int port, }); if (client_cert != nullptr && client_key != nullptr) { + if (!private_key_password.empty()) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, reinterpret_cast( + const_cast(private_key_password.c_str()))); + } + if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { SSL_CTX_free(ctx_); diff --git a/test/Makefile b/test/Makefile index 3b72d20..6594af1 100644 --- a/test/Makefile +++ b/test/Makefile @@ -70,9 +70,11 @@ cert.pem: openssl genrsa 2048 > rootCA.key.pem openssl req -x509 -new -batch -config test.rootCA.conf -key rootCA.key.pem -days 1024 > rootCA.cert.pem openssl genrsa 2048 > client.key.pem - openssl req -new -batch -config test.conf -key client.key.pem | openssl x509 -days 370 -req -CA rootCA.cert.pem -CAkey rootCA.key.pem -CAcreateserial > client.cert.pem + openssl req -new -batch -config test.conf -key client.key.pem| openssl x509 -days 370 -req -CA rootCA.cert.pem -CAkey rootCA.key.pem -CAcreateserial > client.cert.pem openssl genrsa -passout pass:test123! 2048 > key_encrypted.pem openssl req -new -batch -config test.conf -key key_encrypted.pem | openssl x509 -days 3650 -req -signkey key_encrypted.pem > cert_encrypted.pem + openssl genrsa -aes256 -passout pass:test012! 2048 > client_encrypted.key.pem + openssl req -new -batch -config test.conf -key client_encrypted.key.pem -passin pass:test012! | openssl x509 -days 370 -req -CA rootCA.cert.pem -CAkey rootCA.key.pem -CAcreateserial > client_encrypted.cert.pem #c_rehash . clean: diff --git a/test/test.cc b/test/test.cc index 64fc8f2..801673a 100644 --- a/test/test.cc +++ b/test/test.cc @@ -20,6 +20,9 @@ #define CLIENT_CA_CERT_DIR "." #define CLIENT_CERT_FILE "./client.cert.pem" #define CLIENT_PRIVATE_KEY_FILE "./client.key.pem" +#define CLIENT_ENCRYPTED_CERT_FILE "./client_encrypted.cert.pem" +#define CLIENT_ENCRYPTED_PRIVATE_KEY_FILE "./client_encrypted.key.pem" +#define CLIENT_ENCRYPTED_PRIVATE_KEY_PASS "test012!" #define SERVER_ENCRYPTED_CERT_FILE "./cert_encrypted.pem" #define SERVER_ENCRYPTED_PRIVATE_KEY_FILE "./key_encrypted.pem" #define SERVER_ENCRYPTED_PRIVATE_KEY_PASS "test123!" @@ -5109,15 +5112,16 @@ TEST(SSLClientTest, SetInterfaceWithINET6) { } #endif -TEST(SSLClientServerTest, ClientCertPresent) { +void ClientCertPresent( + const std::string &client_cert_file, + const std::string &client_private_key_file, + const std::string &client_encrypted_private_key_pass = std::string()) { SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE, CLIENT_CA_CERT_FILE, CLIENT_CA_CERT_DIR); ASSERT_TRUE(svr.is_valid()); svr.Get("/test", [&](const Request &req, Response &res) { res.set_content("test", "text/plain"); - svr.stop(); - ASSERT_TRUE(true); auto peer_cert = SSL_get_peer_certificate(req.ssl); ASSERT_TRUE(peer_cert != nullptr); @@ -5140,13 +5144,15 @@ TEST(SSLClientServerTest, ClientCertPresent) { thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); }); auto se = detail::scope_exit([&] { + svr.stop(); t.join(); ASSERT_FALSE(svr.is_running()); }); svr.wait_until_ready(); - SSLClient cli(HOST, PORT, CLIENT_CERT_FILE, CLIENT_PRIVATE_KEY_FILE); + SSLClient cli(HOST, PORT, client_cert_file, client_private_key_file, + client_encrypted_private_key_pass); cli.enable_server_certificate_verification(false); cli.set_connection_timeout(30); @@ -5155,35 +5161,43 @@ TEST(SSLClientServerTest, ClientCertPresent) { ASSERT_EQ(StatusCode::OK_200, res->status); } -#if !defined(_WIN32) || defined(OPENSSL_USE_APPLINK) -TEST(SSLClientServerTest, MemoryClientCertPresent) { - X509 *server_cert; - EVP_PKEY *server_private_key; - X509_STORE *client_ca_cert_store; - X509 *client_cert; - EVP_PKEY *client_private_key; +TEST(SSLClientServerTest, ClientCertPresent) { + ClientCertPresent(CLIENT_CERT_FILE, CLIENT_PRIVATE_KEY_FILE); +} - FILE *f = fopen(SERVER_CERT_FILE, "r+"); - server_cert = PEM_read_X509(f, nullptr, nullptr, nullptr); +TEST(SSLClientServerTest, ClientEncryptedCertPresent) { + ClientCertPresent(CLIENT_ENCRYPTED_CERT_FILE, + CLIENT_ENCRYPTED_PRIVATE_KEY_FILE, + CLIENT_ENCRYPTED_PRIVATE_KEY_PASS); +} + +#if !defined(_WIN32) || defined(OPENSSL_USE_APPLINK) +void MemoryClientCertPresent( + const std::string &client_cert_file, + const std::string &client_private_key_file, + const std::string &client_encrypted_private_key_pass = std::string()) { + auto f = fopen(SERVER_CERT_FILE, "r+"); + auto server_cert = PEM_read_X509(f, nullptr, nullptr, nullptr); fclose(f); f = fopen(SERVER_PRIVATE_KEY_FILE, "r+"); - server_private_key = PEM_read_PrivateKey(f, nullptr, nullptr, nullptr); + auto server_private_key = PEM_read_PrivateKey(f, nullptr, nullptr, nullptr); fclose(f); f = fopen(CLIENT_CA_CERT_FILE, "r+"); - client_cert = PEM_read_X509(f, nullptr, nullptr, nullptr); - client_ca_cert_store = X509_STORE_new(); + auto client_cert = PEM_read_X509(f, nullptr, nullptr, nullptr); + auto client_ca_cert_store = X509_STORE_new(); X509_STORE_add_cert(client_ca_cert_store, client_cert); X509_free(client_cert); fclose(f); - f = fopen(CLIENT_CERT_FILE, "r+"); + f = fopen(client_cert_file.c_str(), "r+"); client_cert = PEM_read_X509(f, nullptr, nullptr, nullptr); fclose(f); - f = fopen(CLIENT_PRIVATE_KEY_FILE, "r+"); - client_private_key = PEM_read_PrivateKey(f, nullptr, nullptr, nullptr); + f = fopen(client_private_key_file.c_str(), "r+"); + auto client_private_key = PEM_read_PrivateKey( + f, nullptr, nullptr, (void *)client_encrypted_private_key_pass.c_str()); fclose(f); SSLServer svr(server_cert, server_private_key, client_ca_cert_store); @@ -5191,8 +5205,6 @@ TEST(SSLClientServerTest, MemoryClientCertPresent) { svr.Get("/test", [&](const Request &req, Response &res) { res.set_content("test", "text/plain"); - svr.stop(); - ASSERT_TRUE(true); auto peer_cert = SSL_get_peer_certificate(req.ssl); ASSERT_TRUE(peer_cert != nullptr); @@ -5215,13 +5227,15 @@ TEST(SSLClientServerTest, MemoryClientCertPresent) { thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); }); auto se = detail::scope_exit([&] { + svr.stop(); t.join(); ASSERT_FALSE(svr.is_running()); }); svr.wait_until_ready(); - SSLClient cli(HOST, PORT, client_cert, client_private_key); + SSLClient cli(HOST, PORT, client_cert, client_private_key, + client_encrypted_private_key_pass); cli.enable_server_certificate_verification(false); cli.set_connection_timeout(30); @@ -5234,6 +5248,16 @@ TEST(SSLClientServerTest, MemoryClientCertPresent) { X509_free(client_cert); EVP_PKEY_free(client_private_key); } + +TEST(SSLClientServerTest, MemoryClientCertPresent) { + MemoryClientCertPresent(CLIENT_CERT_FILE, CLIENT_PRIVATE_KEY_FILE); +} + +TEST(SSLClientServerTest, MemoryClientEncryptedCertPresent) { + MemoryClientCertPresent(CLIENT_ENCRYPTED_CERT_FILE, + CLIENT_ENCRYPTED_PRIVATE_KEY_FILE, + CLIENT_ENCRYPTED_PRIVATE_KEY_PASS); +} #endif TEST(SSLClientServerTest, ClientCertMissing) { @@ -5265,11 +5289,11 @@ TEST(SSLClientServerTest, TrustDirOptional) { svr.Get("/test", [&](const Request &, Response &res) { res.set_content("test", "text/plain"); - svr.stop(); }); thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); }); auto se = detail::scope_exit([&] { + svr.stop(); t.join(); ASSERT_FALSE(svr.is_running()); }); @@ -5361,13 +5385,12 @@ TEST(SSLClientServerTest, CustomizeServerSSLCtx) { nullptr); return true; }; + SSLServer svr(setup_ssl_ctx_callback); ASSERT_TRUE(svr.is_valid()); svr.Get("/test", [&](const Request &req, Response &res) { res.set_content("test", "text/plain"); - svr.stop(); - ASSERT_TRUE(true); auto peer_cert = SSL_get_peer_certificate(req.ssl); ASSERT_TRUE(peer_cert != nullptr); @@ -5390,6 +5413,7 @@ TEST(SSLClientServerTest, CustomizeServerSSLCtx) { thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); }); auto se = detail::scope_exit([&] { + svr.stop(); t.join(); ASSERT_FALSE(svr.is_running()); });