1
0
mirror of synced 2025-07-02 20:02:24 +03:00

Issue 2162 (#2163)

* Resolve #2162

* Update
This commit is contained in:
yhirose
2025-06-24 17:37:30 -04:00
committed by GitHub
parent aabd0634ae
commit 1729aa8c1f
4 changed files with 463 additions and 2 deletions

129
httplib.h
View File

@ -670,6 +670,7 @@ struct Request {
std::function<bool()> is_connection_closed = []() { return true; };
// for client
std::vector<std::string> accept_content_types;
ResponseHandler response_handler;
ContentReceiverWithProgress content_receiver;
Progress progress;
@ -2491,6 +2492,9 @@ bool parse_multipart_boundary(const std::string &content_type,
bool parse_range_header(const std::string &s, Ranges &ranges);
bool parse_accept_header(const std::string &s,
std::vector<std::string> &content_types);
int close_socket(socket_t sock);
ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags);
@ -5026,6 +5030,123 @@ inline bool parse_range_header(const std::string &s, Ranges &ranges) try {
} catch (...) { return false; }
#endif
inline bool parse_accept_header(const std::string &s,
std::vector<std::string> &content_types) {
content_types.clear();
// Empty string is considered valid (no preference)
if (s.empty()) { return true; }
// Check for invalid patterns: leading/trailing commas or consecutive commas
if (s.front() == ',' || s.back() == ',' ||
s.find(",,") != std::string::npos) {
return false;
}
struct AcceptEntry {
std::string media_type;
double quality;
int order; // Original order in header
};
std::vector<AcceptEntry> entries;
int order = 0;
bool has_invalid_entry = false;
// Split by comma and parse each entry
split(s.data(), s.data() + s.size(), ',', [&](const char *b, const char *e) {
std::string entry(b, e);
entry = trim_copy(entry);
if (entry.empty()) {
has_invalid_entry = true;
return;
}
AcceptEntry accept_entry;
accept_entry.quality = 1.0; // Default quality
accept_entry.order = order++;
// Find q= parameter
auto q_pos = entry.find(";q=");
if (q_pos == std::string::npos) { q_pos = entry.find("; q="); }
if (q_pos != std::string::npos) {
// Extract media type (before q parameter)
accept_entry.media_type = trim_copy(entry.substr(0, q_pos));
// Extract quality value
auto q_start = entry.find('=', q_pos) + 1;
auto q_end = entry.find(';', q_start);
if (q_end == std::string::npos) { q_end = entry.length(); }
std::string quality_str =
trim_copy(entry.substr(q_start, q_end - q_start));
if (quality_str.empty()) {
has_invalid_entry = true;
return;
}
try {
accept_entry.quality = std::stod(quality_str);
// Check if quality is in valid range [0.0, 1.0]
if (accept_entry.quality < 0.0 || accept_entry.quality > 1.0) {
has_invalid_entry = true;
return;
}
} catch (...) {
has_invalid_entry = true;
return;
}
} else {
// No quality parameter, use entire entry as media type
accept_entry.media_type = entry;
}
// Remove additional parameters from media type
auto param_pos = accept_entry.media_type.find(';');
if (param_pos != std::string::npos) {
accept_entry.media_type =
trim_copy(accept_entry.media_type.substr(0, param_pos));
}
// Basic validation of media type format
if (accept_entry.media_type.empty()) {
has_invalid_entry = true;
return;
}
// Check for basic media type format (should contain '/' or be '*')
if (accept_entry.media_type != "*" &&
accept_entry.media_type.find('/') == std::string::npos) {
has_invalid_entry = true;
return;
}
entries.push_back(accept_entry);
});
// Return false if any invalid entry was found
if (has_invalid_entry) { return false; }
// Sort by quality (descending), then by original order (ascending)
std::sort(entries.begin(), entries.end(),
[](const AcceptEntry &a, const AcceptEntry &b) {
if (a.quality != b.quality) {
return a.quality > b.quality; // Higher quality first
}
return a.order < b.order; // Earlier order first for same quality
});
// Extract sorted media types
content_types.reserve(entries.size());
for (const auto &entry : entries) {
content_types.push_back(entry.media_type);
}
return true;
}
class MultipartFormDataParser {
public:
MultipartFormDataParser() = default;
@ -7446,6 +7567,14 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
req.set_header("LOCAL_ADDR", req.local_addr);
req.set_header("LOCAL_PORT", std::to_string(req.local_port));
if (req.has_header("Accept")) {
const auto &accept_header = req.get_header_value("Accept");
if (!detail::parse_accept_header(accept_header, req.accept_content_types)) {
res.status = StatusCode::BadRequest_400;
return write_response(strm, close_connection, req, res);
}
}
if (req.has_header("Range")) {
const auto &range_header_value = req.get_header_value("Range");
if (!detail::parse_range_header(range_header_value, req.ranges)) {