diff --git a/quic/api/QuicBatchWriter.cpp b/quic/api/QuicBatchWriter.cpp index 5a079c4cb..1d22cd4ce 100644 --- a/quic/api/QuicBatchWriter.cpp +++ b/quic/api/QuicBatchWriter.cpp @@ -53,6 +53,39 @@ ssize_t SinglePacketBatchWriter::write( return sock.write(address, buf_); } +// SinglePacketInplaceBatchWriter +void SinglePacketInplaceBatchWriter::reset() { + ScopedBufAccessor scopedBufAccessor(conn_.bufAccessor); + auto& buf = scopedBufAccessor.buf(); + buf->clear(); +} + +bool SinglePacketInplaceBatchWriter::append( + std::unique_ptr&& /* buf */, + size_t /*unused*/, + const folly::SocketAddress& /*unused*/, + QuicAsyncUDPSocketType* /*unused*/) { + // Always flush. This should trigger a write afterwards. + return true; +} + +ssize_t SinglePacketInplaceBatchWriter::write( + QuicAsyncUDPSocketType& sock, + const folly::SocketAddress& address) { + ScopedBufAccessor scopedBufAccessor(conn_.bufAccessor); + auto& buf = scopedBufAccessor.buf(); + CHECK(!buf->isChained()); + auto ret = sock.write(address, buf); + buf->clear(); + return ret; +} + +bool SinglePacketInplaceBatchWriter::empty() const { + ScopedBufAccessor scopedBufAccessor(conn_.bufAccessor); + auto& buf = scopedBufAccessor.buf(); + return buf->length() == 0; +} + // SendmmsgPacketBatchWriter SendmmsgPacketBatchWriter::SendmmsgPacketBatchWriter(size_t maxBufs) : maxBufs_(maxBufs) { diff --git a/quic/api/QuicBatchWriter.h b/quic/api/QuicBatchWriter.h index ad70ae6da..3ed04cc84 100644 --- a/quic/api/QuicBatchWriter.h +++ b/quic/api/QuicBatchWriter.h @@ -92,6 +92,32 @@ class SinglePacketBatchWriter : public IOBufBatchWriter { const folly::SocketAddress& address) override; }; +/** + * This writer allows for single buf inplace writes. + * The buffer is owned by the conn/accessor, and every append will trigger a + * flush/write. + */ +class SinglePacketInplaceBatchWriter : public IOBufBatchWriter { + public: + explicit SinglePacketInplaceBatchWriter(QuicConnectionStateBase& conn) + : conn_(conn) {} + ~SinglePacketInplaceBatchWriter() override = default; + + void reset() override; + bool append( + std::unique_ptr&& /* buf */, + size_t /*unused*/, + const folly::SocketAddress& /*unused*/, + QuicAsyncUDPSocketType* /*unused*/) override; + ssize_t write( + QuicAsyncUDPSocketType& sock, + const folly::SocketAddress& address) override; + [[nodiscard]] bool empty() const override; + + private: + QuicConnectionStateBase& conn_; +}; + class SendmmsgPacketBatchWriter : public BatchWriter { public: explicit SendmmsgPacketBatchWriter(size_t maxBufs); diff --git a/quic/api/QuicBatchWriterFactory.cpp b/quic/api/QuicBatchWriterFactory.cpp index dc86556a3..165ae9453 100644 --- a/quic/api/QuicBatchWriterFactory.cpp +++ b/quic/api/QuicBatchWriterFactory.cpp @@ -135,6 +135,13 @@ class ThreadLocalBatchWriterCache : public folly::AsyncTimeout { namespace quic { +bool useSinglePacketInplaceBatchWriter( + uint32_t maxBatchSize, + quic::DataPathType dataPathType) { + return maxBatchSize == 1 && + dataPathType == quic::DataPathType::ContinuousMemory; +} + // BatchWriterDeleter void BatchWriterDeleter::operator()(BatchWriter* batchWriter) { #if USE_THREAD_LOCAL_BATCH_WRITER diff --git a/quic/api/QuicBatchWriterFactory.h b/quic/api/QuicBatchWriterFactory.h index 347078d24..6feef7cb9 100644 --- a/quic/api/QuicBatchWriterFactory.h +++ b/quic/api/QuicBatchWriterFactory.h @@ -12,6 +12,10 @@ namespace quic { +bool useSinglePacketInplaceBatchWriter( + uint32_t maxBatchSize, + quic::DataPathType dataPathType); + BatchWriterPtr makeGsoBatchWriter(uint32_t batchSize); BatchWriterPtr makeGsoInPlaceBatchWriter( uint32_t batchSize, @@ -38,6 +42,9 @@ class BatchWriterFactory { bool gsoSupported) { switch (batchingMode) { case quic::QuicBatchingMode::BATCHING_MODE_NONE: + if (useSinglePacketInplaceBatchWriter(batchSize, dataPathType)) { + return BatchWriterPtr(new SinglePacketInplaceBatchWriter(conn)); + } return BatchWriterPtr(new SinglePacketBatchWriter()); case quic::QuicBatchingMode::BATCHING_MODE_GSO: { if (gsoSupported) { diff --git a/quic/api/test/QuicBatchWriterTest.cpp b/quic/api/test/QuicBatchWriterTest.cpp index 20ae680c6..988eb16ad 100644 --- a/quic/api/test/QuicBatchWriterTest.cpp +++ b/quic/api/test/QuicBatchWriterTest.cpp @@ -624,5 +624,155 @@ INSTANTIATE_TEST_SUITE_P( QuicBatchWriterTest, ::testing::Values(false, true)); +class SinglePacketInplaceBatchWriterTest : public ::testing::Test { + public: + SinglePacketInplaceBatchWriterTest() + : conn_(FizzServerQuicHandshakeContext::Builder().build()) {} + + void SetUp() override { + bufAccessor_ = + std::make_unique(conn_.udpSendPacketLen); + conn_.bufAccessor = bufAccessor_.get(); + } + + quic::BatchWriterPtr makeBatchWriter( + quic::QuicBatchingMode batchingMode = + quic::QuicBatchingMode::BATCHING_MODE_NONE) { + return quic::BatchWriterFactory::makeBatchWriter( + batchingMode, + conn_.transportSettings.maxBatchSize, + false /* useThreadLocal */, + quic::kDefaultThreadLocalDelay, + conn_.transportSettings.dataPathType, + conn_, + false /* gsoSupported_ */); + } + + void enableSinglePacketInplaceBatchWriter() { + conn_.transportSettings.maxBatchSize = 1; + conn_.transportSettings.dataPathType = DataPathType::ContinuousMemory; + } + + protected: + std::unique_ptr bufAccessor_; + QuicServerConnectionState conn_; +}; + +TEST_F(SinglePacketInplaceBatchWriterTest, TestFactorySuccess) { + enableSinglePacketInplaceBatchWriter(); + + auto batchWriter = makeBatchWriter(); + CHECK(batchWriter); + CHECK(dynamic_cast(batchWriter.get())); +} + +TEST_F(SinglePacketInplaceBatchWriterTest, TestFactoryNoTransportSetting) { + conn_.transportSettings.maxBatchSize = 1; + conn_.transportSettings.dataPathType = DataPathType::ChainedMemory; + auto batchWriter = makeBatchWriter(); + CHECK(batchWriter); + EXPECT_EQ( + dynamic_cast(batchWriter.get()), + nullptr); +} + +TEST_F(SinglePacketInplaceBatchWriterTest, TestFactoryNoTransportSetting2) { + conn_.transportSettings.maxBatchSize = 16; + conn_.transportSettings.dataPathType = DataPathType::ContinuousMemory; + auto batchWriter = makeBatchWriter(); + CHECK(batchWriter); + EXPECT_EQ( + dynamic_cast(batchWriter.get()), + nullptr); +} + +TEST_F(SinglePacketInplaceBatchWriterTest, TestFactoryWrongBatchingMode) { + enableSinglePacketInplaceBatchWriter(); + + auto batchWriter = makeBatchWriter(quic::QuicBatchingMode::BATCHING_MODE_GSO); + CHECK(batchWriter); + EXPECT_EQ( + dynamic_cast(batchWriter.get()), + nullptr); +} + +TEST_F(SinglePacketInplaceBatchWriterTest, TestReset) { + enableSinglePacketInplaceBatchWriter(); + + auto batchWriter = makeBatchWriter(); + CHECK(batchWriter); + CHECK(dynamic_cast(batchWriter.get())); + + auto buf = bufAccessor_->obtain(); + folly::IOBuf* rawBuf = buf.get(); + bufAccessor_->release(std::move(buf)); + rawBuf->append(700); + + EXPECT_EQ(rawBuf->computeChainDataLength(), 700); + batchWriter->reset(); + EXPECT_EQ(rawBuf->computeChainDataLength(), 0); +} + +TEST_F(SinglePacketInplaceBatchWriterTest, TestAppend) { + enableSinglePacketInplaceBatchWriter(); + + auto batchWriter = makeBatchWriter(); + CHECK(batchWriter); + CHECK(dynamic_cast(batchWriter.get())); + + EXPECT_EQ( + true, batchWriter->append(nullptr, 0, folly::SocketAddress(), nullptr)); +} + +TEST_F(SinglePacketInplaceBatchWriterTest, TestEmpty) { + enableSinglePacketInplaceBatchWriter(); + + auto batchWriter = makeBatchWriter(); + CHECK(batchWriter); + CHECK(dynamic_cast(batchWriter.get())); + EXPECT_TRUE(batchWriter->empty()); + + auto buf = bufAccessor_->obtain(); + folly::IOBuf* rawBuf = buf.get(); + bufAccessor_->release(std::move(buf)); + rawBuf->append(700); + + EXPECT_EQ(rawBuf->computeChainDataLength(), 700); + EXPECT_FALSE(batchWriter->empty()); + + batchWriter->reset(); + EXPECT_TRUE(batchWriter->empty()); +} + +TEST_F(SinglePacketInplaceBatchWriterTest, TestWrite) { + enableSinglePacketInplaceBatchWriter(); + + auto batchWriter = makeBatchWriter(); + CHECK(batchWriter); + CHECK(dynamic_cast(batchWriter.get())); + EXPECT_TRUE(batchWriter->empty()); + + auto buf = bufAccessor_->obtain(); + folly::IOBuf* rawBuf = buf.get(); + bufAccessor_->release(std::move(buf)); + const auto appendSize = conn_.udpSendPacketLen - 200; + rawBuf->append(appendSize); + + EXPECT_EQ(rawBuf->computeChainDataLength(), appendSize); + EXPECT_FALSE(batchWriter->empty()); + + folly::EventBase evb; + folly::test::MockAsyncUDPSocket sock(&evb); + EXPECT_CALL(sock, write(_, _)) + .Times(1) + .WillOnce(Invoke([&](const auto& /* addr */, + const std::unique_ptr& buf) { + EXPECT_EQ(appendSize, buf->length()); + return appendSize; + })); + EXPECT_EQ(appendSize, batchWriter->write(sock, folly::SocketAddress())); + EXPECT_TRUE(batchWriter->empty()); +} + } // namespace testing } // namespace quic