diff --git a/example/Makefile b/example/Makefile index ba68359..3082b88 100644 --- a/example/Makefile +++ b/example/Makefile @@ -18,7 +18,7 @@ ZLIB_SUPPORT = -DCPPHTTPLIB_ZLIB_SUPPORT -lz BROTLI_DIR = $(PREFIX)/opt/brotli BROTLI_SUPPORT = -DCPPHTTPLIB_BROTLI_SUPPORT -I$(BROTLI_DIR)/include -L$(BROTLI_DIR)/lib -lbrotlicommon -lbrotlienc -lbrotlidec -all: server client hello simplecli simplesvr upload redirect ssesvr ssecli benchmark one_time_request server_and_client +all: server client hello simplecli simplesvr upload redirect ssesvr ssecli benchmark one_time_request server_and_client accept_header server : server.cc ../httplib.h Makefile $(CXX) -o server $(CXXFLAGS) server.cc $(OPENSSL_SUPPORT) $(ZLIB_SUPPORT) $(BROTLI_SUPPORT) @@ -56,9 +56,12 @@ one_time_request : one_time_request.cc ../httplib.h Makefile server_and_client : server_and_client.cc ../httplib.h Makefile $(CXX) -o server_and_client $(CXXFLAGS) server_and_client.cc $(OPENSSL_SUPPORT) $(ZLIB_SUPPORT) $(BROTLI_SUPPORT) +accept_header : accept_header.cc ../httplib.h Makefile + $(CXX) -o accept_header $(CXXFLAGS) accept_header.cc $(OPENSSL_SUPPORT) $(ZLIB_SUPPORT) $(BROTLI_SUPPORT) + pem: openssl genrsa 2048 > key.pem openssl req -new -key key.pem | openssl x509 -days 3650 -req -signkey key.pem > cert.pem clean: - rm server client hello simplecli simplesvr upload redirect ssesvr ssecli benchmark one_time_request server_and_client *.pem + rm server client hello simplecli simplesvr upload redirect ssesvr ssecli benchmark one_time_request server_and_client accept_header *.pem diff --git a/example/accept_header.cc b/example/accept_header.cc new file mode 100644 index 0000000..33798d9 --- /dev/null +++ b/example/accept_header.cc @@ -0,0 +1,134 @@ +#include "httplib.h" +#include + +int main() { + using namespace httplib; + + // Example usage of parse_accept_header function + std::cout << "=== Accept Header Parser Example ===" << std::endl; + + // Example 1: Simple Accept header + std::string accept1 = "text/html,application/json,text/plain"; + std::vector result1; + if (detail::parse_accept_header(accept1, result1)) { + std::cout << "\nExample 1: " << accept1 << std::endl; + std::cout << "Parsed order:" << std::endl; + for (size_t i = 0; i < result1.size(); ++i) { + std::cout << " " << (i + 1) << ". " << result1[i] << std::endl; + } + } else { + std::cout << "\nExample 1: Failed to parse Accept header" << std::endl; + } + + // Example 2: Accept header with quality values + std::string accept2 = "text/html;q=0.9,application/json;q=1.0,text/plain;q=0.8"; + std::vector result2; + if (detail::parse_accept_header(accept2, result2)) { + std::cout << "\nExample 2: " << accept2 << std::endl; + std::cout << "Parsed order (sorted by priority):" << std::endl; + for (size_t i = 0; i < result2.size(); ++i) { + std::cout << " " << (i + 1) << ". " << result2[i] << std::endl; + } + } else { + std::cout << "\nExample 2: Failed to parse Accept header" << std::endl; + } + + // Example 3: Browser-like Accept header + std::string accept3 = "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8"; + std::vector result3; + if (detail::parse_accept_header(accept3, result3)) { + std::cout << "\nExample 3: " << accept3 << std::endl; + std::cout << "Parsed order:" << std::endl; + for (size_t i = 0; i < result3.size(); ++i) { + std::cout << " " << (i + 1) << ". " << result3[i] << std::endl; + } + } else { + std::cout << "\nExample 3: Failed to parse Accept header" << std::endl; + } + + // Example 4: Invalid Accept header examples + std::cout << "\n=== Invalid Accept Header Examples ===" << std::endl; + + std::vector invalid_examples = { + "text/html;q=1.5,application/json", // q > 1.0 + "text/html;q=-0.1,application/json", // q < 0.0 + "text/html;q=invalid,application/json", // invalid q value + "invalidtype,application/json", // invalid media type + ",application/json" // empty entry + }; + + for (const auto& invalid_accept : invalid_examples) { + std::vector temp_result; + std::cout << "\nTesting invalid: " << invalid_accept << std::endl; + if (detail::parse_accept_header(invalid_accept, temp_result)) { + std::cout << " Unexpectedly succeeded!" << std::endl; + } else { + std::cout << " Correctly rejected as invalid" << std::endl; + } + } + + // Example 4: Server usage example + std::cout << "\n=== Server Usage Example ===" << std::endl; + Server svr; + + svr.Get("/api/data", [](const Request& req, Response& res) { + // Get Accept header + auto accept_header = req.get_header_value("Accept"); + if (accept_header.empty()) { + accept_header = "*/*"; // Default if no Accept header + } + + // Parse accept header to get preferred content types + std::vector preferred_types; + if (!detail::parse_accept_header(accept_header, preferred_types)) { + // Invalid Accept header + res.status = 400; // Bad Request + res.set_content("Invalid Accept header", "text/plain"); + return; + } + + std::cout << "Client Accept header: " << accept_header << std::endl; + std::cout << "Preferred types in order:" << std::endl; + for (size_t i = 0; i < preferred_types.size(); ++i) { + std::cout << " " << (i + 1) << ". " << preferred_types[i] << std::endl; + } + + // Choose response format based on client preference + std::string response_content; + std::string content_type; + + for (const auto& type : preferred_types) { + if (type == "application/json" || type == "application/*" || type == "*/*") { + response_content = "{\"message\": \"Hello, World!\", \"data\": [1, 2, 3]}"; + content_type = "application/json"; + break; + } else if (type == "text/html" || type == "text/*") { + response_content = "

Hello, World!

Data: 1, 2, 3

"; + content_type = "text/html"; + break; + } else if (type == "text/plain") { + response_content = "Hello, World!\nData: 1, 2, 3"; + content_type = "text/plain"; + break; + } + } + + if (response_content.empty()) { + // No supported content type found + res.status = 406; // Not Acceptable + res.set_content("No acceptable content type found", "text/plain"); + return; + } + + res.set_content(response_content, content_type); + std::cout << "Responding with: " << content_type << std::endl; + }); + + std::cout << "Server configured. You can test it with:" << std::endl; + std::cout << " curl -H \"Accept: application/json\" http://localhost:8080/api/data" << std::endl; + std::cout << " curl -H \"Accept: text/html\" http://localhost:8080/api/data" << std::endl; + std::cout << " curl -H \"Accept: text/plain\" http://localhost:8080/api/data" << std::endl; + std::cout << " curl -H \"Accept: text/html;q=0.9,application/json;q=1.0\" http://localhost:8080/api/data" << std::endl; + + return 0; +} diff --git a/httplib.h b/httplib.h index ec6faf0..63718f5 100644 --- a/httplib.h +++ b/httplib.h @@ -670,6 +670,7 @@ struct Request { std::function is_connection_closed = []() { return true; }; // for client + std::vector accept_content_types; ResponseHandler response_handler; ContentReceiverWithProgress content_receiver; Progress progress; @@ -2491,6 +2492,9 @@ bool parse_multipart_boundary(const std::string &content_type, bool parse_range_header(const std::string &s, Ranges &ranges); +bool parse_accept_header(const std::string &s, + std::vector &content_types); + int close_socket(socket_t sock); ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags); @@ -5026,6 +5030,123 @@ inline bool parse_range_header(const std::string &s, Ranges &ranges) try { } catch (...) { return false; } #endif +inline bool parse_accept_header(const std::string &s, + std::vector &content_types) { + content_types.clear(); + + // Empty string is considered valid (no preference) + if (s.empty()) { return true; } + + // Check for invalid patterns: leading/trailing commas or consecutive commas + if (s.front() == ',' || s.back() == ',' || + s.find(",,") != std::string::npos) { + return false; + } + + struct AcceptEntry { + std::string media_type; + double quality; + int order; // Original order in header + }; + + std::vector entries; + int order = 0; + bool has_invalid_entry = false; + + // Split by comma and parse each entry + split(s.data(), s.data() + s.size(), ',', [&](const char *b, const char *e) { + std::string entry(b, e); + entry = trim_copy(entry); + + if (entry.empty()) { + has_invalid_entry = true; + return; + } + + AcceptEntry accept_entry; + accept_entry.quality = 1.0; // Default quality + accept_entry.order = order++; + + // Find q= parameter + auto q_pos = entry.find(";q="); + if (q_pos == std::string::npos) { q_pos = entry.find("; q="); } + + if (q_pos != std::string::npos) { + // Extract media type (before q parameter) + accept_entry.media_type = trim_copy(entry.substr(0, q_pos)); + + // Extract quality value + auto q_start = entry.find('=', q_pos) + 1; + auto q_end = entry.find(';', q_start); + if (q_end == std::string::npos) { q_end = entry.length(); } + + std::string quality_str = + trim_copy(entry.substr(q_start, q_end - q_start)); + if (quality_str.empty()) { + has_invalid_entry = true; + return; + } + + try { + accept_entry.quality = std::stod(quality_str); + // Check if quality is in valid range [0.0, 1.0] + if (accept_entry.quality < 0.0 || accept_entry.quality > 1.0) { + has_invalid_entry = true; + return; + } + } catch (...) { + has_invalid_entry = true; + return; + } + } else { + // No quality parameter, use entire entry as media type + accept_entry.media_type = entry; + } + + // Remove additional parameters from media type + auto param_pos = accept_entry.media_type.find(';'); + if (param_pos != std::string::npos) { + accept_entry.media_type = + trim_copy(accept_entry.media_type.substr(0, param_pos)); + } + + // Basic validation of media type format + if (accept_entry.media_type.empty()) { + has_invalid_entry = true; + return; + } + + // Check for basic media type format (should contain '/' or be '*') + if (accept_entry.media_type != "*" && + accept_entry.media_type.find('/') == std::string::npos) { + has_invalid_entry = true; + return; + } + + entries.push_back(accept_entry); + }); + + // Return false if any invalid entry was found + if (has_invalid_entry) { return false; } + + // Sort by quality (descending), then by original order (ascending) + std::sort(entries.begin(), entries.end(), + [](const AcceptEntry &a, const AcceptEntry &b) { + if (a.quality != b.quality) { + return a.quality > b.quality; // Higher quality first + } + return a.order < b.order; // Earlier order first for same quality + }); + + // Extract sorted media types + content_types.reserve(entries.size()); + for (const auto &entry : entries) { + content_types.push_back(entry.media_type); + } + + return true; +} + class MultipartFormDataParser { public: MultipartFormDataParser() = default; @@ -7446,6 +7567,14 @@ Server::process_request(Stream &strm, const std::string &remote_addr, req.set_header("LOCAL_ADDR", req.local_addr); req.set_header("LOCAL_PORT", std::to_string(req.local_port)); + if (req.has_header("Accept")) { + const auto &accept_header = req.get_header_value("Accept"); + if (!detail::parse_accept_header(accept_header, req.accept_content_types)) { + res.status = StatusCode::BadRequest_400; + return write_response(strm, close_connection, req, res); + } + } + if (req.has_header("Range")) { const auto &range_header_value = req.get_header_value("Range"); if (!detail::parse_range_header(range_header_value, req.ranges)) { diff --git a/test/test.cc b/test/test.cc index 297a694..347b90e 100644 --- a/test/test.cc +++ b/test/test.cc @@ -308,6 +308,201 @@ TEST(TrimTests, TrimStringTests) { EXPECT_TRUE(detail::trim_copy("").empty()); } +TEST(ParseAcceptHeaderTest, BasicAcceptParsing) { + // Simple case without quality values + std::vector result1; + EXPECT_TRUE(detail::parse_accept_header( + "text/html,application/json,text/plain", result1)); + EXPECT_EQ(result1.size(), 3); + EXPECT_EQ(result1[0], "text/html"); + EXPECT_EQ(result1[1], "application/json"); + EXPECT_EQ(result1[2], "text/plain"); + + // With quality values + std::vector result2; + EXPECT_TRUE(detail::parse_accept_header( + "text/html;q=0.9,application/json;q=1.0,text/plain;q=0.8", result2)); + EXPECT_EQ(result2.size(), 3); + EXPECT_EQ(result2[0], "application/json"); // highest q value + EXPECT_EQ(result2[1], "text/html"); + EXPECT_EQ(result2[2], "text/plain"); // lowest q value +} + +TEST(ParseAcceptHeaderTest, MixedQualityValues) { + // Mixed with and without quality values + std::vector result; + EXPECT_TRUE(detail::parse_accept_header( + "text/html,application/json;q=0.5,text/plain;q=0.8", result)); + EXPECT_EQ(result.size(), 3); + EXPECT_EQ(result[0], "text/html"); // no q value means 1.0 + EXPECT_EQ(result[1], "text/plain"); // q=0.8 + EXPECT_EQ(result[2], "application/json"); // q=0.5 +} + +TEST(ParseAcceptHeaderTest, EdgeCases) { + // Empty header + std::vector empty_result; + EXPECT_TRUE(detail::parse_accept_header("", empty_result)); + EXPECT_TRUE(empty_result.empty()); + + // Single type + std::vector single_result; + EXPECT_TRUE(detail::parse_accept_header("application/json", single_result)); + EXPECT_EQ(single_result.size(), 1); + EXPECT_EQ(single_result[0], "application/json"); + + // Wildcard types + std::vector wildcard_result; + EXPECT_TRUE(detail::parse_accept_header( + "text/*;q=0.5,*/*;q=0.1,application/json", wildcard_result)); + EXPECT_EQ(wildcard_result.size(), 3); + EXPECT_EQ(wildcard_result[0], "application/json"); + EXPECT_EQ(wildcard_result[1], "text/*"); + EXPECT_EQ(wildcard_result[2], "*/*"); +} + +TEST(ParseAcceptHeaderTest, RealWorldExamples) { + // Common browser Accept header + std::vector browser_result; + EXPECT_TRUE( + detail::parse_accept_header("text/html,application/xhtml+xml,application/" + "xml;q=0.9,image/webp,image/apng,*/*;q=0.8", + browser_result)); + EXPECT_EQ(browser_result.size(), 6); + EXPECT_EQ(browser_result[0], "text/html"); // q=1.0 (default) + EXPECT_EQ(browser_result[1], "application/xhtml+xml"); // q=1.0 (default) + EXPECT_EQ(browser_result[2], "image/webp"); // q=1.0 (default) + EXPECT_EQ(browser_result[3], "image/apng"); // q=1.0 (default) + EXPECT_EQ(browser_result[4], "application/xml"); // q=0.9 + EXPECT_EQ(browser_result[5], "*/*"); // q=0.8 + + // API client header + std::vector api_result; + EXPECT_TRUE(detail::parse_accept_header( + "application/json;q=0.9,application/xml;q=0.8,text/plain;q=0.1", + api_result)); + EXPECT_EQ(api_result.size(), 3); + EXPECT_EQ(api_result[0], "application/json"); + EXPECT_EQ(api_result[1], "application/xml"); + EXPECT_EQ(api_result[2], "text/plain"); +} + +TEST(ParseAcceptHeaderTest, SpecialCases) { + // Quality value with 3 decimal places + std::vector decimal_result; + EXPECT_TRUE(detail::parse_accept_header( + "text/html;q=0.123,application/json;q=0.456", decimal_result)); + EXPECT_EQ(decimal_result.size(), 2); + EXPECT_EQ(decimal_result[0], "application/json"); // Higher q value + EXPECT_EQ(decimal_result[1], "text/html"); + + // Zero quality (should still be included but with lowest priority) + std::vector zero_q_result; + EXPECT_TRUE(detail::parse_accept_header("text/html;q=0,application/json;q=1", + zero_q_result)); + EXPECT_EQ(zero_q_result.size(), 2); + EXPECT_EQ(zero_q_result[0], "application/json"); // q=1 + EXPECT_EQ(zero_q_result[1], "text/html"); // q=0 + + // No spaces around commas + std::vector no_space_result; + EXPECT_TRUE(detail::parse_accept_header( + "text/html;q=0.9,application/json;q=0.8,text/plain;q=0.7", + no_space_result)); + EXPECT_EQ(no_space_result.size(), 3); + EXPECT_EQ(no_space_result[0], "text/html"); + EXPECT_EQ(no_space_result[1], "application/json"); + EXPECT_EQ(no_space_result[2], "text/plain"); +} + +TEST(ParseAcceptHeaderTest, InvalidCases) { + std::vector result; + + // Invalid quality value (> 1.0) + EXPECT_FALSE( + detail::parse_accept_header("text/html;q=1.5,application/json", result)); + + // Invalid quality value (< 0.0) + EXPECT_FALSE( + detail::parse_accept_header("text/html;q=-0.1,application/json", result)); + + // Invalid quality value (not a number) + EXPECT_FALSE(detail::parse_accept_header( + "text/html;q=invalid,application/json", result)); + + // Empty quality value + EXPECT_FALSE( + detail::parse_accept_header("text/html;q=,application/json", result)); + + // Invalid media type format (no slash and not wildcard) + EXPECT_FALSE( + detail::parse_accept_header("invalidtype,application/json", result)); + + // Empty media type + result.clear(); + EXPECT_FALSE(detail::parse_accept_header(",application/json", result)); + + // Only commas + result.clear(); + EXPECT_FALSE(detail::parse_accept_header(",,,", result)); + + // Valid cases should still work + EXPECT_TRUE(detail::parse_accept_header("*/*", result)); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0], "*/*"); + + EXPECT_TRUE(detail::parse_accept_header("*", result)); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0], "*"); + + EXPECT_TRUE(detail::parse_accept_header("text/*", result)); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0], "text/*"); +} + +TEST(ParseAcceptHeaderTest, ContentTypesPopulatedAndInvalidHeaderHandling) { + Server svr; + + svr.Get("/accept_ok", [&](const Request &req, Response &res) { + EXPECT_EQ(req.accept_content_types.size(), 3); + EXPECT_EQ(req.accept_content_types[0], "application/json"); + EXPECT_EQ(req.accept_content_types[1], "text/html"); + EXPECT_EQ(req.accept_content_types[2], "*/*"); + res.set_content("ok", "text/plain"); + }); + + svr.Get("/accept_bad_request", [&](const Request & /*req*/, Response &res) { + EXPECT_TRUE(false); + res.set_content("bad request", "text/plain"); + }); + + auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); }); + auto se = detail::scope_exit([&] { + svr.stop(); + listen_thread.join(); + ASSERT_FALSE(svr.is_running()); + }); + + svr.wait_until_ready(); + + Client cli("localhost", PORT); + + { + auto res = + cli.Get("/accept_ok", + {{"Accept", "application/json, text/html;q=0.8, */*;q=0.1"}}); + ASSERT_TRUE(res); + EXPECT_EQ(StatusCode::OK_200, res->status); + } + + { + auto res = cli.Get("/accept_bad_request", + {{"Accept", "text/html;q=abc,application/json"}}); + ASSERT_TRUE(res); + EXPECT_EQ(StatusCode::BadRequest_400, res->status); + } +} + TEST(DivideTest, DivideStringTests) { auto divide = [](const std::string &str, char d) { std::string lhs;