From 307b729549a5243fde63b46e592d04793f1ec73f Mon Sep 17 00:00:00 2001 From: Yoshiki Matsuda <59041398+yosh-matsuda@users.noreply.github.com> Date: Thu, 28 Apr 2022 10:08:39 +0900 Subject: [PATCH] Accept large data transfer over SSL (#1261) * Add large data transfer test * Replace `SSL_read` and `SSL_write` with `ex` functions * Reflect review comment * Fix return value of `SSLSocketStream::read/write` * Fix return value in the case of `SSL_ERROR_ZERO_RETURN` * Disable `LargeDataTransfer` test due to OoM in CI --- httplib.h | 93 ++++++++++++++++++++++++++-------------------------- test/test.cc | 44 +++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 46 deletions(-) diff --git a/httplib.h b/httplib.h index c7ecf74..ce23b15 100644 --- a/httplib.h +++ b/httplib.h @@ -7221,62 +7221,63 @@ inline bool SSLSocketStream::is_writable() const { } inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { + size_t readbytes = 0; if (SSL_pending(ssl_) > 0) { - return SSL_read(ssl_, ptr, static_cast(size)); - } else if (is_readable()) { - auto ret = SSL_read(ssl_, ptr, static_cast(size)); - if (ret < 0) { - auto err = SSL_get_error(ssl_, ret); - int n = 1000; -#ifdef _WIN32 - while (--n >= 0 && (err == SSL_ERROR_WANT_READ || - (err == SSL_ERROR_SYSCALL && - WSAGetLastError() == WSAETIMEDOUT))) { -#else - while (--n >= 0 && err == SSL_ERROR_WANT_READ) { -#endif - if (SSL_pending(ssl_) > 0) { - return SSL_read(ssl_, ptr, static_cast(size)); - } else if (is_readable()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - ret = SSL_read(ssl_, ptr, static_cast(size)); - if (ret >= 0) { return ret; } - err = SSL_get_error(ssl_, ret); - } else { - return -1; - } - } - } - return ret; + auto ret = SSL_read_ex(ssl_, ptr, size, &readbytes); + if (ret == 1) { return static_cast(readbytes); } + if (SSL_get_error(ssl_, ret) == SSL_ERROR_ZERO_RETURN) { return 0; } + return -1; } + if (!is_readable()) { return -1; } + + auto ret = SSL_read_ex(ssl_, ptr, size, &readbytes); + if (ret == 1) { return static_cast(readbytes); } + auto err = SSL_get_error(ssl_, ret); + int n = 1000; +#ifdef _WIN32 + while (--n >= 0 && + (err == SSL_ERROR_WANT_READ || + (err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_READ) { +#endif + if (SSL_pending(ssl_) > 0) { + ret = SSL_read_ex(ssl_, ptr, size, &readbytes); + if (ret == 1) { return static_cast(readbytes); } + if (SSL_get_error(ssl_, ret) == SSL_ERROR_ZERO_RETURN) { return 0; } + return -1; + } + if (!is_readable()) { return -1; } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + ret = SSL_read_ex(ssl_, ptr, size, &readbytes); + if (ret == 1) { return static_cast(readbytes); } + err = SSL_get_error(ssl_, ret); + } + if (err == SSL_ERROR_ZERO_RETURN) { return 0; } return -1; } inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { - if (is_writable()) { - auto ret = SSL_write(ssl_, ptr, static_cast(size)); - if (ret < 0) { - auto err = SSL_get_error(ssl_, ret); - int n = 1000; + if (!is_writable()) { return -1; } + size_t written = 0; + auto ret = SSL_write_ex(ssl_, ptr, size, &written); + if (ret == 1) { return static_cast(written); } + auto err = SSL_get_error(ssl_, ret); + int n = 1000; #ifdef _WIN32 - while (--n >= 0 && (err == SSL_ERROR_WANT_WRITE || - (err == SSL_ERROR_SYSCALL && - WSAGetLastError() == WSAETIMEDOUT))) { + while (--n >= 0 && + (err == SSL_ERROR_WANT_WRITE || + (err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) { #else - while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { + while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { #endif - if (is_writable()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - ret = SSL_write(ssl_, ptr, static_cast(size)); - if (ret >= 0) { return ret; } - err = SSL_get_error(ssl_, ret); - } else { - return -1; - } - } - } - return ret; + if (!is_writable()) { return -1; } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + ret = SSL_write_ex(ssl_, ptr, size, &written); + if (ret == 1) { return static_cast(written); } + err = SSL_get_error(ssl_, ret); } + if (err == SSL_ERROR_ZERO_RETURN) { return 0; } return -1; } diff --git a/test/test.cc b/test/test.cc index b2ac051..d7257ee 100644 --- a/test/test.cc +++ b/test/test.cc @@ -4660,6 +4660,50 @@ TEST(SSLClientServerTest, CustomizeServerSSLCtx) { t.join(); } + +// Disabled due to the out-of-memory problem on GitHub Actions Workflows +TEST(SSLClientServerTest, DISABLED_LargeDataTransfer) { + + // prepare large data + std::random_device seed_gen; + std::mt19937 random(seed_gen()); + constexpr auto large_size_byte = 2147483648UL + 1048576UL; // 2GiB + 1MiB + std::vector binary(large_size_byte / sizeof(std::uint32_t)); + std::generate(binary.begin(), binary.end(), [&random]() { return random(); }); + + // server + SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE); + ASSERT_TRUE(svr.is_valid()); + + svr.Post("/binary", [&](const Request &req, Response &res) { + EXPECT_EQ(large_size_byte, req.body.size()); + EXPECT_EQ(0, std::memcmp(binary.data(), req.body.data(), large_size_byte)); + res.set_content(req.body, "application/octet-stream"); + }); + + auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); }); + while (!svr.is_running()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + // client POST + SSLClient cli("localhost", PORT); + cli.enable_server_certificate_verification(false); + cli.set_read_timeout(std::chrono::seconds(100)); + cli.set_write_timeout(std::chrono::seconds(100)); + auto res = cli.Post("/binary", reinterpret_cast(binary.data()), + large_size_byte, "application/octet-stream"); + + // compare + EXPECT_EQ(200, res->status); + EXPECT_EQ(large_size_byte, res->body.size()); + EXPECT_EQ(0, std::memcmp(binary.data(), res->body.data(), large_size_byte)); + + // cleanup + svr.stop(); + listen_thread.join(); + ASSERT_FALSE(svr.is_running()); +} #endif #ifdef _WIN32