From 22f124f871bf4c45538b8947c69a1742465a598d Mon Sep 17 00:00:00 2001 From: yhirose Date: Fri, 21 Apr 2017 23:00:00 -0400 Subject: [PATCH] Added OpenSSL support. #5 --- README.md | 15 +- example/Makefile | 25 +-- example/client.cc | 4 + example/server.cc | 9 +- example/simplesvr.cc | 9 +- httplib.h | 374 ++++++++++++++++++++++++++++++++++--------- test/Makefile | 19 +-- test/test.cc | 18 ++- 8 files changed, 373 insertions(+), 100 deletions(-) diff --git a/README.md b/README.md index 6fbd506..b1d9097 100644 --- a/README.md +++ b/README.md @@ -52,4 +52,17 @@ int main(void) } ``` -Copyright (c) 2014 Yuji Hirose. All rights reserved. +OpenSSL Support +--------------- + +SSL support is available with `CPPHTTPLIB_OPENSSL_SUPPORT`. `libssl` and `libcrypto` should be linked. + +```c++ +#define CPPHTTPLIB_OPENSSL_SUPPORT + +SSLServer svr("./key.pem", "./cert.pem"); + +SSLClient cli("localhost", 8080); +``` + +Copyright (c) 2017 Yuji Hirose. All rights reserved. diff --git a/example/Makefile b/example/Makefile index ceed0ed..becc4c9 100644 --- a/example/Makefile +++ b/example/Makefile @@ -1,24 +1,25 @@ -USE_CLANG = 1 - -ifdef USE_CLANG CC = clang++ -CFLAGS = -std=c++1y -stdlib=libc++ -g -else -CC = g++-4.9 -CFLAGS = -std=c++1y -g -endif +CFLAGS = -std=c++14 -I.. +#OPENSSL_SUPPORT = -DCPPHTTPLIB_OPENSSL_SUPPORT -I/usr/local/opt/openssl/include -L/usr/local/opt/openssl/lib -lssl -lcrypto all: server client hello simplesvr server : server.cc ../httplib.h - $(CC) -o server $(CFLAGS) -I.. server.cc + $(CC) -o server $(CFLAGS) server.cc $(OPENSSL_SUPPORT) client : client.cc ../httplib.h - $(CC) -o client $(CFLAGS) -I.. client.cc + $(CC) -o client $(CFLAGS) client.cc $(OPENSSL_SUPPORT) hello : hello.cc ../httplib.h - $(CC) -o hello $(CFLAGS) -I.. hello.cc + $(CC) -o hello $(CFLAGS) hello.cc $(OPENSSL_SUPPORT) simplesvr : simplesvr.cc ../httplib.h - $(CC) -o simplesvr $(CFLAGS) -I.. simplesvr.cc + $(CC) -o simplesvr $(CFLAGS) simplesvr.cc $(OPENSSL_SUPPORT) + +pem: + openssl genrsa 2048 > key.pem + openssl req -new -key key.pem | openssl x509 -days 3650 -req -signkey key.pem > cert.pem + +clean: + rm server client hello simplesvr *.pem diff --git a/example/client.cc b/example/client.cc index f08837b..3bd1641 100644 --- a/example/client.cc +++ b/example/client.cc @@ -12,7 +12,11 @@ using namespace std; int main(void) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + httplib::SSLClient cli("localhost", 8080); +#else httplib::Client cli("localhost", 8080); +#endif auto res = cli.get("/hi"); if (res) { diff --git a/example/server.cc b/example/server.cc index 0c521b6..b5e9aad 100644 --- a/example/server.cc +++ b/example/server.cc @@ -8,6 +8,9 @@ #include #include +#define SERVER_CERT_FILE "./cert.pem" +#define SERVER_PRIVATE_KEY_FILE "./key.pem" + using namespace httplib; std::string dump_headers(const MultiMap& headers) @@ -51,7 +54,7 @@ std::string log(const Request& req, const Response& res) snprintf(buf, sizeof(buf), "%d\n", res.status); s += buf; s += dump_headers(res.headers); - + if (!res.body.empty()) { s += res.body; } @@ -63,7 +66,11 @@ std::string log(const Request& req, const Response& res) int main(void) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE); +#else Server svr; +#endif svr.get("/", [=](const auto& req, auto& res) { res.set_redirect("/hi"); diff --git a/example/simplesvr.cc b/example/simplesvr.cc index fc24f56..c881bc2 100644 --- a/example/simplesvr.cc +++ b/example/simplesvr.cc @@ -9,6 +9,9 @@ #include #include +#define SERVER_CERT_FILE "./cert.pem" +#define SERVER_PRIVATE_KEY_FILE "./key.pem" + using namespace httplib; using namespace std; @@ -52,7 +55,7 @@ string log(const Request& req, const Response& res) snprintf(buf, sizeof(buf), "%d\n", res.status); s += buf; s += dump_headers(res.headers); - + return s; } @@ -63,7 +66,11 @@ int main(int argc, const char** argv) return 1; } +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE); +#else Server svr; +#endif svr.set_error_handler([](const auto& req, auto& res) { const char* fmt = "

Error Status: %d

"; diff --git a/httplib.h b/httplib.h index 4dc9976..4ec9a0f 100644 --- a/httplib.h +++ b/httplib.h @@ -1,12 +1,12 @@ // // httplib.h // -// Copyright (c) 2012 Yuji Hirose. All rights reserved. +// Copyright (c) 2017 Yuji Hirose. All rights reserved. // The Boost Software License 1.0 // -#ifndef _CPPHTTPLIB_HTTPSLIB_H_ -#define _CPPHTTPLIB_HTTPSLIB_H_ +#ifndef _CPPHTTPLIB_HTTPLIB_H_ +#define _CPPHTTPLIB_HTTPLIB_H_ #ifdef _MSC_VER #define _CRT_SECURE_NO_WARNINGS @@ -55,6 +55,10 @@ typedef int socket_t; #include #include +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#include +#endif + namespace httplib { @@ -93,12 +97,34 @@ struct Response { Response() : status(-1) {} }; +class Stream { +public: + virtual ~Stream() {} + virtual int read(char* ptr, size_t size) = 0; + virtual int write(const char* ptr, size_t size1) = 0; + virtual int write(const char* ptr) = 0; +}; + +class SocketStream : public Stream { +public: + SocketStream(socket_t sock); + virtual ~SocketStream(); + + virtual int read(char* ptr, size_t size); + virtual int write(const char* ptr, size_t size); + virtual int write(const char* ptr); + +private: + socket_t sock_; +}; + class Server { public: typedef std::function Handler; typedef std::function Logger; Server(); + virtual ~Server(); void get(const char* pattern, Handler handler); void post(const char* pattern, Handler handler); @@ -111,15 +137,20 @@ public: bool listen(const char* host, int port); void stop(); +protected: + void process_request(Stream& strm); + private: typedef std::vector> Handlers; - void process_request(socket_t sock); - bool read_request_line(socket_t sock, Request& req); bool routing(Request& req, Response& res); bool handle_file_request(Request& req, Response& res); bool dispatch_request(Request& req, Response& res, Handlers& handlers); + bool read_request_line(Stream& strm, Request& req); + + virtual bool read_and_close_socket(socket_t sock); + socket_t svr_sock_; std::string base_dir_; Handlers get_handlers_; @@ -131,6 +162,7 @@ private: class Client { public: Client(const char* host, int port); + virtual ~Client(); std::shared_ptr get(const char* url); std::shared_ptr head(const char* url); @@ -139,14 +171,58 @@ public: bool send(const Request& req, Response& res); +protected: + bool process_request(Stream& strm, const Request& req, Response& res); + private: - bool read_response_line(socket_t sock, Response& res); + bool read_response_line(Stream& strm, Response& res); + + virtual bool read_and_close_socket(socket_t sock, const Request& req, Response& res); const std::string host_; const int port_; }; -// Implementation +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLSocketStream : public Stream { +public: + SSLSocketStream(SSL* ssl); + virtual ~SSLSocketStream(); + + virtual int read(char* ptr, size_t size); + virtual int write(const char* ptr, size_t size); + virtual int write(const char* ptr); + +private: + SSL* ssl_; +}; + +class SSLServer : public Server { +public: + SSLServer(const char* cert_path, const char* private_key_path); + virtual ~SSLServer(); + +private: + virtual bool read_and_close_socket(socket_t sock); + + SSL_CTX* ctx_; +}; + +class SSLClient : public Client { +public: + SSLClient(const char* host, int port); + virtual ~SSLClient(); + +private: + virtual bool read_and_close_socket(socket_t sock, const Request& req, Response& res); + + SSL_CTX* ctx_; +}; +#endif + +/* + * Implementation + */ namespace detail { template @@ -168,30 +244,14 @@ void split(const char* b, const char* e, char d, Fn fn) } } -inline int socket_read(socket_t sock, char* ptr, size_t size) -{ - return recv(sock, ptr, size, 0); -} - -inline int socket_write(socket_t sock, const char* ptr, size_t size) -{ - return send(sock, ptr, size, 0); -} - -inline int socket_write(socket_t sock, const char* ptr) -{ - size_t size = strlen(ptr); - return socket_write(sock, ptr, size); -} - -inline bool socket_gets(socket_t sock, char* buf, int bufsiz) +inline bool socket_gets(Stream& strm, char* buf, int bufsiz) { // TODO: buffering for better performance size_t i = 0; for (;;) { char byte; - auto n = socket_read(sock, &byte, 1); + auto n = strm.read(&byte, 1); if (n < 1) { if (i == 0) { @@ -213,7 +273,7 @@ inline bool socket_gets(socket_t sock, char* buf, int bufsiz) } template -inline void socket_printf(socket_t sock, const char* fmt, const Args& ...args) +inline void socket_printf(Stream& strm, const char* fmt, const Args& ...args) { char buf[BUFSIZ]; auto n = snprintf(buf, BUFSIZ, fmt, args...); @@ -221,7 +281,7 @@ inline void socket_printf(socket_t sock, const char* fmt, const Args& ...args) if (n >= BUFSIZ) { // TODO: buffer size is not large enough... } else { - socket_write(sock, buf, n); + strm.write(buf, n); } } } @@ -238,7 +298,8 @@ inline int close_socket(socket_t sock) template inline bool read_and_close_socket(socket_t sock, T callback) { - auto ret = callback(sock); + SocketStream strm(sock); + auto ret = callback(strm); close_socket(sock); return ret; } @@ -394,7 +455,7 @@ inline int get_header_value_int(const MultiMap& map, const char* key, int def) return def; } -inline bool read_headers(socket_t sock, MultiMap& headers) +inline bool read_headers(Stream& strm, MultiMap& headers) { static std::regex re("(.+?): (.+?)\r\n"); @@ -402,7 +463,7 @@ inline bool read_headers(socket_t sock, MultiMap& headers) char buf[BUFSIZ_HEADER]; for (;;) { - if (!socket_gets(sock, buf, BUFSIZ_HEADER)) { + if (!socket_gets(strm, buf, BUFSIZ_HEADER)) { return false; } if (!strcmp(buf, "\r\n")) { @@ -420,12 +481,12 @@ inline bool read_headers(socket_t sock, MultiMap& headers) } template -bool read_content(socket_t sock, T& x) +bool read_content(Stream& strm, T& x) { auto len = get_header_value_int(x.headers, "Content-Length", 0); if (len) { x.body.assign(len, 0); - if (!socket_read(sock, &x.body[0], x.body.size())) { + if (!strm.read(&x.body[0], x.body.size())) { return false; } } @@ -433,30 +494,30 @@ bool read_content(socket_t sock, T& x) } template -inline void write_headers(socket_t sock, const T& res) +inline void write_headers(Stream& strm, const T& res) { - socket_write(sock, "Connection: close\r\n"); + strm.write("Connection: close\r\n"); for (const auto& x: res.headers) { if (x.first != "Content-Type" && x.first != "Content-Length") { - socket_printf(sock, "%s: %s\r\n", x.first.c_str(), x.second.c_str()); + socket_printf(strm, "%s: %s\r\n", x.first.c_str(), x.second.c_str()); } } auto t = get_header_value(res.headers, "Content-Type", "text/plain"); - socket_printf(sock, "Content-Type: %s\r\n", t); - socket_printf(sock, "Content-Length: %ld\r\n", res.body.size()); - socket_write(sock, "\r\n"); + socket_printf(strm, "Content-Type: %s\r\n", t); + socket_printf(strm, "Content-Length: %ld\r\n", res.body.size()); + strm.write("\r\n"); } -inline void write_response(socket_t sock, const Request& req, const Response& res) +inline void write_response(Stream& strm, const Request& req, const Response& res) { - socket_printf(sock, "HTTP/1.0 %d %s\r\n", res.status, status_message(res.status)); + socket_printf(strm, "HTTP/1.0 %d %s\r\n", res.status, status_message(res.status)); - write_headers(sock, res); + write_headers(strm, res); if (!res.body.empty() && req.method != "HEAD") { - socket_write(sock, res.body.c_str(), res.body.size()); + strm.write(res.body.c_str(), res.body.size()); } } @@ -591,19 +652,19 @@ inline std::string decode_url(const std::string& s) return result; } -inline void write_request(socket_t sock, const Request& req) +inline void write_request(Stream& strm, const Request& req) { auto url = encode_url(req.url); - socket_printf(sock, "%s %s HTTP/1.0\r\n", req.method.c_str(), url.c_str()); + socket_printf(strm, "%s %s HTTP/1.0\r\n", req.method.c_str(), url.c_str()); - write_headers(sock, req); + write_headers(strm, req); if (!req.body.empty()) { if (req.has_header("application/x-www-form-urlencoded")) { auto str = encode_url(req.body); - socket_write(sock, str.c_str(), str.size()); + strm.write(str.c_str(), str.size()); } else { - socket_write(sock, req.body.c_str(), req.body.size()); + strm.write(req.body.c_str(), req.body.size()); } } } @@ -697,12 +758,40 @@ inline void Response::set_content(const std::string& s, const char* content_type set_header("Content-Type", content_type); } +// Socket stream implementation +inline SocketStream::SocketStream(socket_t sock): sock_(sock) +{ +} + +inline SocketStream::~SocketStream() +{ +} + +inline int SocketStream::read(char* ptr, size_t size) +{ + return recv(sock_, ptr, size, 0); +} + +inline int SocketStream::write(const char* ptr, size_t size) +{ + return send(sock_, ptr, size, 0); +} + +inline int SocketStream::write(const char* ptr) +{ + return write(ptr, strlen(ptr)); +} + // HTTP server implementation inline Server::Server() : svr_sock_(-1) { } +inline Server::~Server() +{ +} + inline void Server::get(const char* pattern, Handler handler) { get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); @@ -755,10 +844,7 @@ inline bool Server::listen(const char* host, int port) } // TODO: should be async - detail::read_and_close_socket(sock, [this](socket_t sock) { - process_request(sock); - return true; - }); + read_and_close_socket(sock); } return ret; @@ -771,11 +857,11 @@ inline void Server::stop() svr_sock_ = -1; } -inline bool Server::read_request_line(socket_t sock, Request& req) +inline bool Server::read_request_line(Stream& strm, Request& req) { const auto BUFSIZ_REQUESTLINE = 2048; char buf[BUFSIZ_REQUESTLINE]; - if (!detail::socket_gets(sock, buf, BUFSIZ_REQUESTLINE)) { + if (!detail::socket_gets(strm, buf, BUFSIZ_REQUESTLINE)) { return false; } @@ -809,7 +895,7 @@ inline bool Server::handle_file_request(Request& req, Response& res) if (detail::is_file(path)) { detail::read_file(path, res.body); - res.set_header("Content-Type", + res.set_header("Content-Type", detail::get_content_type_from_file_extention( detail::get_file_extention(path))); res.status = 200; @@ -848,18 +934,20 @@ inline bool Server::dispatch_request(Request& req, Response& res, Handlers& hand return false; } -inline void Server::process_request(socket_t sock) +inline void Server::process_request(Stream& strm) { Request req; Response res; - if (!read_request_line(sock, req) || - !detail::read_headers(sock, req.headers)) { + if (!read_request_line(strm, req) || + !detail::read_headers(strm, req.headers)) { + // TODO: return; } if (req.method == "POST") { - if (!detail::read_content(sock, req)) { + if (!detail::read_content(strm, req)) { + // TODO: return; } static std::string type = "application/x-www-form-urlencoded"; @@ -881,13 +969,21 @@ inline void Server::process_request(socket_t sock) error_handler_(req, res); } - detail::write_response(sock, req, res); + detail::write_response(strm, req, res); if (logger_) { logger_(req, res); } } +inline bool Server::read_and_close_socket(socket_t sock) +{ + return detail::read_and_close_socket(sock, [this](Stream& strm) { + process_request(strm); + return true; + }); +} + // HTTP client implementation inline Client::Client(const char* host, int port) : host_(host) @@ -895,11 +991,15 @@ inline Client::Client(const char* host, int port) { } -inline bool Client::read_response_line(socket_t sock, Response& res) +inline Client::~Client() +{ +} + +inline bool Client::read_response_line(Stream& strm, Response& res) { const auto BUFSIZ_RESPONSELINE = 2048; char buf[BUFSIZ_RESPONSELINE]; - if (!detail::socket_gets(sock, buf, BUFSIZ_RESPONSELINE)) { + if (!detail::socket_gets(strm, buf, BUFSIZ_RESPONSELINE)) { return false; } @@ -920,22 +1020,32 @@ inline bool Client::send(const Request& req, Response& res) return false; } - return detail::read_and_close_socket(sock, [&](socket_t sock) { - // Send request - detail::write_request(sock, req); + return read_and_close_socket(sock, req, res); +} - // Receive response - if (!read_response_line(sock, res) || - !detail::read_headers(sock, res.headers)) { +inline bool Client::process_request(Stream& strm, const Request& req, Response& res) +{ + // Send request + detail::write_request(strm, req); + + // Receive response + if (!read_response_line(strm, res) || + !detail::read_headers(strm, res.headers)) { + return false; + } + if (req.method != "HEAD") { + if (!detail::read_content(strm, res)) { return false; } - if (req.method != "HEAD") { - if (!detail::read_content(sock, res)) { - return false; - } - } + } - return true; + return true; +} + +inline bool Client::read_and_close_socket(socket_t sock, const Request& req, Response& res) +{ + return detail::read_and_close_socket(sock, [&](Stream& strm) { + return process_request(strm, req, res); }); } @@ -991,6 +1101,124 @@ inline std::shared_ptr Client::post( return post(url, query, "application/x-www-form-urlencoded"); } +/* + * SSL Implementation + */ +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +namespace detail { + +template +inline bool read_and_close_socket_ssl(socket_t sock, SSL_CTX* ctx, U SSL_connect_or_accept, T callback) +{ + auto ssl = SSL_new(ctx); + SSL_set_fd(ssl, sock); + SSL_connect_or_accept(ssl); + + SSLSocketStream strm(ssl); + auto ret = callback(strm); + + SSL_shutdown(ssl); + SSL_free(ssl); + close_socket(sock); + return ret; +} + +class SSLInit { +public: + SSLInit() { + SSL_load_error_strings(); + SSL_library_init(); + } +}; + +static SSLInit sslinit_; + +} // namespace detail + +// SSL socket stream implementation +inline SSLSocketStream::SSLSocketStream(SSL* ssl): ssl_(ssl) +{ +} + +inline SSLSocketStream::~SSLSocketStream() +{ +} + +inline int SSLSocketStream::read(char* ptr, size_t size) +{ + return SSL_read(ssl_, ptr, size); +} + +inline int SSLSocketStream::write(const char* ptr, size_t size) +{ + return SSL_write(ssl_, ptr, size); +} + +inline int SSLSocketStream::write(const char* ptr) +{ + return write(ptr, strlen(ptr)); +} + +// SSL HTTP server implementation +inline SSLServer::SSLServer(const char* cert_path, const char* private_key_path) +{ + ctx_ = SSL_CTX_new(SSLv23_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); + // EC_KEY_free(ecdh); + + if (SSL_CTX_use_certificate_file(ctx_, cert_path, SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLServer::~SSLServer() +{ + if (ctx_) { + SSL_CTX_free(ctx_); + } +} + +inline bool SSLServer::read_and_close_socket(socket_t sock) +{ + return detail::read_and_close_socket_ssl(sock, ctx_, SSL_accept, [this](Stream& strm) { + process_request(strm); + return true; + }); +} + +// SSL HTTP client implementation +inline SSLClient::SSLClient(const char* host, int port) + : Client(host, port) +{ + ctx_ = SSL_CTX_new(SSLv23_client_method()); +} + +inline SSLClient::~SSLClient() +{ + if (ctx_) { + SSL_CTX_free(ctx_); + } +} + +inline bool SSLClient::read_and_close_socket(socket_t sock, const Request& req, Response& res) +{ + return detail::read_and_close_socket_ssl(sock, ctx_, SSL_connect, [&](Stream& strm) { + return process_request(strm, req, res); + }); +} +#endif + } // namespace httplib #endif diff --git a/test/Makefile b/test/Makefile index 11ce67d..3672ccb 100644 --- a/test/Makefile +++ b/test/Makefile @@ -1,16 +1,17 @@ -USE_CLANG = 1 - -ifdef USE_CLANG CC = clang++ -CCFLAGS = -std=c++1y -stdlib=libc++ -g -DGTEST_USE_OWN_TR1_TUPLE -else -CC = g++-4.9 -CCFLAGS = -std=c++1y -g -endif +CFLAGS = -std=c++14 -DGTEST_USE_OWN_TR1_TUPLE -I.. -I. +#OPENSSL_SUPPORT = -DCPPHTTPLIB_OPENSSL_SUPPORT -I/usr/local/opt/openssl/include -L/usr/local/opt/openssl/lib -lssl -lcrypto all : test ./test test : test.cc ../httplib.h - $(CC) -o test $(CCFLAGS) -I.. -I. test.cc gtest/gtest-all.cc gtest/gtest_main.cc + $(CC) -o test $(CFLAGS) test.cc gtest/gtest-all.cc gtest/gtest_main.cc $(OPENSSL_SUPPORT) + +pem: + openssl genrsa 2048 > key.pem + openssl req -new -key key.pem | openssl x509 -days 3650 -req -signkey key.pem > cert.pem + +clean: + rm test *.pem diff --git a/test/test.cc b/test/test.cc index 3bf1fa2..fea9fbe 100644 --- a/test/test.cc +++ b/test/test.cc @@ -4,6 +4,9 @@ #include #include +#define SERVER_CERT_FILE "./cert.pem" +#define SERVER_PRIVATE_KEY_FILE "./key.pem" + #ifdef _WIN32 #include #define msleep(n) ::Sleep(n) @@ -99,8 +102,12 @@ TEST(GetHeaderValueTest, RegularValueInt) class ServerTest : public ::testing::Test { protected: - ServerTest() : cli_(HOST, PORT), up_(false) { - } + ServerTest() + : cli_(HOST, PORT) +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + , svr_(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE) +#endif + , up_(false) {} virtual void SetUp() { svr_.set_base_dir("./www"); @@ -155,8 +162,13 @@ protected: } map persons_; - Server svr_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli_; + SSLServer svr_; +#else Client cli_; + Server svr_; +#endif future f_; bool up_; };