From 8a7c536ad5b2d304bd467209a13627a8db510bc9 Mon Sep 17 00:00:00 2001 From: yhirose Date: Mon, 10 Feb 2025 06:51:07 -0500 Subject: [PATCH] Fix #2034 (#2048) * Fix #2034 * Fix build error * Adjust threshold * Add temporary debug prints * Adjust threshhold * Another threshold adjustment for macOS on GitHub Actions CI... * Performance improvement by avoiding unnecessary chrono access * More performance improvement to avoid unnecessary chrono access --- httplib.h | 408 ++++++++++++++++++++++------------ test/fuzzing/server_fuzzer.cc | 2 + test/test.cc | 261 +++++++++++++++++++++- 3 files changed, 532 insertions(+), 139 deletions(-) diff --git a/httplib.h b/httplib.h index eb592d2..ed49a54 100644 --- a/httplib.h +++ b/httplib.h @@ -66,6 +66,10 @@ #define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND 0 #endif +#ifndef CPPHTTPLIB_CLIENT_GLOBAL_TIMEOUT_MSECOND +#define CPPHTTPLIB_CLIENT_GLOBAL_TIMEOUT_MSECOND 0 +#endif + #ifndef CPPHTTPLIB_IDLE_INTERVAL_SECOND #define CPPHTTPLIB_IDLE_INTERVAL_SECOND 0 #endif @@ -664,6 +668,8 @@ struct Request { ContentProvider content_provider_; bool is_chunked_content_provider_ = false; size_t authorization_count_ = 0; + std::chrono::time_point start_time_ = + std::chrono::steady_clock::time_point::min(); }; struct Response { @@ -737,6 +743,8 @@ public: virtual void get_local_ip_and_port(std::string &ip, int &port) const = 0; virtual socket_t socket() const = 0; + virtual time_t duration() const = 0; + ssize_t write(const char *ptr); ssize_t write(const std::string &s); }; @@ -1423,6 +1431,10 @@ public: template void set_write_timeout(const std::chrono::duration &duration); + void set_global_timeout(time_t msec); + template + void set_global_timeout(const std::chrono::duration &duration); + void set_basic_auth(const std::string &username, const std::string &password); void set_bearer_token_auth(const std::string &token); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT @@ -1531,6 +1543,7 @@ protected: time_t read_timeout_usec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND; time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND; time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND; + time_t global_timeout_msec_ = CPPHTTPLIB_CLIENT_GLOBAL_TIMEOUT_MSECOND; std::string basic_auth_username_; std::string basic_auth_password_; @@ -1585,9 +1598,6 @@ private: bool send_(Request &req, Response &res, Error &error); Result send_(Request &&req); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - bool is_ssl_peer_could_be_closed(SSL *ssl) const; -#endif socket_t create_client_socket(Error &error) const; bool read_response_line(Stream &strm, const Request &req, Response &res) const; @@ -1613,8 +1623,10 @@ private: std::string adjust_host_string(const std::string &host) const; - virtual bool process_socket(const Socket &socket, - std::function callback); + virtual bool + process_socket(const Socket &socket, + std::chrono::time_point start_time, + std::function callback); virtual bool is_ssl() const; }; @@ -1856,6 +1868,10 @@ public: template void set_write_timeout(const std::chrono::duration &duration); + void set_global_timeout(time_t msec); + template + void set_global_timeout(const std::chrono::duration &duration); + void set_basic_auth(const std::string &username, const std::string &password); void set_bearer_token_auth(const std::string &token); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT @@ -1973,12 +1989,16 @@ private: void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override; void shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully); - bool process_socket(const Socket &socket, - std::function callback) override; + bool + process_socket(const Socket &socket, + std::chrono::time_point start_time, + std::function callback) override; bool is_ssl() const override; - bool connect_with_proxy(Socket &sock, Response &res, bool &success, - Error &error); + bool connect_with_proxy( + Socket &sock, + std::chrono::time_point start_time, + Response &res, bool &success, Error &error); bool initialize_ssl(Socket &socket, Error &error); bool load_certs(); @@ -2240,6 +2260,14 @@ inline void ClientImpl::set_write_timeout( duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); } +template +inline void ClientImpl::set_global_timeout( + const std::chrono::duration &duration) { + auto msec = + std::chrono::duration_cast(duration).count(); + set_global_timeout(msec); +} + template inline void Client::set_connection_timeout( const std::chrono::duration &duration) { @@ -2258,6 +2286,12 @@ Client::set_write_timeout(const std::chrono::duration &duration) { cli_->set_write_timeout(duration); } +template +inline void +Client::set_global_timeout(const std::chrono::duration &duration) { + cli_->set_global_timeout(duration); +} + /* * Forward declarations and types that will be part of the .h file if split into * .h + .cc. @@ -2332,10 +2366,12 @@ void split(const char *b, const char *e, char d, void split(const char *b, const char *e, char d, size_t m, std::function fn); -bool process_client_socket(socket_t sock, time_t read_timeout_sec, - time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec, - std::function callback); +bool process_client_socket( + socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t global_timeout_msec, + std::chrono::time_point start_time, + std::function callback); socket_t create_client_socket(const std::string &host, const std::string &ip, int port, int address_family, bool tcp_nodelay, @@ -2383,6 +2419,7 @@ public: void get_remote_ip_and_port(std::string &ip, int &port) const override; void get_local_ip_and_port(std::string &ip, int &port) const override; socket_t socket() const override; + time_t duration() const override; const std::string &get_buffer() const; @@ -3300,7 +3337,10 @@ inline bool is_socket_alive(socket_t sock) { class SocketStream final : public Stream { public: SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, - time_t write_timeout_sec, time_t write_timeout_usec); + time_t write_timeout_sec, time_t write_timeout_usec, + time_t global_timeout_msec = 0, + std::chrono::time_point start_time = + std::chrono::steady_clock::time_point::min()); ~SocketStream() override; bool is_readable() const override; @@ -3310,6 +3350,7 @@ public: void get_remote_ip_and_port(std::string &ip, int &port) const override; void get_local_ip_and_port(std::string &ip, int &port) const override; socket_t socket() const override; + time_t duration() const override; private: socket_t sock_; @@ -3317,6 +3358,8 @@ private: time_t read_timeout_usec_; time_t write_timeout_sec_; time_t write_timeout_usec_; + time_t global_timeout_msec_; + const std::chrono::time_point start_time; std::vector read_buff_; size_t read_buff_off_ = 0; @@ -3328,9 +3371,12 @@ private: #ifdef CPPHTTPLIB_OPENSSL_SUPPORT class SSLSocketStream final : public Stream { public: - SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, - time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec); + SSLSocketStream( + socket_t sock, SSL *ssl, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, time_t global_timeout_msec = 0, + std::chrono::time_point start_time = + std::chrono::steady_clock::time_point::min()); ~SSLSocketStream() override; bool is_readable() const override; @@ -3340,6 +3386,7 @@ public: void get_remote_ip_and_port(std::string &ip, int &port) const override; void get_local_ip_and_port(std::string &ip, int &port) const override; socket_t socket() const override; + time_t duration() const override; private: socket_t sock_; @@ -3348,6 +3395,8 @@ private: time_t read_timeout_usec_; time_t write_timeout_sec_; time_t write_timeout_usec_; + time_t global_timeout_msec_; + const std::chrono::time_point start_time; }; #endif @@ -3418,13 +3467,15 @@ process_server_socket(const std::atomic &svr_sock, socket_t sock, }); } -inline bool process_client_socket(socket_t sock, time_t read_timeout_sec, - time_t read_timeout_usec, - time_t write_timeout_sec, - time_t write_timeout_usec, - std::function callback) { +inline bool process_client_socket( + socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t global_timeout_msec, + std::chrono::time_point start_time, + std::function callback) { SocketStream strm(sock, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); + write_timeout_sec, write_timeout_usec, global_timeout_msec, + start_time); return callback(strm); } @@ -4313,7 +4364,7 @@ inline bool read_content_without_length(Stream &strm, uint64_t r = 0; for (;;) { auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); - if (n <= 0) { return true; } + if (n <= 0) { return false; } if (!out(buf, static_cast(n), r, 0)) { return false; } r += static_cast(n); @@ -5433,9 +5484,76 @@ inline std::string SHA_256(const std::string &s) { inline std::string SHA_512(const std::string &s) { return message_digest(s, EVP_sha512()); } -#endif -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::pair make_digest_authentication_header( + const Request &req, const std::map &auth, + size_t cnonce_count, const std::string &cnonce, const std::string &username, + const std::string &password, bool is_proxy = false) { + std::string nc; + { + std::stringstream ss; + ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count; + nc = ss.str(); + } + + std::string qop; + if (auth.find("qop") != auth.end()) { + qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else if (qop.find("auth") != std::string::npos) { + qop = "auth"; + } else { + qop.clear(); + } + } + + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + + std::string response; + { + auto H = algo == "SHA-256" ? detail::SHA_256 + : algo == "SHA-512" ? detail::SHA_512 + : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { A2 += ":" + H(req.body); } + + if (qop.empty()) { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2)); + } else { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } + } + + auto opaque = (auth.find("opaque") != auth.end()) ? auth.at("opaque") : ""; + + auto field = "Digest username=\"" + username + "\", realm=\"" + + auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + + "\", uri=\"" + req.path + "\", algorithm=" + algo + + (qop.empty() ? ", response=\"" + : ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + + cnonce + "\", response=\"") + + response + "\"" + + (opaque.empty() ? "" : ", opaque=\"" + opaque + "\""); + + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} + +inline bool is_ssl_peer_could_be_closed(SSL *ssl, socket_t sock) { + detail::set_nonblocking(sock, true); + auto se = detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + char buf[1]; + return !SSL_peek(ssl, buf, 1) && + SSL_get_error(ssl, 0) == SSL_ERROR_ZERO_RETURN; +} + #ifdef _WIN32 // NOTE: This code came up with the following stackoverflow post: // https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store @@ -5574,68 +5692,6 @@ public: static WSInit wsinit_; #endif -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -inline std::pair make_digest_authentication_header( - const Request &req, const std::map &auth, - size_t cnonce_count, const std::string &cnonce, const std::string &username, - const std::string &password, bool is_proxy = false) { - std::string nc; - { - std::stringstream ss; - ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count; - nc = ss.str(); - } - - std::string qop; - if (auth.find("qop") != auth.end()) { - qop = auth.at("qop"); - if (qop.find("auth-int") != std::string::npos) { - qop = "auth-int"; - } else if (qop.find("auth") != std::string::npos) { - qop = "auth"; - } else { - qop.clear(); - } - } - - std::string algo = "MD5"; - if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } - - std::string response; - { - auto H = algo == "SHA-256" ? detail::SHA_256 - : algo == "SHA-512" ? detail::SHA_512 - : detail::MD5; - - auto A1 = username + ":" + auth.at("realm") + ":" + password; - - auto A2 = req.method + ":" + req.path; - if (qop == "auth-int") { A2 += ":" + H(req.body); } - - if (qop.empty()) { - response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2)); - } else { - response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + - ":" + qop + ":" + H(A2)); - } - } - - auto opaque = (auth.find("opaque") != auth.end()) ? auth.at("opaque") : ""; - - auto field = "Digest username=\"" + username + "\", realm=\"" + - auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + - "\", uri=\"" + req.path + "\", algorithm=" + algo + - (qop.empty() ? ", response=\"" - : ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + - cnonce + "\", response=\"") + - response + "\"" + - (opaque.empty() ? "" : ", opaque=\"" + opaque + "\""); - - auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; - return std::make_pair(key, field); -} -#endif - inline bool parse_www_authenticate(const Response &res, std::map &auth, bool is_proxy) { @@ -5954,20 +6010,45 @@ inline ssize_t Stream::write(const std::string &s) { namespace detail { +inline void calc_actual_timeout(time_t global_timeout_msec, + time_t duration_msec, time_t timeout_sec, + time_t timeout_usec, time_t &actual_timeout_sec, + time_t &actual_timeout_usec) { + auto timeout_msec = (timeout_sec * 1000) + (timeout_usec / 1000); + + auto actual_timeout_msec = + std::min(global_timeout_msec - duration_msec, timeout_msec); + + actual_timeout_sec = actual_timeout_msec / 1000; + actual_timeout_usec = (actual_timeout_msec % 1000) * 1000; +} + // Socket stream implementation -inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, - time_t read_timeout_usec, - time_t write_timeout_sec, - time_t write_timeout_usec) +inline SocketStream::SocketStream( + socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t global_timeout_msec, + std::chrono::time_point start_time) : sock_(sock), read_timeout_sec_(read_timeout_sec), read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), - write_timeout_usec_(write_timeout_usec), read_buff_(read_buff_size_, 0) {} + write_timeout_usec_(write_timeout_usec), + global_timeout_msec_(global_timeout_msec), start_time(start_time), + read_buff_(read_buff_size_, 0) {} inline SocketStream::~SocketStream() = default; inline bool SocketStream::is_readable() const { - return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + if (global_timeout_msec_ <= 0) { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + } + + time_t read_timeout_sec; + time_t read_timeout_usec; + calc_actual_timeout(global_timeout_msec_, duration(), read_timeout_sec_, + read_timeout_usec_, read_timeout_sec, read_timeout_usec); + + return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; } inline bool SocketStream::is_writable() const { @@ -6044,6 +6125,12 @@ inline void SocketStream::get_local_ip_and_port(std::string &ip, inline socket_t SocketStream::socket() const { return sock_; } +inline time_t SocketStream::duration() const { + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - start_time) + .count(); +} + // Buffer stream implementation inline bool BufferStream::is_readable() const { return true; } @@ -6072,6 +6159,8 @@ inline void BufferStream::get_local_ip_and_port(std::string & /*ip*/, inline socket_t BufferStream::socket() const { return 0; } +inline time_t BufferStream::duration() const { return 0; } + inline const std::string &BufferStream::get_buffer() const { return buffer; } inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { @@ -7368,6 +7457,7 @@ inline void ClientImpl::copy_settings(const ClientImpl &rhs) { read_timeout_usec_ = rhs.read_timeout_usec_; write_timeout_sec_ = rhs.write_timeout_sec_; write_timeout_usec_ = rhs.write_timeout_usec_; + global_timeout_msec_ = rhs.global_timeout_msec_; basic_auth_username_ = rhs.basic_auth_username_; basic_auth_password_ = rhs.basic_auth_password_; bearer_token_auth_token_ = rhs.bearer_token_auth_token_; @@ -7514,18 +7604,6 @@ inline bool ClientImpl::send(Request &req, Response &res, Error &error) { return ret; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -inline bool ClientImpl::is_ssl_peer_could_be_closed(SSL *ssl) const { - detail::set_nonblocking(socket_.sock, true); - auto se = detail::scope_exit( - [&]() { detail::set_nonblocking(socket_.sock, false); }); - - char buf[1]; - return !SSL_peek(ssl, buf, 1) && - SSL_get_error(ssl, 0) == SSL_ERROR_ZERO_RETURN; -} -#endif - inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { { std::lock_guard guard(socket_mutex_); @@ -7540,7 +7618,9 @@ inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT if (is_alive && is_ssl()) { - if (is_ssl_peer_could_be_closed(socket_.ssl)) { is_alive = false; } + if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + is_alive = false; + } } #endif @@ -7565,7 +7645,8 @@ inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { auto &scli = static_cast(*this); if (!proxy_host_.empty() && proxy_port_ != -1) { auto success = false; - if (!scli.connect_with_proxy(socket_, res, success, error)) { + if (!scli.connect_with_proxy(socket_, req.start_time_, res, success, + error)) { return success; } } @@ -7611,7 +7692,7 @@ inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { } }); - ret = process_socket(socket_, [&](Stream &strm) { + ret = process_socket(socket_, req.start_time_, [&](Stream &strm) { return handle_request(strm, req, res, close_connection, error); }); @@ -8020,6 +8101,9 @@ inline Result ClientImpl::send_with_content_provider( req.headers = headers; req.path = path; req.progress = progress; + if (global_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } auto error = Error::Success; @@ -8046,7 +8130,7 @@ inline bool ClientImpl::process_request(Stream &strm, Request &req, if (is_ssl()) { auto is_proxy_enabled = !proxy_host_.empty() && proxy_port_ != -1; if (!is_proxy_enabled) { - if (is_ssl_peer_could_be_closed(socket_.ssl)) { + if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { error = Error::SSLPeerCouldBeClosed_; return false; } @@ -8171,12 +8255,14 @@ inline ContentProviderWithoutLength ClientImpl::get_multipart_content_provider( }; } -inline bool -ClientImpl::process_socket(const Socket &socket, - std::function callback) { +inline bool ClientImpl::process_socket( + const Socket &socket, + std::chrono::time_point start_time, + std::function callback) { return detail::process_client_socket( socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_, std::move(callback)); + write_timeout_usec_, global_timeout_msec_, start_time, + std::move(callback)); } inline bool ClientImpl::is_ssl() const { return false; } @@ -8200,6 +8286,9 @@ inline Result ClientImpl::Get(const std::string &path, const Headers &headers, req.path = path; req.headers = headers; req.progress = std::move(progress); + if (global_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } return send_(std::move(req)); } @@ -8265,6 +8354,9 @@ inline Result ClientImpl::Get(const std::string &path, const Headers &headers, return content_receiver(data, data_length); }; req.progress = std::move(progress); + if (global_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } return send_(std::move(req)); } @@ -8310,6 +8402,9 @@ inline Result ClientImpl::Head(const std::string &path, req.method = "HEAD"; req.headers = headers; req.path = path; + if (global_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } return send_(std::move(req)); } @@ -8727,6 +8822,9 @@ inline Result ClientImpl::Delete(const std::string &path, req.headers = headers; req.path = path; req.progress = progress; + if (global_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } if (!content_type.empty()) { req.set_header("Content-Type", content_type); } req.body.assign(body, content_length); @@ -8774,6 +8872,9 @@ inline Result ClientImpl::Options(const std::string &path, req.method = "OPTIONS"; req.headers = headers; req.path = path; + if (global_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } return send_(std::move(req)); } @@ -8827,6 +8928,10 @@ inline void ClientImpl::set_write_timeout(time_t sec, time_t usec) { write_timeout_usec_ = usec; } +inline void ClientImpl::set_global_timeout(time_t msec) { + global_timeout_msec_ = msec; +} + inline void ClientImpl::set_basic_auth(const std::string &username, const std::string &password) { basic_auth_username_ = username; @@ -9065,12 +9170,14 @@ inline bool process_server_socket_ssl( } template -inline bool -process_client_socket_ssl(SSL *ssl, socket_t sock, time_t read_timeout_sec, - time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec, T callback) { +inline bool process_client_socket_ssl( + SSL *ssl, socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t global_timeout_msec, + std::chrono::time_point start_time, T callback) { SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); + write_timeout_sec, write_timeout_usec, + global_timeout_msec, start_time); return callback(strm); } @@ -9083,27 +9190,37 @@ public: }; // SSL socket stream implementation -inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl, - time_t read_timeout_sec, - time_t read_timeout_usec, - time_t write_timeout_sec, - time_t write_timeout_usec) +inline SSLSocketStream::SSLSocketStream( + socket_t sock, SSL *ssl, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t global_timeout_msec, + std::chrono::time_point start_time) : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), - write_timeout_usec_(write_timeout_usec) { + write_timeout_usec_(write_timeout_usec), + global_timeout_msec_(global_timeout_msec), start_time(start_time) { SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); } inline SSLSocketStream::~SSLSocketStream() = default; inline bool SSLSocketStream::is_readable() const { - return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + if (global_timeout_msec_ <= 0) { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + } + + time_t read_timeout_sec; + time_t read_timeout_usec; + calc_actual_timeout(global_timeout_msec_, duration(), read_timeout_sec_, + read_timeout_usec_, read_timeout_sec, read_timeout_usec); + + return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; } inline bool SSLSocketStream::is_writable() const { return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && - is_socket_alive(sock_); + is_socket_alive(sock_) && !is_ssl_peer_could_be_closed(ssl_, sock_); } inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { @@ -9134,8 +9251,9 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { } } return ret; + } else { + return -1; } - return -1; } inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { @@ -9181,6 +9299,12 @@ inline void SSLSocketStream::get_local_ip_and_port(std::string &ip, inline socket_t SSLSocketStream::socket() const { return sock_; } +inline time_t SSLSocketStream::duration() const { + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - start_time) + .count(); +} + static SSLInit sslinit_; } // namespace detail @@ -9421,16 +9545,22 @@ inline bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) { } // Assumes that socket_mutex_ is locked and that there are no requests in flight -inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, - bool &success, Error &error) { +inline bool SSLClient::connect_with_proxy( + Socket &socket, + std::chrono::time_point start_time, + Response &res, bool &success, Error &error) { success = true; Response proxy_res; if (!detail::process_client_socket( socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { + write_timeout_sec_, write_timeout_usec_, global_timeout_msec_, + start_time, [&](Stream &strm) { Request req2; req2.method = "CONNECT"; req2.path = host_and_port_; + if (global_timeout_msec_ > 0) { + req2.start_time_ = std::chrono::steady_clock::now(); + } return process_request(strm, req2, proxy_res, false, error); })) { // Thread-safe to close everything because we are assuming there are no @@ -9450,7 +9580,8 @@ inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, proxy_res = Response(); if (!detail::process_client_socket( socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { + write_timeout_sec_, write_timeout_usec_, global_timeout_msec_, + start_time, [&](Stream &strm) { Request req3; req3.method = "CONNECT"; req3.path = host_and_port_; @@ -9458,6 +9589,9 @@ inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, req3, auth, 1, detail::random_string(10), proxy_digest_auth_username_, proxy_digest_auth_password_, true)); + if (global_timeout_msec_ > 0) { + req3.start_time_ = std::chrono::steady_clock::now(); + } return process_request(strm, req3, proxy_res, false, error); })) { // Thread-safe to close everything because we are assuming there are @@ -9613,13 +9747,15 @@ inline void SSLClient::shutdown_ssl_impl(Socket &socket, assert(socket.ssl == nullptr); } -inline bool -SSLClient::process_socket(const Socket &socket, - std::function callback) { +inline bool SSLClient::process_socket( + const Socket &socket, + std::chrono::time_point start_time, + std::function callback) { assert(socket.ssl); return detail::process_client_socket_ssl( socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, std::move(callback)); + write_timeout_sec_, write_timeout_usec_, global_timeout_msec_, start_time, + std::move(callback)); } inline bool SSLClient::is_ssl() const { return true; } diff --git a/test/fuzzing/server_fuzzer.cc b/test/fuzzing/server_fuzzer.cc index 3cffbae..b1ba3db 100644 --- a/test/fuzzing/server_fuzzer.cc +++ b/test/fuzzing/server_fuzzer.cc @@ -39,6 +39,8 @@ public: socket_t socket() const override { return 0; } + time_t duration() const override { return 0; }; + private: const uint8_t *data_; size_t size_; diff --git a/test/test.cc b/test/test.cc index b69be5c..f2b1089 100644 --- a/test/test.cc +++ b/test/test.cc @@ -162,7 +162,8 @@ TEST(SocketStream, is_writable_UNIX) { const auto asSocketStream = [&](socket_t fd, std::function func) { - return detail::process_client_socket(fd, 0, 0, 0, 0, func); + return detail::process_client_socket( + fd, 0, 0, 0, 0, 0, std::chrono::steady_clock::time_point::min(), func); }; asSocketStream(fds[0], [&](Stream &s0) { EXPECT_EQ(s0.socket(), fds[0]); @@ -206,7 +207,8 @@ TEST(SocketStream, is_writable_INET) { const auto asSocketStream = [&](socket_t fd, std::function func) { - return detail::process_client_socket(fd, 0, 0, 0, 0, func); + return detail::process_client_socket( + fd, 0, 0, 0, 0, 0, std::chrono::steady_clock::time_point::min(), func); }; asSocketStream(disconnected_svr_sock, [&](Stream &ss) { EXPECT_EQ(ss.socket(), disconnected_svr_sock); @@ -4904,7 +4906,8 @@ static bool send_request(time_t read_timeout_sec, const std::string &req, if (client_sock == INVALID_SOCKET) { return false; } auto ret = detail::process_client_socket( - client_sock, read_timeout_sec, 0, 0, 0, [&](Stream &strm) { + client_sock, read_timeout_sec, 0, 0, 0, 0, + std::chrono::steady_clock::time_point::min(), [&](Stream &strm) { if (req.size() != static_cast(strm.write(req.data(), req.size()))) { return false; @@ -8190,3 +8193,255 @@ TEST(Expect100ContinueTest, ServerClosesConnection) { } } #endif + +TEST(GlobalTimeoutTest, ContentStream) { + Server svr; + + svr.Get("/stream", [&](const Request &, Response &res) { + auto data = new std::string("01234567890123456789"); + + res.set_content_provider( + data->size(), "text/plain", + [&, data](size_t offset, size_t length, DataSink &sink) { + const size_t DATA_CHUNK_SIZE = 4; + const auto &d = *data; + std::this_thread::sleep_for(std::chrono::seconds(1)); + sink.write(&d[offset], std::min(length, DATA_CHUNK_SIZE)); + return true; + }, + [data](bool success) { + EXPECT_FALSE(success); + delete data; + }); + }); + + svr.Get("/stream_without_length", [&](const Request &, Response &res) { + auto i = new size_t(0); + + res.set_content_provider( + "text/plain", + [i](size_t, DataSink &sink) { + if (*i < 5) { + std::this_thread::sleep_for(std::chrono::seconds(1)); + sink.write("abcd", 4); + (*i)++; + } else { + sink.done(); + } + return true; + }, + [i](bool success) { + EXPECT_FALSE(success); + delete i; + }); + }); + + svr.Get("/chunked", [&](const Request &, Response &res) { + auto i = new size_t(0); + + res.set_chunked_content_provider( + "text/plain", + [i](size_t, DataSink &sink) { + if (*i < 5) { + std::this_thread::sleep_for(std::chrono::seconds(1)); + sink.os << "abcd"; + (*i)++; + } else { + sink.done(); + } + return true; + }, + [i](bool success) { + EXPECT_FALSE(success); + delete i; + }); + }); + + auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); }); + auto se = detail::scope_exit([&] { + svr.stop(); + listen_thread.join(); + ASSERT_FALSE(svr.is_running()); + }); + + svr.wait_until_ready(); + + const time_t timeout = 2000; + const time_t threshold = 200; + + Client cli("localhost", PORT); + cli.set_global_timeout(std::chrono::milliseconds(timeout)); + + + { + auto start = std::chrono::steady_clock::now(); + + auto res = cli.Get("/stream"); + + auto elapsed = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + + ASSERT_FALSE(res); + EXPECT_EQ(Error::Read, res.error()); + EXPECT_TRUE(timeout <= elapsed && elapsed < timeout + threshold); + } + + { + auto start = std::chrono::steady_clock::now(); + + auto res = cli.Get("/stream_without_length"); + + auto elapsed = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + + ASSERT_FALSE(res); + EXPECT_EQ(Error::Read, res.error()); + EXPECT_TRUE(timeout <= elapsed && elapsed < timeout + threshold); + } + + { + auto start = std::chrono::steady_clock::now(); + + auto res = cli.Get("/chunked", [&](const char *data, size_t data_length) { + EXPECT_EQ("abcd", string(data, data_length)); + return true; + }); + + auto elapsed = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + + ASSERT_FALSE(res); + EXPECT_EQ(Error::Read, res.error()); + EXPECT_TRUE(timeout <= elapsed && elapsed < timeout + threshold); + } +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +TEST(GlobalTimeoutTest, ContentStreamSSL) { + SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE); + + svr.Get("/stream", [&](const Request &, Response &res) { + auto data = new std::string("01234567890123456789"); + + res.set_content_provider( + data->size(), "text/plain", + [&, data](size_t offset, size_t length, DataSink &sink) { + const size_t DATA_CHUNK_SIZE = 4; + const auto &d = *data; + std::this_thread::sleep_for(std::chrono::seconds(1)); + sink.write(&d[offset], std::min(length, DATA_CHUNK_SIZE)); + return true; + }, + [data](bool success) { + EXPECT_FALSE(success); + delete data; + }); + }); + + svr.Get("/stream_without_length", [&](const Request &, Response &res) { + auto i = new size_t(0); + + res.set_content_provider( + "text/plain", + [i](size_t, DataSink &sink) { + if (*i < 5) { + std::this_thread::sleep_for(std::chrono::seconds(1)); + sink.write("abcd", 4); + (*i)++; + } else { + sink.done(); + } + return true; + }, + [i](bool success) { + EXPECT_FALSE(success); + delete i; + }); + }); + + svr.Get("/chunked", [&](const Request &, Response &res) { + auto i = new size_t(0); + + res.set_chunked_content_provider( + "text/plain", + [i](size_t, DataSink &sink) { + if (*i < 5) { + std::this_thread::sleep_for(std::chrono::seconds(1)); + sink.os << "abcd"; + (*i)++; + } else { + sink.done(); + } + return true; + }, + [i](bool success) { + EXPECT_FALSE(success); + delete i; + }); + }); + + auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); }); + auto se = detail::scope_exit([&] { + svr.stop(); + listen_thread.join(); + ASSERT_FALSE(svr.is_running()); + }); + + svr.wait_until_ready(); + + const time_t timeout = 2000; + const time_t threshold = 1000; // SSL_shutdown is slow... + + SSLClient cli("localhost", PORT); + cli.enable_server_certificate_verification(false); + cli.set_global_timeout(std::chrono::milliseconds(timeout)); + + { + auto start = std::chrono::steady_clock::now(); + + auto res = cli.Get("/stream"); + + auto elapsed = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + + ASSERT_FALSE(res); + EXPECT_EQ(Error::Read, res.error()); + EXPECT_TRUE(timeout <= elapsed && elapsed < timeout + threshold); + } + + { + auto start = std::chrono::steady_clock::now(); + + auto res = cli.Get("/stream_without_length"); + + auto elapsed = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + + ASSERT_FALSE(res); + EXPECT_EQ(Error::Read, res.error()); + EXPECT_TRUE(timeout <= elapsed && elapsed < timeout + threshold); + } + + { + auto start = std::chrono::steady_clock::now(); + + auto res = cli.Get("/chunked", [&](const char *data, size_t data_length) { + EXPECT_EQ("abcd", string(data, data_length)); + return true; + }); + + auto elapsed = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + + ASSERT_FALSE(res); + EXPECT_EQ(Error::Read, res.error()); + EXPECT_TRUE(timeout <= elapsed && elapsed < timeout + threshold); + } +} +#endif