From eb4fcb5003a4bccb0fe8d39b23db484800089756 Mon Sep 17 00:00:00 2001 From: yhirose Date: Fri, 20 Dec 2019 06:59:59 -0500 Subject: [PATCH] CONNECT method support on client --- httplib.h | 379 ++++++++++++++++++++++++++++++++++++--------------- test/test.cc | 32 ++++- 2 files changed, 297 insertions(+), 114 deletions(-) diff --git a/httplib.h b/httplib.h index 38b4222..b6478ba 100644 --- a/httplib.h +++ b/httplib.h @@ -505,12 +505,13 @@ public: }; #endif +using Logger = std::function; + class Server { public: using Handler = std::function; using HandlerWithContentReader = std::function; - using Logger = std::function; Server(); @@ -614,7 +615,9 @@ private: class Client { public: - explicit Client(const char *host, int port = 80); + explicit Client(const std::string &host, int port = 80, + const std::string &client_cert_path = std::string(), + const std::string &client_key_path = std::string()); virtual ~Client(); @@ -736,11 +739,13 @@ public: void set_timeout_sec(time_t timeout_sec); - void set_keep_alive_max_count(size_t count); - void set_read_timeout(time_t sec, time_t usec); - void set_auth(const char *username, const char *password); + void set_keep_alive_max_count(size_t count); + + void set_basic_auth(const char *username, const char *password); + + void set_digest_auth(const char *username, const char *password); void set_follow_location(bool on); @@ -748,6 +753,14 @@ public: void set_interface(const char *intf); + void set_proxy(const char *host, int port); + + void set_proxy_basic_auth(const char *username, const char *password); + + void set_proxy_digest_auth(const char *username, const char *password); + + void set_logger(Logger logger); + protected: bool process_request(Stream &strm, const Request &req, Response &res, bool last_connection, bool &connection_close); @@ -756,17 +769,60 @@ protected: const int port_; const std::string host_and_port_; - // Options + // Settings + std::string client_cert_path_; + std::string client_key_path_; + time_t timeout_sec_ = 300; - size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; - std::string username_; - std::string password_; + + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; + + std::string basic_auth_username_; + std::string basic_auth_password_; + std::string digest_auth_username_; + std::string digest_auth_password_; + bool follow_location_ = false; + bool compress_ = false; + std::string interface_; + std::string proxy_host_; + int proxy_port_; + + std::string proxy_basic_auth_username_; + std::string proxy_basic_auth_password_; + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; + + Logger logger_; + + void copy_settings(const Client &rhs) { + client_cert_path_ = rhs.client_cert_path_; + client_key_path_ = rhs.client_key_path_; + timeout_sec_ = rhs.timeout_sec_; + read_timeout_sec_ = rhs.read_timeout_sec_; + read_timeout_usec_ = rhs.read_timeout_usec_; + keep_alive_max_count_ = rhs.keep_alive_max_count_; + basic_auth_username_ = rhs.basic_auth_username_; + basic_auth_password_ = rhs.basic_auth_password_; + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; + follow_location_ = rhs.follow_location_; + compress_ = rhs.compress_; + interface_ = rhs.interface_; + proxy_host_ = rhs.proxy_host_; + proxy_port_ = rhs.proxy_port_; + proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; + proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; + logger_ = rhs.logger_; + } + private: socket_t create_client_socket() const; bool read_response_line(Stream &strm, Response &res); @@ -856,9 +912,9 @@ private: class SSLClient : public Client { public: - SSLClient(const char *host, int port = 443, - const char *client_cert_path = nullptr, - const char *client_key_path = nullptr); + SSLClient(const std::string &host, int port = 443, + const std::string &client_cert_path = std::string(), + const std::string &client_key_path = std::string()); virtual ~SSLClient(); @@ -866,6 +922,7 @@ public: void set_ca_cert_path(const char *ca_ceert_file_path, const char *ca_cert_dir_path = nullptr); + void enable_server_certificate_verification(bool enabled); long get_openssl_verify_result() const; @@ -889,7 +946,6 @@ private: std::mutex ctx_mutex_; std::vector host_components_; - // Options std::string ca_cert_file_path_; std::string ca_cert_dir_path_; bool server_certificate_verification_ = false; @@ -1234,10 +1290,9 @@ inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { } template -inline bool process_and_close_socket(bool is_client_request, socket_t sock, - size_t keep_alive_max_count, - time_t read_timeout_sec, - time_t read_timeout_usec, T callback) { +inline bool process_socket(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, time_t read_timeout_sec, + time_t read_timeout_usec, T callback) { assert(keep_alive_max_count > 0); bool ret = false; @@ -1263,6 +1318,16 @@ inline bool process_and_close_socket(bool is_client_request, socket_t sock, ret = callback(strm, true, dummy_connection_close); } + return ret; +} + +template +inline bool process_and_close_socket(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, + time_t read_timeout_sec, + time_t read_timeout_usec, T callback) { + auto ret = process_socket(is_client_request, sock, keep_alive_max_count, + read_timeout_sec, read_timeout_usec, callback); close_socket(sock); return ret; } @@ -1309,8 +1374,8 @@ socket_t create_socket(const char *host, int port, Fn fn, auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT); /** - * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 and above - * the socket creation fails on older Windows Systems. + * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 + * and above the socket creation fails on older Windows Systems. * * Let's try to create a socket the old way in this case. * @@ -1318,11 +1383,12 @@ socket_t create_socket(const char *host, int port, Fn fn, * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa * * WSA_FLAG_NO_HANDLE_INHERIT: - * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with SP1, and later + * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with + * SP1, and later * */ if (sock == INVALID_SOCKET) { - sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); } #else auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); @@ -1880,17 +1946,12 @@ write_content_chunked(Stream &strm, template inline bool redirect(T &cli, const Request &req, Response &res, const std::string &path) { - Request new_req; - new_req.method = req.method; + Request new_req = req; new_req.path = path; - new_req.headers = req.headers; - new_req.body = req.body; - new_req.redirect_count = req.redirect_count - 1; - new_req.response_handler = req.response_handler; - new_req.content_receiver = req.content_receiver; - new_req.progress = req.progress; + new_req.redirect_count -= 1; Response new_res; + auto ret = cli.send(new_req, new_res); if (ret) { res = new_res; } return ret; @@ -2416,16 +2477,17 @@ inline std::pair make_range_header(Ranges ranges) { inline std::pair make_basic_authentication_header(const std::string &username, - const std::string &password) { + const std::string &password, bool proxy = false) { auto field = "Basic " + detail::base64_encode(username + ":" + password); - return std::make_pair("Authorization", field); + auto key = proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT inline std::pair make_digest_authentication_header( const Request &req, const std::map &auth, size_t cnonce_count, const std::string &cnonce, const std::string &username, - const std::string &password) { + const std::string &password, bool proxy = false) { using namespace std; string nc; @@ -2442,10 +2504,11 @@ inline std::pair make_digest_authentication_header( qop = "auth"; } + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + string response; { - auto algo = auth.at("algorithm"); - auto H = algo == "SHA-256" ? detail::SHA_256 : algo == "SHA-512" ? detail::SHA_512 : detail::MD5; @@ -2461,25 +2524,26 @@ inline std::pair make_digest_authentication_header( auto field = "Digest username=\"hello\", realm=\"" + auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + "\", uri=\"" + req.path + - "\", algorithm=" + auth.at("algorithm") + ", qop=" + qop + - ", nc=\"" + nc + "\", cnonce=\"" + cnonce + "\", response=\"" + - response + "\""; + "\", algorithm=" + algo + ", qop=" + qop + ", nc=\"" + nc + + "\", cnonce=\"" + cnonce + "\", response=\"" + response + "\""; - return make_pair("Authorization", field); + auto key = proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); } #endif -inline int -parse_www_authenticate(const httplib::Response &res, - std::map &digest_auth) { - if (res.has_header("WWW-Authenticate")) { +inline bool parse_www_authenticate(const httplib::Response &res, + std::map &auth, + bool proxy) { + auto key = proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(key)) { static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); - auto s = res.get_header_value("WWW-Authenticate"); + auto s = res.get_header_value(key); auto pos = s.find(' '); if (pos != std::string::npos) { auto type = s.substr(0, pos); if (type == "Basic") { - return 1; + return false; } else if (type == "Digest") { s = s.substr(pos + 1); auto beg = std::sregex_iterator(s.begin(), s.end(), re); @@ -2488,13 +2552,13 @@ parse_www_authenticate(const httplib::Response &res, auto key = s.substr(m.position(1), m.length(1)); auto val = m.length(2) > 0 ? s.substr(m.position(2), m.length(2)) : s.substr(m.position(3), m.length(3)); - digest_auth[key] = val; + auth[key] = val; } - return 2; + return true; } } } - return 0; + return false; } // https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c/440240#answer-440240 @@ -3377,15 +3441,22 @@ inline bool Server::process_and_close_socket(socket_t sock) { } // HTTP client implementation -inline Client::Client(const char *host, int port) +inline Client::Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) : host_(host), port_(port), - host_and_port_(host_ + ":" + std::to_string(port_)) {} + host_and_port_(host_ + ":" + std::to_string(port_)), + client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} inline Client::~Client() {} inline bool Client::is_valid() const { return true; } inline socket_t Client::create_client_socket() const { + if (!proxy_host_.empty()) { + return detail::create_client_socket(proxy_host_.c_str(), proxy_port_, + timeout_sec_, interface_); + } return detail::create_client_socket(host_.c_str(), port_, timeout_sec_, interface_); } @@ -3414,54 +3485,97 @@ inline bool Client::send(const Request &req, Response &res) { auto sock = create_client_socket(); if (sock == INVALID_SOCKET) { return false; } - auto ret = process_and_close_socket( - sock, 1, [&](Stream &strm, bool last_connection, bool &connection_close) { - return process_request(strm, req, res, last_connection, - connection_close); - }); - - if (ret && follow_location_ && (300 < res.status && res.status < 400)) { - ret = redirect(req, res); - } - - if (ret && !username_.empty() && !password_.empty() && res.status == 401) { - int type; - std::map digest_auth; - - if ((type = parse_www_authenticate(res, digest_auth)) > 0) { - std::pair header; - - if (type == 1) { - header = make_basic_authentication_header(username_, password_); - } else if (type == 2) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - size_t cnonce_count = 1; - auto cnonce = random_string(10); + // CONNECT + if (is_ssl() && !proxy_host_.empty()) { + Response res2; + if (!detail::process_socket( + true, sock, 1, read_timeout_sec_, read_timeout_usec_, + [&](Stream &strm, bool /*last_connection*/, + bool &connection_close) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + return process_request(strm, req2, res2, false, connection_close); + })) { + return false; + } - header = make_digest_authentication_header( - req, digest_auth, cnonce_count, cnonce, username_, password_); -#endif + if (res2.status == 407 && !proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map auth; + if (parse_www_authenticate(res2, auth, true)) { + detail::close_socket(sock); + sock = create_client_socket(); + if (sock == INVALID_SOCKET) { return false; } + + Response res2; + if (!detail::process_socket( + true, sock, 1, read_timeout_sec_, read_timeout_usec_, + [&](Stream &strm, bool /*last_connection*/, + bool &connection_close) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + req2.headers.insert(make_digest_authentication_header( + req2, auth, 1, random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + return process_request(strm, req2, res2, false, + connection_close); + })) { + return false; + } } - - Request new_req; - new_req.method = req.method; - new_req.path = req.path; - new_req.headers = req.headers; - new_req.body = req.body; - new_req.response_handler = req.response_handler; - new_req.content_receiver = req.content_receiver; - new_req.progress = req.progress; - - new_req.headers.insert(header); - - Response new_res; - auto ret = send(new_req, new_res); - if (ret) { res = new_res; } - return ret; } } +#endif - return ret; + if (!process_and_close_socket( + sock, 1, + [&](Stream &strm, bool last_connection, bool &connection_close) { + if (!is_ssl() && !proxy_host_.empty()) { + auto req2 = req; + req2.path = "http://" + host_and_port_ + req.path; + return process_request(strm, req2, res, last_connection, + connection_close); + } + return process_request(strm, req, res, last_connection, + connection_close); + })) { + return false; + } + + if (300 < res.status && res.status < 400 && follow_location_) { + return redirect(req, res); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (res.status == 401 || res.status == 407) { + auto is_proxy = res.status == 407; + const auto &username = + is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; + const auto &password = + is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; + + if (!username.empty() && !password.empty()) { + std::map auth; + if (parse_www_authenticate(res, auth, is_proxy)) { + Request new_req = req; + new_req.headers.insert(make_digest_authentication_header( + req, auth, 1, random_string(10), username, password, is_proxy)); + + Response new_res; + + auto ret = send(new_req, new_res); + if (ret) { res = new_res; } + return ret; + } + } + } +#endif + + return true; } inline bool Client::send(const std::vector &requests, @@ -3511,28 +3625,30 @@ inline bool Client::redirect(const Request &req, Response &res) { std::smatch m; if (!regex_match(location, m, re)) { return false; } + auto scheme = is_ssl() ? "https" : "http"; + auto next_scheme = m[1].str(); auto next_host = m[2].str(); auto next_path = m[3].str(); + if (next_scheme.empty()) { next_scheme = scheme; } + if (next_scheme.empty()) { next_scheme = scheme; } if (next_host.empty()) { next_host = host_; } if (next_path.empty()) { next_path = "/"; } - auto scheme = is_ssl() ? "https" : "http"; - if (next_scheme == scheme && next_host == host_) { return detail::redirect(*this, req, res, next_path); } else { if (next_scheme == "https") { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT SSLClient cli(next_host.c_str()); - cli.set_follow_location(true); + cli.copy_settings(*this); return detail::redirect(cli, req, res, next_path); #else return false; #endif } else { Client cli(next_host.c_str()); - cli.set_follow_location(true); + cli.copy_settings(*this); return detail::redirect(cli, req, res, next_path); } } @@ -3544,7 +3660,7 @@ inline bool Client::write_request(Stream &strm, const Request &req, // Request line const static std::regex re( - R"(^([^:/?#]+://[^/?#]*)?([^?#]*(?:\?[^#]*)?(?:#.*)?))"); + R"(^((?:[^:/?#]+://)?(?:[^/?#]*)?)?([^?#]*(?:\?[^#]*)?(?:#.*)?))"); std::smatch m; if (!regex_match(req.path, m, re)) { return false; } @@ -3597,6 +3713,17 @@ inline bool Client::write_request(Stream &strm, const Request &req, } } + if (!basic_auth_username_.empty() && !basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + basic_auth_username_, basic_auth_password_, false)); + } + + if (!proxy_basic_auth_username_.empty() && + !proxy_basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } + detail::write_headers(bstrm, req, headers); // Flush buffer @@ -3689,7 +3816,7 @@ inline bool Client::process_request(Stream &strm, const Request &req, } // Body - if (req.method != "HEAD") { + if (req.method != "HEAD" && req.method != "CONNECT") { ContentReceiver out = [&](const char *buf, size_t n) { if (res.body.size() + n > res.body.max_size()) { return false; } res.body.append(buf, n); @@ -3709,6 +3836,9 @@ inline bool Client::process_request(Stream &strm, const Request &req, } } + // Log + if (logger_) { logger_(req, res); } + return true; } @@ -4010,18 +4140,24 @@ inline void Client::set_timeout_sec(time_t timeout_sec) { timeout_sec_ = timeout_sec; } -inline void Client::set_keep_alive_max_count(size_t count) { - keep_alive_max_count_ = count; -} - inline void Client::set_read_timeout(time_t sec, time_t usec) { read_timeout_sec_ = sec; read_timeout_usec_ = usec; } -inline void Client::set_auth(const char *username, const char *password) { - username_ = username; - password_ = password; +inline void Client::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; +} + +inline void Client::set_basic_auth(const char *username, const char *password) { + basic_auth_username_ = username; + basic_auth_password_ = password; +} + +inline void Client::set_digest_auth(const char *username, + const char *password) { + digest_auth_username_ = username; + digest_auth_password_ = password; } inline void Client::set_follow_location(bool on) { follow_location_ = on; } @@ -4030,6 +4166,25 @@ inline void Client::set_compress(bool on) { compress_ = on; } inline void Client::set_interface(const char *intf) { interface_ = intf; } +inline void Client::set_proxy(const char *host, int port) { + proxy_host_ = host; + proxy_port_ = port; +} + +inline void Client::set_proxy_basic_auth(const char *username, + const char *password) { + proxy_basic_auth_username_ = username; + proxy_basic_auth_password_ = password; +} + +inline void Client::set_proxy_digest_auth(const char *username, + const char *password) { + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; +} + +inline void Client::set_logger(Logger logger) { logger_ = std::move(logger); } + /* * SSL Implementation */ @@ -4249,21 +4404,21 @@ inline bool SSLServer::process_and_close_socket(socket_t sock) { } // SSL HTTP client implementation -inline SSLClient::SSLClient(const char *host, int port, - const char *client_cert_path, - const char *client_key_path) - : Client(host, port) { +inline SSLClient::SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : Client(host, port, client_cert_path, client_key_path) { ctx_ = SSL_CTX_new(SSLv23_client_method()); detail::split(&host_[0], &host_[host_.size()], '.', [&](const char *b, const char *e) { host_components_.emplace_back(std::string(b, e)); }); - if (client_cert_path && client_key_path) { - if (SSL_CTX_use_certificate_file(ctx_, client_cert_path, + if (!client_cert_path.empty() && !client_key_path.empty()) { + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), SSL_FILETYPE_PEM) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, client_key_path, SSL_FILETYPE_PEM) != - 1) { + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), + SSL_FILETYPE_PEM) != 1) { SSL_CTX_free(ctx_); ctx_ = nullptr; } diff --git a/test/test.cc b/test/test.cc index 998a927..2cca8f7 100644 --- a/test/test.cc +++ b/test/test.cc @@ -474,13 +474,27 @@ TEST(BaseAuthTest, FromHTTPWatch) { } { - cli.set_auth("hello", "world"); + cli.set_basic_auth("hello", "world"); auto res = cli.Get("/basic-auth/hello/world"); ASSERT_TRUE(res != nullptr); EXPECT_EQ(res->body, "{\n \"authenticated\": true, \n \"user\": \"hello\"\n}\n"); EXPECT_EQ(200, res->status); } + + { + cli.set_basic_auth("hello", "bad"); + auto res = cli.Get("/basic-auth/hello/world"); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(401, res->status); + } + + { + cli.set_basic_auth("bad", "world"); + auto res = cli.Get("/basic-auth/hello/world"); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(401, res->status); + } } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT @@ -504,7 +518,7 @@ TEST(DigestAuthTest, FromHTTPWatch) { "/digest-auth/auth-int/hello/world/MD5", }; - cli.set_auth("hello", "world"); + cli.set_digest_auth("hello", "world"); for (auto path : paths) { auto res = cli.Get(path.c_str()); ASSERT_TRUE(res != nullptr); @@ -512,6 +526,20 @@ TEST(DigestAuthTest, FromHTTPWatch) { "{\n \"authenticated\": true, \n \"user\": \"hello\"\n}\n"); EXPECT_EQ(200, res->status); } + + cli.set_digest_auth("hello", "bad"); + for (auto path : paths) { + auto res = cli.Get(path.c_str()); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(400, res->status); + } + + cli.set_digest_auth("bad", "world"); + for (auto path : paths) { + auto res = cli.Get(path.c_str()); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(400, res->status); + } } } #endif