From 1275798146552adde3b578661b72f41cef092ac7 Mon Sep 17 00:00:00 2001 From: Matt Joras Date: Tue, 20 Dec 2022 11:08:43 -0800 Subject: [PATCH] Make the AckState for Initial/Handshake a unique_ptr Summary: We don't need to carry these states after the handshake is confirmed, so make them pointers instead. This will facilitate adding a structure to the AckState for tracking duplicate packets. (Note: this ignores all push blocking failures!) Reviewed By: hanidamlaj Differential Revision: D41626895 fbshipit-source-id: d8ac960b3672b9bb9adaaececa53a1203ec801e0 --- quic/api/QuicTransportBase.cpp | 8 +- quic/api/QuicTransportFunctions.cpp | 64 ++++++----- quic/api/QuicTransportFunctions.h | 4 + quic/api/test/QuicPacketSchedulerTest.cpp | 94 ++++++++-------- quic/api/test/QuicTransportBaseTest.cpp | 2 +- quic/api/test/QuicTransportFunctionsTest.cpp | 88 +++++++-------- quic/client/QuicClientTransport.cpp | 16 +-- quic/client/state/ClientStateMachine.cpp | 8 +- quic/codec/QuicReadCodec.cpp | 5 +- quic/loss/test/QuicLossFunctionsTest.cpp | 16 +-- quic/server/QuicServerTransport.cpp | 8 +- quic/server/state/ServerStateMachine.cpp | 8 +- quic/server/test/QuicServerTransportTest.cpp | 14 +-- quic/state/AckStates.h | 10 +- quic/state/QuicStateFunctions.cpp | 106 ++++++++++++------ quic/state/QuicStateFunctions.h | 4 + quic/state/StateData.h | 2 + .../stream/test/StreamStateMachineTest.cpp | 4 +- quic/state/test/AckHandlersTest.cpp | 24 ++-- quic/state/test/QuicStateFunctionsTest.cpp | 13 ++- 20 files changed, 281 insertions(+), 217 deletions(-) diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index dff1671a3..f1e452c77 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -385,8 +385,12 @@ void QuicTransportBase::closeImpl( conn_->pendingEvents = QuicConnectionStateBase::PendingEvents(); conn_->streamManager->clearActionable(); conn_->streamManager->clearWritable(); - conn_->ackStates.initialAckState.acks.clear(); - conn_->ackStates.handshakeAckState.acks.clear(); + if (conn_->ackStates.initialAckState) { + conn_->ackStates.initialAckState->acks.clear(); + } + if (conn_->ackStates.handshakeAckState) { + conn_->ackStates.handshakeAckState->acks.clear(); + } conn_->ackStates.appDataAckState.acks.clear(); if (transportReadyNotified_) { diff --git a/quic/api/QuicTransportFunctions.cpp b/quic/api/QuicTransportFunctions.cpp index bd2c33f9f..b987b6a68 100644 --- a/quic/api/QuicTransportFunctions.cpp +++ b/quic/api/QuicTransportFunctions.cpp @@ -49,9 +49,15 @@ std::string largestAckScheduledToString( const quic::QuicConnectionStateBase& conn) noexcept { return folly::to( "[", - optionalToString(conn.ackStates.initialAckState.largestAckScheduled), + optionalToString( + conn.ackStates.initialAckState + ? conn.ackStates.initialAckState->largestAckScheduled + : folly::none), ",", - optionalToString(conn.ackStates.handshakeAckState.largestAckScheduled), + optionalToString( + conn.ackStates.handshakeAckState + ? conn.ackStates.handshakeAckState->largestAckScheduled + : folly::none), ",", optionalToString(conn.ackStates.appDataAckState.largestAckScheduled), "]"); @@ -61,35 +67,20 @@ std::string largestAckToSendToString( const quic::QuicConnectionStateBase& conn) noexcept { return folly::to( "[", - optionalToString(largestAckToSend(conn.ackStates.initialAckState)), + optionalToString( + conn.ackStates.initialAckState + ? largestAckToSend(*conn.ackStates.initialAckState) + : folly::none), ",", - optionalToString(largestAckToSend(conn.ackStates.handshakeAckState)), + optionalToString( + conn.ackStates.handshakeAckState + ? largestAckToSend(*conn.ackStates.handshakeAckState) + : folly::none), ",", optionalToString(largestAckToSend(conn.ackStates.appDataAckState)), "]"); } -bool toWriteInitialAcks(const quic::QuicConnectionStateBase& conn) { - return ( - conn.initialWriteCipher && - hasAcksToSchedule(conn.ackStates.initialAckState) && - conn.ackStates.initialAckState.needsToSendAckImmediately); -} - -bool toWriteHandshakeAcks(const quic::QuicConnectionStateBase& conn) { - return ( - conn.handshakeWriteCipher && - hasAcksToSchedule(conn.ackStates.handshakeAckState) && - conn.ackStates.handshakeAckState.needsToSendAckImmediately); -} - -bool toWriteAppDataAcks(const quic::QuicConnectionStateBase& conn) { - return ( - conn.oneRttWriteCipher && - hasAcksToSchedule(conn.ackStates.appDataAckState) && - conn.ackStates.appDataAckState.needsToSendAckImmediately); -} - using namespace quic; /** @@ -1796,11 +1787,13 @@ void handshakeConfirmed(QuicConnectionStateBase& conn) { conn.readCodec->setInitialReadCipher(nullptr); conn.readCodec->setInitialHeaderCipher(nullptr); implicitAckCryptoStream(conn, EncryptionLevel::Initial); + conn.ackStates.initialAckState.reset(); conn.handshakeWriteCipher.reset(); conn.handshakeWriteHeaderCipher.reset(); conn.readCodec->setHandshakeReadCipher(nullptr); conn.readCodec->setHandshakeHeaderCipher(nullptr); implicitAckCryptoStream(conn, EncryptionLevel::Handshake); + conn.ackStates.handshakeAckState.reset(); } bool hasInitialOrHandshakeCiphers(QuicConnectionStateBase& conn) { @@ -1839,4 +1832,25 @@ bool setCustomTransportParameter( return true; } +bool toWriteInitialAcks(const quic::QuicConnectionStateBase& conn) { + return ( + conn.initialWriteCipher && conn.ackStates.initialAckState && + hasAcksToSchedule(*conn.ackStates.initialAckState) && + conn.ackStates.initialAckState->needsToSendAckImmediately); +} + +bool toWriteHandshakeAcks(const quic::QuicConnectionStateBase& conn) { + return ( + conn.handshakeWriteCipher && conn.ackStates.handshakeAckState && + hasAcksToSchedule(*conn.ackStates.handshakeAckState) && + conn.ackStates.handshakeAckState->needsToSendAckImmediately); +} + +bool toWriteAppDataAcks(const quic::QuicConnectionStateBase& conn) { + return ( + conn.oneRttWriteCipher && + hasAcksToSchedule(conn.ackStates.appDataAckState) && + conn.ackStates.appDataAckState.needsToSendAckImmediately); +} + } // namespace quic diff --git a/quic/api/QuicTransportFunctions.h b/quic/api/QuicTransportFunctions.h index 6625c185d..b91bcc11a 100644 --- a/quic/api/QuicTransportFunctions.h +++ b/quic/api/QuicTransportFunctions.h @@ -344,4 +344,8 @@ bool setCustomTransportParameter( std::unique_ptr customParam, std::vector& customTransportParameters); +bool toWriteInitialAcks(const quic::QuicConnectionStateBase& conn); +bool toWriteHandshakeAcks(const quic::QuicConnectionStateBase& conn); +bool toWriteAppDataAcks(const quic::QuicConnectionStateBase& conn); + } // namespace quic diff --git a/quic/api/test/QuicPacketSchedulerTest.cpp b/quic/api/test/QuicPacketSchedulerTest.cpp index 66368cec6..5757ec09e 100644 --- a/quic/api/test/QuicPacketSchedulerTest.cpp +++ b/quic/api/test/QuicPacketSchedulerTest.cpp @@ -175,7 +175,7 @@ TEST_F(QuicPacketSchedulerTest, CryptoPaddingInitialPacket) { RegularQuicPacketBuilder builder( conn.udpSendPacketLen, std::move(longHeader), - conn.ackStates.initialAckState.largestAckedByPeer.value_or(0)); + conn.ackStates.initialAckState->largestAckedByPeer.value_or(0)); FrameScheduler cryptoOnlyScheduler = std::move( FrameScheduler::Builder( @@ -207,10 +207,10 @@ TEST_F(QuicPacketSchedulerTest, PaddingInitialPureAcks) { RegularQuicPacketBuilder builder( conn.udpSendPacketLen, std::move(longHeader), - conn.ackStates.handshakeAckState.largestAckedByPeer.value_or(0)); - conn.ackStates.initialAckState.largestRecvdPacketTime = Clock::now(); - conn.ackStates.initialAckState.needsToSendAckImmediately = true; - conn.ackStates.initialAckState.acks.insert(10); + conn.ackStates.handshakeAckState->largestAckedByPeer.value_or(0)); + conn.ackStates.initialAckState->largestRecvdPacketTime = Clock::now(); + conn.ackStates.initialAckState->needsToSendAckImmediately = true; + conn.ackStates.initialAckState->acks.insert(10); FrameScheduler acksOnlyScheduler = std::move( FrameScheduler::Builder( @@ -241,10 +241,10 @@ TEST_F(QuicPacketSchedulerTest, InitialPaddingDoesNotUseWrapper) { RegularQuicPacketBuilder builder( conn.udpSendPacketLen, std::move(longHeader), - conn.ackStates.handshakeAckState.largestAckedByPeer.value_or(0)); - conn.ackStates.initialAckState.largestRecvdPacketTime = Clock::now(); - conn.ackStates.initialAckState.needsToSendAckImmediately = true; - conn.ackStates.initialAckState.acks.insert(10); + conn.ackStates.handshakeAckState->largestAckedByPeer.value_or(0)); + conn.ackStates.initialAckState->largestRecvdPacketTime = Clock::now(); + conn.ackStates.initialAckState->needsToSendAckImmediately = true; + conn.ackStates.initialAckState->acks.insert(10); FrameScheduler acksOnlyScheduler = std::move( FrameScheduler::Builder( @@ -275,7 +275,7 @@ TEST_F(QuicPacketSchedulerTest, CryptoServerInitialPadded) { RegularQuicPacketBuilder builder1( conn.udpSendPacketLen, std::move(longHeader1), - conn.ackStates.initialAckState.largestAckedByPeer.value_or(0)); + conn.ackStates.initialAckState->largestAckedByPeer.value_or(0)); FrameScheduler scheduler = std::move( FrameScheduler::Builder( @@ -308,7 +308,7 @@ TEST_F(QuicPacketSchedulerTest, PadTwoInitialPackets) { RegularQuicPacketBuilder builder1( conn.udpSendPacketLen, std::move(longHeader1), - conn.ackStates.initialAckState.largestAckedByPeer.value_or(0)); + conn.ackStates.initialAckState->largestAckedByPeer.value_or(0)); FrameScheduler scheduler = std::move( FrameScheduler::Builder( @@ -336,7 +336,7 @@ TEST_F(QuicPacketSchedulerTest, PadTwoInitialPackets) { RegularQuicPacketBuilder builder2( conn.udpSendPacketLen, std::move(longHeader2), - conn.ackStates.initialAckState.largestAckedByPeer.value_or(0)); + conn.ackStates.initialAckState->largestAckedByPeer.value_or(0)); writeDataToQuicStream( conn.cryptoState->initialStream, folly::IOBuf::copyBuffer("shlo again")); auto result2 = scheduler.scheduleFramesForPacket( @@ -359,7 +359,7 @@ TEST_F(QuicPacketSchedulerTest, CryptoPaddingRetransmissionClientInitial) { RegularQuicPacketBuilder builder( conn.udpSendPacketLen, std::move(longHeader), - conn.ackStates.initialAckState.largestAckedByPeer.value_or(0)); + conn.ackStates.initialAckState->largestAckedByPeer.value_or(0)); FrameScheduler scheduler = std::move( FrameScheduler::Builder( @@ -391,7 +391,7 @@ TEST_F(QuicPacketSchedulerTest, CryptoSchedulerOnlySingleLossFits) { RegularQuicPacketBuilder builder( conn.udpSendPacketLen, std::move(longHeader), - conn.ackStates.handshakeAckState.largestAckedByPeer.value_or(0)); + conn.ackStates.handshakeAckState->largestAckedByPeer.value_or(0)); builder.encodePacketHeader(); PacketBuilderWrapper builderWrapper(builder, 13); CryptoStreamScheduler scheduler( @@ -419,8 +419,8 @@ TEST_F(QuicPacketSchedulerTest, CryptoWritePartialLossBuffer) { RegularQuicPacketBuilder builder( 25, std::move(longHeader), - conn.ackStates.initialAckState.largestAckedByPeer.value_or( - conn.ackStates.initialAckState.nextPacketNum)); + conn.ackStates.initialAckState->largestAckedByPeer.value_or( + conn.ackStates.initialAckState->nextPacketNum)); FrameScheduler cryptoOnlyScheduler = std::move( FrameScheduler::Builder( @@ -712,7 +712,7 @@ TEST_F(QuicPacketSchedulerTest, DoNotCloneProcessedClonedPacket) { RegularQuicPacketBuilder builder( conn.udpSendPacketLen, std::move(header), - conn.ackStates.initialAckState.largestAckedByPeer.value_or(0)); + conn.ackStates.initialAckState->largestAckedByPeer.value_or(0)); auto result = cloningScheduler.scheduleFramesForPacket( std::move(builder), kDefaultUDPSendPacketLen); EXPECT_TRUE(result.packetEvent.has_value() && result.packet.has_value()); @@ -767,8 +767,8 @@ TEST_F(QuicPacketSchedulerTest, CloneSchedulerHasHandshakeDataAndAcks) { WriteCryptoFrame(0, 4)); // Make it look like we received some acks from the peer. - conn.ackStates.handshakeAckState.acks.insert(10); - conn.ackStates.handshakeAckState.largestRecvdPacketTime = Clock::now(); + conn.ackStates.handshakeAckState->acks.insert(10); + conn.ackStates.handshakeAckState->largestRecvdPacketTime = Clock::now(); // Create cloning scheduler. CloningScheduler cloningScheduler(noopScheduler, conn, "CopyCat", 0); @@ -1184,19 +1184,19 @@ class AckSchedulingTest : public TestWithParam {}; TEST_F(QuicPacketSchedulerTest, AckStateHasAcksToSchedule) { QuicClientConnectionState conn( FizzClientQuicHandshakeContext::Builder().build()); - EXPECT_FALSE(hasAcksToSchedule(conn.ackStates.initialAckState)); - EXPECT_FALSE(hasAcksToSchedule(conn.ackStates.handshakeAckState)); + EXPECT_FALSE(hasAcksToSchedule(*conn.ackStates.initialAckState)); + EXPECT_FALSE(hasAcksToSchedule(*conn.ackStates.handshakeAckState)); EXPECT_FALSE(hasAcksToSchedule(conn.ackStates.appDataAckState)); - conn.ackStates.initialAckState.acks.insert(0, 100); - EXPECT_TRUE(hasAcksToSchedule(conn.ackStates.initialAckState)); + conn.ackStates.initialAckState->acks.insert(0, 100); + EXPECT_TRUE(hasAcksToSchedule(*conn.ackStates.initialAckState)); - conn.ackStates.handshakeAckState.acks.insert(0, 100); - conn.ackStates.handshakeAckState.largestAckScheduled = 200; - EXPECT_FALSE(hasAcksToSchedule(conn.ackStates.handshakeAckState)); + conn.ackStates.handshakeAckState->acks.insert(0, 100); + conn.ackStates.handshakeAckState->largestAckScheduled = 200; + EXPECT_FALSE(hasAcksToSchedule(*conn.ackStates.handshakeAckState)); - conn.ackStates.handshakeAckState.largestAckScheduled = folly::none; - EXPECT_TRUE(hasAcksToSchedule(conn.ackStates.handshakeAckState)); + conn.ackStates.handshakeAckState->largestAckScheduled = folly::none; + EXPECT_TRUE(hasAcksToSchedule(*conn.ackStates.handshakeAckState)); } TEST_F(QuicPacketSchedulerTest, AckSchedulerHasAcksToSchedule) { @@ -1212,30 +1212,30 @@ TEST_F(QuicPacketSchedulerTest, AckSchedulerHasAcksToSchedule) { EXPECT_FALSE(handshakeAckScheduler.hasPendingAcks()); EXPECT_FALSE(appDataAckScheduler.hasPendingAcks()); - conn.ackStates.initialAckState.acks.insert(0, 100); + conn.ackStates.initialAckState->acks.insert(0, 100); EXPECT_TRUE(initialAckScheduler.hasPendingAcks()); - conn.ackStates.handshakeAckState.acks.insert(0, 100); - conn.ackStates.handshakeAckState.largestAckScheduled = 200; + conn.ackStates.handshakeAckState->acks.insert(0, 100); + conn.ackStates.handshakeAckState->largestAckScheduled = 200; EXPECT_FALSE(handshakeAckScheduler.hasPendingAcks()); - conn.ackStates.handshakeAckState.largestAckScheduled = folly::none; + conn.ackStates.handshakeAckState->largestAckScheduled = folly::none; EXPECT_TRUE(handshakeAckScheduler.hasPendingAcks()); } TEST_F(QuicPacketSchedulerTest, LargestAckToSend) { QuicClientConnectionState conn( FizzClientQuicHandshakeContext::Builder().build()); - EXPECT_EQ(folly::none, largestAckToSend(conn.ackStates.initialAckState)); - EXPECT_EQ(folly::none, largestAckToSend(conn.ackStates.handshakeAckState)); + EXPECT_EQ(folly::none, largestAckToSend(*conn.ackStates.initialAckState)); + EXPECT_EQ(folly::none, largestAckToSend(*conn.ackStates.handshakeAckState)); EXPECT_EQ(folly::none, largestAckToSend(conn.ackStates.appDataAckState)); - conn.ackStates.initialAckState.acks.insert(0, 50); - conn.ackStates.handshakeAckState.acks.insert(0, 50); - conn.ackStates.handshakeAckState.acks.insert(75, 150); + conn.ackStates.initialAckState->acks.insert(0, 50); + conn.ackStates.handshakeAckState->acks.insert(0, 50); + conn.ackStates.handshakeAckState->acks.insert(75, 150); - EXPECT_EQ(50, *largestAckToSend(conn.ackStates.initialAckState)); - EXPECT_EQ(150, *largestAckToSend(conn.ackStates.handshakeAckState)); + EXPECT_EQ(50, *largestAckToSend(*conn.ackStates.initialAckState)); + EXPECT_EQ(150, *largestAckToSend(*conn.ackStates.handshakeAckState)); EXPECT_EQ(folly::none, largestAckToSend(conn.ackStates.appDataAckState)); } @@ -1253,18 +1253,18 @@ TEST_F(QuicPacketSchedulerTest, NeedsToSendAckWithoutAcksAvailable) { EXPECT_FALSE(handshakeAckScheduler.hasPendingAcks()); EXPECT_FALSE(appDataAckScheduler.hasPendingAcks()); - conn.ackStates.initialAckState.needsToSendAckImmediately = true; - conn.ackStates.handshakeAckState.needsToSendAckImmediately = true; + conn.ackStates.initialAckState->needsToSendAckImmediately = true; + conn.ackStates.handshakeAckState->needsToSendAckImmediately = true; conn.ackStates.appDataAckState.needsToSendAckImmediately = true; - conn.ackStates.initialAckState.acks.insert(0, 100); + conn.ackStates.initialAckState->acks.insert(0, 100); EXPECT_TRUE(initialAckScheduler.hasPendingAcks()); - conn.ackStates.handshakeAckState.acks.insert(0, 100); - conn.ackStates.handshakeAckState.largestAckScheduled = 200; + conn.ackStates.handshakeAckState->acks.insert(0, 100); + conn.ackStates.handshakeAckState->largestAckScheduled = 200; EXPECT_FALSE(handshakeAckScheduler.hasPendingAcks()); - conn.ackStates.handshakeAckState.largestAckScheduled = folly::none; + conn.ackStates.handshakeAckState->largestAckScheduled = folly::none; EXPECT_TRUE(handshakeAckScheduler.hasPendingAcks()); } @@ -2380,7 +2380,7 @@ TEST_F(QuicPacketSchedulerTest, ImmediateAckFrameSchedulerOnRequest) { RegularQuicPacketBuilder builder( conn.udpSendPacketLen, std::move(longHeader), - conn.ackStates.initialAckState.largestAckedByPeer.value_or(0)); + conn.ackStates.initialAckState->largestAckedByPeer.value_or(0)); FrameScheduler immediateAckOnlyScheduler = std::move( @@ -2416,7 +2416,7 @@ TEST_F(QuicPacketSchedulerTest, ImmediateAckFrameSchedulerNotRequested) { RegularQuicPacketBuilder builder( conn.udpSendPacketLen, std::move(longHeader), - conn.ackStates.initialAckState.largestAckedByPeer.value_or(0)); + conn.ackStates.initialAckState->largestAckedByPeer.value_or(0)); FrameScheduler immediateAckOnlyScheduler = std::move( diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index 93ac1147b..72c6e2f33 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -841,7 +841,7 @@ TEST_P(QuicTransportImplTestBase, WriteAckPacketUnsetsLooper) { pktHasCryptoData, Clock::now()); ASSERT_TRUE(transport->transportConn->ackStates.initialAckState - .needsToSendAckImmediately); + ->needsToSendAckImmediately); // Trigger the loop callback. This will trigger writes and we assume this will // write the acks since we have nothing else to write. transport->writeLooper()->runLoopCallback(); diff --git a/quic/api/test/QuicTransportFunctionsTest.cpp b/quic/api/test/QuicTransportFunctionsTest.cpp index ce395d83c..a86049c18 100644 --- a/quic/api/test/QuicTransportFunctionsTest.cpp +++ b/quic/api/test/QuicTransportFunctionsTest.cpp @@ -114,14 +114,14 @@ auto buildEmptyPacket( LongHeader::Types::Initial, *conn.clientConnectionId, *conn.serverConnectionId, - conn.ackStates.initialAckState.nextPacketNum, + conn.ackStates.initialAckState->nextPacketNum, *conn.version); } else if (pnSpace == PacketNumberSpace::Handshake) { header = LongHeader( LongHeader::Types::Handshake, *conn.clientConnectionId, *conn.serverConnectionId, - conn.ackStates.handshakeAckState.nextPacketNum, + conn.ackStates.handshakeAckState->nextPacketNum, *conn.version); } else if (pnSpace == PacketNumberSpace::AppData) { header = LongHeader( @@ -249,9 +249,9 @@ TEST_F(QuicTransportFunctionsTest, TestUpdateConnection) { packet.packet.frames.push_back(std::move(writeStreamFrame2)); auto currentNextInitialPacketNum = - conn->ackStates.initialAckState.nextPacketNum; + conn->ackStates.initialAckState->nextPacketNum; auto currentNextHandshakePacketNum = - conn->ackStates.handshakeAckState.nextPacketNum; + conn->ackStates.handshakeAckState->nextPacketNum; auto currentNextAppDataPacketNum = conn->ackStates.appDataAckState.nextPacketNum; EXPECT_CALL(*rawCongestionController, onPacketSent(_)).Times(1); @@ -268,10 +268,10 @@ TEST_F(QuicTransportFunctionsTest, TestUpdateConnection) { false /* isDSRPacket */); EXPECT_EQ( - conn->ackStates.initialAckState.nextPacketNum, + conn->ackStates.initialAckState->nextPacketNum, currentNextInitialPacketNum); EXPECT_GT( - conn->ackStates.handshakeAckState.nextPacketNum, + conn->ackStates.handshakeAckState->nextPacketNum, currentNextHandshakePacketNum); EXPECT_EQ( conn->ackStates.appDataAckState.nextPacketNum, @@ -313,9 +313,9 @@ TEST_F(QuicTransportFunctionsTest, TestUpdateConnection) { packet2.packet.frames.push_back(std::move(writeStreamFrame5)); EXPECT_CALL(*rawCongestionController, onPacketSent(_)).Times(1); - currentNextInitialPacketNum = conn->ackStates.initialAckState.nextPacketNum; + currentNextInitialPacketNum = conn->ackStates.initialAckState->nextPacketNum; currentNextHandshakePacketNum = - conn->ackStates.handshakeAckState.nextPacketNum; + conn->ackStates.handshakeAckState->nextPacketNum; currentNextAppDataPacketNum = conn->ackStates.appDataAckState.nextPacketNum; EXPECT_CALL(*rawCongestionController, isAppLimited()) .Times(1) @@ -329,10 +329,10 @@ TEST_F(QuicTransportFunctionsTest, TestUpdateConnection) { getEncodedBodySize(packet2), false /* isDSRPacket */); EXPECT_EQ( - conn->ackStates.initialAckState.nextPacketNum, + conn->ackStates.initialAckState->nextPacketNum, currentNextInitialPacketNum); EXPECT_GT( - conn->ackStates.handshakeAckState.nextPacketNum, + conn->ackStates.handshakeAckState->nextPacketNum, currentNextHandshakePacketNum); EXPECT_EQ( conn->ackStates.appDataAckState.nextPacketNum, @@ -430,9 +430,9 @@ TEST_F(QuicTransportFunctionsTest, TestUpdateConnectionPacketRetrans) { // mimic send, call updateConnection auto currentNextInitialPacketNum = - conn->ackStates.initialAckState.nextPacketNum; + conn->ackStates.initialAckState->nextPacketNum; auto currentNextHandshakePacketNum = - conn->ackStates.handshakeAckState.nextPacketNum; + conn->ackStates.handshakeAckState->nextPacketNum; auto currentNextAppDataPacketNum = conn->ackStates.appDataAckState.nextPacketNum; EXPECT_CALL(*rawCongestionController, onPacketSent(_)).Times(1); @@ -450,10 +450,10 @@ TEST_F(QuicTransportFunctionsTest, TestUpdateConnectionPacketRetrans) { // appData packet number should increase EXPECT_EQ( - conn->ackStates.initialAckState.nextPacketNum, + conn->ackStates.initialAckState->nextPacketNum, currentNextInitialPacketNum); // no change EXPECT_EQ( - conn->ackStates.handshakeAckState.nextPacketNum, + conn->ackStates.handshakeAckState->nextPacketNum, currentNextHandshakePacketNum); // no change EXPECT_GT( conn->ackStates.appDataAckState.nextPacketNum, @@ -494,9 +494,9 @@ TEST_F(QuicTransportFunctionsTest, TestUpdateConnectionPacketRetrans) { // mimic send, call updateConnection EXPECT_CALL(*rawCongestionController, onPacketSent(_)).Times(1); - currentNextInitialPacketNum = conn->ackStates.initialAckState.nextPacketNum; + currentNextInitialPacketNum = conn->ackStates.initialAckState->nextPacketNum; currentNextHandshakePacketNum = - conn->ackStates.handshakeAckState.nextPacketNum; + conn->ackStates.handshakeAckState->nextPacketNum; currentNextAppDataPacketNum = conn->ackStates.appDataAckState.nextPacketNum; EXPECT_CALL(*rawCongestionController, isAppLimited()) .Times(1) @@ -510,10 +510,10 @@ TEST_F(QuicTransportFunctionsTest, TestUpdateConnectionPacketRetrans) { getEncodedBodySize(packet2), false /* isDSRPacket */); EXPECT_EQ( - conn->ackStates.initialAckState.nextPacketNum, + conn->ackStates.initialAckState->nextPacketNum, currentNextInitialPacketNum); // no change EXPECT_EQ( - conn->ackStates.handshakeAckState.nextPacketNum, + conn->ackStates.handshakeAckState->nextPacketNum, currentNextHandshakePacketNum); // no change EXPECT_GT( conn->ackStates.appDataAckState.nextPacketNum, @@ -598,9 +598,9 @@ TEST_F( // mimic send, call updateConnection auto currentNextInitialPacketNum = - conn->ackStates.initialAckState.nextPacketNum; + conn->ackStates.initialAckState->nextPacketNum; auto currentNextHandshakePacketNum = - conn->ackStates.handshakeAckState.nextPacketNum; + conn->ackStates.handshakeAckState->nextPacketNum; auto currentNextAppDataPacketNum = conn->ackStates.appDataAckState.nextPacketNum; EXPECT_CALL(*rawCongestionController, onPacketSent(_)).Times(1); @@ -618,10 +618,10 @@ TEST_F( // appData packet number should increase EXPECT_EQ( - conn->ackStates.initialAckState.nextPacketNum, + conn->ackStates.initialAckState->nextPacketNum, currentNextInitialPacketNum); // no change EXPECT_EQ( - conn->ackStates.handshakeAckState.nextPacketNum, + conn->ackStates.handshakeAckState->nextPacketNum, currentNextHandshakePacketNum); // no change EXPECT_GT( conn->ackStates.appDataAckState.nextPacketNum, @@ -683,9 +683,9 @@ TEST_F( // mimic send, call updateConnection EXPECT_CALL(*rawCongestionController, onPacketSent(_)).Times(1); - currentNextInitialPacketNum = conn->ackStates.initialAckState.nextPacketNum; + currentNextInitialPacketNum = conn->ackStates.initialAckState->nextPacketNum; currentNextHandshakePacketNum = - conn->ackStates.handshakeAckState.nextPacketNum; + conn->ackStates.handshakeAckState->nextPacketNum; currentNextAppDataPacketNum = conn->ackStates.appDataAckState.nextPacketNum; EXPECT_CALL(*rawCongestionController, isAppLimited()) .Times(1) @@ -699,10 +699,10 @@ TEST_F( getEncodedBodySize(packet2), false /* isDSRPacket */); EXPECT_EQ( - conn->ackStates.initialAckState.nextPacketNum, + conn->ackStates.initialAckState->nextPacketNum, currentNextInitialPacketNum); // no change EXPECT_EQ( - conn->ackStates.handshakeAckState.nextPacketNum, + conn->ackStates.handshakeAckState->nextPacketNum, currentNextHandshakePacketNum); // no change EXPECT_GT( conn->ackStates.appDataAckState.nextPacketNum, @@ -782,8 +782,8 @@ TEST_F(QuicTransportFunctionsTest, TestUpdateConnectionD6DNeedsAppDataPNSpace) { TEST_F(QuicTransportFunctionsTest, TestUpdateConnectionPacketSorting) { auto conn = createConn(); conn->qLogger = std::make_shared(VantagePoint::Client); - conn->ackStates.initialAckState.nextPacketNum = 0; - conn->ackStates.handshakeAckState.nextPacketNum = 1; + conn->ackStates.initialAckState->nextPacketNum = 0; + conn->ackStates.handshakeAckState->nextPacketNum = 1; conn->ackStates.appDataAckState.nextPacketNum = 2; auto initialPacket = buildEmptyPacket(*conn, PacketNumberSpace::Initial); auto handshakePacket = buildEmptyPacket(*conn, PacketNumberSpace::Handshake); @@ -967,7 +967,7 @@ TEST_F(QuicTransportFunctionsTest, TestUpdateConnectionEmptyAckWriteResult) { // buildEmptyPacket() builds a Handshake packet, we use handshakeAckState to // verify. auto currentPendingLargestAckScheduled = - conn->ackStates.handshakeAckState.largestAckScheduled; + conn->ackStates.handshakeAckState->largestAckScheduled; updateConnection( *conn, folly::none, @@ -991,7 +991,7 @@ TEST_F(QuicTransportFunctionsTest, TestUpdateConnectionEmptyAckWriteResult) { EXPECT_EQ( currentPendingLargestAckScheduled, - conn->ackStates.handshakeAckState.largestAckScheduled); + conn->ackStates.handshakeAckState->largestAckScheduled); } TEST_F(QuicTransportFunctionsTest, TestUpdateConnectionPureAckCounter) { @@ -2970,11 +2970,11 @@ TEST_F(QuicTransportFunctionsTest, NothingWritten) { EXPECT_CALL(*rawCongestionController, getWritableBytes()) .WillRepeatedly(Return(18)); - addAckStatesWithCurrentTimestamps(conn->ackStates.initialAckState, 0, 1000); + addAckStatesWithCurrentTimestamps(*conn->ackStates.initialAckState, 0, 1000); addAckStatesWithCurrentTimestamps( - conn->ackStates.initialAckState, 1500, 2000); + *conn->ackStates.initialAckState, 1500, 2000); addAckStatesWithCurrentTimestamps( - conn->ackStates.initialAckState, 2500, 3000); + *conn->ackStates.initialAckState, 2500, 3000); auto res = writeQuicDataToSocket( *rawSocket, *conn, @@ -3155,7 +3155,7 @@ TEST_F(QuicTransportFunctionsTest, WriteProbingCryptoData) { conn.clientConnectionId = getTestConnectionId(); // writeCryptoDataProbesToSocketForTest writes Initial LongHeader, thus it // writes at Initial level. - auto currentPacketSeqNum = conn.ackStates.initialAckState.nextPacketNum; + auto currentPacketSeqNum = conn.ackStates.initialAckState->nextPacketNum; // Replace real congestionController with MockCongestionController: auto mockCongestionController = std::make_unique>(); @@ -3184,7 +3184,7 @@ TEST_F(QuicTransportFunctionsTest, WriteProbingCryptoData) { })); writeCryptoDataProbesToSocketForTest( *rawSocket, conn, 1, *aead, *headerCipher, getVersion(conn)); - EXPECT_LT(currentPacketSeqNum, conn.ackStates.initialAckState.nextPacketNum); + EXPECT_LT(currentPacketSeqNum, conn.ackStates.initialAckState->nextPacketNum); EXPECT_FALSE(conn.outstandings.packets.empty()); EXPECT_TRUE(conn.pendingEvents.setLossDetectionAlarm); EXPECT_GT(cryptoStream->currentWriteOffset, currentStreamWriteOffset); @@ -3202,7 +3202,7 @@ TEST_F(QuicTransportFunctionsTest, WriteableBytesLimitedProbingCryptoData) { conn.clientConnectionId = getTestConnectionId(); // writeCryptoDataProbesToSocketForTest writes Initial LongHeader, thus it // writes at Initial level. - auto currentPacketSeqNum = conn.ackStates.initialAckState.nextPacketNum; + auto currentPacketSeqNum = conn.ackStates.initialAckState->nextPacketNum; // Replace real congestionController with MockCongestionController: auto mockCongestionController = std::make_unique>(); @@ -3232,7 +3232,7 @@ TEST_F(QuicTransportFunctionsTest, WriteableBytesLimitedProbingCryptoData) { *rawSocket, conn, probesToSend, *aead, *headerCipher, getVersion(conn)); EXPECT_EQ(conn.numProbesWritableBytesLimited, 1); - EXPECT_LT(currentPacketSeqNum, conn.ackStates.initialAckState.nextPacketNum); + EXPECT_LT(currentPacketSeqNum, conn.ackStates.initialAckState->nextPacketNum); EXPECT_FALSE(conn.outstandings.packets.empty()); EXPECT_TRUE(conn.pendingEvents.setLossDetectionAlarm); EXPECT_GT(cryptoStream->currentWriteOffset, currentStreamWriteOffset); @@ -3643,8 +3643,8 @@ TEST_F(QuicTransportFunctionsTest, HasAckDataToWriteCipherAndAckStateMatch) { conn->ackStates.appDataAckState.needsToSendAckImmediately = true; conn->ackStates.appDataAckState.acks.insert(0, 100); EXPECT_FALSE(hasAckDataToWrite(*conn)); - conn->ackStates.initialAckState.needsToSendAckImmediately = true; - conn->ackStates.initialAckState.acks.insert(0, 100); + conn->ackStates.initialAckState->needsToSendAckImmediately = true; + conn->ackStates.initialAckState->acks.insert(0, 100); EXPECT_TRUE(hasAckDataToWrite(*conn)); } @@ -3661,15 +3661,15 @@ TEST_F(QuicTransportFunctionsTest, HasAckDataToWriteNoImmediateAcks) { TEST_F(QuicTransportFunctionsTest, HasAckDataToWriteNoAcksScheduled) { auto conn = createConn(); conn->oneRttWriteCipher = test::createNoOpAead(); - conn->ackStates.initialAckState.needsToSendAckImmediately = true; + conn->ackStates.initialAckState->needsToSendAckImmediately = true; EXPECT_FALSE(hasAckDataToWrite(*conn)); } TEST_F(QuicTransportFunctionsTest, HasAckDataToWrite) { auto conn = createConn(); conn->oneRttWriteCipher = test::createNoOpAead(); - conn->ackStates.initialAckState.needsToSendAckImmediately = true; - conn->ackStates.initialAckState.acks.insert(0); + conn->ackStates.initialAckState->needsToSendAckImmediately = true; + conn->ackStates.initialAckState->acks.insert(0); EXPECT_TRUE(hasAckDataToWrite(*conn)); } @@ -3679,9 +3679,9 @@ TEST_F(QuicTransportFunctionsTest, HasAckDataToWriteMismatch) { // should not send. auto conn = createConn(); EXPECT_FALSE(hasAckDataToWrite(*conn)); - conn->ackStates.initialAckState.needsToSendAckImmediately = true; + conn->ackStates.initialAckState->needsToSendAckImmediately = true; EXPECT_FALSE(hasAckDataToWrite(*conn)); - conn->ackStates.handshakeAckState.acks.insert(0, 10); + conn->ackStates.handshakeAckState->acks.insert(0, 10); conn->handshakeWriteCipher = test::createNoOpAead(); EXPECT_FALSE(hasAckDataToWrite(*conn)); } diff --git a/quic/client/QuicClientTransport.cpp b/quic/client/QuicClientTransport.cpp index 0b883abb1..b91c3c3c2 100644 --- a/quic/client/QuicClientTransport.cpp +++ b/quic/client/QuicClientTransport.cpp @@ -634,8 +634,12 @@ void QuicClientTransport::processPacketData( FrameType::IMMEDIATE_ACK); } // Send an ACK from any packet number space. - conn_->ackStates.initialAckState.needsToSendAckImmediately = true; - conn_->ackStates.handshakeAckState.needsToSendAckImmediately = true; + if (conn_->ackStates.initialAckState) { + conn_->ackStates.initialAckState->needsToSendAckImmediately = true; + } + if (conn_->ackStates.handshakeAckState) { + conn_->ackStates.handshakeAckState->needsToSendAckImmediately = true; + } conn_->ackStates.appDataAckState.needsToSendAckImmediately = true; break; } @@ -908,9 +912,7 @@ void QuicClientTransport::writeData() { if ((initialCryptoStream.retransmissionBuffer.size() && conn_->outstandings.packetCount[PacketNumberSpace::Initial] && numProbePackets) || - initialScheduler.hasData() || - (conn_->ackStates.initialAckState.needsToSendAckImmediately && - hasAcksToSchedule(conn_->ackStates.initialAckState))) { + initialScheduler.hasData() || toWriteInitialAcks(*conn_)) { CHECK(conn_->initialHeaderCipher); std::string& token = clientConn_->retryToken.empty() ? clientConn_->newToken @@ -941,9 +943,7 @@ void QuicClientTransport::writeData() { if ((conn_->outstandings.packetCount[PacketNumberSpace::Handshake] && handshakeCryptoStream.retransmissionBuffer.size() && numProbePackets) || - handshakeScheduler.hasData() || - (conn_->ackStates.handshakeAckState.needsToSendAckImmediately && - hasAcksToSchedule(conn_->ackStates.handshakeAckState))) { + handshakeScheduler.hasData() || toWriteHandshakeAcks(*conn_)) { CHECK(conn_->handshakeWriteHeaderCipher); packetLimit -= writeCryptoAndAckDataToSocket( *socket_, diff --git a/quic/client/state/ClientStateMachine.cpp b/quic/client/state/ClientStateMachine.cpp index 699989421..1ce968e0c 100644 --- a/quic/client/state/ClientStateMachine.cpp +++ b/quic/client/state/ClientStateMachine.cpp @@ -39,10 +39,10 @@ std::unique_ptr undoAllClientStateForRetry( conn->originalDestinationConnectionId; // TODO: don't carry server connection id over to the new connection. newConn->serverConnectionId = conn->serverConnectionId; - newConn->ackStates.initialAckState.nextPacketNum = - conn->ackStates.initialAckState.nextPacketNum; - newConn->ackStates.handshakeAckState.nextPacketNum = - conn->ackStates.handshakeAckState.nextPacketNum; + newConn->ackStates.initialAckState->nextPacketNum = + conn->ackStates.initialAckState->nextPacketNum; + newConn->ackStates.handshakeAckState->nextPacketNum = + conn->ackStates.handshakeAckState->nextPacketNum; newConn->ackStates.appDataAckState.nextPacketNum = conn->ackStates.appDataAckState.nextPacketNum; newConn->version = conn->version; diff --git a/quic/codec/QuicReadCodec.cpp b/quic/codec/QuicReadCodec.cpp index 0d62a70c9..ad2e4efff 100644 --- a/quic/codec/QuicReadCodec.cpp +++ b/quic/codec/QuicReadCodec.cpp @@ -188,11 +188,12 @@ CodecResult QuicReadCodec::parseLongHeaderPacket( folly::Optional largestRecvdPacketNum; switch (longHeaderTypeToProtectionType(type)) { case ProtectionType::Initial: - largestRecvdPacketNum = ackStates.initialAckState.largestRecvdPacketNum; + largestRecvdPacketNum = ackStates.initialAckState->largestRecvdPacketNum; break; case ProtectionType::Handshake: - largestRecvdPacketNum = ackStates.handshakeAckState.largestRecvdPacketNum; + largestRecvdPacketNum = + ackStates.handshakeAckState->largestRecvdPacketNum; break; case ProtectionType::ZeroRtt: diff --git a/quic/loss/test/QuicLossFunctionsTest.cpp b/quic/loss/test/QuicLossFunctionsTest.cpp index fea1e5037..20cbf370e 100644 --- a/quic/loss/test/QuicLossFunctionsTest.cpp +++ b/quic/loss/test/QuicLossFunctionsTest.cpp @@ -101,8 +101,8 @@ class QuicLossFunctionsTest : public TestWithParam { FizzServerQuicHandshakeContext::Builder().build()); conn->clientConnectionId = getTestConnectionId(); conn->version = QuicVersion::MVFST; - conn->ackStates.initialAckState.nextPacketNum = 1; - conn->ackStates.handshakeAckState.nextPacketNum = 1; + conn->ackStates.initialAckState->nextPacketNum = 1; + conn->ackStates.handshakeAckState->nextPacketNum = 1; conn->ackStates.appDataAckState.nextPacketNum = 1; conn->flowControlState.peerAdvertisedInitialMaxStreamOffsetBidiLocal = kDefaultStreamWindowSize; @@ -134,8 +134,8 @@ class QuicLossFunctionsTest : public TestWithParam { FizzClientQuicHandshakeContext::Builder().build()); conn->clientConnectionId = getTestConnectionId(); conn->version = QuicVersion::MVFST; - conn->ackStates.initialAckState.nextPacketNum = 1; - conn->ackStates.handshakeAckState.nextPacketNum = 1; + conn->ackStates.initialAckState->nextPacketNum = 1; + conn->ackStates.handshakeAckState->nextPacketNum = 1; conn->ackStates.appDataAckState.nextPacketNum = 1; conn->flowControlState.peerAdvertisedInitialMaxStreamOffsetBidiLocal = kDefaultStreamWindowSize; @@ -198,7 +198,7 @@ PacketNum QuicLossFunctionsTest::sendPacket( LongHeader::Types::Initial, *conn.clientConnectionId, *conn.serverConnectionId, - conn.ackStates.initialAckState.nextPacketNum, + conn.ackStates.initialAckState->nextPacketNum, *conn.version); isHandshake = true; break; @@ -207,7 +207,7 @@ PacketNum QuicLossFunctionsTest::sendPacket( LongHeader::Types::Handshake, *conn.clientConnectionId, *conn.serverConnectionId, - conn.ackStates.handshakeAckState.nextPacketNum, + conn.ackStates.handshakeAckState->nextPacketNum, *conn.version); isHandshake = true; break; @@ -736,7 +736,7 @@ TEST_F(QuicLossFunctionsTest, TestMarkPacketLoss) { writeDataToQuicStream(*stream1, buf->clone(), true); writeDataToQuicStream(*stream2, buf->clone(), true); - auto packetSeqNum = conn->ackStates.handshakeAckState.nextPacketNum; + auto packetSeqNum = conn->ackStates.handshakeAckState->nextPacketNum; LongHeader header( LongHeader::Types::Handshake, *conn->clientConnectionId, @@ -1026,7 +1026,7 @@ TEST_F(QuicLossFunctionsTest, TestHandleAckForLoss) { LongHeader::Types::Handshake, *conn->clientConnectionId, *conn->serverConnectionId, - conn->ackStates.handshakeAckState.nextPacketNum++, + conn->ackStates.handshakeAckState->nextPacketNum++, conn->version.value()); RegularQuicWritePacket outstandingRegularPacket(std::move(longHeader)); auto now = Clock::now(); diff --git a/quic/server/QuicServerTransport.cpp b/quic/server/QuicServerTransport.cpp index 091be4982..8315810be 100644 --- a/quic/server/QuicServerTransport.cpp +++ b/quic/server/QuicServerTransport.cpp @@ -273,9 +273,7 @@ void QuicServerTransport::writeData() { conn_->pendingEvents.numProbePackets[PacketNumberSpace::Initial]; if ((numProbePackets && initialCryptoStream.retransmissionBuffer.size() && conn_->outstandings.packetCount[PacketNumberSpace::Initial]) || - initialScheduler.hasData() || - (conn_->ackStates.initialAckState.needsToSendAckImmediately && - hasAcksToSchedule(conn_->ackStates.initialAckState))) { + initialScheduler.hasData() || toWriteInitialAcks(*conn_)) { CHECK(conn_->initialWriteCipher); CHECK(conn_->initialHeaderCipher); @@ -306,9 +304,7 @@ void QuicServerTransport::writeData() { if ((conn_->outstandings.packetCount[PacketNumberSpace::Handshake] && handshakeCryptoStream.retransmissionBuffer.size() && numProbePackets) || - handshakeScheduler.hasData() || - (conn_->ackStates.handshakeAckState.needsToSendAckImmediately && - hasAcksToSchedule(conn_->ackStates.handshakeAckState))) { + handshakeScheduler.hasData() || toWriteHandshakeAcks(*conn_)) { CHECK(conn_->handshakeWriteCipher); CHECK(conn_->handshakeWriteHeaderCipher); auto res = writeCryptoAndAckDataToSocket( diff --git a/quic/server/state/ServerStateMachine.cpp b/quic/server/state/ServerStateMachine.cpp index e976cedb2..d94864369 100644 --- a/quic/server/state/ServerStateMachine.cpp +++ b/quic/server/state/ServerStateMachine.cpp @@ -1320,8 +1320,12 @@ void onServerReadDataFromOpen( FrameType::IMMEDIATE_ACK); } // Send an ACK from any packet number space. - conn.ackStates.initialAckState.needsToSendAckImmediately = true; - conn.ackStates.handshakeAckState.needsToSendAckImmediately = true; + if (conn.ackStates.initialAckState) { + conn.ackStates.initialAckState->needsToSendAckImmediately = true; + } + if (conn.ackStates.handshakeAckState) { + conn.ackStates.handshakeAckState->needsToSendAckImmediately = true; + } conn.ackStates.appDataAckState.needsToSendAckImmediately = true; break; } diff --git a/quic/server/test/QuicServerTransportTest.cpp b/quic/server/test/QuicServerTransportTest.cpp index 4fbd1ba6b..9d36e0407 100644 --- a/quic/server/test/QuicServerTransportTest.cpp +++ b/quic/server/test/QuicServerTransportTest.cpp @@ -112,12 +112,12 @@ TEST_F(QuicServerTransportTest, TestReadMultipleStreams) { auto packet = std::move(builder).buildPacket(); // Clear out the existing acks to make sure that we are the cause of the acks. - server->getNonConstConn().ackStates.initialAckState.acks.clear(); - server->getNonConstConn().ackStates.initialAckState.largestRecvdPacketTime = - folly::none; - server->getNonConstConn().ackStates.handshakeAckState.acks.clear(); - server->getNonConstConn().ackStates.handshakeAckState.largestRecvdPacketTime = + server->getNonConstConn().ackStates.initialAckState->acks.clear(); + server->getNonConstConn().ackStates.initialAckState->largestRecvdPacketTime = folly::none; + server->getNonConstConn().ackStates.handshakeAckState->acks.clear(); + server->getNonConstConn() + .ackStates.handshakeAckState->largestRecvdPacketTime = folly::none; server->getNonConstConn().ackStates.appDataAckState.acks.clear(); server->getNonConstConn().ackStates.appDataAckState.largestRecvdPacketTime = folly::none; @@ -3056,9 +3056,9 @@ TEST_F(QuicServerTransportTest, ImmediateAckProtocolViolation) { ASSERT_THROW(deliverData(packetToBuf(packet)), std::runtime_error); // Verify that none of the ack states have changed EXPECT_FALSE( - server->getConn().ackStates.initialAckState.needsToSendAckImmediately); + server->getConn().ackStates.initialAckState->needsToSendAckImmediately); EXPECT_FALSE( - server->getConn().ackStates.handshakeAckState.needsToSendAckImmediately); + server->getConn().ackStates.handshakeAckState->needsToSendAckImmediately); EXPECT_FALSE( server->getConn().ackStates.appDataAckState.needsToSendAckImmediately); } diff --git a/quic/state/AckStates.h b/quic/state/AckStates.h index 7cac24cd7..5a20545a4 100644 --- a/quic/state/AckStates.h +++ b/quic/state/AckStates.h @@ -51,17 +51,19 @@ struct AckState : WriteAckState { struct AckStates { explicit AckStates(PacketNum startingNum) { - initialAckState.nextPacketNum = startingNum; - handshakeAckState.nextPacketNum = startingNum; + initialAckState = std::make_unique(); + handshakeAckState = std::make_unique(); + initialAckState->nextPacketNum = startingNum; + handshakeAckState->nextPacketNum = startingNum; appDataAckState.nextPacketNum = startingNum; } AckStates() : AckStates(folly::Random::secureRand32(kMaxInitialPacketNum)) {} // AckState for acks to peer packets in Initial packet number space. - AckState initialAckState; + std::unique_ptr initialAckState{}; // AckState for acks to peer packets in Handshake packet number space. - AckState handshakeAckState; + std::unique_ptr handshakeAckState{}; // AckState for acks to peer packets in AppData packet number space. AckState appDataAckState; std::chrono::microseconds maxAckDelay{kMaxAckTimeout}; diff --git a/quic/state/QuicStateFunctions.cpp b/quic/state/QuicStateFunctions.cpp index eee5c4de2..ac0fe5a04 100644 --- a/quic/state/QuicStateFunctions.cpp +++ b/quic/state/QuicStateFunctions.cpp @@ -199,9 +199,9 @@ AckState& getAckState( PacketNumberSpace pnSpace) noexcept { switch (pnSpace) { case PacketNumberSpace::Initial: - return conn.ackStates.initialAckState; + return *CHECK_NOTNULL(conn.ackStates.initialAckState.get()); case PacketNumberSpace::Handshake: - return conn.ackStates.handshakeAckState; + return *CHECK_NOTNULL(conn.ackStates.handshakeAckState.get()); case PacketNumberSpace::AppData: return conn.ackStates.appDataAckState; } @@ -213,21 +213,43 @@ const AckState& getAckState( PacketNumberSpace pnSpace) noexcept { switch (pnSpace) { case PacketNumberSpace::Initial: - return conn.ackStates.initialAckState; + return *CHECK_NOTNULL(conn.ackStates.initialAckState.get()); case PacketNumberSpace::Handshake: - return conn.ackStates.handshakeAckState; + return *CHECK_NOTNULL(conn.ackStates.handshakeAckState.get()); case PacketNumberSpace::AppData: return conn.ackStates.appDataAckState; } folly::assume_unreachable(); } +const AckState* getAckStatePtr( + const QuicConnectionStateBase& conn, + PacketNumberSpace pnSpace) noexcept { + switch (pnSpace) { + case PacketNumberSpace::Initial: + return conn.ackStates.initialAckState.get(); + case PacketNumberSpace::Handshake: + return conn.ackStates.handshakeAckState.get(); + case PacketNumberSpace::AppData: + return &conn.ackStates.appDataAckState; + } + folly::assume_unreachable(); +} + AckStateVersion currentAckStateVersion( const QuicConnectionStateBase& conn) noexcept { - return AckStateVersion( - conn.ackStates.initialAckState.acks.insertVersion(), - conn.ackStates.handshakeAckState.acks.insertVersion(), - conn.ackStates.appDataAckState.acks.insertVersion()); + AckStateVersion ret; + if (conn.ackStates.initialAckState) { + ret.initialAckStateVersion = + conn.ackStates.initialAckState->acks.insertVersion(); + } + if (conn.ackStates.handshakeAckState) { + ret.handshakeAckStateVersion = + conn.ackStates.handshakeAckState->acks.insertVersion(); + } + ret.appDataAckStateVersion = + conn.ackStates.appDataAckState.acks.insertVersion(); + return ret; } PacketNum getNextPacketNum( @@ -280,47 +302,57 @@ std::deque::iterator getNextOutstandingPacket( bool hasReceivedPacketsAtLastCloseSent( const QuicConnectionStateBase& conn) noexcept { - return conn.ackStates.initialAckState.largestReceivedAtLastCloseSent || - conn.ackStates.handshakeAckState.largestReceivedAtLastCloseSent || - conn.ackStates.appDataAckState.largestReceivedAtLastCloseSent; + const auto* initialAckState = conn.ackStates.initialAckState.get(); + const auto* handshakeAckState = conn.ackStates.handshakeAckState.get(); + const auto& appDataAckState = conn.ackStates.appDataAckState; + return (initialAckState && initialAckState->largestReceivedAtLastCloseSent) || + (handshakeAckState && + handshakeAckState->largestReceivedAtLastCloseSent) || + appDataAckState.largestReceivedAtLastCloseSent; } bool hasNotReceivedNewPacketsSinceLastCloseSent( const QuicConnectionStateBase& conn) noexcept { - DCHECK( - !conn.ackStates.initialAckState.largestReceivedAtLastCloseSent || - *conn.ackStates.initialAckState.largestReceivedAtLastCloseSent <= - *conn.ackStates.initialAckState.largestRecvdPacketNum); - DCHECK( - !conn.ackStates.handshakeAckState.largestReceivedAtLastCloseSent || - *conn.ackStates.handshakeAckState.largestReceivedAtLastCloseSent <= - *conn.ackStates.handshakeAckState.largestRecvdPacketNum); - DCHECK( - !conn.ackStates.appDataAckState.largestReceivedAtLastCloseSent || - *conn.ackStates.appDataAckState.largestReceivedAtLastCloseSent <= - *conn.ackStates.appDataAckState.largestRecvdPacketNum); - return conn.ackStates.initialAckState.largestReceivedAtLastCloseSent == - conn.ackStates.initialAckState.largestRecvdPacketNum && - conn.ackStates.handshakeAckState.largestReceivedAtLastCloseSent == - conn.ackStates.handshakeAckState.largestRecvdPacketNum && - conn.ackStates.appDataAckState.largestReceivedAtLastCloseSent == - conn.ackStates.appDataAckState.largestRecvdPacketNum; + const auto* initialAckState = conn.ackStates.initialAckState.get(); + const auto* handshakeAckState = conn.ackStates.handshakeAckState.get(); + const auto& appDataAckState = conn.ackStates.appDataAckState; + + return (initialAckState ? initialAckState->largestReceivedAtLastCloseSent == + initialAckState->largestRecvdPacketNum + : true) && + (handshakeAckState ? handshakeAckState->largestReceivedAtLastCloseSent == + handshakeAckState->largestRecvdPacketNum + : true) && + appDataAckState.largestReceivedAtLastCloseSent == + appDataAckState.largestRecvdPacketNum; } void updateLargestReceivedPacketsAtLastCloseSent( QuicConnectionStateBase& conn) noexcept { - conn.ackStates.initialAckState.largestReceivedAtLastCloseSent = - conn.ackStates.initialAckState.largestRecvdPacketNum; - conn.ackStates.handshakeAckState.largestReceivedAtLastCloseSent = - conn.ackStates.handshakeAckState.largestRecvdPacketNum; - conn.ackStates.appDataAckState.largestReceivedAtLastCloseSent = + auto* initialAckState = conn.ackStates.initialAckState.get(); + auto* handshakeAckState = conn.ackStates.handshakeAckState.get(); + auto& appDataAckState = conn.ackStates.appDataAckState; + + if (initialAckState) { + initialAckState->largestReceivedAtLastCloseSent = + conn.ackStates.initialAckState->largestRecvdPacketNum; + } + if (handshakeAckState) { + handshakeAckState->largestReceivedAtLastCloseSent = + handshakeAckState->largestRecvdPacketNum; + } + appDataAckState.largestReceivedAtLastCloseSent = conn.ackStates.appDataAckState.largestRecvdPacketNum; } bool hasReceivedPackets(const QuicConnectionStateBase& conn) noexcept { - return conn.ackStates.initialAckState.largestRecvdPacketNum || - conn.ackStates.handshakeAckState.largestRecvdPacketNum || - conn.ackStates.appDataAckState.largestRecvdPacketNum; + const auto* initialAckState = conn.ackStates.initialAckState.get(); + const auto* handshakeAckState = conn.ackStates.handshakeAckState.get(); + const auto& appDataAckState = conn.ackStates.appDataAckState; + + return (initialAckState ? initialAckState->largestRecvdPacketNum : true) || + (handshakeAckState ? handshakeAckState->largestRecvdPacketNum : true) || + appDataAckState.largestRecvdPacketNum; } folly::Optional& getLossTime( diff --git a/quic/state/QuicStateFunctions.h b/quic/state/QuicStateFunctions.h index a293c7568..fda1dca73 100644 --- a/quic/state/QuicStateFunctions.h +++ b/quic/state/QuicStateFunctions.h @@ -44,6 +44,10 @@ const AckState& getAckState( const QuicConnectionStateBase& conn, PacketNumberSpace pnSpace) noexcept; +const AckState* getAckStatePtr( + const QuicConnectionStateBase& conn, + PacketNumberSpace pnSpace) noexcept; + AckStateVersion currentAckStateVersion( const QuicConnectionStateBase& conn) noexcept; diff --git a/quic/state/StateData.h b/quic/state/StateData.h index d0b99af83..df834ce35 100644 --- a/quic/state/StateData.h +++ b/quic/state/StateData.h @@ -808,6 +808,8 @@ struct AckStateVersion { uint64_t handshakeVersion, uint64_t appDataVersion); + AckStateVersion() = default; + bool operator==(const AckStateVersion& other) const; bool operator!=(const AckStateVersion& other) const; }; diff --git a/quic/state/stream/test/StreamStateMachineTest.cpp b/quic/state/stream/test/StreamStateMachineTest.cpp index 2df99d82e..592508a29 100644 --- a/quic/state/stream/test/StreamStateMachineTest.cpp +++ b/quic/state/stream/test/StreamStateMachineTest.cpp @@ -35,8 +35,8 @@ std::unique_ptr createConn() { FizzServerQuicHandshakeContext::Builder().build()); conn->clientConnectionId = getTestConnectionId(); conn->version = QuicVersion::MVFST; - conn->ackStates.initialAckState.nextPacketNum = 1; - conn->ackStates.handshakeAckState.nextPacketNum = 1; + conn->ackStates.initialAckState->nextPacketNum = 1; + conn->ackStates.handshakeAckState->nextPacketNum = 1; conn->ackStates.appDataAckState.nextPacketNum = 1; conn->flowControlState.peerAdvertisedInitialMaxStreamOffsetBidiLocal = kDefaultStreamWindowSize; diff --git a/quic/state/test/AckHandlersTest.cpp b/quic/state/test/AckHandlersTest.cpp index 37a95f7b8..3704d85c3 100644 --- a/quic/state/test/AckHandlersTest.cpp +++ b/quic/state/test/AckHandlersTest.cpp @@ -1164,17 +1164,17 @@ TEST_P(AckHandlersTest, PurgeAcks) { WriteAckFrame ackFrame; ackFrame.ackBlocks.emplace_back(900, 1000); ackFrame.ackBlocks.emplace_back(500, 700); - conn.ackStates.initialAckState.acks.insert(900, 1200); - conn.ackStates.initialAckState.acks.insert(500, 800); + conn.ackStates.initialAckState->acks.insert(900, 1200); + conn.ackStates.initialAckState->acks.insert(500, 800); auto expectedTime = Clock::now(); - conn.ackStates.initialAckState.largestRecvdPacketTime = expectedTime; - commonAckVisitorForAckFrame(conn.ackStates.initialAckState, ackFrame); + conn.ackStates.initialAckState->largestRecvdPacketTime = expectedTime; + commonAckVisitorForAckFrame(*conn.ackStates.initialAckState, ackFrame); // We should have purged old packets in ack state - EXPECT_EQ(conn.ackStates.initialAckState.acks.size(), 1); - EXPECT_EQ(conn.ackStates.initialAckState.acks.front().start, 1001); - EXPECT_EQ(conn.ackStates.initialAckState.acks.front().end, 1200); + EXPECT_EQ(conn.ackStates.initialAckState->acks.size(), 1); + EXPECT_EQ(conn.ackStates.initialAckState->acks.front().start, 1001); + EXPECT_EQ(conn.ackStates.initialAckState->acks.front().end, 1200); EXPECT_EQ( - expectedTime, *conn.ackStates.initialAckState.largestRecvdPacketTime); + expectedTime, *conn.ackStates.initialAckState->largestRecvdPacketTime); } TEST_P(AckHandlersTest, NoSkipAckVisitor) { @@ -1590,9 +1590,9 @@ TEST_P(AckHandlersTest, AckNotOutstandingButLoss) { conn.lossState.lrtt = 150ms; // Packet 2 has been sent and acked: if (GetParam() == PacketNumberSpace::Initial) { - conn.ackStates.initialAckState.largestAckedByPeer = 2; + conn.ackStates.initialAckState->largestAckedByPeer = 2; } else if (GetParam() == PacketNumberSpace::Handshake) { - conn.ackStates.handshakeAckState.largestAckedByPeer = 2; + conn.ackStates.handshakeAckState->largestAckedByPeer = 2; } else { conn.ackStates.appDataAckState.largestAckedByPeer = 2; } @@ -3171,14 +3171,14 @@ class AckEventForAppDataTest : public Test { LongHeader::Types::Initial, *conn_->clientConnectionId, *conn_->serverConnectionId, - conn_->ackStates.initialAckState.nextPacketNum, + conn_->ackStates.initialAckState->nextPacketNum, *conn_->version); } else if (pnSpace == PacketNumberSpace::Handshake) { header = LongHeader( LongHeader::Types::Handshake, *conn_->clientConnectionId, *conn_->serverConnectionId, - conn_->ackStates.handshakeAckState.nextPacketNum, + conn_->ackStates.handshakeAckState->nextPacketNum, *conn_->version); } else if (pnSpace == PacketNumberSpace::AppData) { header = LongHeader( diff --git a/quic/state/test/QuicStateFunctionsTest.cpp b/quic/state/test/QuicStateFunctionsTest.cpp index 001ec8735..de219470e 100644 --- a/quic/state/test/QuicStateFunctionsTest.cpp +++ b/quic/state/test/QuicStateFunctionsTest.cpp @@ -1161,17 +1161,18 @@ TEST_F(QuicStateFunctionsTest, GetOutstandingPackets) { TEST_F(QuicStateFunctionsTest, UpdateLargestReceivePacketsAtLatCloseSent) { QuicConnectionStateBase conn(QuicNodeType::Client); - EXPECT_FALSE(conn.ackStates.initialAckState.largestReceivedAtLastCloseSent); - EXPECT_FALSE(conn.ackStates.handshakeAckState.largestReceivedAtLastCloseSent); + EXPECT_FALSE(conn.ackStates.initialAckState->largestReceivedAtLastCloseSent); + EXPECT_FALSE( + conn.ackStates.handshakeAckState->largestReceivedAtLastCloseSent); EXPECT_FALSE(conn.ackStates.appDataAckState.largestReceivedAtLastCloseSent); - conn.ackStates.initialAckState.largestRecvdPacketNum = 123; - conn.ackStates.handshakeAckState.largestRecvdPacketNum = 654; + conn.ackStates.initialAckState->largestRecvdPacketNum = 123; + conn.ackStates.handshakeAckState->largestRecvdPacketNum = 654; conn.ackStates.appDataAckState.largestRecvdPacketNum = 789; updateLargestReceivedPacketsAtLastCloseSent(conn); EXPECT_EQ( - 123, *conn.ackStates.initialAckState.largestReceivedAtLastCloseSent); + 123, *conn.ackStates.initialAckState->largestReceivedAtLastCloseSent); EXPECT_EQ( - 654, *conn.ackStates.handshakeAckState.largestReceivedAtLastCloseSent); + 654, *conn.ackStates.handshakeAckState->largestReceivedAtLastCloseSent); EXPECT_EQ( 789, *conn.ackStates.appDataAckState.largestReceivedAtLastCloseSent); }