From 9ca1fa8b18e57727c7f6e2b0db82c3841132b0e8 Mon Sep 17 00:00:00 2001 From: yhirose Date: Sat, 25 Jul 2020 09:37:57 -0400 Subject: [PATCH] Fix #576 --- httplib.h | 171 +++++++++++++++++++++++++++++---------------------- test/test.cc | 16 ++++- 2 files changed, 111 insertions(+), 76 deletions(-) diff --git a/httplib.h b/httplib.h index 44a96f0..6cc04cc 100644 --- a/httplib.h +++ b/httplib.h @@ -349,6 +349,8 @@ struct Request { bool has_header(const char *key) const; std::string get_header_value(const char *key, size_t id = 0) const; + template + T get_header_value(const char *key, size_t id = 0) const; size_t get_header_value_count(const char *key) const; void set_header(const char *key, const char *val); void set_header(const char *key, const std::string &val); @@ -374,6 +376,8 @@ struct Response { bool has_header(const char *key) const; std::string get_header_value(const char *key, size_t id = 0) const; + template + T get_header_value(const char *key, size_t id = 0) const; size_t get_header_value_count(const char *key) const; void set_header(const char *key, const char *val); void set_header(const char *key, const std::string &val); @@ -1580,6 +1584,74 @@ inline bool is_valid_path(const std::string &path) { return true; } +inline std::string encode_url(const std::string &s) { + std::string result; + + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': result += "%20"; break; + case '+': result += "%2B"; break; + case '\r': result += "%0D"; break; + case '\n': result += "%0A"; break; + case '\'': result += "%27"; break; + case ',': result += "%2C"; break; + // case ':': result += "%3A"; break; // ok? probably... + case ';': result += "%3B"; break; + default: + auto c = static_cast(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, static_cast(len)); + } else { + result += s[i]; + } + break; + } + } + + return result; +} + +inline std::string decode_url(const std::string &s, + bool convert_plus_to_space) { + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + int val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { result.append(buff, len); } + i += 5; // 'u0000' + } else { + result += s[i]; + } + } else { + int val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (convert_plus_to_space && s[i] == '+') { + result += ' '; + } else { + result += s[i]; + } + } + + return result; +} + inline void read_file(const std::string &path, std::string &out) { std::ifstream fs(path, std::ios_base::binary); fs.seekg(0, std::ios_base::end); @@ -2379,10 +2451,18 @@ inline const char *get_header_value(const Headers &headers, const char *key, return def; } -inline uint64_t get_header_value_uint64(const Headers &headers, const char *key, - uint64_t def = 0) { - auto it = headers.find(key); - if (it != headers.end()) { +template +inline T get_header_value(const Headers & /*headers*/, const char * /*key*/, + size_t /*id*/ = 0, uint64_t /*def*/ = 0) {} + +template <> +inline uint64_t get_header_value(const Headers &headers, + const char *key, size_t id, + uint64_t def) { + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return std::strtoull(it->second.data(), nullptr, 10); } return def; @@ -2404,7 +2484,8 @@ inline void parse_header(const char *beg, const char *end, Headers &headers) { while (p < end) { p++; } - headers.emplace(std::string(beg, key_end), std::string(val_begin, end)); + headers.emplace(std::string(beg, key_end), + decode_url(std::string(val_begin, end), true)); } } } @@ -2574,7 +2655,7 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, } else if (!has_header(x.headers, "Content-Length")) { ret = read_content_without_length(strm, out); } else { - auto len = get_header_value_uint64(x.headers, "Content-Length", 0); + auto len = get_header_value(x.headers, "Content-Length"); if (len > payload_max_length) { exceed_payload_max_length = true; skip_content_with_length(strm, len); @@ -2765,74 +2846,6 @@ inline bool redirect(T &cli, const Request &req, Response &res, return ret; } -inline std::string encode_url(const std::string &s) { - std::string result; - - for (size_t i = 0; s[i]; i++) { - switch (s[i]) { - case ' ': result += "%20"; break; - case '+': result += "%2B"; break; - case '\r': result += "%0D"; break; - case '\n': result += "%0A"; break; - case '\'': result += "%27"; break; - case ',': result += "%2C"; break; - // case ':': result += "%3A"; break; // ok? probably... - case ';': result += "%3B"; break; - default: - auto c = static_cast(s[i]); - if (c >= 0x80) { - result += '%'; - char hex[4]; - auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); - assert(len == 2); - result.append(hex, static_cast(len)); - } else { - result += s[i]; - } - break; - } - } - - return result; -} - -inline std::string decode_url(const std::string &s, - bool convert_plus_to_space) { - std::string result; - - for (size_t i = 0; i < s.size(); i++) { - if (s[i] == '%' && i + 1 < s.size()) { - if (s[i + 1] == 'u') { - int val = 0; - if (from_hex_to_i(s, i + 2, 4, val)) { - // 4 digits Unicode codes - char buff[4]; - size_t len = to_utf8(val, buff); - if (len > 0) { result.append(buff, len); } - i += 5; // 'u0000' - } else { - result += s[i]; - } - } else { - int val = 0; - if (from_hex_to_i(s, i + 1, 2, val)) { - // 2 digits hex codes - result += static_cast(val); - i += 2; // '00' - } else { - result += s[i]; - } - } - } else if (convert_plus_to_space && s[i] == '+') { - result += ' '; - } else { - result += s[i]; - } - } - - return result; -} - inline std::string params_to_query_str(const Params ¶ms) { std::string query; @@ -3458,6 +3471,11 @@ inline std::string Request::get_header_value(const char *key, size_t id) const { return detail::get_header_value(headers, key, id, ""); } +template +inline T Request::get_header_value(const char *key, size_t id) const { + return detail::get_header_value(headers, key, id, 0); +} + inline size_t Request::get_header_value_count(const char *key) const { auto r = headers.equal_range(key); return static_cast(std::distance(r.first, r.second)); @@ -3517,6 +3535,11 @@ inline std::string Response::get_header_value(const char *key, return detail::get_header_value(headers, key, id, ""); } +template +inline T Response::get_header_value(const char *key, size_t id) const { + return detail::get_header_value(headers, key, id, 0); +} + inline size_t Response::get_header_value_count(const char *key) const { auto r = headers.equal_range(key); return static_cast(std::distance(r.first, r.second)); diff --git a/test/test.cc b/test/test.cc index 7dda645..616b72a 100644 --- a/test/test.cc +++ b/test/test.cc @@ -100,7 +100,8 @@ TEST(GetHeaderValueTest, DefaultValue) { TEST(GetHeaderValueTest, DefaultValueInt) { Headers headers = {{"Dummy", "Dummy"}}; - auto val = detail::get_header_value_uint64(headers, "Content-Length", 100); + auto val = + detail::get_header_value(headers, "Content-Length", 0, 100); EXPECT_EQ(100ull, val); } @@ -112,7 +113,8 @@ TEST(GetHeaderValueTest, RegularValue) { TEST(GetHeaderValueTest, RegularValueInt) { Headers headers = {{"Content-Length", "100"}, {"Dummy", "Dummy"}}; - auto val = detail::get_header_value_uint64(headers, "Content-Length", 0); + auto val = + detail::get_header_value(headers, "Content-Length", 0, 0); EXPECT_EQ(100ull, val); } @@ -716,6 +718,16 @@ TEST(RedirectToDifferentPort, Redirect) { ASSERT_FALSE(svr8080.is_running()); ASSERT_FALSE(svr8081.is_running()); } + +TEST(UrlWithSpace, Redirect) { + httplib::SSLClient cli("edge.forgecdn.net"); + cli.set_follow_location(true); + + auto res = cli.Get("/files/2595/310/Neat 1.4-17.jar"); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(200, res->status); + EXPECT_EQ(18527, res->get_header_value("Content-Length")); +} #endif TEST(Server, BindDualStack) {