diff --git a/httplib.h b/httplib.h index 2df93c7..bf2f64c 100644 --- a/httplib.h +++ b/httplib.h @@ -299,7 +299,7 @@ private: using ContentProvider = std::function; -using ChunkedContentProvider = +using ContentProviderWithoutLength = std::function; using ContentReceiver = @@ -404,8 +404,12 @@ struct Response { size_t length, const char *content_type, ContentProvider provider, std::function resource_releaser = [] {}); + void set_content_provider( + const char *content_type, ContentProviderWithoutLength provider, + std::function resource_releaser = [] {}); + void set_chunked_content_provider( - const char *content_type, ChunkedContentProvider provider, + const char *content_type, ContentProviderWithoutLength provider, std::function resource_releaser = [] {}); Response() = default; @@ -423,6 +427,7 @@ struct Response { size_t content_length_ = 0; ContentProvider content_provider_; std::function content_provider_resource_releaser_; + bool is_chunked_content_provider = false; }; class Stream { @@ -2664,19 +2669,19 @@ inline ssize_t write_content(Stream &strm, ContentProvider content_provider, size_t offset, size_t length, T is_shutting_down) { size_t begin_offset = offset; size_t end_offset = offset + length; - auto ok = true; - DataSink data_sink; + data_sink.write = [&](const char *d, size_t l) { if (ok) { offset += l; if (!write_data(strm, d, l)) { ok = false; } } }; + data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; - while (ok && offset < end_offset && !is_shutting_down()) { + while (offset < end_offset && !is_shutting_down()) { if (!content_provider(offset, end_offset - offset, data_sink)) { return -1; } @@ -2686,6 +2691,34 @@ inline ssize_t write_content(Stream &strm, ContentProvider content_provider, return static_cast(offset - begin_offset); } +template +inline ssize_t write_content_without_length(Stream &strm, + ContentProvider content_provider, + T is_shutting_down) { + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) { + if (ok) { + offset += l; + if (!write_data(strm, d, l)) { ok = false; } + } + }; + + data_sink.done = [&](void) { data_available = false; }; + + data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; + + while (data_available && !is_shutting_down()) { + if (!content_provider(offset, 0, data_sink)) { return -1; } + if (!ok) { return -1; } + } + + return static_cast(offset); +} + template inline ssize_t write_content_chunked(Stream &strm, ContentProvider content_provider, @@ -2693,7 +2726,6 @@ inline ssize_t write_content_chunked(Stream &strm, size_t offset = 0; auto data_available = true; ssize_t total_written_length = 0; - auto ok = true; DataSink data_sink; @@ -3544,10 +3576,11 @@ Response::set_content_provider(size_t in_length, const char *content_type, return provider(offset, length, sink); }; content_provider_resource_releaser_ = resource_releaser; + is_chunked_content_provider = false; } -inline void Response::set_chunked_content_provider( - const char *content_type, ChunkedContentProvider provider, +inline void Response::set_content_provider( + const char *content_type, ContentProviderWithoutLength provider, std::function resource_releaser) { set_header("Content-Type", content_type); content_length_ = 0; @@ -3555,6 +3588,19 @@ inline void Response::set_chunked_content_provider( return provider(offset, sink); }; content_provider_resource_releaser_ = resource_releaser; + is_chunked_content_provider = false; +} + +inline void Response::set_chunked_content_provider( + const char *content_type, ContentProviderWithoutLength provider, + std::function resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = [provider](size_t offset, size_t, DataSink &sink) { + return provider(offset, sink); + }; + content_provider_resource_releaser_ = resource_releaser; + is_chunked_content_provider = true; } // Rstream implementation @@ -3893,7 +3939,7 @@ inline bool Server::write_response(Stream &strm, bool close_connection, } if (!res.has_header("Content-Type") && - (!res.body.empty() || res.content_length_ > 0)) { + (!res.body.empty() || res.content_length_ > 0 || res.content_provider_)) { res.set_header("Content-Type", "text/plain"); } @@ -3939,11 +3985,13 @@ inline bool Server::write_response(Stream &strm, bool close_connection, res.set_header("Content-Length", std::to_string(length)); } else { if (res.content_provider_) { - res.set_header("Transfer-Encoding", "chunked"); - if (type == detail::EncodingType::Gzip) { - res.set_header("Content-Encoding", "gzip"); - } else if (type == detail::EncodingType::Brotli) { - res.set_header("Content-Encoding", "br"); + if (res.is_chunked_content_provider) { + res.set_header("Transfer-Encoding", "chunked"); + if (type == detail::EncodingType::Gzip) { + res.set_header("Content-Encoding", "gzip"); + } else if (type == detail::EncodingType::Brotli) { + res.set_header("Content-Encoding", "br"); + } } } else { res.set_header("Content-Length", "0"); @@ -4033,7 +4081,7 @@ Server::write_content_with_provider(Stream &strm, const Request &req, return this->svr_sock_ == INVALID_SOCKET; }; - if (res.content_length_) { + if (res.content_length_ > 0) { if (req.ranges.empty()) { if (detail::write_content(strm, res.content_provider_, 0, res.content_length_, is_shutting_down) < 0) { @@ -4055,25 +4103,32 @@ Server::write_content_with_provider(Stream &strm, const Request &req, } } } else { - auto type = detail::encoding_type(req, res); + if (res.is_chunked_content_provider) { + auto type = detail::encoding_type(req, res); - std::shared_ptr compressor; - if (type == detail::EncodingType::Gzip) { + std::shared_ptr compressor; + if (type == detail::EncodingType::Gzip) { #ifdef CPPHTTPLIB_ZLIB_SUPPORT - compressor = std::make_shared(); + compressor = std::make_shared(); #endif - } else if (type == detail::EncodingType::Brotli) { + } else if (type == detail::EncodingType::Brotli) { #ifdef CPPHTTPLIB_BROTLI_SUPPORT - compressor = std::make_shared(); + compressor = std::make_shared(); #endif - } else { - compressor = std::make_shared(); - } - assert(compressor != nullptr); + } else { + compressor = std::make_shared(); + } + assert(compressor != nullptr); - if (detail::write_content_chunked(strm, res.content_provider_, - is_shutting_down, *compressor) < 0) { - return false; + if (detail::write_content_chunked(strm, res.content_provider_, + is_shutting_down, *compressor) < 0) { + return false; + } + } else { + if (detail::write_content_without_length(strm, res.content_provider_, + is_shutting_down) < 0) { + return false; + } } } return true; diff --git a/test/test.cc b/test/test.cc index b603acd..c04bb0c 100644 --- a/test/test.cc +++ b/test/test.cc @@ -1895,8 +1895,7 @@ TEST_F(ServerTest, ClientStop) { auto res = cli_.Get("/streamed-cancel", [&](const char *, uint64_t) { return true; }); ASSERT_TRUE(!res); - EXPECT_TRUE(res.error() == Error::Canceled || - res.error() == Error::Read); + EXPECT_TRUE(res.error() == Error::Canceled || res.error() == Error::Read); })); } @@ -2730,6 +2729,46 @@ TEST(ServerStopTest, StopServerWithChunkedTransmission) { ASSERT_FALSE(svr.is_running()); } +TEST(StreamingTest, NoContentLengthStreaming) { + Server svr; + + svr.Get("/stream", [](const Request & /*req*/, Response &res) { + res.set_content_provider( + "text/plain", [](size_t offset, DataSink &sink) { + if (offset < 6) { + sink.os << (offset < 3 ? "a" : "b"); + } else { + sink.done(); + } + return true; + }); + }); + + auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); }); + while (!svr.is_running()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + Client client(HOST, PORT); + + auto get_thread = std::thread([&client]() { + auto res = client.Get("/stream", [](const char *data, size_t len) -> bool { + EXPECT_EQ("aaabbb", std::string(data, len)); + return true; + }); + }); + + // Give GET time to get a few messages. + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + + svr.stop(); + + listen_thread.join(); + get_thread.join(); + + ASSERT_FALSE(svr.is_running()); +} + TEST(MountTest, Unmount) { Server svr;