diff --git a/httplib.h b/httplib.h index 2bdae5d..7b7b3f2 100644 --- a/httplib.h +++ b/httplib.h @@ -932,6 +932,7 @@ public: bool is_running() const; void wait_until_ready() const; void stop(); + void decommission(); std::function new_task_queue; @@ -1006,7 +1007,7 @@ private: virtual bool process_and_close_socket(socket_t sock); std::atomic is_running_{false}; - std::atomic done_{false}; + std::atomic is_decommisioned{false}; struct MountPointEntry { std::string mount_point; @@ -6111,27 +6112,27 @@ inline Server &Server::set_payload_max_length(size_t length) { inline bool Server::bind_to_port(const std::string &host, int port, int socket_flags) { - return bind_internal(host, port, socket_flags) >= 0; + auto ret = bind_internal(host, port, socket_flags); + if (ret == -1) { is_decommisioned = true; } + return ret >= 0; } inline int Server::bind_to_any_port(const std::string &host, int socket_flags) { - return bind_internal(host, 0, socket_flags); + auto ret = bind_internal(host, 0, socket_flags); + if (ret == -1) { is_decommisioned = true; } + return ret; } -inline bool Server::listen_after_bind() { - auto se = detail::scope_exit([&]() { done_ = true; }); - return listen_internal(); -} +inline bool Server::listen_after_bind() { return listen_internal(); } inline bool Server::listen(const std::string &host, int port, int socket_flags) { - auto se = detail::scope_exit([&]() { done_ = true; }); return bind_to_port(host, port, socket_flags) && listen_internal(); } inline bool Server::is_running() const { return is_running_; } inline void Server::wait_until_ready() const { - while (!is_running() && !done_) { + while (!is_running_ && !is_decommisioned) { std::this_thread::sleep_for(std::chrono::milliseconds{1}); } } @@ -6143,8 +6144,11 @@ inline void Server::stop() { detail::shutdown_socket(sock); detail::close_socket(sock); } + is_decommisioned = false; } +inline void Server::decommission() { is_decommisioned = true; } + inline bool Server::parse_request_line(const char *s, Request &req) const { auto len = strlen(s); if (len < 2 || s[len - 2] != '\r' || s[len - 1] != '\n') { return false; } @@ -6499,6 +6503,8 @@ Server::create_server_socket(const std::string &host, int port, inline int Server::bind_internal(const std::string &host, int port, int socket_flags) { + if (is_decommisioned) { return -1; } + if (!is_valid()) { return -1; } svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); @@ -6524,6 +6530,8 @@ inline int Server::bind_internal(const std::string &host, int port, } inline bool Server::listen_internal() { + if (is_decommisioned) { return false; } + auto ret = true; is_running_ = true; auto se = detail::scope_exit([&]() { is_running_ = false; }); @@ -6613,6 +6621,7 @@ inline bool Server::listen_internal() { task_queue->shutdown(); } + is_decommisioned = !ret; return ret; } diff --git a/test/test.cc b/test/test.cc index e13d3b9..8c25ec0 100644 --- a/test/test.cc +++ b/test/test.cc @@ -4926,6 +4926,52 @@ TEST(ServerStopTest, ListenFailure) { t.join(); } +TEST(ServerStopTest, Decommision) { + Server svr; + + svr.Get("/hi", [&](const Request &, Response &res) { res.body = "hi..."; }); + + for (int i = 0; i < 4; i++) { + auto is_even = !(i % 2); + + std::thread t{[&] { + try { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + if (is_even) { + throw std::runtime_error("Some thing that happens to go wrong."); + } + + svr.listen(HOST, PORT); + } catch (...) { svr.decommission(); } + }}; + + svr.wait_until_ready(); + + // Server is up + { + Client cli(HOST, PORT); + auto res = cli.Get("/hi"); + if (is_even) { + EXPECT_FALSE(res); + } else { + EXPECT_TRUE(res); + EXPECT_EQ("hi...", res->body); + } + } + + svr.stop(); + t.join(); + + // Server is down... + { + Client cli(HOST, PORT); + auto res = cli.Get("/hi"); + EXPECT_FALSE(res); + } + } +} + TEST(StreamingTest, NoContentLengthStreaming) { Server svr;