diff --git a/README.md b/README.md index aed19ca..e85f9bc 100644 --- a/README.md +++ b/README.md @@ -285,6 +285,22 @@ svr.set_post_routing_handler([](const auto& req, auto& res) { }); ``` +### Pre request handler + +```cpp +svr.set_pre_request_handler([](const auto& req, auto& res) { + if (req.matched_route == "/user/:user") { + auto user = req.path_params.at("user"); + if (user != "john") { + res.status = StatusCode::Forbidden_403; + res.set_content("error", "text/html"); + return Server::HandlerResponse::Handled; + } + } + return Server::HandlerResponse::Unhandled; +}); +``` + ### 'multipart/form-data' POST data ```cpp diff --git a/httplib.h b/httplib.h index e1eb470..54b80d7 100644 --- a/httplib.h +++ b/httplib.h @@ -636,6 +636,7 @@ using Ranges = std::vector; struct Request { std::string method; std::string path; + std::string matched_route; Params params; Headers headers; std::string body; @@ -887,10 +888,16 @@ namespace detail { class MatcherBase { public: + MatcherBase(std::string pattern) : pattern_(pattern) {} virtual ~MatcherBase() = default; + const std::string &pattern() const { return pattern_; } + // Match request path and populate its matches and virtual bool match(Request &request) const = 0; + +private: + std::string pattern_; }; /** @@ -942,7 +949,8 @@ private: */ class RegexMatcher final : public MatcherBase { public: - RegexMatcher(const std::string &pattern) : regex_(pattern) {} + RegexMatcher(const std::string &pattern) + : MatcherBase(pattern), regex_(pattern) {} bool match(Request &request) const override; @@ -1009,9 +1017,12 @@ public: } Server &set_exception_handler(ExceptionHandler handler); + Server &set_pre_routing_handler(HandlerWithResponse handler); Server &set_post_routing_handler(Handler handler); + Server &set_pre_request_handler(HandlerWithResponse handler); + Server &set_expect_100_continue_handler(Expect100ContinueHandler handler); Server &set_logger(Logger logger); @@ -1153,6 +1164,7 @@ private: ExceptionHandler exception_handler_; HandlerWithResponse pre_routing_handler_; Handler post_routing_handler_; + HandlerWithResponse pre_request_handler_; Expect100ContinueHandler expect_100_continue_handler_; Logger logger_; @@ -6224,7 +6236,8 @@ inline time_t BufferStream::duration() const { return 0; } inline const std::string &BufferStream::get_buffer() const { return buffer; } -inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { +inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) + : MatcherBase(pattern) { constexpr const char marker[] = "/:"; // One past the last ending position of a path param substring @@ -6475,6 +6488,11 @@ inline Server &Server::set_post_routing_handler(Handler handler) { return *this; } +inline Server &Server::set_pre_request_handler(HandlerWithResponse handler) { + pre_request_handler_ = std::move(handler); + return *this; +} + inline Server &Server::set_logger(Logger logger) { logger_ = std::move(logger); return *this; @@ -7129,7 +7147,11 @@ inline bool Server::dispatch_request(Request &req, Response &res, const auto &handler = x.second; if (matcher->match(req)) { - handler(req, res); + req.matched_route = matcher->pattern(); + if (!pre_request_handler_ || + pre_request_handler_(req, res) != HandlerResponse::Handled) { + handler(req, res); + } return true; } } @@ -7256,7 +7278,11 @@ inline bool Server::dispatch_request_for_content_reader( const auto &handler = x.second; if (matcher->match(req)) { - handler(req, res, content_reader); + req.matched_route = matcher->pattern(); + if (!pre_request_handler_ || + pre_request_handler_(req, res) != HandlerResponse::Handled) { + handler(req, res, content_reader); + } return true; } } diff --git a/test/test.cc b/test/test.cc index 0f012bd..de3aae5 100644 --- a/test/test.cc +++ b/test/test.cc @@ -2263,7 +2263,7 @@ TEST(NoContentTest, ContentLength) { } } -TEST(RoutingHandlerTest, PreRoutingHandler) { +TEST(RoutingHandlerTest, PreAndPostRoutingHandlers) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE); ASSERT_TRUE(svr.is_valid()); @@ -2354,6 +2354,63 @@ TEST(RoutingHandlerTest, PreRoutingHandler) { } } +TEST(RequestHandlerTest, PreRequestHandler) { + auto route_path = "/user/:user"; + + Server svr; + + svr.Get("/hi", [](const Request &, Response &res) { + res.set_content("hi", "text/plain"); + }); + + svr.Get(route_path, [](const Request &req, Response &res) { + res.set_content(req.path_params.at("user"), "text/plain"); + }); + + svr.set_pre_request_handler([&](const Request &req, Response &res) { + if (req.matched_route == route_path) { + auto user = req.path_params.at("user"); + if (user != "john") { + res.status = StatusCode::Forbidden_403; + res.set_content("error", "text/html"); + return Server::HandlerResponse::Handled; + } + } + return Server::HandlerResponse::Unhandled; + }); + + auto thread = std::thread([&]() { svr.listen(HOST, PORT); }); + auto se = detail::scope_exit([&] { + svr.stop(); + thread.join(); + ASSERT_FALSE(svr.is_running()); + }); + + svr.wait_until_ready(); + + Client cli(HOST, PORT); + { + auto res = cli.Get("/hi"); + ASSERT_TRUE(res); + EXPECT_EQ(StatusCode::OK_200, res->status); + EXPECT_EQ("hi", res->body); + } + + { + auto res = cli.Get("/user/john"); + ASSERT_TRUE(res); + EXPECT_EQ(StatusCode::OK_200, res->status); + EXPECT_EQ("john", res->body); + } + + { + auto res = cli.Get("/user/invalid-user"); + ASSERT_TRUE(res); + EXPECT_EQ(StatusCode::Forbidden_403, res->status); + EXPECT_EQ("error", res->body); + } +} + TEST(InvalidFormatTest, StatusCode) { Server svr;