diff --git a/httplib.h b/httplib.h index 2aefc4b..d6d6541 100644 --- a/httplib.h +++ b/httplib.h @@ -8541,13 +8541,29 @@ inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, return ssl; } -inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, +inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock, bool shutdown_gracefully) { // sometimes we may want to skip this to try to avoid SIGPIPE if we know // the remote has closed the network connection // Note that it is not always possible to avoid SIGPIPE, this is merely a // best-efforts. - if (shutdown_gracefully) { SSL_shutdown(ssl); } + if (shutdown_gracefully) { +#ifdef _WIN32 + SSL_shutdown(ssl); +#else + timeval tv; + tv.tv_sec = 1; + tv.tv_usec = 0; + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, + reinterpret_cast(&tv), sizeof(tv)); + + auto ret = SSL_shutdown(ssl); + while (ret == 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + ret = SSL_shutdown(ssl); + } +#endif + } std::lock_guard guard(ctx_mutex); SSL_free(ssl); @@ -8826,7 +8842,7 @@ inline bool SSLServer::process_and_close_socket(socket_t sock) { // Shutdown gracefully if the result seemed successful, non-gracefully if // the connection appeared to be closed. const bool shutdown_gracefully = ret; - detail::ssl_delete(ctx_mutex_, ssl, shutdown_gracefully); + detail::ssl_delete(ctx_mutex_, ssl, sock, shutdown_gracefully); } detail::shutdown_socket(sock); @@ -9109,7 +9125,8 @@ inline void SSLClient::shutdown_ssl_impl(Socket &socket, return; } if (socket.ssl) { - detail::ssl_delete(ctx_mutex_, socket.ssl, shutdown_gracefully); + detail::ssl_delete(ctx_mutex_, socket.ssl, socket.sock, + shutdown_gracefully); socket.ssl = nullptr; } assert(socket.ssl == nullptr); diff --git a/test/test.cc b/test/test.cc index df69d4a..d6610a7 100644 --- a/test/test.cc +++ b/test/test.cc @@ -54,6 +54,166 @@ MultipartFormData &get_file_value(MultipartFormDataItems &files, #endif } +#ifndef _WIN32 +class UnixSocketTest : public ::testing::Test { +protected: + void TearDown() override { std::remove(pathname_.c_str()); } + + void client_GET(const std::string &addr) { + httplib::Client cli{addr}; + cli.set_address_family(AF_UNIX); + ASSERT_TRUE(cli.is_valid()); + + const auto &result = cli.Get(pattern_); + ASSERT_TRUE(result) << "error: " << result.error(); + + const auto &resp = result.value(); + EXPECT_EQ(resp.status, StatusCode::OK_200); + EXPECT_EQ(resp.body, content_); + } + + const std::string pathname_{"./httplib-server.sock"}; + const std::string pattern_{"/hi"}; + const std::string content_{"Hello World!"}; +}; + +TEST_F(UnixSocketTest, pathname) { + httplib::Server svr; + svr.Get(pattern_, [&](const httplib::Request &, httplib::Response &res) { + res.set_content(content_, "text/plain"); + }); + + std::thread t{[&] { + ASSERT_TRUE(svr.set_address_family(AF_UNIX).listen(pathname_, 80)); + }}; + auto se = detail::scope_exit([&] { + svr.stop(); + t.join(); + ASSERT_FALSE(svr.is_running()); + }); + + svr.wait_until_ready(); + ASSERT_TRUE(svr.is_running()); + + client_GET(pathname_); +} + +#if defined(__linux__) || \ + /* __APPLE__ */ (defined(SOL_LOCAL) && defined(SO_PEERPID)) +TEST_F(UnixSocketTest, PeerPid) { + httplib::Server svr; + std::string remote_port_val; + svr.Get(pattern_, [&](const httplib::Request &req, httplib::Response &res) { + res.set_content(content_, "text/plain"); + remote_port_val = req.get_header_value("REMOTE_PORT"); + }); + + std::thread t{[&] { + ASSERT_TRUE(svr.set_address_family(AF_UNIX).listen(pathname_, 80)); + }}; + auto se = detail::scope_exit([&] { + svr.stop(); + t.join(); + ASSERT_FALSE(svr.is_running()); + }); + + svr.wait_until_ready(); + ASSERT_TRUE(svr.is_running()); + + client_GET(pathname_); + EXPECT_EQ(std::to_string(getpid()), remote_port_val); +} +#endif + +#ifdef __linux__ +TEST_F(UnixSocketTest, abstract) { + constexpr char svr_path[]{"\x00httplib-server.sock"}; + const std::string abstract_addr{svr_path, sizeof(svr_path) - 1}; + + httplib::Server svr; + svr.Get(pattern_, [&](const httplib::Request &, httplib::Response &res) { + res.set_content(content_, "text/plain"); + }); + + std::thread t{[&] { + ASSERT_TRUE(svr.set_address_family(AF_UNIX).listen(abstract_addr, 80)); + }}; + auto se = detail::scope_exit([&] { + svr.stop(); + t.join(); + ASSERT_FALSE(svr.is_running()); + }); + + svr.wait_until_ready(); + ASSERT_TRUE(svr.is_running()); + + client_GET(abstract_addr); +} +#endif + +TEST(SocketStream, is_writable_UNIX) { + int fds[2]; + ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, fds)); + + const auto asSocketStream = [&](socket_t fd, + std::function func) { + return detail::process_client_socket(fd, 0, 0, 0, 0, func); + }; + asSocketStream(fds[0], [&](Stream &s0) { + EXPECT_EQ(s0.socket(), fds[0]); + EXPECT_TRUE(s0.is_writable()); + + EXPECT_EQ(0, close(fds[1])); + EXPECT_FALSE(s0.is_writable()); + + return true; + }); + EXPECT_EQ(0, close(fds[0])); +} + +TEST(SocketStream, is_writable_INET) { + sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT + 1); + addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + + int disconnected_svr_sock = -1; + std::thread svr{[&] { + const int s = socket(AF_INET, SOCK_STREAM, 0); + ASSERT_LE(0, s); + ASSERT_EQ(0, ::bind(s, reinterpret_cast(&addr), sizeof(addr))); + ASSERT_EQ(0, listen(s, 1)); + ASSERT_LE(0, disconnected_svr_sock = accept(s, nullptr, nullptr)); + ASSERT_EQ(0, close(s)); + }}; + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + std::thread cli{[&] { + const int s = socket(AF_INET, SOCK_STREAM, 0); + ASSERT_LE(0, s); + ASSERT_EQ(0, connect(s, reinterpret_cast(&addr), sizeof(addr))); + ASSERT_EQ(0, close(s)); + }}; + cli.join(); + svr.join(); + ASSERT_NE(disconnected_svr_sock, -1); + + const auto asSocketStream = [&](socket_t fd, + std::function func) { + return detail::process_client_socket(fd, 0, 0, 0, 0, func); + }; + asSocketStream(disconnected_svr_sock, [&](Stream &ss) { + EXPECT_EQ(ss.socket(), disconnected_svr_sock); + EXPECT_FALSE(ss.is_writable()); + + return true; + }); + + ASSERT_EQ(0, close(disconnected_svr_sock)); +} +#endif // #ifndef _WIN32 + TEST(ClientTest, MoveConstructible) { EXPECT_FALSE(std::is_copy_constructible::value); EXPECT_TRUE(std::is_nothrow_move_constructible::value); @@ -4996,6 +5156,60 @@ TEST(KeepAliveTest, SSLClientReconnection) { ASSERT_TRUE(result); EXPECT_EQ(StatusCode::OK_200, result->status); } + +TEST(KeepAliveTest, SSLClientReconnectionPost) { + SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE); + ASSERT_TRUE(svr.is_valid()); + svr.set_keep_alive_timeout(1); + std::string content = "reconnect"; + + svr.Post("/hi", [](const httplib::Request &, httplib::Response &res) { + res.set_content("Hello World!", "text/plain"); + }); + + auto f = std::async(std::launch::async, [&svr] { svr.listen(HOST, PORT); }); + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + + SSLClient cli(HOST, PORT); + cli.enable_server_certificate_verification(false); + cli.set_keep_alive(true); + + auto result = cli.Post( + "/hi", content.size(), + [&content](size_t offset, size_t length, DataSink &sink) { + sink.write(content.c_str(), content.size()); + return true; + }, + "text/plain"); + ASSERT_TRUE(result); + EXPECT_EQ(200, result->status); + + std::this_thread::sleep_for(std::chrono::seconds(2)); + + // Recoonect + result = cli.Post( + "/hi", content.size(), + [&content](size_t offset, size_t length, DataSink &sink) { + sink.write(content.c_str(), content.size()); + return true; + }, + "text/plain"); + ASSERT_TRUE(result); + EXPECT_EQ(200, result->status); + + result = cli.Post( + "/hi", content.size(), + [&content](size_t offset, size_t length, DataSink &sink) { + sink.write(content.c_str(), content.size()); + return true; + }, + "text/plain"); + ASSERT_TRUE(result); + EXPECT_EQ(200, result->status); + + svr.stop(); + f.wait(); +} #endif TEST(ClientProblemDetectionTest, ContentProvider) { @@ -6970,166 +7184,6 @@ TEST(MultipartFormDataTest, ContentLength) { #endif -#ifndef _WIN32 -class UnixSocketTest : public ::testing::Test { -protected: - void TearDown() override { std::remove(pathname_.c_str()); } - - void client_GET(const std::string &addr) { - httplib::Client cli{addr}; - cli.set_address_family(AF_UNIX); - ASSERT_TRUE(cli.is_valid()); - - const auto &result = cli.Get(pattern_); - ASSERT_TRUE(result) << "error: " << result.error(); - - const auto &resp = result.value(); - EXPECT_EQ(resp.status, StatusCode::OK_200); - EXPECT_EQ(resp.body, content_); - } - - const std::string pathname_{"./httplib-server.sock"}; - const std::string pattern_{"/hi"}; - const std::string content_{"Hello World!"}; -}; - -TEST_F(UnixSocketTest, pathname) { - httplib::Server svr; - svr.Get(pattern_, [&](const httplib::Request &, httplib::Response &res) { - res.set_content(content_, "text/plain"); - }); - - std::thread t{[&] { - ASSERT_TRUE(svr.set_address_family(AF_UNIX).listen(pathname_, 80)); - }}; - auto se = detail::scope_exit([&] { - svr.stop(); - t.join(); - ASSERT_FALSE(svr.is_running()); - }); - - svr.wait_until_ready(); - ASSERT_TRUE(svr.is_running()); - - client_GET(pathname_); -} - -#if defined(__linux__) || \ - /* __APPLE__ */ (defined(SOL_LOCAL) && defined(SO_PEERPID)) -TEST_F(UnixSocketTest, PeerPid) { - httplib::Server svr; - std::string remote_port_val; - svr.Get(pattern_, [&](const httplib::Request &req, httplib::Response &res) { - res.set_content(content_, "text/plain"); - remote_port_val = req.get_header_value("REMOTE_PORT"); - }); - - std::thread t{[&] { - ASSERT_TRUE(svr.set_address_family(AF_UNIX).listen(pathname_, 80)); - }}; - auto se = detail::scope_exit([&] { - svr.stop(); - t.join(); - ASSERT_FALSE(svr.is_running()); - }); - - svr.wait_until_ready(); - ASSERT_TRUE(svr.is_running()); - - client_GET(pathname_); - EXPECT_EQ(std::to_string(getpid()), remote_port_val); -} -#endif - -#ifdef __linux__ -TEST_F(UnixSocketTest, abstract) { - constexpr char svr_path[]{"\x00httplib-server.sock"}; - const std::string abstract_addr{svr_path, sizeof(svr_path) - 1}; - - httplib::Server svr; - svr.Get(pattern_, [&](const httplib::Request &, httplib::Response &res) { - res.set_content(content_, "text/plain"); - }); - - std::thread t{[&] { - ASSERT_TRUE(svr.set_address_family(AF_UNIX).listen(abstract_addr, 80)); - }}; - auto se = detail::scope_exit([&] { - svr.stop(); - t.join(); - ASSERT_FALSE(svr.is_running()); - }); - - svr.wait_until_ready(); - ASSERT_TRUE(svr.is_running()); - - client_GET(abstract_addr); -} -#endif - -TEST(SocketStream, is_writable_UNIX) { - int fds[2]; - ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, fds)); - - const auto asSocketStream = [&](socket_t fd, - std::function func) { - return detail::process_client_socket(fd, 0, 0, 0, 0, func); - }; - asSocketStream(fds[0], [&](Stream &s0) { - EXPECT_EQ(s0.socket(), fds[0]); - EXPECT_TRUE(s0.is_writable()); - - EXPECT_EQ(0, close(fds[1])); - EXPECT_FALSE(s0.is_writable()); - - return true; - }); - EXPECT_EQ(0, close(fds[0])); -} - -TEST(SocketStream, is_writable_INET) { - sockaddr_in addr; - memset(&addr, 0, sizeof(addr)); - addr.sin_family = AF_INET; - addr.sin_port = htons(PORT + 1); - addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - - int disconnected_svr_sock = -1; - std::thread svr{[&] { - const int s = socket(AF_INET, SOCK_STREAM, 0); - ASSERT_LE(0, s); - ASSERT_EQ(0, ::bind(s, reinterpret_cast(&addr), sizeof(addr))); - ASSERT_EQ(0, listen(s, 1)); - ASSERT_LE(0, disconnected_svr_sock = accept(s, nullptr, nullptr)); - ASSERT_EQ(0, close(s)); - }}; - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - std::thread cli{[&] { - const int s = socket(AF_INET, SOCK_STREAM, 0); - ASSERT_LE(0, s); - ASSERT_EQ(0, connect(s, reinterpret_cast(&addr), sizeof(addr))); - ASSERT_EQ(0, close(s)); - }}; - cli.join(); - svr.join(); - ASSERT_NE(disconnected_svr_sock, -1); - - const auto asSocketStream = [&](socket_t fd, - std::function func) { - return detail::process_client_socket(fd, 0, 0, 0, 0, func); - }; - asSocketStream(disconnected_svr_sock, [&](Stream &ss) { - EXPECT_EQ(ss.socket(), disconnected_svr_sock); - EXPECT_FALSE(ss.is_writable()); - - return true; - }); - - ASSERT_EQ(0, close(disconnected_svr_sock)); -} -#endif // #ifndef _WIN32 - TEST(TaskQueueTest, IncreaseAtomicInteger) { static constexpr unsigned int number_of_tasks{1000000}; std::atomic_uint count{0};