diff --git a/httplib.h b/httplib.h index 5a4b64a..8944945 100644 --- a/httplib.h +++ b/httplib.h @@ -737,6 +737,8 @@ private: std::regex regex_; }; +ssize_t write_headers(Stream &strm, const Headers &headers); + } // namespace detail class Server { @@ -800,6 +802,8 @@ public: Server &set_socket_options(SocketOptions socket_options); Server &set_default_headers(Headers headers); + Server & + set_header_writer(std::function const &writer); Server &set_keep_alive_max_count(size_t count); Server &set_keep_alive_timeout(time_t sec); @@ -934,6 +938,8 @@ private: SocketOptions socket_options_ = default_socket_options; Headers default_headers_; + std::function header_writer_ = + detail::write_headers; }; enum class Error { @@ -1164,6 +1170,9 @@ public: void set_default_headers(Headers headers); + void + set_header_writer(std::function const &writer); + void set_address_family(int family); void set_tcp_nodelay(bool on); void set_socket_options(SocketOptions socket_options); @@ -1273,6 +1282,10 @@ protected: // Default headers Headers default_headers_; + // Header writer + std::function header_writer_ = + detail::write_headers; + // Settings std::string client_cert_path_; std::string client_key_path_; @@ -1539,6 +1552,9 @@ public: void set_default_headers(Headers headers); + void + set_header_writer(std::function const &writer); + void set_address_family(int family); void set_tcp_nodelay(bool on); void set_socket_options(SocketOptions socket_options); @@ -5672,6 +5688,12 @@ inline Server &Server::set_default_headers(Headers headers) { return *this; } +inline Server &Server::set_header_writer( + std::function const &writer) { + header_writer_ = writer; + return *this; +} + inline Server &Server::set_keep_alive_max_count(size_t count) { keep_alive_max_count_ = count; return *this; @@ -5866,7 +5888,7 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection, return false; } - if (!detail::write_headers(bstrm, res.headers)) { return false; } + if (!header_writer_(bstrm, res.headers)) { return false; } // Flush buffer auto &data = bstrm.get_buffer(); @@ -7105,7 +7127,7 @@ inline bool ClientImpl::write_request(Stream &strm, Request &req, const auto &path = url_encode_ ? detail::encode_url(req.path) : req.path; bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); - detail::write_headers(bstrm, req.headers); + header_writer_(bstrm, req.headers); // Flush buffer auto &data = bstrm.get_buffer(); @@ -7916,6 +7938,11 @@ inline void ClientImpl::set_default_headers(Headers headers) { default_headers_ = std::move(headers); } +inline void ClientImpl::set_header_writer( + std::function const &writer) { + header_writer_ = writer; +} + inline void ClientImpl::set_address_family(int family) { address_family_ = family; } @@ -9110,6 +9137,11 @@ inline void Client::set_default_headers(Headers headers) { cli_->set_default_headers(std::move(headers)); } +inline void Client::set_header_writer( + std::function const &writer) { + cli_->set_header_writer(writer); +} + inline void Client::set_address_family(int family) { cli_->set_address_family(family); } diff --git a/test/test.cc b/test/test.cc index c8cf9e0..6d3f586 100644 --- a/test/test.cc +++ b/test/test.cc @@ -1592,6 +1592,46 @@ TEST(URLFragmentTest, WithFragment) { } } +TEST(HeaderWriter, SetHeaderWriter) { + Server svr; + + svr.set_header_writer([](Stream &strm, Headers &hdrs) { + hdrs.emplace("CustomServerHeader", "CustomServerValue"); + return detail::write_headers(strm, hdrs); + }); + svr.Get("/hi", [](const Request &req, Response &res) { + auto it = req.headers.find("CustomClientHeader"); + EXPECT_TRUE(it != req.headers.end()); + EXPECT_EQ(it->second, "CustomClientValue"); + res.set_content("Hello World!\n", "text/plain"); + }); + + auto thread = std::thread([&]() { svr.listen(HOST, PORT); }); + auto se = detail::scope_exit([&] { + svr.stop(); + thread.join(); + ASSERT_FALSE(svr.is_running()); + }); + + std::this_thread::sleep_for(std::chrono::seconds(1)); + + { + Client cli(HOST, PORT); + cli.set_header_writer([](Stream &strm, Headers &hdrs) { + hdrs.emplace("CustomClientHeader", "CustomClientValue"); + return detail::write_headers(strm, hdrs); + }); + + auto res = cli.Get("/hi"); + EXPECT_TRUE(res); + EXPECT_EQ(200, res->status); + + auto it = res->headers.find("CustomServerHeader"); + EXPECT_TRUE(it != res->headers.end()); + EXPECT_EQ(it->second, "CustomServerValue"); + } +} + class ServerTest : public ::testing::Test { protected: ServerTest()