diff --git a/README.md b/README.md index 5cc674c..8560bcf 100644 --- a/README.md +++ b/README.md @@ -94,11 +94,20 @@ int main(void) res.set_content("Hello World!", "text/plain"); }); + // Match the request path against a regular expression + // and extract its captures svr.Get(R"(/numbers/(\d+))", [&](const Request& req, Response& res) { auto numbers = req.matches[1]; res.set_content(numbers, "text/plain"); }); + // Capture the second segment of the request path as "id" path param + svr.Get("/users/:id", [&](const Request& req, Response& res) { + auto user_id = req.path_params.at("id"); + res.set_content(user_id, "text/plain"); + }); + + // Extract values from HTTP headers and URL query params svr.Get("/body-header-param", [](const Request& req, Response& res) { if (req.has_header("Content-Length")) { auto val = req.get_header_value("Content-Length"); diff --git a/httplib.h b/httplib.h index 5b42992..24d4f07 100644 --- a/httplib.h +++ b/httplib.h @@ -229,6 +229,8 @@ using socket_t = int; #include #include #include +#include +#include #include #ifdef CPPHTTPLIB_OPENSSL_SUPPORT @@ -472,6 +474,7 @@ struct Request { MultipartFormDataMap files; Ranges ranges; Match matches; + std::unordered_map path_params; // for client ResponseHandler response_handler; @@ -665,6 +668,76 @@ using SocketOptions = std::function; void default_socket_options(socket_t sock); +namespace detail { + +class MatcherBase { +public: + virtual ~MatcherBase() = default; + + // Match request path and populate its matches and path params + virtual bool match(Request &request) const = 0; +}; + +/** + * Captures parameters in request path and stores them in Request::path_params + * + * Capture name is a substring of a pattern from : to /. + * The rest of the pattern is matched against the request path directly + * Parameters are captured starting from the next character after + * the end of the last matched static pattern fragment until the next /. 
+ * + * Example pattern: + * "/path/fragments/:capture/more/fragments/:second_capture" + * Static fragments: + * "/path/fragments/", "more/fragments/" + * + * Given the following request path: + * "/path/fragments/1/more/fragments/2" + * the resulting capture will be + * {{"capture", "1"}, {"second_capture", "2"}} + */ +class PathParamsMatcher : public MatcherBase { +public: + PathParamsMatcher(const std::string &pattern); + + bool match(Request &request) const override; + +private: + static constexpr char marker = ':'; + // Treat segment separators as the end of path parameter capture + // Does not need to handle query parameters as they are parsed before path + // matching + static constexpr char separator = '/'; + + // Contains static path fragments to match against, excluding the '/' after + // path params + // Fragments are separated by path params + std::vector static_fragments_; + // Stores the names of the path parameters to be used as keys in the + // Request::path_params map + std::vector param_names_; +}; + +/** + * Performs std::regex_match on request path + * and stores the result in Request::matches + * + * Note that regex match is performed directly on the whole request path. + * This means that wildcard patterns may match multiple path segments with /: + * "/begin/(.*)/end" will match both "/begin/middle/end" and "/begin/1/2/end". 
+ */ +class RegexMatcher : public MatcherBase { +public: + RegexMatcher(const std::string &pattern) : regex_(pattern) {} + + bool match(Request &request) const override; + +private: + std::regex regex_; +}; + +} // namespace detail + class Server { public: using Handler = std::function; @@ -772,9 +845,14 @@ protected: size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; private: - using Handlers = std::vector>; + using Handlers = + std::vector, Handler>>; using HandlersForContentReader = - std::vector>; + std::vector, + HandlerWithContentReader>>; + + static std::unique_ptr + make_matcher(const std::string &pattern); socket_t create_server_socket(const std::string &host, int port, int socket_flags, @@ -5147,6 +5225,99 @@ inline socket_t BufferStream::socket() const { return 0; } inline const std::string &BufferStream::get_buffer() const { return buffer; } +inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { + // One past the last ending position of a path param substring + std::size_t last_param_end = 0; + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + // Needed to ensure that parameter names are unique during matcher + // construction + // If exceptions are disabled, only the first duplicate path + // parameter will be set + std::unordered_set param_name_set; +#endif + + while (true) { + const auto marker_pos = pattern.find(marker, last_param_end); + if (marker_pos == std::string::npos) { break; } + + static_fragments_.push_back( + pattern.substr(last_param_end, marker_pos - last_param_end)); + + const auto param_name_start = marker_pos + 1; + + auto sep_pos = pattern.find(separator, param_name_start); + if (sep_pos == std::string::npos) { sep_pos = pattern.length(); } + + auto param_name = + pattern.substr(param_name_start, sep_pos - param_name_start); + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (param_name_set.find(param_name) != param_name_set.cend()) { + std::string msg = "Encountered path parameter '" + param_name + + "' multiple times in route 
pattern '" + pattern + "'."; + throw std::invalid_argument(msg); + } +#endif + + param_names_.push_back(std::move(param_name)); + + last_param_end = sep_pos + 1; + } + + if (last_param_end < pattern.length()) { + static_fragments_.push_back(pattern.substr(last_param_end)); + } +} + +inline bool PathParamsMatcher::match(Request &request) const { + request.matches = {}; + request.path_params.clear(); + request.path_params.reserve(param_names_.size()); + + // One past the position at which the path matched the pattern last time + std::size_t starting_pos = 0; + for (size_t i = 0; i < static_fragments_.size(); ++i) { + const auto &fragment = static_fragments_[i]; + + if (starting_pos + fragment.length() > request.path.length()) { + return false; + } + + // Avoid unnecessary allocation by using strncmp instead of substr + + // comparison + if (std::strncmp(request.path.c_str() + starting_pos, fragment.c_str(), + fragment.length()) != 0) { + return false; + } + + starting_pos += fragment.length(); + + // Should only happen when we have a static fragment after a param + // Example: '/users/:id/subscriptions' + // The 'subscriptions' fragment here does not have a corresponding param + if (i >= param_names_.size()) { continue; } + + auto sep_pos = request.path.find(separator, starting_pos); + if (sep_pos == std::string::npos) { sep_pos = request.path.length(); } + + const auto &param_name = param_names_[i]; + + request.path_params.emplace( + param_name, request.path.substr(starting_pos, sep_pos - starting_pos)); + + // Mark everything up to '/' as matched + starting_pos = sep_pos + 1; + } + // Returns false if the path is longer than the pattern + return starting_pos >= request.path.length(); +} + +inline bool RegexMatcher::match(Request &request) const { + request.path_params.clear(); + return std::regex_match(request.path, request.matches, regex_); +} + } // namespace detail // HTTP server implementation @@ -5160,67 +5331,76 @@ inline Server::Server() inline Server::~Server() 
{} +inline std::unique_ptr +Server::make_matcher(const std::string &pattern) { + if (pattern.find("/:") != std::string::npos) { + return detail::make_unique(pattern); + } else { + return detail::make_unique(pattern); + } +} + inline Server &Server::Get(const std::string &pattern, Handler handler) { get_handlers_.push_back( - std::make_pair(std::regex(pattern), std::move(handler))); + std::make_pair(make_matcher(pattern), std::move(handler))); return *this; } inline Server &Server::Post(const std::string &pattern, Handler handler) { post_handlers_.push_back( - std::make_pair(std::regex(pattern), std::move(handler))); + std::make_pair(make_matcher(pattern), std::move(handler))); return *this; } inline Server &Server::Post(const std::string &pattern, HandlerWithContentReader handler) { post_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), std::move(handler))); + std::make_pair(make_matcher(pattern), std::move(handler))); return *this; } inline Server &Server::Put(const std::string &pattern, Handler handler) { put_handlers_.push_back( - std::make_pair(std::regex(pattern), std::move(handler))); + std::make_pair(make_matcher(pattern), std::move(handler))); return *this; } inline Server &Server::Put(const std::string &pattern, HandlerWithContentReader handler) { put_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), std::move(handler))); + std::make_pair(make_matcher(pattern), std::move(handler))); return *this; } inline Server &Server::Patch(const std::string &pattern, Handler handler) { patch_handlers_.push_back( - std::make_pair(std::regex(pattern), std::move(handler))); + std::make_pair(make_matcher(pattern), std::move(handler))); return *this; } inline Server &Server::Patch(const std::string &pattern, HandlerWithContentReader handler) { patch_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), std::move(handler))); + std::make_pair(make_matcher(pattern), std::move(handler))); return 
*this; } inline Server &Server::Delete(const std::string &pattern, Handler handler) { delete_handlers_.push_back( - std::make_pair(std::regex(pattern), std::move(handler))); + std::make_pair(make_matcher(pattern), std::move(handler))); return *this; } inline Server &Server::Delete(const std::string &pattern, HandlerWithContentReader handler) { delete_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), std::move(handler))); + std::make_pair(make_matcher(pattern), std::move(handler))); return *this; } inline Server &Server::Options(const std::string &pattern, Handler handler) { options_handlers_.push_back( - std::make_pair(std::regex(pattern), std::move(handler))); + std::make_pair(make_matcher(pattern), std::move(handler))); return *this; } @@ -5930,10 +6110,10 @@ inline bool Server::routing(Request &req, Response &res, Stream &strm) { inline bool Server::dispatch_request(Request &req, Response &res, const Handlers &handlers) { for (const auto &x : handlers) { - const auto &pattern = x.first; + const auto &matcher = x.first; const auto &handler = x.second; - if (std::regex_match(req.path, req.matches, pattern)) { + if (matcher->match(req)) { handler(req, res); return true; } @@ -6055,10 +6235,10 @@ inline bool Server::dispatch_request_for_content_reader( Request &req, Response &res, ContentReader content_reader, const HandlersForContentReader &handlers) { for (const auto &x : handlers) { - const auto &pattern = x.first; + const auto &matcher = x.first; const auto &handler = x.second; - if (std::regex_match(req.path, req.matches, pattern)) { + if (matcher->match(req)) { handler(req, res, content_reader); return true; } diff --git a/test/test.cc b/test/test.cc index 3f4a6f7..7750a51 100644 --- a/test/test.cc +++ b/test/test.cc @@ -4379,6 +4379,12 @@ TEST(GetWithParametersTest, GetWithParameters) { EXPECT_EQ("bar", req.get_param_value("param2")); }); + svr.Get("/users/:id", [&](const Request &req, Response &) { + EXPECT_EQ("user-id", 
req.path_params.at("id")); + EXPECT_EQ("foo", req.get_param_value("param1")); + EXPECT_EQ("bar", req.get_param_value("param2")); + }); + auto listen_thread = std::thread([&svr]() { svr.listen(HOST, PORT); }); auto se = detail::scope_exit([&] { svr.stop(); @@ -4419,6 +4425,15 @@ TEST(GetWithParametersTest, GetWithParameters) { ASSERT_TRUE(res); EXPECT_EQ(200, res->status); } + + { + Client cli(HOST, PORT); + + auto res = cli.Get("/users/user-id?param1=foo&param2=bar"); + + ASSERT_TRUE(res); + EXPECT_EQ(200, res->status); + } } TEST(GetWithParametersTest, GetWithParameters2) { @@ -6290,3 +6305,140 @@ TEST(VulnerabilityTest, CRLFInjection) { cli.Patch("/test4", "content", "text/plain\r\nevil: hello4"); } } + +TEST(PathParamsTest, StaticMatch) { + const auto pattern = "/users/all"; + detail::PathParamsMatcher matcher(pattern); + + Request request; + request.path = "/users/all"; + ASSERT_TRUE(matcher.match(request)); + + std::unordered_map expected_params = {}; + + EXPECT_EQ(request.path_params, expected_params); +} + +TEST(PathParamsTest, StaticMismatch) { + const auto pattern = "/users/all"; + detail::PathParamsMatcher matcher(pattern); + + Request request; + request.path = "/users/1"; + ASSERT_FALSE(matcher.match(request)); +} + +TEST(PathParamsTest, SingleParamInTheMiddle) { + const auto pattern = "/users/:id/subscriptions"; + detail::PathParamsMatcher matcher(pattern); + + Request request; + request.path = "/users/42/subscriptions"; + ASSERT_TRUE(matcher.match(request)); + + std::unordered_map expected_params = {{"id", "42"}}; + + EXPECT_EQ(request.path_params, expected_params); +} + +TEST(PathParamsTest, SingleParamInTheEnd) { + const auto pattern = "/users/:id"; + detail::PathParamsMatcher matcher(pattern); + + Request request; + request.path = "/users/24"; + ASSERT_TRUE(matcher.match(request)); + + std::unordered_map expected_params = {{"id", "24"}}; + + EXPECT_EQ(request.path_params, expected_params); +} + +TEST(PathParamsTest, SingleParamInTheEndTrailingSlash) { + 
const auto pattern = "/users/:id/"; + detail::PathParamsMatcher matcher(pattern); + + Request request; + request.path = "/users/42/"; + ASSERT_TRUE(matcher.match(request)); + std::unordered_map expected_params = {{"id", "42"}}; + + EXPECT_EQ(request.path_params, expected_params); +} + +TEST(PathParamsTest, EmptyParam) { + const auto pattern = "/users/:id/"; + detail::PathParamsMatcher matcher(pattern); + + Request request; + request.path = "/users//"; + ASSERT_TRUE(matcher.match(request)); + + std::unordered_map expected_params = {{"id", ""}}; + + EXPECT_EQ(request.path_params, expected_params); +} + +TEST(PathParamsTest, FragmentMismatch) { + const auto pattern = "/users/:id/"; + detail::PathParamsMatcher matcher(pattern); + + Request request; + request.path = "/admins/24/"; + ASSERT_FALSE(matcher.match(request)); +} + +TEST(PathParamsTest, ExtraFragments) { + const auto pattern = "/users/:id"; + detail::PathParamsMatcher matcher(pattern); + + Request request; + request.path = "/users/42/subscriptions"; + ASSERT_FALSE(matcher.match(request)); +} + +TEST(PathParamsTest, MissingTrailingParam) { + const auto pattern = "/users/:id"; + detail::PathParamsMatcher matcher(pattern); + + Request request; + request.path = "/users"; + ASSERT_FALSE(matcher.match(request)); +} + +TEST(PathParamsTest, MissingParamInTheMiddle) { + const auto pattern = "/users/:id/subscriptions"; + detail::PathParamsMatcher matcher(pattern); + + Request request; + request.path = "/users/subscriptions"; + ASSERT_FALSE(matcher.match(request)); +} + +TEST(PathParamsTest, MultipleParams) { + const auto pattern = "/users/:userid/subscriptions/:subid"; + detail::PathParamsMatcher matcher(pattern); + + Request request; + request.path = "/users/42/subscriptions/2"; + ASSERT_TRUE(matcher.match(request)); + + std::unordered_map expected_params = { + {"userid", "42"}, {"subid", "2"}}; + + EXPECT_EQ(request.path_params, expected_params); +} + +TEST(PathParamsTest, SequenceOfParams) { + const auto pattern = 
"/values/:x/:y/:z"; + detail::PathParamsMatcher matcher(pattern); + + Request request; + request.path = "/values/1/2/3"; + ASSERT_TRUE(matcher.match(request)); + + std::unordered_map expected_params = { + {"x", "1"}, {"y", "2"}, {"z", "3"}}; + + EXPECT_EQ(request.path_params, expected_params); +}