diff --git a/httplib.h b/httplib.h index c878fea..724ecc6 100644 --- a/httplib.h +++ b/httplib.h @@ -800,7 +800,9 @@ public: bool send(const std::vector &requests, std::vector &responses); - virtual void stop(); + size_t is_socket_open() const; + + void stop(); CPPHTTPLIB_DEPRECATED void set_timeout_sec(time_t timeout_sec); void set_connection_timeout(time_t sec, time_t usec = 0); @@ -831,26 +833,31 @@ public: void set_logger(Logger logger); protected: - struct Endpoint { + struct Socket { socket_t sock = INVALID_SOCKET; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT SSL *ssl = nullptr; #endif + + bool is_open() const { return sock != INVALID_SOCKET; } }; - virtual bool create_and_connect_socket(Endpoint &endpoint); - virtual void close_socket(Endpoint &endpoint, bool process_socket_ret); + virtual bool create_and_connect_socket(Socket &socket); + virtual void close_socket(Socket &socket, bool process_socket_ret); bool process_request(Stream &strm, const Request &req, Response &res, - bool last_connection, bool &connection_close); - - std::vector endpoints_; - std::mutex endpoints_mutex_; + bool &connection_close); + // Socket endoint information const std::string host_; const int port_; const std::string host_and_port_; + // Current open socket + Socket socket_; + mutable std::mutex socket_mutex_; + std::recursive_mutex request_mutex_; + // Settings std::string client_cert_path_; std::string client_key_path_; @@ -923,13 +930,10 @@ protected: private: socket_t create_client_socket() const; bool read_response_line(Stream &strm, Response &res); - bool write_request(Stream &strm, const Request &req, bool last_connection); + bool write_request(Stream &strm, const Request &req); bool redirect(const Request &req, Response &res); bool handle_request(Stream &strm, const Request &req, Response &res, - bool last_connection, bool &connection_close); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - bool connect_with_proxy(socket_t sock, Response &res, bool &error); -#endif + bool &connection_close); std::shared_ptr send_with_content_provider( const char *method, const char *path, const Headers &headers, @@ -937,7 +941,7 @@ private: ContentProvider content_provider, const char *content_type); virtual bool - process_socket(Endpoint &endpoint, size_t request_count, + process_socket(Socket &socket, size_t request_count, std::function callback); @@ -1026,8 +1030,6 @@ public: ~SSLClient() override; - void stop() override; - bool is_valid() const override; void set_ca_cert_path(const char *ca_cert_file_path, @@ -1042,16 +1044,17 @@ public: SSL_CTX *ssl_context() const; private: - bool create_and_connect_socket(Endpoint &endpoint) override; - void close_socket(Endpoint &endpoint, bool process_socket_ret) override; + bool create_and_connect_socket(Socket &socket) override; + bool connect_with_proxy(Socket &sock, bool &error); + void close_socket(Socket &socket, bool process_socket_ret) override; - bool process_socket(Endpoint &endpoint, size_t request_count, + bool process_socket(Socket &socket, size_t request_count, std::function callback) override; bool is_ssl() const override; - bool initialize_ssl(Endpoint &endpoint); + bool initialize_ssl(Socket &socket); bool verify_host(X509 *server_cert) const; bool verify_host_with_subject_alt_name(X509 *server_cert) const; @@ -1303,6 +1306,8 @@ public: return cli_->send(requests, responses); } + bool is_socket_open() { return cli_->is_socket_open(); } + void stop() { cli_->stop(); } Client2 &set_connection_timeout(time_t sec, time_t usec) { @@ -4330,7 +4335,12 @@ inline Client::Client(const std::string &host, int port, host_and_port_(host_ + ":" + std::to_string(port_)), client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} -inline Client::~Client() {} +inline Client::~Client() { + assert(socket_.sock == INVALID_SOCKET); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + assert(socket_.ssl == nullptr); +#endif +} inline bool Client::is_valid() const { return true; } @@ -4345,24 +4355,19 @@ inline socket_t Client::create_client_socket() const { connection_timeout_usec_, interface_); } -inline bool Client::create_and_connect_socket(Endpoint &endpoint) { +inline bool Client::create_and_connect_socket(Socket &socket) { auto sock = create_client_socket(); if (sock == INVALID_SOCKET) { return false; } - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (is_ssl() && !proxy_host_.empty()) { - Response res; - bool error; - if (!connect_with_proxy(sock, res, error)) { return error; } - } -#endif - endpoint.sock = sock; + socket.sock = sock; return true; } -inline void Client::close_socket(Endpoint &endpoint, - bool /*process_socket_ret*/) { - detail::close_socket(endpoint.sock); +inline void Client::close_socket(Socket &socket, bool /*process_socket_ret*/) { + detail::close_socket(socket.sock); + socket_.sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + socket_.ssl = nullptr; +#endif } inline bool Client::read_response_line(Stream &strm, Response &res) { @@ -4384,32 +4389,23 @@ inline bool Client::read_response_line(Stream &strm, Response &res) { } inline bool Client::send(const Request &req, Response &res) { - Endpoint endpoint; - if (!create_and_connect_socket(endpoint)) { return false; } + std::lock_guard guard(request_mutex_); + auto need_new_socket = !is_socket_open(); - { - std::lock_guard guard(endpoints_mutex_); - endpoints_.push_back(endpoint); + if (need_new_socket) { + std::lock_guard guard(socket_mutex_); + if (!create_and_connect_socket(socket_)) { return false; } } auto ret = process_socket( - endpoint, 1, - [&](Stream &strm, bool last_connection, bool &connection_close) { - return handle_request(strm, req, res, last_connection, - connection_close); + socket_, 1, + [&](Stream &strm, bool /*last_connection*/, bool &connection_close) { + return handle_request(strm, req, res, connection_close); }); - { - std::lock_guard guard(endpoints_mutex_); - - auto it = std::find_if( - endpoints_.begin(), endpoints_.end(), - [&](Endpoint &endpoint2) { return endpoint.sock == endpoint2.sock; }); - - if (it != endpoints_.end()) { - close_socket(endpoint, ret); - endpoints_.erase(it); - } + if (need_new_socket) { + std::lock_guard guard(socket_mutex_); + if (socket_.is_open()) { close_socket(socket_, ret); } } return ret; @@ -4417,43 +4413,30 @@ inline bool Client::send(const Request &req, Response &res) { inline bool Client::send(const std::vector &requests, std::vector &responses) { + std::lock_guard guard(request_mutex_); + size_t i = 0; while (i < requests.size()) { - Endpoint endpoint; - if (!create_and_connect_socket(endpoint)) { return false; } - { - std::lock_guard guard(endpoints_mutex_); - endpoints_.push_back(endpoint); + std::lock_guard guard(socket_mutex_); + if (!create_and_connect_socket(socket_)) { return false; } } auto request_count = (std::min)(requests.size() - i, keep_alive_max_count_); - auto ret = process_socket(endpoint, request_count, - [&](Stream &strm, bool last_connection, - bool &connection_close) -> bool { - auto &req = requests[i++]; - auto res = Response(); - auto ret = handle_request(strm, req, res, - last_connection, - connection_close); - if (ret) { - responses.emplace_back(std::move(res)); - } - return ret; - }); + auto ret = process_socket( + socket_, request_count, + [&](Stream &strm, bool /*last_connection*/, bool &connection_close) { + auto &req = requests[i++]; + auto res = Response(); + auto ret = handle_request(strm, req, res, connection_close); + if (ret) { responses.emplace_back(std::move(res)); } + return ret; + }); { - std::lock_guard guard(endpoints_mutex_); - - auto it = std::find_if( - endpoints_.begin(), endpoints_.end(), - [&](Endpoint &endpoint2) { return endpoint.sock == endpoint2.sock; }); - - if (it != endpoints_.end()) { - close_socket(endpoint, ret); - endpoints_.erase(it); - } + std::lock_guard guard(socket_mutex_); + if (socket_.is_open()) { close_socket(socket_, ret); } } if (!ret) { return false; } @@ -4463,8 +4446,7 @@ inline bool Client::send(const std::vector &requests, } inline bool Client::handle_request(Stream &strm, const Request &req, - Response &res, bool last_connection, - bool &connection_close) { + Response &res, bool &connection_close) { if (req.path.empty()) { return false; } bool ret; @@ -4472,9 +4454,9 @@ inline bool Client::handle_request(Stream &strm, const Request &req, if (!is_ssl() && !proxy_host_.empty()) { auto req2 = req; req2.path = "http://" + host_and_port_ + req.path; - ret = process_request(strm, req2, res, last_connection, connection_close); + ret = process_request(strm, req2, res, connection_close); } else { - ret = process_request(strm, req, res, last_connection, connection_close); + ret = process_request(strm, req, res, connection_close); } if (!ret) { return false; } @@ -4515,64 +4497,6 @@ inline bool Client::handle_request(Stream &strm, const Request &req, return ret; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -inline bool Client::connect_with_proxy(socket_t sock, Response &res, - bool &error) { - error = true; - Response res2; - - if (!detail::process_socket_core( - true, sock, 1, [&](bool /*last_connection*/, bool &connection_close) { - detail::SocketStream strm(sock, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_); - Request req2; - req2.method = "CONNECT"; - req2.path = host_and_port_; - return process_request(strm, req2, res2, false, connection_close); - })) { - detail::close_socket(sock); - error = false; - return false; - } - - if (res2.status == 407) { - if (!proxy_digest_auth_username_.empty() && - !proxy_digest_auth_password_.empty()) { - std::map auth; - if (parse_www_authenticate(res2, auth, true)) { - Response res3; - if (!detail::process_socket_core( - true, sock, 1, - [&](bool /*last_connection*/, bool &connection_close) { - detail::SocketStream strm( - sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_); - Request req3; - req3.method = "CONNECT"; - req3.path = host_and_port_; - req3.headers.insert(make_digest_authentication_header( - req3, auth, 1, random_string(10), - proxy_digest_auth_username_, proxy_digest_auth_password_, - true)); - return process_request(strm, req3, res3, false, - connection_close); - })) { - detail::close_socket(sock); - error = false; - return false; - } - } - } else { - res = res2; - return false; - } - } - - return true; -} -#endif - inline bool Client::redirect(const Request &req, Response &res) { if (req.redirect_count == 0) { return false; } @@ -4622,8 +4546,7 @@ inline bool Client::redirect(const Request &req, Response &res) { } } -inline bool Client::write_request(Stream &strm, const Request &req, - bool last_connection) { +inline bool Client::write_request(Stream &strm, const Request &req) { detail::BufferStream bstrm; // Request line @@ -4633,8 +4556,6 @@ inline bool Client::write_request(Stream &strm, const Request &req, // Additonal headers Headers headers; - if (last_connection) { headers.emplace("Connection", "close"); } - if (!req.has_header("Host")) { if (is_ssl()) { if (port_ == 443) { @@ -4777,10 +4698,9 @@ inline std::shared_ptr Client::send_with_content_provider( } inline bool Client::process_request(Stream &strm, const Request &req, - Response &res, bool last_connection, - bool &connection_close) { + Response &res, bool &connection_close) { // Send request - if (!write_request(strm, req, last_connection)) { return false; } + if (!write_request(strm, req)) { return false; } // Receive response and headers if (!read_response_line(strm, res) || @@ -4824,12 +4744,12 @@ inline bool Client::process_request(Stream &strm, const Request &req, } inline bool -Client::process_socket(Endpoint &endpoint, size_t request_count, +Client::process_socket(Socket &socket, size_t request_count, std::function callback) { return detail::process_socket( - true, endpoint.sock, request_count, read_timeout_sec_, read_timeout_usec_, + true, socket.sock, request_count, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, callback); } @@ -5125,13 +5045,17 @@ inline std::shared_ptr Client::Options(const char *path, return send(req, *res) ? res : nullptr; } +inline size_t Client::is_socket_open() const { + std::lock_guard guard(socket_mutex_); + return socket_.is_open(); +} + inline void Client::stop() { - std::lock_guard guard(endpoints_mutex_); - for (auto &endpoint : endpoints_) { - detail::shutdown_socket(endpoint.sock); - detail::close_socket(endpoint.sock); + std::lock_guard guard(socket_mutex_); + if (socket_.is_open()) { + detail::shutdown_socket(socket_.sock); + close_socket(socket_, true); } - endpoints_.clear(); } inline void Client::set_timeout_sec(time_t timeout_sec) { @@ -5494,25 +5418,6 @@ inline SSLClient::~SSLClient() { if (ctx_) { SSL_CTX_free(ctx_); } } -inline void SSLClient::stop() { - auto endpoints = endpoints_; - { - std::lock_guard guard(endpoints_mutex_); - for (auto &endpoint : endpoints_) { - detail::shutdown_socket(endpoint.sock); - detail::close_socket(endpoint.sock); - } - endpoints_.clear(); - } - - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - for (auto &endpoint : endpoints) { - SSL_shutdown(endpoint.ssl); - SSL_free(endpoint.ssl); - } -} - inline bool SSLClient::is_valid() const { return ctx_; } inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path, @@ -5535,14 +5440,75 @@ inline long SSLClient::get_openssl_verify_result() const { inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; } -inline bool SSLClient::create_and_connect_socket(Endpoint &endpoint) { - return is_valid() && Client::create_and_connect_socket(endpoint) && - initialize_ssl(endpoint); +inline bool SSLClient::create_and_connect_socket(Socket &socket) { + if (is_valid() && Client::create_and_connect_socket(socket) && + initialize_ssl(socket)) { + if (!proxy_host_.empty()) { + bool error; + if (!connect_with_proxy(socket, error)) { return error; } + } + return true; + } + return false; } -inline bool SSLClient::initialize_ssl(Endpoint &endpoint) { +inline bool SSLClient::connect_with_proxy(Socket &socket, bool &error) { + error = true; + Response res; + + if (!detail::process_socket_core( + true, socket.sock, 1, + [&](bool /*last_connection*/, bool &connection_close) { + detail::SocketStream strm(socket.sock, read_timeout_sec_, + read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_); + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + return process_request(strm, req2, res, connection_close); + })) { + close_socket(socket, true); + error = false; + return false; + } + + if (res.status == 407) { + if (!proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map auth; + if (parse_www_authenticate(res, auth, true)) { + Response res3; + if (!detail::process_socket_core( + true, socket.sock, 1, + [&](bool /*last_connection*/, bool &connection_close) { + detail::SocketStream strm( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_); + Request req3; + req3.method = "CONNECT"; + req3.path = host_and_port_; + req3.headers.insert(make_digest_authentication_header( + req3, auth, 1, random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + return process_request(strm, req3, res3, connection_close); + })) { + close_socket(socket, true); + error = false; + return false; + } + } + } else { + return false; + } + } + + return true; +} + +inline bool SSLClient::initialize_ssl(Socket &socket) { auto ssl = detail::ssl_new( - endpoint.sock, ctx_, ctx_mutex_, + socket.sock, ctx_, ctx_mutex_, [&](SSL *ssl) { if (ca_cert_file_path_.empty() && ca_cert_store_ == nullptr) { SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr); @@ -5585,29 +5551,32 @@ inline bool SSLClient::initialize_ssl(Endpoint &endpoint) { }); if (ssl) { - endpoint.ssl = ssl; + socket.ssl = ssl; return true; } - detail::close_socket(endpoint.sock); + close_socket(socket, false); return false; } -inline void SSLClient::close_socket(Endpoint &endpoint, - bool process_socket_ret) { - assert(endpoint.ssl); - detail::ssl_delete(ctx_mutex_, endpoint.ssl, process_socket_ret); - detail::close_socket(endpoint.sock); +inline void SSLClient::close_socket(Socket &socket, bool process_socket_ret) { + detail::close_socket(socket.sock); + socket_.sock = INVALID_SOCKET; + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + if (socket.ssl) { + detail::ssl_delete(ctx_mutex_, socket.ssl, process_socket_ret); + socket_.ssl = nullptr; + } } inline bool -SSLClient::process_socket(Endpoint &endpoint, size_t request_count, +SSLClient::process_socket(Socket &socket, size_t request_count, std::function callback) { - assert(endpoint.ssl); + assert(socket.ssl); return detail::process_socket_ssl( - endpoint.ssl, true, endpoint.sock, request_count, read_timeout_sec_, + socket.ssl, true, socket.sock, request_count, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, [&](Stream &strm, bool last_connection, bool &connection_close) { return callback(strm, last_connection, connection_close); diff --git a/test/test.cc b/test/test.cc index bb81036..1f79494 100644 --- a/test/test.cc +++ b/test/test.cc @@ -1767,15 +1767,20 @@ TEST_F(ServerTest, GetStreamedEndless) { TEST_F(ServerTest, ClientStop) { std::vector threads; - for (auto i = 0; i < 3; i++) { + for (auto i = 0; i < 100; i++) { threads.emplace_back(thread([&]() { auto res = cli_.Get("/streamed-cancel", [&](const char *, uint64_t) { return true; }); ASSERT_TRUE(res == nullptr); })); } - std::this_thread::sleep_for(std::chrono::seconds(3)); - cli_.stop(); + + std::this_thread::sleep_for(std::chrono::seconds(1)); + + while (cli_.is_socket_open()) { + cli_.stop(); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } for (auto &t : threads) { t.join(); }