diff --git a/quic/api/QuicBatchWriter.cpp b/quic/api/QuicBatchWriter.cpp index 189434ac4..abb2c4732 100644 --- a/quic/api/QuicBatchWriter.cpp +++ b/quic/api/QuicBatchWriter.cpp @@ -227,10 +227,15 @@ void GSOInplacePacketBatchWriter::reset() { lastPacketEnd_ = nullptr; prevSize_ = 0; numPackets_ = 0; + nextPacketSize_ = 0; } bool GSOInplacePacketBatchWriter::needsFlush(size_t size) { - return prevSize_ && size > prevSize_; + auto shouldFlush = prevSize_ && size > prevSize_; + if (shouldFlush) { + nextPacketSize_ = size; + } + return shouldFlush; } bool GSOInplacePacketBatchWriter::append( @@ -274,11 +279,13 @@ ssize_t GSOInplacePacketBatchWriter::write( CHECK(lastPacketEnd_ >= buf->data() && lastPacketEnd_ <= buf->tail()) << "lastPacketEnd_=" << (long)lastPacketEnd_ << " data=" << (long)buf->data() << " tail=" << (long)buf->tail(); - auto diffToEnd = buf->tail() - lastPacketEnd_; + uint64_t diffToEnd = buf->tail() - lastPacketEnd_; CHECK( - diffToEnd >= 0 && - static_cast(diffToEnd) <= conn_.udpSendPacketLen); - auto diffToStart = lastPacketEnd_ - buf->data(); + diffToEnd <= conn_.udpSendPacketLen || + (nextPacketSize_ && diffToEnd == nextPacketSize_)) + << "diffToEnd=" << diffToEnd << ", pktLimit=" << conn_.udpSendPacketLen + << ", nextPacketSize_=" << nextPacketSize_; + uint64_t diffToStart = lastPacketEnd_ - buf->data(); buf->trimEnd(diffToEnd); auto bytesWritten = (numPackets_ > 1) ? sock.writeGSO(address, buf, static_cast(prevSize_)) @@ -296,8 +303,15 @@ ssize_t GSOInplacePacketBatchWriter::write( buf->trimStart(diffToStart); buf->append(diffToEnd); buf->retreat(diffToStart); - CHECK(buf->length() <= conn_.udpSendPacketLen); - CHECK(0 == buf->headroom()); + auto bufLength = buf->length(); + CHECK_EQ(diffToEnd, bufLength) + << "diffToEnd=" << diffToEnd << ", bufLength=" << bufLength; + CHECK( + bufLength <= conn_.udpSendPacketLen || + (nextPacketSize_ && bufLength == nextPacketSize_)) + << "bufLength=" << bufLength << ", pktLimit=" << conn_.udpSendPacketLen + << ", nextPacketSize_=" << nextPacketSize_; + CHECK(0 == buf->headroom()) << "headroom=" << buf->headroom(); } else { buf->clear(); } diff --git a/quic/api/QuicBatchWriter.h b/quic/api/QuicBatchWriter.h index 47f54ec7e..bdc53addc 100644 --- a/quic/api/QuicBatchWriter.h +++ b/quic/api/QuicBatchWriter.h @@ -175,6 +175,15 @@ class GSOInplacePacketBatchWriter : public BatchWriter { const uint8_t* lastPacketEnd_{nullptr}; size_t prevSize_{0}; size_t numPackets_{0}; + + /** + * If we flush the batch due to the next packet being larger than current GSO + * size, we use the following value to keep track of that next packet, and + * checks against buffer residue after writes. The reason we cannot just check + * the buffer residue against the Quic packet limit is that there may be some + * retranmission packets slightly larger than the limit. + */ + size_t nextPacketSize_{0}; }; class SendmmsgPacketBatchWriter : public BatchWriter { diff --git a/quic/api/test/QuicBatchWriterTest.cpp b/quic/api/test/QuicBatchWriterTest.cpp index ad2fe11c9..7acee2660 100644 --- a/quic/api/test/QuicBatchWriterTest.cpp +++ b/quic/api/test/QuicBatchWriterTest.cpp @@ -583,6 +583,50 @@ TEST_P(QuicBatchWriterTest, InplaceWriterLastOneTooBig) { EXPECT_EQ(0, buf->headroom()); } +TEST_P(QuicBatchWriterTest, InplaceWriterBufResidueCheck) { + bool useThreadLocal = GetParam(); + folly::EventBase evb; + folly::test::MockAsyncUDPSocket sock(&evb); + EXPECT_CALL(sock, getGSO()).WillRepeatedly(Return(1)); + + uint32_t batchSize = 20; + auto bufAccessor = + std::make_unique(conn_.udpSendPacketLen * batchSize); + conn_.bufAccessor = bufAccessor.get(); + conn_.udpSendPacketLen = 1000; + auto batchWriter = quic::BatchWriterFactory::makeBatchWriter( + sock, + quic::QuicBatchingMode::BATCHING_MODE_GSO, + batchSize, + useThreadLocal, + quic::kDefaultThreadLocalDelay, + DataPathType::ContinuousMemory, + conn_); + auto buf = bufAccessor->obtain(); + folly::IOBuf* rawBuf = buf.get(); + bufAccessor->release(std::move(buf)); + rawBuf->append(700); + ASSERT_FALSE( + batchWriter->append(nullptr, 700, folly::SocketAddress(), nullptr)); + + size_t packetSizeTooBig = 1200; + rawBuf->append(packetSizeTooBig); + EXPECT_TRUE(batchWriter->needsFlush(packetSizeTooBig)); + + EXPECT_CALL(sock, write(_, _)) + .Times(1) + .WillOnce(Invoke([&](const auto& /* addr */, + const std::unique_ptr& buf) { + EXPECT_EQ(700, buf->length()); + return 700; + })); + // No crash: + EXPECT_EQ(700, batchWriter->write(sock, folly::SocketAddress())); + + EXPECT_EQ(1200, rawBuf->length()); + EXPECT_EQ(0, rawBuf->headroom()); +} + INSTANTIATE_TEST_CASE_P( QuicBatchWriterTest, QuicBatchWriterTest,