diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index 0d0ca150a..5c8bc3d68 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -1873,7 +1873,7 @@ void QuicTransportBase::onNetworkData( updateWriteLooper(true); }; try { - conn_->lossState.totalBytesRecvd += networkData.totalData; + conn_->lossState.totalBytesRecvd += networkData.getTotalData(); auto originalAckVersion = currentAckStateVersion(*conn_); // handle PacketsReceivedEvent if requested by observers @@ -1883,13 +1883,13 @@ void QuicTransportBase::onNetworkData( SocketObserverInterface::Events::packetsReceivedEvents>()) { auto builder = SocketObserverInterface::PacketsReceivedEvent::Builder() .setReceiveLoopTime(TimePoint::clock::now()) - .setNumPacketsReceived(networkData.packets.size()) - .setNumBytesReceived(networkData.totalData); - for (auto& packet : networkData.packets) { + .setNumPacketsReceived(networkData.getPackets().size()) + .setNumBytesReceived(networkData.getTotalData()); + for (auto& packet : networkData.getPackets()) { builder.addReceivedPacket( SocketObserverInterface::PacketsReceivedEvent::ReceivedPacket:: Builder() - .setPacketReceiveTime(networkData.receiveTimePoint) + .setPacketReceiveTime(networkData.getReceiveTimePoint()) .setPacketNumBytes(packet.buf->computeChainDataLength()) .build()); } @@ -1903,12 +1903,13 @@ void QuicTransportBase::onNetworkData( }); } - for (auto& packet : networkData.packets) { + const auto receiveTimePoint = networkData.getReceiveTimePoint(); + auto packets = std::move(networkData).movePackets(); + for (auto& packet : packets) { onReadData( peer, NetworkDataSingle( - ReceivedPacket(std::move(packet.buf)), - networkData.receiveTimePoint)); + ReceivedPacket(std::move(packet.buf)), receiveTimePoint)); if (conn_->peerConnectionError) { closeImpl(QuicError( QuicErrorCode(TransportErrorCode::NO_ERROR), "Peer closed")); diff --git a/quic/client/QuicClientTransport.cpp b/quic/client/QuicClientTransport.cpp index ab02b5ee0..e4fbe4257 100644 --- a/quic/client/QuicClientTransport.cpp +++ b/quic/client/QuicClientTransport.cpp @@ -1236,9 +1236,9 @@ void QuicClientTransport::recvMsg( size_t len = bytesRead; size_t remaining = len; size_t offset = 0; - size_t totalNumPackets = - networkData.packets.size() + ((len + params.gro - 1) / params.gro); - networkData.packets.reserve(totalNumPackets); + size_t totalNumPackets = networkData.getPackets().size() + + ((len + params.gro - 1) / params.gro); + networkData.reserve(totalNumPackets); while (remaining) { if (static_cast(remaining) > params.gro) { auto tmp = readBuffer->cloneOne(); @@ -1251,18 +1251,18 @@ void QuicClientTransport::recvMsg( offset += params.gro; remaining -= params.gro; - networkData.packets.emplace_back(std::move(tmp)); + networkData.addPacket(ReceivedPacket(std::move(tmp))); } else { // do not clone the last packet // start at offset, use all the remaining data readBuffer->trimStart(offset); DCHECK_EQ(readBuffer->length(), remaining); remaining = 0; - networkData.packets.emplace_back(std::move(readBuffer)); + networkData.addPacket(ReceivedPacket(std::move(readBuffer))); } } } else { - networkData.packets.emplace_back(std::move(readBuffer)); + networkData.addPacket(ReceivedPacket(std::move(readBuffer))); } trackDatagramReceived(bytesRead); } @@ -1382,9 +1382,9 @@ void QuicClientTransport::recvMmsg( size_t len = bytesRead; size_t remaining = len; size_t offset = 0; - size_t totalNumPackets = - networkData.packets.size() + ((len + params.gro - 1) / params.gro); - networkData.packets.reserve(totalNumPackets); + size_t totalNumPackets = networkData.getPackets().size() + + ((len + params.gro - 1) / params.gro); + networkData.reserve(totalNumPackets); while (remaining) { if (static_cast(remaining) > params.gro) { auto tmp = readBuffer->cloneOne(); @@ -1397,18 +1397,18 @@ void QuicClientTransport::recvMmsg( offset += params.gro; remaining -= params.gro; - networkData.packets.emplace_back(std::move(tmp)); + networkData.addPacket(ReceivedPacket(std::move(tmp))); } else { // do not clone the last packet // start at offset, use all the remaining data readBuffer->trimStart(offset); DCHECK_EQ(readBuffer->length(), remaining); remaining = 0; - networkData.packets.emplace_back(std::move(readBuffer)); + networkData.addPacket(ReceivedPacket(std::move(readBuffer))); } } } else { - networkData.packets.emplace_back(std::move(readBuffer)); + networkData.addPacket(ReceivedPacket(std::move(readBuffer))); } trackDatagramReceived(bytesRead); @@ -1423,7 +1423,7 @@ void QuicClientTransport::onNotifyDataAvailable( const uint16_t numPackets = conn_->transportSettings.maxRecvBatchSize; NetworkData networkData; - networkData.packets.reserve(numPackets); + networkData.reserve(numPackets); size_t totalData = 0; folly::Optional server; @@ -1432,7 +1432,7 @@ void QuicClientTransport::onNotifyDataAvailable( readBufferSize, numPackets, networkData, server, totalData); // track the received packets - for (const auto& packet : networkData.packets) { + for (const auto& packet : networkData.getPackets()) { if (!packet.buf) { continue; } @@ -1479,7 +1479,7 @@ void QuicClientTransport::onNotifyDataAvailable( recvMsg(sock, readBufferSize, numPackets, networkData, server, totalData); } - if (networkData.packets.empty()) { + if (networkData.getPackets().empty()) { // recvMmsg and recvMsg might have already set the reason and counter if (conn_->loopDetectorCallback) { if (conn_->readDebugState.noReadReason == NoReadReason::READ_OK) { @@ -1497,8 +1497,8 @@ void QuicClientTransport::onNotifyDataAvailable( // TODO: we can get better receive time accuracy than this, with // SO_TIMESTAMP or SIOCGSTAMP. auto packetReceiveTime = Clock::now(); - networkData.receiveTimePoint = packetReceiveTime; - networkData.totalData = totalData; + networkData.setReceiveTimePoint(packetReceiveTime); + networkData.setTotalData(totalData); onNetworkData(*server, std::move(networkData)); } diff --git a/quic/common/NetworkData.h b/quic/common/NetworkData.h index 1c0db1255..c226db5f0 100644 --- a/quic/common/NetworkData.h +++ b/quic/common/NetworkData.h @@ -24,24 +24,20 @@ struct ReceivedPacket { }; struct NetworkData { - TimePoint receiveTimePoint; - std::vector packets; - size_t totalData{0}; - NetworkData() = default; NetworkData(Buf&& buf, const TimePoint& receiveTime) - : receiveTimePoint(receiveTime) { + : receiveTimePoint_(receiveTime) { if (buf) { - totalData = buf->computeChainDataLength(); - packets.emplace_back(std::move(buf)); + totalData_ = buf->computeChainDataLength(); + packets_.emplace_back(std::move(buf)); } } NetworkData( std::vector&& packetBufs, const TimePoint& receiveTimePointIn) - : receiveTimePoint(receiveTimePointIn), - packets([&packetBufs]() { + : receiveTimePoint_(receiveTimePointIn), + packets_([&packetBufs]() { std::vector result; result.reserve(packetBufs.size()); for (auto& packetBuf : packetBufs) { @@ -49,17 +45,49 @@ struct NetworkData { } return result; }()), - totalData([this]() { + totalData_([this]() { size_t result = 0; - for (const auto& packet : packets) { + for (const auto& packet : packets_) { result += packet.buf->computeChainDataLength(); } return result; }()) {} + void reserve(size_t size) { + packets_.reserve(size); + } + + void addPacket(ReceivedPacket&& packetIn) { + packets_.emplace_back(std::move(packetIn)); + } + + [[nodiscard]] const std::vector& getPackets() const { + return packets_; + } + + std::vector movePackets() && { + return std::move(packets_); + } + + void setReceiveTimePoint(const TimePoint& receiveTimePointIn) { + receiveTimePoint_ = receiveTimePointIn; + } + + [[nodiscard]] TimePoint getReceiveTimePoint() const { + return receiveTimePoint_; + } + + void setTotalData(const size_t totalDataIn) { + totalData_ = totalDataIn; + } + + [[nodiscard]] size_t getTotalData() const { + return totalData_; + } + std::unique_ptr moveAllData() && { std::unique_ptr buf; - for (auto& packet : packets) { + for (auto& packet : packets_) { if (buf) { buf->prependChain(std::move(packet.buf)); } else { @@ -68,6 +96,11 @@ struct NetworkData { } return buf; } + + private: + TimePoint receiveTimePoint_; + std::vector packets_; + size_t totalData_{0}; }; struct NetworkDataSingle { diff --git a/quic/common/QuicAsyncUDPSocketWrapper.cpp b/quic/common/QuicAsyncUDPSocketWrapper.cpp index 066c68d41..d1fb95c1e 100644 --- a/quic/common/QuicAsyncUDPSocketWrapper.cpp +++ b/quic/common/QuicAsyncUDPSocketWrapper.cpp @@ -142,9 +142,9 @@ QuicAsyncUDPSocketWrapperImpl::recvMmsg( size_t len = bytesRead; size_t remaining = len; size_t offset = 0; - size_t totalNumPackets = - networkData.packets.size() + ((len + params.gro - 1) / params.gro); - networkData.packets.reserve(totalNumPackets); + size_t totalNumPackets = networkData.getPackets().size() + + ((len + params.gro - 1) / params.gro); + networkData.reserve(totalNumPackets); while (remaining) { if (static_cast(remaining) > params.gro) { auto tmp = readBuffer->cloneOne(); @@ -157,18 +157,18 @@ QuicAsyncUDPSocketWrapperImpl::recvMmsg( offset += params.gro; remaining -= params.gro; - networkData.packets.emplace_back(std::move(tmp)); + networkData.addPacket(ReceivedPacket(std::move(tmp))); } else { // do not clone the last packet // start at offset, use all the remaining data readBuffer->trimStart(offset); DCHECK_EQ(readBuffer->length(), remaining); remaining = 0; - networkData.packets.emplace_back(std::move(readBuffer)); + networkData.addPacket(ReceivedPacket(std::move(readBuffer))); } } } else { - networkData.packets.emplace_back(std::move(readBuffer)); + networkData.addPacket(ReceivedPacket(std::move(readBuffer))); } } diff --git a/quic/fizz/client/test/QuicClientTransportTestUtil.h b/quic/fizz/client/test/QuicClientTransportTestUtil.h index e66878e62..160951757 100644 --- a/quic/fizz/client/test/QuicClientTransportTestUtil.h +++ b/quic/fizz/client/test/QuicClientTransportTestUtil.h @@ -720,7 +720,7 @@ class QuicClientTransportTestBase : public virtual testing::Test { NetworkData&& data, bool writes = true, folly::SocketAddress* peer = nullptr) { - for (const auto& packet : data.packets) { + for (const auto& packet : data.getPackets()) { deliverDataWithoutErrorCheck( peer == nullptr ? serverAddr : *peer, packet.buf->coalesce(), writes); } @@ -756,7 +756,7 @@ class QuicClientTransportTestBase : public virtual testing::Test { NetworkData&& data, bool writes = true, folly::SocketAddress* peer = nullptr) { - for (const auto& packet : data.packets) { + for (const auto& packet : data.getPackets()) { deliverData( peer == nullptr ? serverAddr : *peer, packet.buf->coalesce(), writes); } diff --git a/quic/server/QuicServerWorker.cpp b/quic/server/QuicServerWorker.cpp index 0a66c6023..1279e0c49 100644 --- a/quic/server/QuicServerWorker.cpp +++ b/quic/server/QuicServerWorker.cpp @@ -590,7 +590,7 @@ void QuicServerWorker::forwardNetworkData( "Forwarding packet with unknown connId version from client={} to another process, routingInfo={}", client.describe(), logRoutingInfo(routingData.destinationConnId)); - auto recvTime = networkData.receiveTimePoint; + auto recvTime = networkData.getReceiveTimePoint(); takeoverPktHandler_.forwardPacketToAnotherServer( client, std::move(networkData).moveAllData(), recvTime); QUIC_STATS(statsCallback_, onPacketForwarded); @@ -795,7 +795,7 @@ void QuicServerWorker::dispatchPacketData( "Forwarding packet from client={} to another process, routingInfo={}", client.describe(), logRoutingInfo(dstConnId)); - auto recvTime = networkData.receiveTimePoint; + auto recvTime = networkData.getReceiveTimePoint(); takeoverPktHandler_.forwardPacketToAnotherServer( client, std::move(networkData).moveAllData(), recvTime); QUIC_STATS(statsCallback_, onPacketForwarded); @@ -879,7 +879,7 @@ void QuicServerWorker::dispatchPacketData( // This could be a new connection, add it in the map // verify that the initial packet is at least min initial bytes // to avoid amplification attacks. Also check CID sizes. - if (networkData.totalData < kMinInitialPacketSize || + if (networkData.getTotalData() < kMinInitialPacketSize || !isValidConnIdLength(dstConnId)) { // Don't even attempt to forward the packet, just drop it. VLOG(3) << "Dropping small initial packet from client=" << client; @@ -889,7 +889,7 @@ void QuicServerWorker::dispatchPacketData( // If there is a token present, decrypt it (could be either a retry // token or a new token) - folly::io::Cursor cursor(networkData.packets.front().buf.get()); + folly::io::Cursor cursor(networkData.getPackets().front().buf.get()); auto maybeEncryptedToken = maybeGetEncryptedToken(cursor); bool hasTokenSecret = transportSettings_.retryTokenSecret.hasValue(); @@ -914,7 +914,7 @@ void QuicServerWorker::dispatchPacketData( // send a retry packet back to the client if (!isValidRetryToken && ((newConnRateLimiter_ && - newConnRateLimiter_->check(networkData.receiveTimePoint)) || + newConnRateLimiter_->check(networkData.getReceiveTimePoint())) || (unfinishedHandshakeLimitFn_.has_value() && globalUnfinishedHandshakes >= (*unfinishedHandshakeLimitFn_)()))) { QUIC_STATS(statsCallback_, onConnectionRateLimited); @@ -948,7 +948,7 @@ void QuicServerWorker::sendResetPacket( // Only send resets in response to short header packets. return; } - auto packetSize = networkData.totalData; + auto packetSize = networkData.getTotalData(); auto resetSize = std::min(packetSize, kDefaultMaxUDPPayload); // Per the spec, less than 43 we should respond with packet size - 1. if (packetSize < 43) { diff --git a/quic/server/test/QuicServerTest.cpp b/quic/server/test/QuicServerTest.cpp index 3c5834d05..9cd680e98 100644 --- a/quic/server/test/QuicServerTest.cpp +++ b/quic/server/test/QuicServerTest.cpp @@ -44,9 +44,9 @@ namespace quic { namespace test { MATCHER_P(NetworkDataMatches, networkData, "") { - for (size_t i = 0; i < arg.packets.size(); ++i) { + for (size_t i = 0; i < arg.getPackets().size(); ++i) { folly::IOBufEqualTo eq; - bool equals = eq(*arg.packets[i].buf, networkData); + bool equals = eq(*arg.getPackets()[i].buf, networkData); if (equals) { return true; } @@ -1971,8 +1971,8 @@ TEST_F(QuicServerWorkerTakeoverTest, QuicServerTakeoverProcessForwardedPkt) { EXPECT_EQ(addr.getPort(), clientAddr.getPort()); // the original data should be extracted after processing takeover // protocol related information - EXPECT_EQ(networkData->packets.size(), 1); - EXPECT_TRUE(eq(*data, *(networkData->packets[0].buf))); + EXPECT_EQ(networkData->getPackets().size(), 1); + EXPECT_TRUE(eq(*data, *(networkData->getPackets()[0].buf))); EXPECT_TRUE(isForwardedData); }; EXPECT_CALL(*takeoverWorkerCb_, routeDataToWorkerLong(_, _, _, _, _)) @@ -2163,9 +2163,9 @@ class QuicServerTest : public Test { .WillByDefault(Invoke( [&, expected = std::shared_ptr(data->clone())]( auto, const auto& networkData) mutable { - EXPECT_GT(networkData.packets.size(), 0); + EXPECT_GT(networkData.getPackets().size(), 0); EXPECT_TRUE(folly::IOBufEqualTo()( - *networkData.packets[0].buf, *expected)); + *networkData.getPackets()[0].buf, *expected)); std::unique_lock lg(m); calledOnNetworkData = true; cv.notify_one(); @@ -2324,9 +2324,9 @@ TEST_F(QuicServerTest, RouteDataFromDifferentThread) { EXPECT_CALL(*transport, onNetworkData(_, _)) .WillOnce(Invoke([&](auto, const auto& networkData) { - EXPECT_GT(networkData.packets.size(), 0); - EXPECT_TRUE( - folly::IOBufEqualTo()(*networkData.packets[0].buf, *initialData)); + EXPECT_GT(networkData.getPackets().size(), 0); + EXPECT_TRUE(folly::IOBufEqualTo()( + *networkData.getPackets()[0].buf, *initialData)); })); server_->routeDataToWorker( @@ -2424,9 +2424,9 @@ class QuicServerTakeoverTest : public Test { EXPECT_CALL(*transport, onNetworkData(_, _)) .WillOnce(Invoke( [&, expected = data.get()](auto, const auto& networkData) { - EXPECT_GT(networkData.packets.size(), 0); + EXPECT_GT(networkData.getPackets().size(), 0); EXPECT_TRUE(folly::IOBufEqualTo()( - *networkData.packets[0].buf, *expected)); + *networkData.getPackets()[0].buf, *expected)); baton.post(); })); return transport; @@ -2532,9 +2532,9 @@ class QuicServerTakeoverTest : public Test { EXPECT_CALL(*transportCbForOldServer, onNetworkData(_, _)) .WillOnce( Invoke([&, expected = data.get()](auto, const auto& networkData) { - EXPECT_GT(networkData.packets.size(), 0); + EXPECT_GT(networkData.getPackets().size(), 0); EXPECT_TRUE(folly::IOBufEqualTo()( - *networkData.packets[0].buf, *expected)); + *networkData.getPackets()[0].buf, *expected)); b1.post(); })); // new quic server receives the packet and forwards it @@ -2995,9 +2995,9 @@ TEST_F(QuicServerTest, ZeroRttPacketRoute) { EXPECT_CALL(*transport, onNetworkData(_, _)) .WillOnce(Invoke( [&, expected = data.get()](auto, const auto& networkData) { - EXPECT_GT(networkData.packets.size(), 0); + EXPECT_GT(networkData.getPackets().size(), 0); EXPECT_TRUE(folly::IOBufEqualTo()( - *networkData.packets[0].buf, *expected)); + *networkData.getPackets()[0].buf, *expected)); b.post(); })); return transport; @@ -3038,9 +3038,9 @@ TEST_F(QuicServerTest, ZeroRttPacketRoute) { folly::Baton<> b1; auto verifyZeroRtt = [&](const folly::SocketAddress& peer, const NetworkData& networkData) noexcept { - EXPECT_GT(networkData.packets.size(), 0); + EXPECT_GT(networkData.getPackets().size(), 0); EXPECT_EQ(peer, reader->getSocket().address()); - EXPECT_TRUE(folly::IOBufEqualTo()(*data, *networkData.packets[0].buf)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *networkData.getPackets()[0].buf)); b1.post(); }; EXPECT_CALL(*transport, onNetworkData(_, _)).WillOnce(Invoke(verifyZeroRtt)); @@ -3091,7 +3091,7 @@ TEST_F(QuicServerTest, ZeroRttBeforeInitial) { EXPECT_CALL(*transport, onNetworkData(_, _)) .Times(2) .WillRepeatedly(Invoke([&](auto, auto& networkData) { - for (auto& packet : networkData.packets) { + for (const auto& packet : networkData.getPackets()) { receivedData.emplace_back(packet.buf->clone()); } if (receivedData.size() == 2) {