1
0
mirror of synced 2025-04-28 09:25:05 +03:00

Added Endpoint structure in Client

This commit is contained in:
yhirose 2020-06-13 01:26:57 -04:00
parent 5af7222217
commit f80b6bd980
2 changed files with 281 additions and 209 deletions

352
httplib.h
View File

@ -194,7 +194,6 @@ using socket_t = int;
#include <mutex> #include <mutex>
#include <random> #include <random>
#include <regex> #include <regex>
#include <set>
#include <string> #include <string>
#include <sys/stat.h> #include <sys/stat.h>
#include <thread> #include <thread>
@ -801,7 +800,7 @@ public:
bool send(const std::vector<Request> &requests, bool send(const std::vector<Request> &requests,
std::vector<Response> &responses); std::vector<Response> &responses);
void stop(); virtual void stop();
CPPHTTPLIB_DEPRECATED void set_timeout_sec(time_t timeout_sec); CPPHTTPLIB_DEPRECATED void set_timeout_sec(time_t timeout_sec);
void set_connection_timeout(time_t sec, time_t usec = 0); void set_connection_timeout(time_t sec, time_t usec = 0);
@ -832,11 +831,21 @@ public:
void set_logger(Logger logger); void set_logger(Logger logger);
protected: protected:
struct Endpoint {
socket_t sock = INVALID_SOCKET;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
SSL *ssl = nullptr;
#endif
};
virtual bool create_and_connect_socket(Endpoint &endpoint);
virtual void close_socket(Endpoint &endpoint, bool process_socket_ret);
bool process_request(Stream &strm, const Request &req, Response &res, bool process_request(Stream &strm, const Request &req, Response &res,
bool last_connection, bool &connection_close); bool last_connection, bool &connection_close);
std::set<socket_t> cli_socks_; std::vector<Endpoint> endpoints_;
std::mutex cli_socks_mutex_; std::mutex endpoints_mutex_;
const std::string host_; const std::string host_;
const int port_; const int port_;
@ -913,14 +922,13 @@ protected:
private: private:
socket_t create_client_socket() const; socket_t create_client_socket() const;
bool create_and_connect_socket(socket_t &sock);
bool read_response_line(Stream &strm, Response &res); bool read_response_line(Stream &strm, Response &res);
bool write_request(Stream &strm, const Request &req, bool last_connection); bool write_request(Stream &strm, const Request &req, bool last_connection);
bool redirect(const Request &req, Response &res); bool redirect(const Request &req, Response &res);
bool handle_request(Stream &strm, const Request &req, Response &res, bool handle_request(Stream &strm, const Request &req, Response &res,
bool last_connection, bool &connection_close); bool last_connection, bool &connection_close);
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
bool connect(socket_t sock, Response &res, bool &error); bool connect_with_proxy(socket_t sock, Response &res, bool &error);
#endif #endif
std::shared_ptr<Response> send_with_content_provider( std::shared_ptr<Response> send_with_content_provider(
@ -928,8 +936,8 @@ private:
const std::string &body, size_t content_length, const std::string &body, size_t content_length,
ContentProvider content_provider, const char *content_type); ContentProvider content_provider, const char *content_type);
virtual bool process_and_close_socket( virtual bool
socket_t sock, size_t request_count, process_socket(Endpoint &endpoint, size_t request_count,
std::function<bool(Stream &strm, bool last_connection, std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)> bool &connection_close)>
callback); callback);
@ -1018,6 +1026,8 @@ public:
~SSLClient() override; ~SSLClient() override;
void stop() override;
bool is_valid() const override; bool is_valid() const override;
void set_ca_cert_path(const char *ca_cert_file_path, void set_ca_cert_path(const char *ca_cert_file_path,
@ -1032,13 +1042,17 @@ public:
SSL_CTX *ssl_context() const; SSL_CTX *ssl_context() const;
private: private:
bool process_and_close_socket( bool create_and_connect_socket(Endpoint &endpoint) override;
socket_t sock, size_t request_count, void close_socket(Endpoint &endpoint, bool process_socket_ret) override;
bool process_socket(Endpoint &endpoint, size_t request_count,
std::function<bool(Stream &strm, bool last_connection, std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)> bool &connection_close)>
callback) override; callback) override;
bool is_ssl() const override; bool is_ssl() const override;
bool initialize_ssl(Endpoint &endpoint);
bool verify_host(X509 *server_cert) const; bool verify_host(X509 *server_cert) const;
bool verify_host_with_subject_alt_name(X509 *server_cert) const; bool verify_host_with_subject_alt_name(X509 *server_cert) const;
bool verify_host_with_common_name(X509 *server_cert) const; bool verify_host_with_common_name(X509 *server_cert) const;
@ -1845,10 +1859,8 @@ private:
}; };
template <typename T> template <typename T>
inline bool process_socket(bool is_client_request, socket_t sock, inline bool process_socket_core(bool is_client_request, socket_t sock,
size_t keep_alive_max_count, time_t read_timeout_sec, size_t keep_alive_max_count, T callback) {
time_t read_timeout_usec, time_t write_timeout_sec,
time_t write_timeout_usec, T callback) {
assert(keep_alive_max_count > 0); assert(keep_alive_max_count > 0);
auto ret = false; auto ret = false;
@ -1859,37 +1871,34 @@ inline bool process_socket(bool is_client_request, socket_t sock,
(is_client_request || (is_client_request ||
select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) {
SocketStream strm(sock, read_timeout_sec, read_timeout_usec,
write_timeout_sec, write_timeout_usec);
auto last_connection = count == 1; auto last_connection = count == 1;
auto connection_close = false; auto connection_close = false;
ret = callback(strm, last_connection, connection_close); ret = callback(last_connection, connection_close);
if (!ret || connection_close) { break; } if (!ret || connection_close) { break; }
count--; count--;
} }
} else { // keep_alive_max_count is 0 or 1 } else { // keep_alive_max_count is 0 or 1
SocketStream strm(sock, read_timeout_sec, read_timeout_usec,
write_timeout_sec, write_timeout_usec);
auto dummy_connection_close = false; auto dummy_connection_close = false;
ret = callback(strm, true, dummy_connection_close); ret = callback(true, dummy_connection_close);
} }
return ret; return ret;
} }
template <typename T> template <typename T>
inline bool inline bool process_socket(bool is_client_request, socket_t sock,
process_and_close_socket(bool is_client_request, socket_t sock,
size_t keep_alive_max_count, time_t read_timeout_sec, size_t keep_alive_max_count, time_t read_timeout_sec,
time_t read_timeout_usec, time_t write_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec,
time_t write_timeout_usec, T callback) { time_t write_timeout_usec, T callback) {
auto ret = process_socket(is_client_request, sock, keep_alive_max_count, return process_socket_core(
read_timeout_sec, read_timeout_usec, is_client_request, sock, keep_alive_max_count,
write_timeout_sec, write_timeout_usec, callback); [&](bool last_connection, bool connection_close) {
close_socket(sock); SocketStream strm(sock, read_timeout_sec, read_timeout_usec,
return ret; write_timeout_sec, write_timeout_usec);
return callback(strm, last_connection, connection_close);
});
} }
inline int shutdown_socket(socket_t sock) { inline int shutdown_socket(socket_t sock) {
@ -4295,13 +4304,16 @@ Server::process_request(Stream &strm, bool last_connection,
inline bool Server::is_valid() const { return true; } inline bool Server::is_valid() const { return true; }
inline bool Server::process_and_close_socket(socket_t sock) { inline bool Server::process_and_close_socket(socket_t sock) {
return detail::process_and_close_socket( auto ret = detail::process_socket(
false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_,
write_timeout_sec_, write_timeout_usec_, write_timeout_sec_, write_timeout_usec_,
[this](Stream &strm, bool last_connection, bool &connection_close) { [this](Stream &strm, bool last_connection, bool &connection_close) {
return process_request(strm, last_connection, connection_close, return process_request(strm, last_connection, connection_close,
nullptr); nullptr);
}); });
detail::close_socket(sock);
return ret;
} }
// HTTP client implementation // HTTP client implementation
@ -4333,20 +4345,26 @@ inline socket_t Client::create_client_socket() const {
connection_timeout_usec_, interface_); connection_timeout_usec_, interface_);
} }
inline bool Client::create_and_connect_socket(socket_t &sock) { inline bool Client::create_and_connect_socket(Endpoint &endpoint) {
sock = create_client_socket(); auto sock = create_client_socket();
if (sock == INVALID_SOCKET) { return false; } if (sock == INVALID_SOCKET) { return false; }
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
if (is_ssl() && !proxy_host_.empty()) { if (is_ssl() && !proxy_host_.empty()) {
Response res; Response res;
bool error; bool error;
if (!connect(sock, res, error)) { return error; } if (!connect_with_proxy(sock, res, error)) { return error; }
} }
#endif #endif
endpoint.sock = sock;
return true; return true;
} }
inline void Client::close_socket(Endpoint &endpoint,
bool /*process_socket_ret*/) {
detail::close_socket(endpoint.sock);
}
inline bool Client::read_response_line(Stream &strm, Response &res) { inline bool Client::read_response_line(Stream &strm, Response &res) {
std::array<char, 2048> buf; std::array<char, 2048> buf;
@ -4366,23 +4384,32 @@ inline bool Client::read_response_line(Stream &strm, Response &res) {
} }
inline bool Client::send(const Request &req, Response &res) { inline bool Client::send(const Request &req, Response &res) {
socket_t sock = INVALID_SOCKET; Endpoint endpoint;
if (!create_and_connect_socket(sock)) { return false; } if (!create_and_connect_socket(endpoint)) { return false; }
{ {
std::lock_guard<std::mutex> guard(cli_socks_mutex_); std::lock_guard<std::mutex> guard(endpoints_mutex_);
cli_socks_.insert(sock); endpoints_.push_back(endpoint);
} }
auto ret = process_and_close_socket( auto ret = process_socket(
sock, 1, [&](Stream &strm, bool last_connection, bool &connection_close) { endpoint, 1,
[&](Stream &strm, bool last_connection, bool &connection_close) {
return handle_request(strm, req, res, last_connection, return handle_request(strm, req, res, last_connection,
connection_close); connection_close);
}); });
{ {
std::lock_guard<std::mutex> guard(cli_socks_mutex_); std::lock_guard<std::mutex> guard(endpoints_mutex_);
cli_socks_.erase(sock);
auto it = std::find_if(
endpoints_.begin(), endpoints_.end(),
[&](Endpoint &endpoint2) { return endpoint.sock == endpoint2.sock; });
if (it != endpoints_.end()) {
close_socket(endpoint, ret);
endpoints_.erase(it);
}
} }
return ret; return ret;
@ -4392,29 +4419,41 @@ inline bool Client::send(const std::vector<Request> &requests,
std::vector<Response> &responses) { std::vector<Response> &responses) {
size_t i = 0; size_t i = 0;
while (i < requests.size()) { while (i < requests.size()) {
socket_t sock = INVALID_SOCKET; Endpoint endpoint;
if (!create_and_connect_socket(sock)) { return false; } if (!create_and_connect_socket(endpoint)) { return false; }
{ {
std::lock_guard<std::mutex> guard(cli_socks_mutex_); std::lock_guard<std::mutex> guard(endpoints_mutex_);
cli_socks_.insert(sock); endpoints_.push_back(endpoint);
} }
auto ret = process_and_close_socket( auto request_count = (std::min)(requests.size() - i, keep_alive_max_count_);
sock, requests.size() - i,
auto ret = process_socket(endpoint, request_count,
[&](Stream &strm, bool last_connection, [&](Stream &strm, bool last_connection,
bool &connection_close) -> bool { bool &connection_close) -> bool {
auto &req = requests[i++]; auto &req = requests[i++];
auto res = Response(); auto res = Response();
auto ret = auto ret = handle_request(strm, req, res,
handle_request(strm, req, res, last_connection, connection_close); last_connection,
if (ret) { responses.emplace_back(std::move(res)); } connection_close);
if (ret) {
responses.emplace_back(std::move(res));
}
return ret; return ret;
}); });
{ {
std::lock_guard<std::mutex> guard(cli_socks_mutex_); std::lock_guard<std::mutex> guard(endpoints_mutex_);
cli_socks_.erase(sock);
auto it = std::find_if(
endpoints_.begin(), endpoints_.end(),
[&](Endpoint &endpoint2) { return endpoint.sock == endpoint2.sock; });
if (it != endpoints_.end()) {
close_socket(endpoint, ret);
endpoints_.erase(it);
}
} }
if (!ret) { return false; } if (!ret) { return false; }
@ -4477,14 +4516,16 @@ inline bool Client::handle_request(Stream &strm, const Request &req,
} }
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
inline bool Client::connect(socket_t sock, Response &res, bool &error) { inline bool Client::connect_with_proxy(socket_t sock, Response &res,
bool &error) {
error = true; error = true;
Response res2; Response res2;
if (!detail::process_socket( if (!detail::process_socket_core(
true, sock, 1, read_timeout_sec_, read_timeout_usec_, true, sock, 1, [&](bool /*last_connection*/, bool &connection_close) {
write_timeout_sec_, write_timeout_usec_, detail::SocketStream strm(sock, read_timeout_sec_,
[&](Stream &strm, bool /*last_connection*/, bool &connection_close) { read_timeout_usec_, write_timeout_sec_,
write_timeout_usec_);
Request req2; Request req2;
req2.method = "CONNECT"; req2.method = "CONNECT";
req2.path = host_and_port_; req2.path = host_and_port_;
@ -4501,11 +4542,12 @@ inline bool Client::connect(socket_t sock, Response &res, bool &error) {
std::map<std::string, std::string> auth; std::map<std::string, std::string> auth;
if (parse_www_authenticate(res2, auth, true)) { if (parse_www_authenticate(res2, auth, true)) {
Response res3; Response res3;
if (!detail::process_socket( if (!detail::process_socket_core(
true, sock, 1, read_timeout_sec_, read_timeout_usec_, true, sock, 1,
write_timeout_sec_, write_timeout_usec_, [&](bool /*last_connection*/, bool &connection_close) {
[&](Stream &strm, bool /*last_connection*/, detail::SocketStream strm(
bool &connection_close) { sock, read_timeout_sec_, read_timeout_usec_,
write_timeout_sec_, write_timeout_usec_);
Request req3; Request req3;
req3.method = "CONNECT"; req3.method = "CONNECT";
req3.path = host_and_port_; req3.path = host_and_port_;
@ -4781,14 +4823,13 @@ inline bool Client::process_request(Stream &strm, const Request &req,
return true; return true;
} }
inline bool Client::process_and_close_socket( inline bool
socket_t sock, size_t request_count, Client::process_socket(Endpoint &endpoint, size_t request_count,
std::function<bool(Stream &strm, bool last_connection, std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)> bool &connection_close)>
callback) { callback) {
request_count = (std::min)(request_count, keep_alive_max_count_); return detail::process_socket(
return detail::process_and_close_socket( true, endpoint.sock, request_count, read_timeout_sec_, read_timeout_usec_,
true, sock, request_count, read_timeout_sec_, read_timeout_usec_,
write_timeout_sec_, write_timeout_usec_, callback); write_timeout_sec_, write_timeout_usec_, callback);
} }
@ -5085,12 +5126,12 @@ inline std::shared_ptr<Response> Client::Options(const char *path,
} }
inline void Client::stop() { inline void Client::stop() {
std::lock_guard<std::mutex> guard(cli_socks_mutex_); std::lock_guard<std::mutex> guard(endpoints_mutex_);
for (auto &sock : cli_socks_) { for (auto &endpoint : endpoints_) {
detail::shutdown_socket(sock); detail::shutdown_socket(endpoint.sock);
detail::close_socket(sock); detail::close_socket(endpoint.sock);
} }
cli_socks_.clear(); endpoints_.clear();
} }
inline void Client::set_timeout_sec(time_t timeout_sec) { inline void Client::set_timeout_sec(time_t timeout_sec) {
@ -5164,77 +5205,55 @@ inline void Client::set_logger(Logger logger) { logger_ = std::move(logger); }
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
namespace detail { namespace detail {
template <typename U, typename V, typename T> template <typename U, typename V>
inline bool process_and_close_socket_ssl( inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex,
bool is_client_request, socket_t sock, size_t keep_alive_max_count, U SSL_connect_or_accept, V setup) {
time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec,
time_t write_timeout_usec, SSL_CTX *ctx, std::mutex &ctx_mutex,
U SSL_connect_or_accept, V setup, T callback) {
assert(keep_alive_max_count > 0);
SSL *ssl = nullptr; SSL *ssl = nullptr;
{ {
std::lock_guard<std::mutex> guard(ctx_mutex); std::lock_guard<std::mutex> guard(ctx_mutex);
ssl = SSL_new(ctx); ssl = SSL_new(ctx);
} }
if (!ssl) { if (ssl) {
close_socket(sock);
return false;
}
auto bio = BIO_new_socket(static_cast<int>(sock), BIO_NOCLOSE); auto bio = BIO_new_socket(static_cast<int>(sock), BIO_NOCLOSE);
SSL_set_bio(ssl, bio, bio); SSL_set_bio(ssl, bio, bio);
if (!setup(ssl)) { if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) {
SSL_shutdown(ssl); SSL_shutdown(ssl);
{ {
std::lock_guard<std::mutex> guard(ctx_mutex); std::lock_guard<std::mutex> guard(ctx_mutex);
SSL_free(ssl); SSL_free(ssl);
} }
return nullptr;
close_socket(sock);
return false;
}
auto ret = false;
if (SSL_connect_or_accept(ssl) == 1) {
if (keep_alive_max_count > 1) {
auto count = keep_alive_max_count;
while (count > 0 &&
(is_client_request ||
select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) {
SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec,
write_timeout_sec, write_timeout_usec);
auto last_connection = count == 1;
auto connection_close = false;
ret = callback(ssl, strm, last_connection, connection_close);
if (!ret || connection_close) { break; }
count--;
}
} else {
SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec,
write_timeout_sec, write_timeout_usec);
auto dummy_connection_close = false;
ret = callback(ssl, strm, true, dummy_connection_close);
} }
} }
if (ret) { return ssl;
}
inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl,
bool process_socket_ret) {
if (process_socket_ret) {
SSL_shutdown(ssl); // shutdown only if not already closed by remote SSL_shutdown(ssl); // shutdown only if not already closed by remote
} }
{
std::lock_guard<std::mutex> guard(ctx_mutex); std::lock_guard<std::mutex> guard(ctx_mutex);
SSL_free(ssl); SSL_free(ssl);
} }
close_socket(sock); template <typename T>
inline bool
return ret; process_socket_ssl(SSL *ssl, bool is_client_request, socket_t sock,
size_t keep_alive_max_count, time_t read_timeout_sec,
time_t read_timeout_usec, time_t write_timeout_sec,
time_t write_timeout_usec, T callback) {
return process_socket_core(
is_client_request, sock, keep_alive_max_count,
[&](bool last_connection, bool connection_close) {
SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec,
write_timeout_sec, write_timeout_usec);
return callback(strm, last_connection, connection_close);
});
} }
#if OPENSSL_VERSION_NUMBER < 0x10100000L #if OPENSSL_VERSION_NUMBER < 0x10100000L
@ -5311,8 +5330,7 @@ inline bool SSLSocketStream::is_writable() const {
} }
inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
if (SSL_pending(ssl_) > 0 || if (SSL_pending(ssl_) > 0 || is_readable()) {
select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0) {
return SSL_read(ssl_, ptr, static_cast<int>(size)); return SSL_read(ssl_, ptr, static_cast<int>(size));
} }
return -1; return -1;
@ -5405,15 +5423,25 @@ inline SSLServer::~SSLServer() {
inline bool SSLServer::is_valid() const { return ctx_; } inline bool SSLServer::is_valid() const { return ctx_; }
inline bool SSLServer::process_and_close_socket(socket_t sock) { inline bool SSLServer::process_and_close_socket(socket_t sock) {
return detail::process_and_close_socket_ssl( auto ssl = detail::ssl_new(sock, ctx_, ctx_mutex_, SSL_accept,
false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, [](SSL * /*ssl*/) { return true; });
write_timeout_sec_, write_timeout_usec_, ctx_, ctx_mutex_, SSL_accept,
[](SSL * /*ssl*/) { return true; }, if (ssl) {
[this](SSL *ssl, Stream &strm, bool last_connection, auto ret = detail::process_socket_ssl(
ssl, false, sock, keep_alive_max_count_, read_timeout_sec_,
read_timeout_usec_, write_timeout_sec_, write_timeout_usec_,
[this, ssl](Stream &strm, bool last_connection,
bool &connection_close) { bool &connection_close) {
return process_request(strm, last_connection, connection_close, return process_request(strm, last_connection, connection_close,
[&](Request &req) { req.ssl = ssl; }); [&](Request &req) { req.ssl = ssl; });
}); });
detail::ssl_delete(ctx_mutex_, ssl, ret);
return ret;
}
detail::close_socket(sock);
return false;
} }
// SSL HTTP client implementation // SSL HTTP client implementation
@ -5466,6 +5494,25 @@ inline SSLClient::~SSLClient() {
if (ctx_) { SSL_CTX_free(ctx_); } if (ctx_) { SSL_CTX_free(ctx_); }
} }
inline void SSLClient::stop() {
auto endpoints = endpoints_;
{
std::lock_guard<std::mutex> guard(endpoints_mutex_);
for (auto &endpoint : endpoints_) {
detail::shutdown_socket(endpoint.sock);
detail::close_socket(endpoint.sock);
}
endpoints_.clear();
}
std::this_thread::sleep_for(std::chrono::milliseconds(100));
for (auto &endpoint : endpoints) {
SSL_shutdown(endpoint.ssl);
SSL_free(endpoint.ssl);
}
}
inline bool SSLClient::is_valid() const { return ctx_; } inline bool SSLClient::is_valid() const { return ctx_; }
inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path, inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path,
@ -5488,24 +5535,20 @@ inline long SSLClient::get_openssl_verify_result() const {
inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; } inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; }
inline bool SSLClient::process_and_close_socket( inline bool SSLClient::create_and_connect_socket(Endpoint &endpoint) {
socket_t sock, size_t request_count, return is_valid() && Client::create_and_connect_socket(endpoint) &&
std::function<bool(Stream &strm, bool last_connection, initialize_ssl(endpoint);
bool &connection_close)> }
callback) {
request_count = std::min(request_count, keep_alive_max_count_); inline bool SSLClient::initialize_ssl(Endpoint &endpoint) {
auto ssl = detail::ssl_new(
return is_valid() && endpoint.sock, ctx_, ctx_mutex_,
detail::process_and_close_socket_ssl(
true, sock, request_count, read_timeout_sec_, read_timeout_usec_,
write_timeout_sec_, write_timeout_usec_, ctx_, ctx_mutex_,
[&](SSL *ssl) { [&](SSL *ssl) {
if (ca_cert_file_path_.empty() && ca_cert_store_ == nullptr) { if (ca_cert_file_path_.empty() && ca_cert_store_ == nullptr) {
SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr); SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr);
} else if (!ca_cert_file_path_.empty()) { } else if (!ca_cert_file_path_.empty()) {
if (!SSL_CTX_load_verify_locations( if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(),
ctx_, ca_cert_file_path_.c_str(), nullptr)) { nullptr)) {
return false; return false;
} }
SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr); SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr);
@ -5539,9 +5582,34 @@ inline bool SSLClient::process_and_close_socket(
[&](SSL *ssl) { [&](SSL *ssl) {
SSL_set_tlsext_host_name(ssl, host_.c_str()); SSL_set_tlsext_host_name(ssl, host_.c_str());
return true; return true;
}, });
[&](SSL * /*ssl*/, Stream &strm, bool last_connection,
bool &connection_close) { if (ssl) {
endpoint.ssl = ssl;
return true;
}
detail::close_socket(endpoint.sock);
return false;
}
inline void SSLClient::close_socket(Endpoint &endpoint,
bool process_socket_ret) {
assert(endpoint.ssl);
detail::ssl_delete(ctx_mutex_, endpoint.ssl, process_socket_ret);
detail::close_socket(endpoint.sock);
}
inline bool
SSLClient::process_socket(Endpoint &endpoint, size_t request_count,
std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)>
callback) {
assert(endpoint.ssl);
return detail::process_socket_ssl(
endpoint.ssl, true, endpoint.sock, request_count, read_timeout_sec_,
read_timeout_usec_, write_timeout_sec_, write_timeout_usec_,
[&](Stream &strm, bool last_connection, bool &connection_close) {
return callback(strm, last_connection, connection_close); return callback(strm, last_connection, connection_close);
}); });
} }

View File

@ -1767,16 +1767,16 @@ TEST_F(ServerTest, GetStreamedEndless) {
TEST_F(ServerTest, ClientStop) { TEST_F(ServerTest, ClientStop) {
std::vector<std::thread> threads; std::vector<std::thread> threads;
for (auto i = 0; i < 10; i++) { for (auto i = 0; i < 8; i++) {
threads.emplace_back(thread([&]() { threads.emplace_back(thread([&]() {
auto res = cli_.Get("/streamed-cancel", auto res = cli_.Get("/streamed-cancel",
[&](const char *, uint64_t) { return true; }); [&](const char *, uint64_t) { return true; });
ASSERT_TRUE(res == nullptr); ASSERT_TRUE(res == nullptr);
})); }));
} }
std::this_thread::sleep_for(std::chrono::seconds(1)); std::this_thread::sleep_for(std::chrono::seconds(3));
cli_.stop(); cli_.stop();
for (auto& t: threads) { for (auto &t : threads) {
t.join(); t.join();
} }
} }
@ -2299,13 +2299,13 @@ TEST_F(ServerTest, MultipartFormDataGzip) {
// Sends a raw request to a server listening at HOST:PORT. // Sends a raw request to a server listening at HOST:PORT.
static bool send_request(time_t read_timeout_sec, const std::string &req, static bool send_request(time_t read_timeout_sec, const std::string &req,
std::string *resp = nullptr) { std::string *resp = nullptr) {
auto client_sock = detail::create_client_socket( auto client_sock =
HOST, PORT, nullptr, detail::create_client_socket(HOST, PORT, nullptr,
/*timeout_sec=*/5, 0, std::string()); /*timeout_sec=*/5, 0, std::string());
if (client_sock == INVALID_SOCKET) { return false; } if (client_sock == INVALID_SOCKET) { return false; }
return detail::process_and_close_socket( auto ret = detail::process_socket(
true, client_sock, 1, read_timeout_sec, 0, 0, 0, true, client_sock, 1, read_timeout_sec, 0, 0, 0,
[&](Stream &strm, bool /*last_connection*/, bool & [&](Stream &strm, bool /*last_connection*/, bool &
/*connection_close*/) -> bool { /*connection_close*/) -> bool {
@ -2322,6 +2322,10 @@ static bool send_request(time_t read_timeout_sec, const std::string &req,
} }
return true; return true;
}); });
detail::close_socket(client_sock);
return ret;
} }
TEST(ServerRequestParsingTest, TrimWhitespaceFromHeaderValues) { TEST(ServerRequestParsingTest, TrimWhitespaceFromHeaderValues) {