From c6d19083df89e690c04e8588b7f53f070b3f2ca6 Mon Sep 17 00:00:00 2001 From: Aman Sharma Date: Mon, 10 Feb 2025 16:28:00 -0800 Subject: [PATCH] Make a SendmmsgInplacePacketBatchWriter Summary: See title. Reviewed By: mjoras Differential Revision: D69073543 fbshipit-source-id: 11f8a376909df9ae2bc7560cac09a1d2f431acaa --- quic/api/QuicBatchWriter.cpp | 81 +++++++++++++++++++++++++++ quic/api/QuicBatchWriter.h | 33 +++++++++++ quic/api/QuicBatchWriterFactory.h | 8 ++- quic/api/test/QuicBatchWriterTest.cpp | 77 +++++++++++++++++++++++++ 4 files changed, 198 insertions(+), 1 deletion(-) diff --git a/quic/api/QuicBatchWriter.cpp b/quic/api/QuicBatchWriter.cpp index 4daa0876b..def0042c6 100644 --- a/quic/api/QuicBatchWriter.cpp +++ b/quic/api/QuicBatchWriter.cpp @@ -222,4 +222,85 @@ bool useSinglePacketInplaceBatchWriter( dataPathType == quic::DataPathType::ContinuousMemory; } +SendmmsgInplacePacketBatchWriter::SendmmsgInplacePacketBatchWriter( + QuicConnectionStateBase& conn, + size_t maxBufs) + : conn_(conn), maxBufs_(maxBufs) { + CHECK_LT(maxBufs, kMaxIovecs) << "maxBufs must be less than " << kMaxIovecs; +} + +bool SendmmsgInplacePacketBatchWriter::empty() const { + return currSize_ == 0; +} + +size_t SendmmsgInplacePacketBatchWriter::size() const { + return currSize_; +} + +void SendmmsgInplacePacketBatchWriter::reset() { + currSize_ = 0; + numPacketsBuffered_ = 0; +} + +bool SendmmsgInplacePacketBatchWriter::append( + std::unique_ptr&& /* buf */, + size_t size, + const folly::SocketAddress& /*unused*/, + QuicAsyncUDPSocket* /*unused*/) { + CHECK_LT(numPacketsBuffered_, maxBufs_); + + auto& buf = conn_.bufAccessor->buf(); + CHECK(!buf->isChained() && buf->length() >= size); + iovecs_[numPacketsBuffered_].iov_base = (void*)(buf->tail() - size); + iovecs_[numPacketsBuffered_].iov_len = size; + + ++numPacketsBuffered_; + currSize_ += size; + + // reached max buffers + if (FOLLY_UNLIKELY(numPacketsBuffered_ == maxBufs_)) { + return true; + } + + // does not need to be flushed yet + return false; +} + +ssize_t SendmmsgInplacePacketBatchWriter::write( + QuicAsyncUDPSocket& sock, + const folly::SocketAddress& address) { + CHECK_GT(numPacketsBuffered_, 0); + + auto& buf = conn_.bufAccessor->buf(); + buf->clear(); + + if (numPacketsBuffered_ == 1) { + return sock.write(address, &iovecs_[0], 1); + } + + int ret = 0; + std::array messageSizes{}; + + for (size_t i = 0; i < numPacketsBuffered_; i++) { + messageSizes[i] = iovecs_[i].iov_len; + } + + sock.writem( + folly::range(&address, &address + 1), + &iovecs_[0], + &messageSizes[0], + numPacketsBuffered_); + if (ret <= 0) { + return ret; + } + + if (static_cast(ret) == numPacketsBuffered_) { + return currSize_; + } + + // this is a partial write - we just need to + // return a different number than currSize_ + return 0; +} + } // namespace quic diff --git a/quic/api/QuicBatchWriter.h b/quic/api/QuicBatchWriter.h index cfc0f57e4..3ef06a47b 100644 --- a/quic/api/QuicBatchWriter.h +++ b/quic/api/QuicBatchWriter.h @@ -158,6 +158,39 @@ class SendmmsgPacketBatchWriter : public BatchWriter { std::vector> bufs_; }; +class SendmmsgInplacePacketBatchWriter : public BatchWriter { + public: + explicit SendmmsgInplacePacketBatchWriter( + QuicConnectionStateBase& conn, + size_t maxBufs); + ~SendmmsgInplacePacketBatchWriter() override = default; + + [[nodiscard]] bool empty() const override; + + [[nodiscard]] size_t size() const override; + + void reset() override; + bool append( + std::unique_ptr&& /* buf */, + size_t size, + const folly::SocketAddress& /*unused*/, + QuicAsyncUDPSocket* /*unused*/) override; + ssize_t write(QuicAsyncUDPSocket& sock, const folly::SocketAddress& address) + override; + + private: + static const size_t kMaxIovecs = 64; + + QuicConnectionStateBase& conn_; + // Max number of packets we can accumulate before we need to flush + size_t maxBufs_{1}; + // size of data in all the buffers + size_t currSize_{0}; + // Number of packets that have been written to iovec_ + size_t numPacketsBuffered_{0}; + std::array iovecs_{}; +}; + struct BatchWriterDeleter { void operator()(BatchWriter* batchWriter); }; diff --git a/quic/api/QuicBatchWriterFactory.h b/quic/api/QuicBatchWriterFactory.h index f7c3e4579..99ca2974d 100644 --- a/quic/api/QuicBatchWriterFactory.h +++ b/quic/api/QuicBatchWriterFactory.h @@ -56,7 +56,13 @@ class BatchWriterFactory { } [[fallthrough]]; case quic::QuicBatchingMode::BATCHING_MODE_SENDMMSG: - return BatchWriterPtr(new SendmmsgPacketBatchWriter(batchSize)); + switch (dataPathType) { + case DataPathType::ChainedMemory: + return BatchWriterPtr(new SendmmsgPacketBatchWriter(batchSize)); + case DataPathType::ContinuousMemory: + return BatchWriterPtr( + new SendmmsgInplacePacketBatchWriter(conn, batchSize)); + } case quic::QuicBatchingMode::BATCHING_MODE_SENDMMSG_GSO: { if (gsoSupported) { return makeSendmmsgGsoBatchWriter(batchSize); diff --git a/quic/api/test/QuicBatchWriterTest.cpp b/quic/api/test/QuicBatchWriterTest.cpp index 8d2dfc195..9f019188f 100644 --- a/quic/api/test/QuicBatchWriterTest.cpp +++ b/quic/api/test/QuicBatchWriterTest.cpp @@ -390,6 +390,83 @@ TEST_F(QuicBatchWriterTest, TestBatchingSendmmsgNewlyAllocatedIovecMatches) { batchWriter->write(sock, folly::SocketAddress()); } +TEST_F(QuicBatchWriterTest, TestBatchingSendmmsgInplace) { + auto bufAccessor = + std::make_unique(conn_.udpSendPacketLen * kBatchNum); + conn_.bufAccessor = bufAccessor.get(); + + folly::EventBase evb; + std::shared_ptr qEvb = + std::make_shared(&evb); + quic::test::MockAsyncUDPSocket sock(qEvb); + + auto batchWriter = quic::BatchWriterFactory::makeBatchWriter( + quic::QuicBatchingMode::BATCHING_MODE_SENDMMSG, + kBatchNum, + false, /* enable backpressure */ + DataPathType::ContinuousMemory, + conn_, + gsoSupported_); + CHECK(batchWriter); + + // run multiple loops + for (size_t i = 0; i < kNumLoops; i++) { + std::vector expectedIovecs; + + // try to batch up to kBatchNum + CHECK(batchWriter->empty()); + CHECK_EQ(batchWriter->size(), 0); + size_t size = 0; + for (auto j = 0; j < kBatchNum - 1; j++) { + iovec vec{}; + vec.iov_base = (void*)bufAccessor->buf()->tail(); + vec.iov_len = kStrLen; + bufAccessor->buf()->append(kStrLen); + expectedIovecs.push_back(vec); + + EXPECT_FALSE(batchWriter->append( + nullptr, kStrLen, folly::SocketAddress(), nullptr)); + size += kStrLen; + CHECK_EQ(batchWriter->size(), size); + } + + // add the kBatchNum buf + iovec vec{}; + vec.iov_base = (void*)bufAccessor->buf()->tail(); + vec.iov_len = kStrLen; + bufAccessor->buf()->append(kStrLen); + expectedIovecs.push_back(vec); + + CHECK( + batchWriter->append(nullptr, kStrLen, folly::SocketAddress(), nullptr)); + size += kStrLen; + CHECK_EQ(batchWriter->size(), size); + + EXPECT_CALL(sock, writem(_, _, _, _)) + .Times(1) + .WillOnce(Invoke([&](folly::Range addrs, + iovec* iovecs, + size_t* messageSizes, + size_t count) { + EXPECT_EQ(addrs.size(), 1); + EXPECT_EQ(count, kBatchNum); + + for (size_t k = 0; k < count; k++) { + EXPECT_EQ(messageSizes[k], expectedIovecs[k].iov_len); + EXPECT_EQ(expectedIovecs[k].iov_base, iovecs[k].iov_base); + EXPECT_EQ(expectedIovecs[k].iov_len, iovecs[k].iov_len); + } + + return 0; + })); + batchWriter->write(sock, folly::SocketAddress()); + expectedIovecs.clear(); + EXPECT_TRUE(bufAccessor->buf()->empty()); + + batchWriter->reset(); + } +} + TEST_F(QuicBatchWriterTest, TestBatchingSendmmsgGSOBatchNum) { folly::EventBase evb; std::shared_ptr qEvb =