diff --git a/quic/api/QuicTransportFunctions.cpp b/quic/api/QuicTransportFunctions.cpp index 7d1e7a3af..d1f364a57 100644 --- a/quic/api/QuicTransportFunctions.cpp +++ b/quic/api/QuicTransportFunctions.cpp @@ -181,6 +181,98 @@ uint64_t writeQuicDataToSocketImpl( return written; } +DataPathResult continuousMemoryBuildScheduleEncrypt( + QuicConnectionStateBase& connection, + PacketHeader header, + PacketNumberSpace pnSpace, + PacketNum packetNum, + uint64_t cipherOverhead, + QuicPacketScheduler& scheduler, + uint64_t writableBytes, + IOBufQuicBatch& ioBufBatch, + const Aead& aead, + const PacketNumberCipher& headerCipher) { + auto buf = connection.bufAccessor->obtain(); + auto prevSize = buf->length(); + connection.bufAccessor->release(std::move(buf)); + + auto rollbackBuf = [&]() { + auto buf = connection.bufAccessor->obtain(); + buf->trimEnd(buf->length() - prevSize); + connection.bufAccessor->release(std::move(buf)); + }; + + // It's the scheduler's job to invoke encode header + InplaceQuicPacketBuilder pktBuilder( + *connection.bufAccessor, + connection.udpSendPacketLen, + std::move(header), + getAckState(connection, pnSpace).largestAckedByPeer); + pktBuilder.setCipherOverhead(cipherOverhead); + CHECK(scheduler.hasData()); + auto result = + scheduler.scheduleFramesForPacket(std::move(pktBuilder), writableBytes); + CHECK(connection.bufAccessor->ownsBuffer()); + auto& packet = result.packet; + if (!packet || packet->packet.frames.empty()) { + rollbackBuf(); + ioBufBatch.flush(); + if (connection.loopDetectorCallback) { + connection.writeDebugState.noWriteReason = NoWriteReason::NO_FRAME; + } + return DataPathResult::makeBuildFailure(); + } + if (!packet->body) { + // No more space remaining. + rollbackBuf(); + ioBufBatch.flush(); + if (connection.loopDetectorCallback) { + connection.writeDebugState.noWriteReason = NoWriteReason::NO_BODY; + } + return DataPathResult::makeBuildFailure(); + } + CHECK(!packet->header->isChained()); + auto headerLen = packet->header->length(); + buf = connection.bufAccessor->obtain(); + CHECK( + packet->body->data() > buf->data() && + packet->body->tail() <= buf->tail()); + CHECK( + packet->header->data() >= buf->data() && + packet->header->tail() < buf->tail()); + // Trim off everything before the current packet, and the header length, so + // buf's data starts from the body part of buf. + buf->trimStart(prevSize + headerLen); + // buf and packetBuf is actually the same. + auto packetBuf = + aead.inplaceEncrypt(std::move(buf), packet->header.get(), packetNum); + CHECK(packetBuf->headroom() == headerLen + prevSize); + // Include header back. + packetBuf->prepend(headerLen); + + HeaderForm headerForm = packet->packet.header.getHeaderForm(); + encryptPacketHeader( + headerForm, + packetBuf->writableData(), + headerLen, + packetBuf->data() + headerLen, + packetBuf->length() - headerLen, + headerCipher); + CHECK(!packetBuf->isChained()); + auto encodedSize = packetBuf->length(); + // Include previous packets back. + packetBuf->prepend(prevSize); + connection.bufAccessor->release(std::move(packetBuf)); + // TODO: I think we should add an API that doesn't need a buffer. + bool ret = ioBufBatch.write(nullptr /* no need to pass buf */, encodedSize); + // update stats and connection + if (ret) { + QUIC_STATS(connection.statsCallback, onWrite, encodedSize); + QUIC_STATS(connection.statsCallback, onPacketSent); + } + return DataPathResult::makeWriteResult(ret, std::move(result), encodedSize); +} + DataPathResult iobufChainBasedBuildScheduleEncrypt( QuicConnectionStateBase& connection, PacketHeader header, @@ -1046,7 +1138,10 @@ uint64_t writeConnectionDataToSocket( } // TODO: Select a different DataPathFunc based on TransportSettings - const auto& dataPlainFunc = iobufChainBasedBuildScheduleEncrypt; + const auto& dataPlainFunc = + connection.transportSettings.dataPathType == DataPathType::ChainedMemory + ? iobufChainBasedBuildScheduleEncrypt + : continuousMemoryBuildScheduleEncrypt; auto ret = dataPlainFunc( connection, std::move(header), @@ -1088,6 +1183,13 @@ uint64_t writeConnectionDataToSocket( } ioBufBatch.flush(); + if (connection.transportSettings.dataPathType == + DataPathType::ContinuousMemory) { + CHECK(connection.bufAccessor->ownsBuffer()); + auto buf = connection.bufAccessor->obtain(); + CHECK(buf->length() == 0 && buf->headroom() == 0); + connection.bufAccessor->release(std::move(buf)); + } return ioBufBatch.getPktSent(); } diff --git a/quic/api/test/QuicTransportFunctionsTest.cpp b/quic/api/test/QuicTransportFunctionsTest.cpp index d63e460ef..4387f8c4a 100644 --- a/quic/api/test/QuicTransportFunctionsTest.cpp +++ b/quic/api/test/QuicTransportFunctionsTest.cpp @@ -2221,5 +2221,258 @@ TEST_F(QuicTransportFunctionsTest, ProbeWriteNewFunctionalFrames) { conn->outstandingPackets[1].packet.frames[0].type()); } +TEST_F(QuicTransportFunctionsTest, WriteWithInplaceBuilder) { + auto conn = createConn(); + conn->transportSettings.dataPathType = DataPathType::ContinuousMemory; + auto simpleBufAccessor = + std::make_unique(conn->udpSendPacketLen * 16); + auto outputBuf = simpleBufAccessor->obtain(); + auto bufPtr = outputBuf.get(); + simpleBufAccessor->release(std::move(outputBuf)); + conn->bufAccessor = simpleBufAccessor.get(); + conn->transportSettings.batchingMode = QuicBatchingMode::BATCHING_MODE_GSO; + EventBase evb; + folly::test::MockAsyncUDPSocket mockSock(&evb); + EXPECT_CALL(mockSock, getGSO()).WillRepeatedly(Return(true)); + auto stream = conn->streamManager->createNextBidirectionalStream().value(); + auto buf = folly::IOBuf::copyBuffer("Andante in C minor"); + writeDataToQuicStream(*stream, buf->clone(), true); + EXPECT_CALL(mockSock, write(_, _)) + .Times(1) + .WillOnce(Invoke([&](const SocketAddress&, + const std::unique_ptr& sockBuf) { + EXPECT_GT(bufPtr->length(), 0); + EXPECT_GE(sockBuf->length(), buf->length()); + EXPECT_EQ(sockBuf.get(), bufPtr); + EXPECT_TRUE(folly::IOBufEqualTo()(*sockBuf, *bufPtr)); + EXPECT_FALSE(sockBuf->isChained()); + return sockBuf->computeChainDataLength(); + })); + writeQuicDataToSocket( + mockSock, + *conn, + *conn->clientConnectionId, + *conn->serverConnectionId, + *aead, + *headerCipher, + getVersion(*conn), + conn->transportSettings.writeConnectionDataPacketsLimit); + EXPECT_EQ(0, bufPtr->length()); + EXPECT_EQ(0, bufPtr->headroom()); +} + +TEST_F(QuicTransportFunctionsTest, WriteWithInplaceBuilderRollbackBuf) { + auto conn = createConn(); + conn->transportSettings.dataPathType = DataPathType::ContinuousMemory; + auto simpleBufAccessor = + std::make_unique(conn->udpSendPacketLen * 16); + auto outputBuf = simpleBufAccessor->obtain(); + auto bufPtr = outputBuf.get(); + simpleBufAccessor->release(std::move(outputBuf)); + conn->bufAccessor = simpleBufAccessor.get(); + conn->transportSettings.batchingMode = QuicBatchingMode::BATCHING_MODE_GSO; + EventBase evb; + folly::test::MockAsyncUDPSocket mockSock(&evb); + EXPECT_CALL(mockSock, getGSO()).WillRepeatedly(Return(true)); + EXPECT_CALL(mockSock, write(_, _)).Times(0); + writeQuicDataToSocket( + mockSock, + *conn, + *conn->clientConnectionId, + *conn->serverConnectionId, + *aead, + *headerCipher, + getVersion(*conn), + conn->transportSettings.writeConnectionDataPacketsLimit); + EXPECT_EQ(0, bufPtr->length()); + EXPECT_EQ(0, bufPtr->headroom()); +} + +TEST_F(QuicTransportFunctionsTest, WriteWithInplaceBuilderGSOMultiplePackets) { + auto conn = createConn(); + conn->transportSettings.dataPathType = DataPathType::ContinuousMemory; + auto simpleBufAccessor = + std::make_unique(conn->udpSendPacketLen * 16); + auto outputBuf = simpleBufAccessor->obtain(); + auto bufPtr = outputBuf.get(); + simpleBufAccessor->release(std::move(outputBuf)); + conn->bufAccessor = simpleBufAccessor.get(); + conn->transportSettings.batchingMode = QuicBatchingMode::BATCHING_MODE_GSO; + EventBase evb; + folly::test::MockAsyncUDPSocket mockSock(&evb); + EXPECT_CALL(mockSock, getGSO()).WillRepeatedly(Return(true)); + auto stream = conn->streamManager->createNextBidirectionalStream().value(); + auto buf = buildRandomInputData(conn->udpSendPacketLen * 10); + writeDataToQuicStream(*stream, buf->clone(), true); + EXPECT_CALL(mockSock, writeGSO(_, _, _)) + .Times(1) + .WillOnce(Invoke([&](const folly::SocketAddress&, + const std::unique_ptr& sockBuf, + int gso) { + EXPECT_LE(gso, conn->udpSendPacketLen); + EXPECT_GT(bufPtr->length(), 0); + EXPECT_EQ(sockBuf.get(), bufPtr); + EXPECT_TRUE(folly::IOBufEqualTo()(*sockBuf, *bufPtr)); + EXPECT_FALSE(sockBuf->isChained()); + return sockBuf->length(); + })); + writeQuicDataToSocket( + mockSock, + *conn, + *conn->clientConnectionId, + *conn->serverConnectionId, + *aead, + *headerCipher, + getVersion(*conn), + conn->transportSettings.writeConnectionDataPacketsLimit); + EXPECT_EQ(0, bufPtr->length()); + EXPECT_EQ(0, bufPtr->headroom()); +} + +TEST_F(QuicTransportFunctionsTest, WriteProbingWithInplaceBuilder) { + auto conn = createConn(); + conn->transportSettings.dataPathType = DataPathType::ContinuousMemory; + conn->transportSettings.batchingMode = QuicBatchingMode::BATCHING_MODE_GSO; + EventBase evb; + folly::test::MockAsyncUDPSocket mockSock(&evb); + EXPECT_CALL(mockSock, getGSO()).WillRepeatedly(Return(true)); + + SimpleBufAccessor bufAccessor( + conn->udpSendPacketLen * conn->transportSettings.maxBatchSize); + conn->bufAccessor = &bufAccessor; + auto buf = bufAccessor.obtain(); + auto bufPtr = buf.get(); + bufAccessor.release(std::move(buf)); + + auto stream = conn->streamManager->createNextBidirectionalStream().value(); + auto inputBuf = buildRandomInputData( + conn->udpSendPacketLen * + conn->transportSettings.writeConnectionDataPacketsLimit); + writeDataToQuicStream(*stream, inputBuf->clone(), true); + EXPECT_CALL(mockSock, writeGSO(_, _, _)) + .Times(1) + .WillOnce(Invoke([&](const folly::SocketAddress&, + const std::unique_ptr& sockBuf, + int gso) { + EXPECT_LE(gso, conn->udpSendPacketLen); + EXPECT_GE( + bufPtr->length(), + conn->udpSendPacketLen * + conn->transportSettings.writeConnectionDataPacketsLimit); + EXPECT_EQ(sockBuf.get(), bufPtr); + EXPECT_TRUE(folly::IOBufEqualTo()(*sockBuf, *bufPtr)); + EXPECT_FALSE(sockBuf->isChained()); + return sockBuf->length(); + })); + writeQuicDataToSocket( + mockSock, + *conn, + *conn->clientConnectionId, + *conn->serverConnectionId, + *aead, + *headerCipher, + getVersion(*conn), + conn->transportSettings.writeConnectionDataPacketsLimit + 1); + ASSERT_EQ(0, bufPtr->length()); + ASSERT_EQ(0, bufPtr->headroom()); + EXPECT_GE(conn->outstandingPackets.size(), 5); + // Make sure there no more new data to write: + StreamFrameScheduler streamScheduler(*conn); + ASSERT_FALSE(streamScheduler.hasPendingData()); + + // The last packet may not be a full packet + auto lastPacketSize = conn->outstandingPackets.back().encodedSize; + size_t expectedOutstandingPacketsCount = 5; + if (lastPacketSize < conn->udpSendPacketLen) { + expectedOutstandingPacketsCount++; + } + EXPECT_CALL(mockSock, write(_, _)) + .Times(1) + .WillOnce(Invoke([&](const folly::SocketAddress&, + const std::unique_ptr& buf) { + EXPECT_FALSE(buf->isChained()); + // If the last packet isn't full, it may have the stream length field + // but the clone won't have it. + EXPECT_LE(buf->length(), lastPacketSize); + return buf->length(); + })); + writeProbingDataToSocketForTest( + mockSock, + *conn, + 1 /* probesToSend */, + *aead, + *headerCipher, + getVersion(*conn)); + EXPECT_EQ( + conn->outstandingPackets.size(), expectedOutstandingPacketsCount + 1); + EXPECT_EQ(0, bufPtr->length()); + EXPECT_EQ(0, bufPtr->headroom()); + + // Clone again, this time 2 pacckets. + if (lastPacketSize < conn->udpSendPacketLen) { + EXPECT_CALL(mockSock, writeGSO(_, _, _)) + .Times(1) + .WillOnce(Invoke([&](const folly::SocketAddress&, + const std::unique_ptr& buf, + int gso) { + EXPECT_FALSE(buf->isChained()); + EXPECT_LE(gso, lastPacketSize); + EXPECT_LE(buf->length(), lastPacketSize * 2); + return buf->length(); + })); + } else { + EXPECT_CALL(mockSock, writeGSO(_, _, _)) + .Times(1) + .WillOnce(Invoke([&](const folly::SocketAddress&, + const std::unique_ptr& buf, + int gso) { + EXPECT_FALSE(buf->isChained()); + EXPECT_EQ(conn->udpSendPacketLen, gso); + EXPECT_EQ(buf->length(), conn->udpSendPacketLen * 4); + return buf->length(); + })); + } + writeProbingDataToSocketForTest( + mockSock, + *conn, + 2 /* probesToSend */, + *aead, + *headerCipher, + getVersion(*conn)); + EXPECT_EQ(0, bufPtr->length()); + EXPECT_EQ(0, bufPtr->headroom()); + EXPECT_EQ( + conn->outstandingPackets.size(), expectedOutstandingPacketsCount + 3); + + // Clear out all the small packets: + while (conn->outstandingPackets.back().encodedSize < conn->udpSendPacketLen) { + conn->outstandingPackets.pop_back(); + } + ASSERT_FALSE(conn->outstandingPackets.empty()); + auto currentOutstandingPackets = conn->outstandingPackets.size(); + + // Clone 2 full size packets + EXPECT_CALL(mockSock, writeGSO(_, _, _)) + .Times(1) + .WillOnce(Invoke([&](const folly::SocketAddress&, + const std::unique_ptr& buf, + int gso) { + EXPECT_FALSE(buf->isChained()); + EXPECT_EQ(conn->udpSendPacketLen, gso); + EXPECT_EQ(buf->length(), conn->udpSendPacketLen * 2); + return buf->length(); + })); + writeProbingDataToSocketForTest( + mockSock, + *conn, + 2 /* probesToSend */, + *aead, + *headerCipher, + getVersion(*conn)); + EXPECT_EQ(conn->outstandingPackets.size(), currentOutstandingPackets + 2); + EXPECT_EQ(0, bufPtr->length()); + EXPECT_EQ(0, bufPtr->headroom()); +} + } // namespace test } // namespace quic