diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index bc3e1a6ac..a44019ab6 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -1579,12 +1579,13 @@ void QuicTransportBase::onNetworkData( updateWriteLooper(true); }; try { - if (networkData.data) { - conn_->lossState.totalBytesRecvd += - networkData.data->computeChainDataLength(); - } + conn_->lossState.totalBytesRecvd += networkData.totalData; auto originalAckVersion = currentAckStateVersion(*conn_); - onReadData(peer, std::move(networkData)); + for (auto& packet : networkData.packets) { + onReadData( + peer, + NetworkDataSingle(std::move(packet), networkData.receiveTimePoint)); + } processCallbacksAfterNetworkData(); if (closeState_ != CloseState::CLOSED) { if (currentAckStateVersion(*conn_) != originalAckVersion) { diff --git a/quic/api/QuicTransportBase.h b/quic/api/QuicTransportBase.h index f75518dbc..ea4eb77cd 100644 --- a/quic/api/QuicTransportBase.h +++ b/quic/api/QuicTransportBase.h @@ -278,7 +278,7 @@ class QuicTransportBase : public QuicSocket { */ virtual void onReadData( const folly::SocketAddress& peer, - NetworkData&& networkData) = 0; + NetworkDataSingle&& networkData) = 0; /** * Invoked when we have to write some data to the wire. diff --git a/quic/api/test/Mocks.h b/quic/api/test/Mocks.h index 0b759cd0f..edcbe88a1 100644 --- a/quic/api/test/Mocks.h +++ b/quic/api/test/Mocks.h @@ -201,7 +201,7 @@ class MockQuicTransport : public QuicServerTransport { void onNetworkData( const folly::SocketAddress& peer, NetworkData&& networkData) noexcept override { - onNetworkData(peer, networkData.data.get()); + onNetworkData(peer, networkData); } GMOCK_METHOD2_( @@ -209,7 +209,7 @@ class MockQuicTransport : public QuicServerTransport { noexcept, , onNetworkData, - void(const folly::SocketAddress&, const folly::IOBuf*)); + void(const folly::SocketAddress&, const NetworkData&)); GMOCK_METHOD1_( , diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index f1007b5ce..2da504965 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -156,7 +156,8 @@ class TestQuicTransport return lossTimeout_.getTimeRemaining(); } - void onReadData(const folly::SocketAddress&, NetworkData&& data) override { + void onReadData(const folly::SocketAddress&, NetworkDataSingle&& data) + override { if (!data.data) { return; } diff --git a/quic/api/test/QuicTransportTest.cpp b/quic/api/test/QuicTransportTest.cpp index d2b0e9934..460d05bcd 100644 --- a/quic/api/test/QuicTransportTest.cpp +++ b/quic/api/test/QuicTransportTest.cpp @@ -78,7 +78,7 @@ class TestQuicTransport void onReadData( const folly::SocketAddress& /*peer*/, - NetworkData&& /*networkData*/) noexcept override {} + NetworkDataSingle&& /*networkData*/) noexcept override {} void writeData() override { if (closed) { diff --git a/quic/client/QuicClientTransport.cpp b/quic/client/QuicClientTransport.cpp index d3f39997a..3b8232e1e 100644 --- a/quic/client/QuicClientTransport.cpp +++ b/quic/client/QuicClientTransport.cpp @@ -95,7 +95,7 @@ QuicClientTransport::~QuicClientTransport() { void QuicClientTransport::processUDPData( const folly::SocketAddress& peer, - NetworkData&& networkData) { + NetworkDataSingle&& networkData) { folly::IOBufQueue udpData{folly::IOBufQueue::cacheChainLength()}; udpData.append(std::move(networkData.data)); @@ -652,7 +652,7 @@ void QuicClientTransport::processPacketData( void QuicClientTransport::onReadData( const folly::SocketAddress& peer, - NetworkData&& networkData) { + NetworkDataSingle&& networkData) { if (closeState_ == CloseState::CLOSED) { // If we are closed, then we shoudn't process new network data. // TODO: we might want to process network data if we decide that we should @@ -1029,49 +1029,74 @@ bool QuicClientTransport::shouldOnlyNotify() { void QuicClientTransport::onNotifyDataAvailable() noexcept { DCHECK(conn_) << "trying to receive packets without a connection"; auto readBufferSize = conn_->transportSettings.maxRecvPacketSize; - auto readBuffer = folly::IOBuf::create(readBufferSize); + const size_t numPackets = conn_->transportSettings.maxRecvBatchSize; - struct iovec vec; - vec.iov_base = readBuffer->writableData(); - vec.iov_len = readBufferSize; + NetworkData networkData; + networkData.packets.reserve(numPackets); + size_t totalData = 0; + folly::Optional server; + for (size_t packetNum = 0; packetNum < numPackets; ++packetNum) { + // We create 1 buffer per packet so that it is not shared, this enables + // us to decrypt in place. If the fizz decrypt api could decrypt in-place + // even if shared, then we could allocate one giant IOBuf here. + Buf readBuffer = folly::IOBuf::createCombined(readBufferSize); + struct iovec vec {}; + vec.iov_base = readBuffer->writableData(); + vec.iov_len = readBufferSize; - struct sockaddr_storage addrStorage; - socklen_t addrLen = sizeof(addrStorage); - memset(&addrStorage, 0, size_t(addrLen)); - auto rawAddr = reinterpret_cast(&addrStorage); - rawAddr->sa_family = socket_->address().getFamily(); - - struct msghdr msg; - memset(&msg, 0, sizeof(msg)); - msg.msg_name = rawAddr; - msg.msg_namelen = addrLen; - msg.msg_iov = &vec; - msg.msg_iovlen = 1; - - ssize_t ret = socket_->recvmsg(&msg, 0); - if (ret < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK) { - return; + sockaddr* rawAddr{nullptr}; + struct sockaddr_storage addrStorage {}; + socklen_t addrLen{sizeof(addrStorage)}; + if (!server) { + memset(&addrStorage, 0, size_t(addrLen)); + rawAddr = reinterpret_cast(&addrStorage); + rawAddr->sa_family = socket_->address().getFamily(); + } + + struct msghdr msg {}; + msg.msg_name = rawAddr; + msg.msg_namelen = size_t(addrLen); + msg.msg_iov = &vec; + msg.msg_iovlen = 1; + + ssize_t ret = socket_->recvmsg(&msg, 0); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK || packetNum != 0) { + // If we got a retriable error, or we had previously processed + // a packet successfully, let's use that packet. + break; + } + return onReadError(folly::AsyncSocketException( + folly::AsyncSocketException::INTERNAL_ERROR, + "::recvmsg() failed", + errno)); + } else if (ret == 0) { + break; + } + size_t bytesRead = size_t(ret); + totalData += bytesRead; + if (!server) { + server = folly::SocketAddress(); + server->setFromSockaddr(rawAddr, addrLen); + } + VLOG(10) << "Got data from socket peer=" << *server << " len=" << bytesRead; + readBuffer->append(bytesRead); + networkData.packets.emplace_back(std::move(readBuffer)); + QUIC_TRACE(udp_recvd, *conn_, bytesRead); + if (conn_->qLogger) { + conn_->qLogger->addDatagramReceived(bytesRead); } - return onReadError(folly::AsyncSocketException( - folly::AsyncSocketException::INTERNAL_ERROR, - "::recvmsg() failed", - errno)); } - size_t bytesRead = size_t(ret); - folly::SocketAddress server; - server.setFromSockaddr(rawAddr, addrLen); - VLOG(10) << "Got data from socket peer=" << server << " len=" << bytesRead; + if (networkData.packets.empty()) { + return; + } + DCHECK(server.hasValue()); // TODO: we can get better receive time accuracy than this, with // SO_TIMESTAMP or SIOCGSTAMP. auto packetReceiveTime = Clock::now(); - readBuffer->append(bytesRead); - QUIC_TRACE(udp_recvd, *conn_, bytesRead); - if (conn_->qLogger) { - conn_->qLogger->addDatagramReceived(bytesRead); - } - NetworkData networkData(std::move(readBuffer), packetReceiveTime); - onNetworkData(server, std::move(networkData)); + networkData.receiveTimePoint = packetReceiveTime; + networkData.totalData = totalData; + onNetworkData(*server, std::move(networkData)); } void QuicClientTransport:: diff --git a/quic/client/QuicClientTransport.h b/quic/client/QuicClientTransport.h index f051ca5df..1d25897e6 100644 --- a/quic/client/QuicClientTransport.h +++ b/quic/client/QuicClientTransport.h @@ -89,8 +89,9 @@ class QuicClientTransport bool isTLSResumed() const; // From QuicTransportBase - void onReadData(const folly::SocketAddress& peer, NetworkData&& networkData) - override; + void onReadData( + const folly::SocketAddress& peer, + NetworkDataSingle&& networkData) override; void writeData() override; void closeTransport() override; void unbindConnection() override; @@ -153,7 +154,7 @@ class QuicClientTransport void processUDPData( const folly::SocketAddress& peer, - NetworkData&& networkData); + NetworkDataSingle&& networkData); void processPacketData( const folly::SocketAddress& peer, diff --git a/quic/client/test/QuicClientTransportTest.cpp b/quic/client/test/QuicClientTransportTest.cpp index 5ea9daec6..299a7e024 100644 --- a/quic/client/test/QuicClientTransportTest.cpp +++ b/quic/client/test/QuicClientTransportTest.cpp @@ -1287,6 +1287,17 @@ class FakeOneRttHandshakeLayer : public ClientHandshake { class QuicClientTransportTest : public Test { public: + struct TestReadData { + std::unique_ptr data; + folly::SocketAddress addr; + folly::Optional err; + + TestReadData(folly::ByteRange dataIn, const folly::SocketAddress& addrIn) + : data(folly::IOBuf::copyBuffer(dataIn)), addr(addrIn) {} + + explicit TestReadData(int errIn) : err(errIn) {} + }; + QuicClientTransportTest() : eventbase_(std::make_unique()) { auto socket = std::make_unique(eventbase_.get()); @@ -1306,6 +1317,30 @@ class QuicClientTransportTest : public Test { ON_CALL(*sock, resumeRead(_)) .WillByDefault(SaveArg<0>(&networkReadCallback)); ON_CALL(*sock, address()).WillByDefault(ReturnRef(serverAddr)); + ON_CALL(*sock, recvmsg(_, _)) + .WillByDefault(Invoke([&](struct msghdr* msg, int) -> ssize_t { + DCHECK_GT(msg->msg_iovlen, 0); + if (socketReads.empty()) { + errno = EAGAIN; + return -1; + } + if (socketReads[0].err) { + errno = *socketReads[0].err; + return -1; + } + auto testData = std::move(socketReads[0].data); + testData->coalesce(); + size_t testDataLen = testData->length(); + memcpy( + msg->msg_iov[0].iov_base, testData->data(), testData->length()); + if (msg->msg_name) { + socklen_t msg_len = socketReads[0].addr.getAddress( + static_cast(msg->msg_name)); + msg->msg_namelen = msg_len; + } + socketReads.pop_front(); + return testDataLen; + })); EXPECT_EQ(client->getConn().selfConnectionIds.size(), 1); EXPECT_EQ( client->getConn().selfConnectionIds[0].connId, @@ -1483,23 +1518,19 @@ class QuicClientTransportTest : public Test { folly::ByteRange data, bool writes = true) { ASSERT_TRUE(networkReadCallback); - EXPECT_CALL(*sock, recvmsg(_, _)) - .WillOnce(Invoke([&](struct msghdr* msg, int) { - DCHECK_GT(msg->msg_iovlen, 0); - memcpy(msg->msg_iov[0].iov_base, data.data(), data.size()); - if (msg->msg_name) { - socklen_t msg_len = - addr.getAddress(static_cast(msg->msg_name)); - msg->msg_namelen = msg_len; - } - return data.size(); - })); + socketReads.emplace_back(TestReadData(data, addr)); networkReadCallback->onNotifyDataAvailable(); if (writes) { loopForWrites(); } } + void deliverNetworkError(int err) { + ASSERT_TRUE(networkReadCallback); + socketReads.emplace_back(TestReadData(err)); + networkReadCallback->onNotifyDataAvailable(); + } + void deliverDataWithoutErrorCheck(folly::ByteRange data, bool writes = true) { deliverDataWithoutErrorCheck(serverAddr, std::move(data), writes); } @@ -1650,6 +1681,7 @@ class QuicClientTransportTest : public Test { protected: std::vector> socketWrites; + std::deque socketReads; MockDeliveryCallback deliveryCallback; MockReadCallback readCb; MockConnectionCallback clientConnCallback; @@ -2734,6 +2766,140 @@ TEST_F(QuicClientTransportAfterStartTest, ReadStream) { client->close(folly::none); } +TEST_F(QuicClientTransportAfterStartTest, ReadStreamMultiplePackets) { + StreamId streamId = client->createBidirectionalStream().value(); + + client->setReadCallback(streamId, &readCb); + bool dataDelivered = false; + auto data = IOBuf::copyBuffer("hello"); + + auto expected = data->clone(); + expected->prependChain(data->clone()); + EXPECT_CALL(readCb, readAvailable(streamId)).WillOnce(Invoke([&](auto) { + auto readData = client->read(streamId, 1000); + auto copy = readData->first->clone(); + LOG(INFO) << "Client received data=" + << copy->clone()->moveToFbString().toStdString() + << " on stream=" << streamId; + EXPECT_EQ( + copy->moveToFbString().toStdString(), + expected->clone()->moveToFbString().toStdString()); + dataDelivered = true; + eventbase_->terminateLoopSoon(); + })); + auto packet1 = packetToBuf(createStreamPacket( + *serverChosenConnId /* src */, + *originalConnId /* dest */, + appDataPacketNum++, + streamId, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */, + folly::none /* longHeaderOverride */, + false /* eof */)); + auto packet2 = packetToBuf(createStreamPacket( + *serverChosenConnId /* src */, + *originalConnId /* dest */, + appDataPacketNum++, + streamId, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */, + folly::none /* longHeaderOverride */, + true /* eof */, + folly::none /* shortHeaderOverride */, + data->length() /* offset */)); + + socketReads.emplace_back(TestReadData(packet1->coalesce(), serverAddr)); + deliverData(packet2->coalesce()); + if (!dataDelivered) { + eventbase_->loopForever(); + } + EXPECT_TRUE(dataDelivered); + client->close(folly::none); +} + +TEST_F(QuicClientTransportAfterStartTest, ReadStreamWithRetriableError) { + StreamId streamId = client->createBidirectionalStream().value(); + client->setReadCallback(streamId, &readCb); + EXPECT_CALL(readCb, readAvailable(_)).Times(0); + EXPECT_CALL(readCb, readError(_, _)).Times(0); + deliverNetworkError(EAGAIN); + client->setReadCallback(streamId, nullptr); + client->close(folly::none); +} + +TEST_F(QuicClientTransportAfterStartTest, ReadStreamWithNonRetriableError) { + StreamId streamId = client->createBidirectionalStream().value(); + client->setReadCallback(streamId, &readCb); + EXPECT_CALL(readCb, readAvailable(_)).Times(0); + // TODO: we currently do not close the socket, but maybe we can in the future. + EXPECT_CALL(readCb, readError(_, _)).Times(0); + deliverNetworkError(EBADF); + client->setReadCallback(streamId, nullptr); + client->close(folly::none); +} + +TEST_F( + QuicClientTransportAfterStartTest, + ReadStreamMultiplePacketsWithRetriableError) { + StreamId streamId = client->createBidirectionalStream().value(); + + client->setReadCallback(streamId, &readCb); + bool dataDelivered = false; + auto expected = IOBuf::copyBuffer("hello"); + EXPECT_CALL(readCb, readAvailable(streamId)).WillOnce(Invoke([&](auto) { + auto readData = client->read(streamId, 1000); + auto copy = readData->first->clone(); + LOG(INFO) << "Client received data=" << copy->moveToFbString().toStdString() + << " on stream=" << streamId; + EXPECT_TRUE(folly::IOBufEqualTo()((*readData).first, expected)); + dataDelivered = true; + eventbase_->terminateLoopSoon(); + })); + auto packet = packetToBuf(createStreamPacket( + *serverChosenConnId /* src */, + *originalConnId /* dest */, + appDataPacketNum++, + streamId, + *expected, + 0 /* cipherOverhead */, + 0 /* largestAcked */)); + + socketReads.emplace_back(TestReadData(packet->coalesce(), serverAddr)); + deliverNetworkError(EAGAIN); + if (!dataDelivered) { + eventbase_->loopForever(); + } + EXPECT_TRUE(dataDelivered); + client->close(folly::none); +} + +TEST_F( + QuicClientTransportAfterStartTest, + ReadStreamMultiplePacketsWithNonRetriableError) { + StreamId streamId = client->createBidirectionalStream().value(); + + client->setReadCallback(streamId, &readCb); + auto expected = IOBuf::copyBuffer("hello"); + EXPECT_CALL(readCb, readAvailable(streamId)).Times(0); + + // TODO: we currently do not close the socket, but maybe we can in the future. + EXPECT_CALL(readCb, readError(_, _)).Times(0); + auto packet = packetToBuf(createStreamPacket( + *serverChosenConnId /* src */, + *originalConnId /* dest */, + appDataPacketNum++, + streamId, + *expected, + 0 /* cipherOverhead */, + 0 /* largestAcked */)); + + socketReads.emplace_back(TestReadData(packet->coalesce(), serverAddr)); + deliverNetworkError(EBADF); + client->setReadCallback(streamId, nullptr); +} + TEST_F(QuicClientTransportAfterStartTest, RecvNewConnectionIdValid) { auto& conn = client->getNonConstConn(); conn.transportSettings.selfActiveConnectionIdLimit = 1; diff --git a/quic/server/QuicServerTransport.cpp b/quic/server/QuicServerTransport.cpp index df484f3fb..73f5e587f 100644 --- a/quic/server/QuicServerTransport.cpp +++ b/quic/server/QuicServerTransport.cpp @@ -97,7 +97,7 @@ void QuicServerTransport::setCongestionControllerFactory( void QuicServerTransport::onReadData( const folly::SocketAddress& peer, - NetworkData&& networkData) { + NetworkDataSingle&& networkData) { ServerEvents::ReadData readData; readData.peer = peer; readData.networkData = std::move(networkData); @@ -341,7 +341,10 @@ void QuicServerTransport::processPendingData(bool async) { auto serverPtr = static_cast(self.get()); for (auto& pendingPacket : *pendingData) { serverPtr->onNetworkData( - pendingPacket.peer, std::move(pendingPacket.networkData)); + pendingPacket.peer, + NetworkData( + std::move(pendingPacket.networkData.data), + pendingPacket.networkData.receiveTimePoint)); if (serverPtr->closeState_ == CloseState::CLOSED) { // The pending data could potentially contain a connection close, or // the app could have triggered a connection close with an error. It diff --git a/quic/server/QuicServerTransport.h b/quic/server/QuicServerTransport.h index 67d9c64fa..38d2ddcb2 100644 --- a/quic/server/QuicServerTransport.h +++ b/quic/server/QuicServerTransport.h @@ -91,8 +91,9 @@ class QuicServerTransport virtual void setClientConnectionId(const ConnectionId& clientConnectionId); // From QuicTransportBase - void onReadData(const folly::SocketAddress& peer, NetworkData&& networkData) - override; + void onReadData( + const folly::SocketAddress& peer, + NetworkDataSingle&& networkData) override; void writeData() override; void closeTransport() override; void unbindConnection() override; diff --git a/quic/server/QuicServerWorker.cpp b/quic/server/QuicServerWorker.cpp index 52ab314c4..7834778d9 100644 --- a/quic/server/QuicServerWorker.cpp +++ b/quic/server/QuicServerWorker.cpp @@ -298,8 +298,9 @@ void QuicServerWorker::forwardNetworkData( if (packetForwardingEnabled_ && !isForwardedData) { VLOG(3) << "Forwarding packet with unknown connId version from client=" << client << " to another process"; + auto recvTime = networkData.receiveTimePoint; takeoverPktHandler_.forwardPacketToAnotherServer( - client, std::move(networkData.data), networkData.receiveTimePoint); + client, std::move(networkData).moveAllData(), recvTime); QUIC_STATS(infoCallback_, onPacketForwarded); return; } else { @@ -370,8 +371,7 @@ void QuicServerWorker::dispatchPacketData( // This could be a new connection, add it in the map // verify that the initial packet is at least min initial bytes // to avoid amplification attacks. - if (networkData.data->computeChainDataLength() < - kMinInitialPacketSize) { + if (networkData.totalData < kMinInitialPacketSize) { // Don't even attempt to forward the packet, just drop it. VLOG(3) << "Dropping small initial packet from client=" << client; QUIC_STATS( @@ -476,8 +476,9 @@ void QuicServerWorker::dispatchPacketData( VLOG(4) << "Forwarding packet from client=" << client << " to another process, workerId=" << (uint32_t)workerId_ << ", processId_=" << (uint32_t) static_cast(processId_); + auto recvTime = networkData.receiveTimePoint; takeoverPktHandler_.forwardPacketToAnotherServer( - client, std::move(networkData.data), networkData.receiveTimePoint); + client, std::move(networkData).moveAllData(), recvTime); QUIC_STATS(infoCallback_, onPacketForwarded); } @@ -490,7 +491,7 @@ void QuicServerWorker::sendResetPacket( // Only send resets in response to short header packets. return; } - uint16_t packetSize = networkData.data->computeChainDataLength(); + uint16_t packetSize = networkData.totalData; uint16_t maxResetPacketSize = std::min( std::max(kMinStatelessPacketSize, packetSize), kDefaultUDPSendPacketLen); diff --git a/quic/server/state/ServerStateMachine.cpp b/quic/server/state/ServerStateMachine.cpp index b8f7acec1..e4fff926d 100644 --- a/quic/server/state/ServerStateMachine.cpp +++ b/quic/server/state/ServerStateMachine.cpp @@ -481,7 +481,7 @@ void handleCipherUnavailable( } ServerEvents::ReadData pendingReadData; pendingReadData.peer = readData.peer; - pendingReadData.networkData = NetworkData( + pendingReadData.networkData = NetworkDataSingle( std::move(originalData->packet), readData.networkData.receiveTimePoint); pendingData->emplace_back(std::move(pendingReadData)); VLOG(10) << "Adding pending data to " diff --git a/quic/server/state/ServerStateMachine.h b/quic/server/state/ServerStateMachine.h index 9dde7f981..05ed48621 100644 --- a/quic/server/state/ServerStateMachine.h +++ b/quic/server/state/ServerStateMachine.h @@ -43,7 +43,7 @@ enum ServerState { struct ServerEvents { struct ReadData { folly::SocketAddress peer; - NetworkData networkData; + NetworkDataSingle networkData; }; struct Close {}; diff --git a/quic/server/test/QuicServerTest.cpp b/quic/server/test/QuicServerTest.cpp index 1648d8cf0..9e9542c5e 100644 --- a/quic/server/test/QuicServerTest.cpp +++ b/quic/server/test/QuicServerTest.cpp @@ -32,9 +32,15 @@ using PacketDropReason = QuicTransportStatsCallback::PacketDropReason; } // namespace namespace test { -MATCHER_P(BufMatches, buf, "") { - folly::IOBufEqualTo eq; - return eq(*arg, buf); +MATCHER_P(NetworkDataMatches, networkData, "") { + for (size_t i = 0; i < arg.packets.size(); ++i) { + folly::IOBufEqualTo eq; + bool equals = eq(*arg.packets[i], networkData); + if (equals) { + return true; + } + } + return false; } class TestingEventBaseObserver : public folly::EventBaseObserver { @@ -215,7 +221,7 @@ void QuicServerWorkerTest::createQuicConnection( transport = transportOverride; } expectConnectionCreation(addr, connId, transport); - EXPECT_CALL(*transport, onNetworkData(addr, BufMatches(*data))); + EXPECT_CALL(*transport, onNetworkData(addr, NetworkDataMatches(*data))); worker_->dispatchPacketData( addr, std::move(routingData), NetworkData(data->clone(), Clock::now())); @@ -364,7 +370,8 @@ TEST_F(QuicServerWorkerTest, QuicServerMultipleConnIdsRouting) { EXPECT_EQ(addrMap.count(std::make_pair(kClientAddr, connId)), 0); // routing by connid after connid available. - EXPECT_CALL(*transport_, onNetworkData(kClientAddr, BufMatches(*data))) + EXPECT_CALL( + *transport_, onNetworkData(kClientAddr, NetworkDataMatches(*data))) .Times(1); RoutingData routingData2( HeaderForm::Short, false, false, connId, folly::none); @@ -380,7 +387,8 @@ TEST_F(QuicServerWorkerTest, QuicServerMultipleConnIdsRouting) { EXPECT_EQ(connIdMap.size(), 2); - EXPECT_CALL(*transport_, onNetworkData(kClientAddr, BufMatches(*data))) + EXPECT_CALL( + *transport_, onNetworkData(kClientAddr, NetworkDataMatches(*data))) .Times(1); RoutingData routingData3( HeaderForm::Short, false, false, connId2, folly::none); @@ -449,7 +457,8 @@ TEST_F(QuicServerWorkerTest, QuicServerNewConnection) { 0); // routing by connid after connid available. - EXPECT_CALL(*transport_, onNetworkData(kClientAddr, BufMatches(*data))); + EXPECT_CALL( + *transport_, onNetworkData(kClientAddr, NetworkDataMatches(*data))); RoutingData routingData2( HeaderForm::Short, false, @@ -1012,7 +1021,8 @@ TEST_F(QuicServerWorkerTakeoverTest, QuicServerTakeoverProcessForwardedPkt) { EXPECT_EQ(addr.getPort(), clientAddr.getPort()); // the original data should be extracted after processing takeover // protocol related information - EXPECT_TRUE(eq(*data, *(networkData->data))); + EXPECT_EQ(networkData->packets.size(), 1); + EXPECT_TRUE(eq(*data, *(networkData->packets[0]))); EXPECT_TRUE(isForwardedData); }; EXPECT_CALL(*takeoverWorkerCb_, routeDataToWorkerLong(_, _, _, _)) @@ -1182,12 +1192,15 @@ class QuicServerTest : public Test { transportSettings.advertisedInitialConnectionWindowSize); })); ON_CALL(*transport, onNetworkData(_, _)) - .WillByDefault(Invoke([&, expected = data.get()](auto, auto buf) { - EXPECT_TRUE(folly::IOBufEqualTo()(*buf, *expected)); - std::unique_lock lg(m); - calledOnNetworkData = true; - cv.notify_one(); - })); + .WillByDefault( + Invoke([&, expected = data.get()](auto, const auto& networkData) { + EXPECT_GT(networkData.packets.size(), 0); + EXPECT_TRUE( + folly::IOBufEqualTo()(*networkData.packets[0], *expected)); + std::unique_lock lg(m); + calledOnNetworkData = true; + cv.notify_one(); + })); return transport; }; EXPECT_CALL(*factory_, _make(_, _, _, _)).WillOnce(Invoke(makeTransport)); @@ -1361,10 +1374,13 @@ class QuicServerTakeoverTest : public Test { EXPECT_EQ(params.workerId, 0); })); EXPECT_CALL(*transport, onNetworkData(_, _)) - .WillOnce(Invoke([&, expected = data.get()](auto, auto buf) { - EXPECT_TRUE(folly::IOBufEqualTo()(*buf, *expected)); - baton.post(); - })); + .WillOnce( + Invoke([&, expected = data.get()](auto, const auto& networkData) { + EXPECT_GT(networkData.packets.size(), 0); + EXPECT_TRUE( + folly::IOBufEqualTo()(*networkData.packets[0], *expected)); + baton.post(); + })); return transport; }; EXPECT_CALL(*factory, _make(_, _, _, _)).WillOnce(Invoke(makeTransport)); @@ -1465,10 +1481,13 @@ class QuicServerTakeoverTest : public Test { // onNetworkData(_, _) shouldn't be called on the newServer_ transport, // but should be routed to oldServer_ EXPECT_CALL(*transportCbForOldServer, onNetworkData(_, _)) - .WillOnce(Invoke([&, expected = data.get()](auto, auto buf) { - EXPECT_TRUE(folly::IOBufEqualTo()(*buf, *expected)); - b1.post(); - })); + .WillOnce( + Invoke([&, expected = data.get()](auto, const auto& networkData) { + EXPECT_GT(networkData.packets.size(), 0); + EXPECT_TRUE( + folly::IOBufEqualTo()(*networkData.packets[0], *expected)); + b1.post(); + })); // new quic server receives the packet and forwards it EXPECT_CALL(*newTransInfoCb_, onPacketReceived()); EXPECT_CALL(*newTransInfoCb_, onRead(_)); @@ -1885,10 +1904,13 @@ TEST_F(QuicServerTest, ZeroRttPacketRoute) { EXPECT_CALL(*transport, accept()); // post baton upon receiving the data EXPECT_CALL(*transport, onNetworkData(_, _)) - .WillOnce(Invoke([&, expected = data.get()](auto, auto buf) { - EXPECT_TRUE(folly::IOBufEqualTo()(*buf, *expected)); - b.post(); - })); + .WillOnce( + Invoke([&, expected = data.get()](auto, const auto& networkData) { + EXPECT_GT(networkData.packets.size(), 0); + EXPECT_TRUE( + folly::IOBufEqualTo()(*networkData.packets[0], *expected)); + b.post(); + })); return transport; }; EXPECT_CALL(*factory_, _make(_, _, _, _)).WillOnce(Invoke(makeTransport)); @@ -1926,9 +1948,11 @@ TEST_F(QuicServerTest, ZeroRttPacketRoute) { data = std::move(packet); folly::Baton<> b1; auto verifyZeroRtt = [&]( - const folly::SocketAddress& peer, const folly::IOBuf* rcvdPkt) noexcept { + const folly::SocketAddress& peer, + const NetworkData& networkData) noexcept { + EXPECT_GT(networkData.packets.size(), 0); EXPECT_EQ(peer, reader->getSocket().address()); - EXPECT_TRUE(folly::IOBufEqualTo()(*rcvdPkt, *data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *networkData.packets[0])); b1.post(); }; EXPECT_CALL(*transport, onNetworkData(_, _)).WillOnce(Invoke(verifyZeroRtt)); diff --git a/quic/state/StateData.h b/quic/state/StateData.h index 3d10afcfd..38d53f6e6 100644 --- a/quic/state/StateData.h +++ b/quic/state/StateData.h @@ -34,12 +34,47 @@ namespace quic { struct NetworkData { - Buf data; TimePoint receiveTimePoint; + std::vector> packets; + size_t totalData{0}; NetworkData() = default; NetworkData(Buf&& buf, const TimePoint& receiveTime) - : data(std::move(buf)), receiveTimePoint(receiveTime) {} + : receiveTimePoint(receiveTime) { + if (buf) { + totalData = buf->computeChainDataLength(); + packets.emplace_back(std::move(buf)); + } + } + + std::unique_ptr moveAllData() && { + std::unique_ptr buf; + for (size_t i = 0; i < packets.size(); ++i) { + if (buf) { + buf->prependChain(std::move(packets[i])); + } else { + buf = std::move(packets[i]); + } + } + return buf; + } +}; + +struct NetworkDataSingle { + std::unique_ptr data; + TimePoint receiveTimePoint; + size_t totalData{0}; + + NetworkDataSingle() = default; + + NetworkDataSingle( + std::unique_ptr buf, + const TimePoint& receiveTime) + : data(std::move(buf)), receiveTimePoint(receiveTime) { + if (data) { + totalData += data->computeChainDataLength(); + } + } }; /** diff --git a/quic/state/TransportSettings.h b/quic/state/TransportSettings.h index 749f104a4..5d3e2abb1 100644 --- a/quic/state/TransportSettings.h +++ b/quic/state/TransportSettings.h @@ -115,6 +115,10 @@ struct TransportSettings { // The active_connection_id_limit that is sent to the peer. uint64_t selfActiveConnectionIdLimit{0}; + + // Maximum size of the batch that should be used when receiving packets from + // the kernel in one event loop. + size_t maxRecvBatchSize{5}; }; } // namespace quic