diff --git a/README.md b/README.md index ec7c543..a63b53f 100644 --- a/README.md +++ b/README.md @@ -324,16 +324,21 @@ std::shared_ptr res = This feature was contributed by [underscorediscovery](https://github.com/yhirose/cpp-httplib/pull/23). -### Basic Authentication +### Authentication ```cpp httplib::Client cli("httplib.org"); +cli.set_auth("user", "pass"); -auto res = cli.Get("/basic-auth/hello/world", { - httplib::make_basic_authentication_header("hello", "world") -}); +// Basic +auto res = cli.Get("/basic-auth/user/pass"); // res->status should be 200 -// res->body should be "{\n \"authenticated\": true, \n \"user\": \"hello\"\n}\n". +// res->body should be "{\n \"authenticated\": true, \n \"user\": \"user\"\n}\n". + +// Digest +res = cli.Get("/digest-auth/auth/user/pass/SHA-256"); +// res->status should be 200 +// res->body should be "{\n \"authenticated\": true, \n \"user\": \"user\"\n}\n". ``` ### Range diff --git a/example/Makefile b/example/Makefile index 28cea04..0add1e4 100644 --- a/example/Makefile +++ b/example/Makefile @@ -33,4 +33,4 @@ pem: openssl req -new -key key.pem | openssl x509 -days 3650 -req -signkey key.pem > cert.pem clean: - rm server client hello simplesvr upload redirect *.pem + rm server client hello simplesvr upload redirect benchmark *.pem diff --git a/httplib.h b/httplib.h index dc04f88..de929e4 100644 --- a/httplib.h +++ b/httplib.h @@ -149,9 +149,13 @@ using socket_t = int; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT #include +#include #include #include +#include +#include + // #if OPENSSL_VERSION_NUMBER < 0x1010100fL // #error Sorry, OpenSSL versions prior to 1.1.1 are not supported // #endif @@ -756,10 +760,13 @@ public: std::vector &responses); void set_keep_alive_max_count(size_t count); + void set_read_timeout(time_t sec, time_t usec); void follow_location(bool on); + void set_auth(const char *username, const char *password); + protected: bool process_request(Stream &strm, const Request &req, Response &res, bool last_connection, bool &connection_close); @@ -772,6 +779,8 @@ protected: time_t read_timeout_sec_; time_t read_timeout_usec_; size_t follow_location_; + std::string username_; + std::string password_; private: socket_t create_client_socket() const; @@ -1439,6 +1448,7 @@ inline const char *status_message(int status) { case 303: return "See Other"; case 304: return "Not Modified"; case 400: return "Bad Request"; + case 401: return "Unauthorized"; case 403: return "Forbidden"; case 404: return "Not Found"; case 413: return "Payload Too Large"; @@ -2287,6 +2297,43 @@ inline bool expect_content(const Request &req) { return false; } +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +template +inline std::string message_digest(const std::string &s, Init init, + Update update, Final final, + size_t digest_length) { + using namespace std; + + unsigned char md[digest_length]; + CTX ctx; + init(&ctx); + update(&ctx, s.data(), s.size()); + final(md, &ctx); + + stringstream ss; + for (auto c : md) { + ss << setfill('0') << setw(2) << hex << (unsigned int)c; + } + return ss.str(); +} + +inline std::string MD5(const std::string &s) { + using namespace detail; + return message_digest(s, MD5_Init, MD5_Update, MD5_Final, + MD5_DIGEST_LENGTH); +} + +inline std::string SHA_256(const std::string &s) { + return message_digest(s, SHA256_Init, SHA256_Update, SHA256_Final, + SHA256_DIGEST_LENGTH); +} + +inline std::string SHA_512(const std::string &s) { + return message_digest(s, SHA512_Init, SHA512_Update, SHA512_Final, + SHA512_DIGEST_LENGTH); +} +#endif + #ifdef _WIN32 class WSInit { public: @@ -2324,6 +2371,98 @@ make_basic_authentication_header(const std::string &username, return std::make_pair("Authorization", field); } +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::pair make_digest_authentication_header( + const Request &req, + const std::map &auth, + size_t cnonce_count, const std::string &cnonce, + const std::string &username, const std::string &password) { + using namespace std; + + string nc; + { + stringstream ss; + ss << setfill('0') << setw(8) << hex << cnonce_count; + nc = ss.str(); + } + + auto qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else { + qop = "auth"; + } + + string response; + { + auto algo = auth.at("algorithm"); + + auto H = algo == "SHA-256" + ? detail::SHA_256 + : algo == "SHA-512" ? detail::SHA_512 : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { + A2 += ":" + H(req.body); + } + + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } + + auto field = "Digest username=\"hello\", realm=\"" + auth.at("realm") + + "\", nonce=\"" + auth.at("nonce") + "\", uri=\"" + req.path + + "\", algorithm=" + auth.at("algorithm") + ", qop=" + qop + ", nc=\"" + + nc + "\", cnonce=\"" + cnonce + "\", response=\"" + response + + "\""; + + return make_pair("Authorization", field); +} +#endif + +inline int parse_www_authenticate(const httplib::Response &res, + std::map &digest_auth) { + if (res.has_header("WWW-Authenticate")) { + static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*)))))~"); + auto s = res.get_header_value("WWW-Authenticate"); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return 1; + } else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + auto m = *i; + auto key = s.substr(m.position(1), m.length(1)); + auto val = m.length(2) > 0 ? s.substr(m.position(2), m.length(2)) + : s.substr(m.position(3), m.length(3)); + digest_auth[key] = val; + } + return 2; + } + } + } + return 0; +} + +// https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c/440240#answer-440240 +inline std::string random_string(size_t length) { + auto randchar = []() -> char { + const char charset[] = "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const size_t max_index = (sizeof(charset) - 1); + return charset[rand() % max_index]; + }; + std::string str(length, 0); + std::generate_n(str.begin(), length, randchar); + return str; +} + // Request implementation inline bool Request::has_header(const char *key) const { return detail::has_header(headers, key); @@ -3244,6 +3383,43 @@ inline bool Client::send(const Request &req, Response &res) { ret = redirect(req, res); } + if (ret && !username_.empty() && !password_.empty() && res.status == 401) { + int type; + std::map digest_auth; + + if ((type = parse_www_authenticate(res, digest_auth)) > 0) { + std::pair header; + + if (type == 1) { + header = make_basic_authentication_header(username_, password_); + } else if (type == 2) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + size_t cnonce_count = 1; + auto cnonce = random_string(10); + + header = make_digest_authentication_header( + req, digest_auth, cnonce_count, cnonce, username_, password_); +#endif + } + + Request new_req; + new_req.method = req.method; + new_req.path = req.path; + new_req.headers = req.headers; + new_req.body = req.body; + new_req.response_handler = req.response_handler; + new_req.content_receiver = req.content_receiver; + new_req.progress = req.progress; + + new_req.headers.insert(header); + + Response new_res; + auto ret = send(new_req, new_res); + if (ret) { res = new_res; } + return ret; + } + } + return ret; } @@ -3810,6 +3986,11 @@ inline void Client::set_read_timeout(time_t sec, time_t usec) { inline void Client::follow_location(bool on) { follow_location_ = on; } +inline void Client::set_auth(const char *username, const char *password) { + username_ = username; + password_ = password; +} + /* * SSL Implementation */ diff --git a/test/test.cc b/test/test.cc index bdc02d4..8fecdec 100644 --- a/test/test.cc +++ b/test/test.cc @@ -469,8 +469,50 @@ TEST(BaseAuthTest, FromHTTPWatch) { "{\n \"authenticated\": true, \n \"user\": \"hello\"\n}\n"); EXPECT_EQ(200, res->status); } + + { + cli.set_auth("hello", "world"); + auto res = cli.Get("/basic-auth/hello/world"); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(res->body, + "{\n \"authenticated\": true, \n \"user\": \"hello\"\n}\n"); + EXPECT_EQ(200, res->status); + } } +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +TEST(DigestAuthTest, FromHTTPWatch) { + auto host = "httpbin.org"; + auto port = 443; + httplib::SSLClient cli(host, port); + + { + auto res = cli.Get("/digest-auth/auth/hello/world"); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(401, res->status); + } + + { + std::vector paths = { + "/digest-auth/auth/hello/world/MD5", + "/digest-auth/auth/hello/world/SHA-256", + "/digest-auth/auth/hello/world/SHA-512", + "/digest-auth/auth-init/hello/world/MD5", + "/digest-auth/auth-int/hello/world/MD5", + }; + + cli.set_auth("hello", "world"); + for (auto path: paths) { + auto res = cli.Get(path.c_str()); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(res->body, + "{\n \"authenticated\": true, \n \"user\": \"hello\"\n}\n"); + EXPECT_EQ(200, res->status); + } + } +} +#endif + TEST(AbsoluteRedirectTest, Redirect) { auto host = "httpbin.org";