diff --git a/quic/api/QuicTransportFunctions.cpp b/quic/api/QuicTransportFunctions.cpp index 4c7252d86..785e1f1f7 100644 --- a/quic/api/QuicTransportFunctions.cpp +++ b/quic/api/QuicTransportFunctions.cpp @@ -496,7 +496,8 @@ void updateConnection( ++conn.outstandingHandshakePacketsCount; conn.lossState.lastHandshakePacketSentTime = pkt.time; } - conn.lossState.lastRetransmittablePacketSentTime = pkt.time; + conn.lossState.lastRetransmittablePacketSentTimes[packetNumberSpace] = + pkt.time; if (pkt.associatedEvent) { CHECK_EQ(packetNumberSpace, PacketNumberSpace::AppData); ++conn.outstandingClonedPacketsCount; diff --git a/quic/api/test/QuicTransportTest.cpp b/quic/api/test/QuicTransportTest.cpp index 9b2a213c7..4d9146fb9 100644 --- a/quic/api/test/QuicTransportTest.cpp +++ b/quic/api/test/QuicTransportTest.cpp @@ -1022,9 +1022,9 @@ TEST_F(QuicTransportTest, ClonePathChallenge) { // knock every handshake outstanding packets out conn.outstandingHandshakePacketsCount = 0; conn.outstandingPackets.clear(); - conn.lossState.initialLossTime.clear(); - conn.lossState.handshakeLossTime.clear(); - conn.lossState.appDataLossTime.clear(); + for (auto& t : conn.lossState.lossTimes) { + t.clear(); + } PathChallengeFrame pathChallenge(123); conn.pendingEvents.pathChallenge = pathChallenge; @@ -1057,9 +1057,9 @@ TEST_F(QuicTransportTest, OnlyClonePathValidationIfOutstanding) { // knock every handshake outstanding packets out conn.outstandingHandshakePacketsCount = 0; conn.outstandingPackets.clear(); - conn.lossState.initialLossTime.clear(); - conn.lossState.handshakeLossTime.clear(); - conn.lossState.appDataLossTime.clear(); + for (auto& t : conn.lossState.lossTimes) { + t.clear(); + } PathChallengeFrame pathChallenge(123); conn.pendingEvents.pathChallenge = pathChallenge; @@ -1199,9 +1199,9 @@ TEST_F(QuicTransportTest, ClonePathResponse) { // knock every handshake outstanding packets out conn.outstandingHandshakePacketsCount = 0; conn.outstandingPackets.clear(); - conn.lossState.initialLossTime.clear(); - conn.lossState.handshakeLossTime.clear(); - conn.lossState.appDataLossTime.clear(); + for (auto& t : conn.lossState.lossTimes) { + t.clear(); + } EXPECT_EQ(conn.pendingEvents.frames.size(), 0); PathResponseFrame pathResponse(123); @@ -1282,9 +1282,9 @@ TEST_F(QuicTransportTest, CloneNewConnectionIdFrame) { // knock every handshake outstanding packets out conn.outstandingHandshakePacketsCount = 0; conn.outstandingPackets.clear(); - conn.lossState.initialLossTime.clear(); - conn.lossState.handshakeLossTime.clear(); - conn.lossState.appDataLossTime.clear(); + for (auto& t : conn.lossState.lossTimes) { + t.clear(); + } NewConnectionIdFrame newConnId( 1, 0, ConnectionId({2, 4, 2, 3}), StatelessResetToken()); @@ -1420,9 +1420,9 @@ TEST_F(QuicTransportTest, CloneRetireConnectionIdFrame) { // knock every handshake outstanding packets out conn.outstandingHandshakePacketsCount = 0; conn.outstandingPackets.clear(); - conn.lossState.initialLossTime.clear(); - conn.lossState.handshakeLossTime.clear(); - conn.lossState.appDataLossTime.clear(); + for (auto& t : conn.lossState.lossTimes) { + t.clear(); + } RetireConnectionIdFrame retireConnId(1); sendSimpleFrame(conn, retireConnId); diff --git a/quic/codec/Types.h b/quic/codec/Types.h index 367f37641..7ff85d447 100644 --- a/quic/codec/Types.h +++ b/quic/codec/Types.h @@ -42,6 +42,8 @@ enum class PacketNumberSpace : uint8_t { Initial, Handshake, AppData, + // MAX has to be updated whenever other enumerators are added to this enum + MAX = PacketNumberSpace::AppData }; constexpr uint8_t kHeaderFormMask = 0x80; diff --git a/quic/common/EnumArray.h b/quic/common/EnumArray.h new file mode 100644 index 000000000..abe1463b5 --- /dev/null +++ b/quic/common/EnumArray.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + * + */ + +#pragma once + +#include +#include + +namespace quic { + +// A generic class that extends std::array to be indexable using an Enum. +// The enum K has to list enumerators with all values between 0 and K::MAX +// (inclusive) and no others + +template +class EnumArray : public std::array { + public: + using IntType = typename std::underlying_type::type; + static constexpr IntType ArraySize = IntType(K::MAX) + 1; + constexpr const V& operator[](K key) const { + size_t ik = keyToInt(key); + return this->std::array::operator[](ik); + } + constexpr V& operator[](K key) { + size_t ik = keyToInt(key); + return this->std::array::operator[](ik); + } + // Returns all valid values for the enum + [[nodiscard]] constexpr std::array keys() const { + return keyArrayHelper(std::make_integer_sequence{}); + } + + private: + constexpr IntType keyToInt(K key) const { + auto ik = static_cast(key); + DCHECK(ik >= 0 && ik < ArraySize); + return ik; + } + + template + constexpr auto keyArrayHelper(std::integer_sequence) const { + return std::array{static_cast(i)...}; + } + + std::array arr; +}; + +} // namespace quic diff --git a/quic/loss/QuicLossFunctions.h b/quic/loss/QuicLossFunctions.h index 288d10277..6677a4357 100644 --- a/quic/loss/QuicLossFunctions.h +++ b/quic/loss/QuicLossFunctions.h @@ -64,8 +64,11 @@ std::pair calculateAlarmDuration(const QuicConnectionStateBase& conn) { std::chrono::microseconds alarmDuration; folly::Optional alarmMethod; - TimePoint lastSentPacketTime = - conn.lossState.lastRetransmittablePacketSentTime; + auto lastSentPacketTimeAndSpace = earliestTimeAndSpace( + conn.lossState.lastRetransmittablePacketSentTimes, + canSetLossTimerForAppData(conn)); + DCHECK(lastSentPacketTimeAndSpace.first.hasValue()); + TimePoint lastSentPacketTime = lastSentPacketTimeAndSpace.first.value(); auto lossTimeAndSpace = earliestLossTimer(conn); if (lossTimeAndSpace.first) { if (*lossTimeAndSpace.first > lastSentPacketTime) { diff --git a/quic/loss/test/QuicLossFunctionsTest.cpp b/quic/loss/test/QuicLossFunctionsTest.cpp index 221f94658..091a9de47 100644 --- a/quic/loss/test/QuicLossFunctionsTest.cpp +++ b/quic/loss/test/QuicLossFunctionsTest.cpp @@ -83,6 +83,8 @@ class QuicLossFunctionsTest : public TestWithParam { ServerConnectionIdParams params(0, 0, 0); conn->connIdAlgo = connIdAlgo_.get(); conn->serverConnectionId = connIdAlgo_->encodeConnectionId(params); + // for canSetLossTimerForAppData() + conn->oneRttWriteCipher = createNoOpAead(); return conn; } @@ -187,7 +189,7 @@ PacketNum QuicLossFunctionsTest::sendPacket( conn.outstandingHandshakePacketsCount++; conn.lossState.lastHandshakePacketSentTime = time; } - conn.lossState.lastRetransmittablePacketSentTime = time; + conn.lossState.lastRetransmittablePacketSentTimes[packetNumberSpace] = time; if (conn.congestionController) { conn.congestionController->onPacketSent(outstandingPacket); } @@ -900,7 +902,7 @@ TEST_F(QuicLossFunctionsTest, TestTimeReordering) { auto packetNum = getFirstOutstandingPacket(*conn, PacketNumberSpace::AppData) ->packet.header.getPacketSequenceNum(); EXPECT_EQ(packetNum, 6); - EXPECT_TRUE(conn->lossState.appDataLossTime); + EXPECT_TRUE(conn->lossState.lossTimes[PacketNumberSpace::AppData]); } TEST_F(QuicLossFunctionsTest, LossTimePreemptsCryptoTimer) { @@ -923,10 +925,11 @@ TEST_F(QuicLossFunctionsTest, LossTimePreemptsCryptoTimer) { lossTime, PacketNumberSpace::Handshake); EXPECT_TRUE(lostPackets.empty()); - EXPECT_TRUE(conn->lossState.handshakeLossTime.hasValue()); + EXPECT_TRUE( + conn->lossState.lossTimes[PacketNumberSpace::Handshake].hasValue()); EXPECT_EQ( expectedDelayUntilLost + sendTime, - conn->lossState.handshakeLossTime.value()); + conn->lossState.lossTimes[PacketNumberSpace::Handshake].value()); MockClock::mockNow = [=]() { return sendTime; }; auto alarm = calculateAlarmDuration(*conn); @@ -945,7 +948,8 @@ TEST_F(QuicLossFunctionsTest, LossTimePreemptsCryptoTimer) { onLossDetectionAlarm( *conn, testingLossMarkFunc(lostPackets)); EXPECT_EQ(1, lostPackets.size()); - EXPECT_FALSE(conn->lossState.handshakeLossTime.hasValue()); + EXPECT_FALSE( + conn->lossState.lossTimes[PacketNumberSpace::Handshake].hasValue()); EXPECT_TRUE(conn->outstandingPackets.empty()); } @@ -1081,7 +1085,7 @@ TEST_F(QuicLossFunctionsTest, AlarmDurationHasLossTime) { TimePoint lastPacketSentTime = Clock::now(); auto thisMoment = lastPacketSentTime; MockClock::mockNow = [=]() { return thisMoment; }; - conn->lossState.appDataLossTime = thisMoment + 100ms; + conn->lossState.lossTimes[PacketNumberSpace::AppData] = thisMoment + 100ms; conn->lossState.srtt = 200ms; conn->lossState.lrtt = 150ms; @@ -1099,7 +1103,8 @@ TEST_F(QuicLossFunctionsTest, AlarmDurationLossTimeIsZero) { TimePoint lastPacketSentTime = Clock::now(); auto thisMoment = lastPacketSentTime + 200ms; MockClock::mockNow = [=]() { return thisMoment; }; - conn->lossState.appDataLossTime = lastPacketSentTime + 100ms; + conn->lossState.lossTimes[PacketNumberSpace::AppData] = + lastPacketSentTime + 100ms; conn->lossState.srtt = 200ms; conn->lossState.lrtt = 150ms; diff --git a/quic/server/test/QuicServerTransportTest.cpp b/quic/server/test/QuicServerTransportTest.cpp index 53acc6213..5497a5165 100644 --- a/quic/server/test/QuicServerTransportTest.cpp +++ b/quic/server/test/QuicServerTransportTest.cpp @@ -1741,9 +1741,9 @@ TEST_F(QuicServerTransportTest, TestCloneStopSending) { // knock every handshake outstanding packets out server->getNonConstConn().outstandingHandshakePacketsCount = 0; server->getNonConstConn().outstandingPackets.clear(); - server->getNonConstConn().lossState.initialLossTime.clear(); - server->getNonConstConn().lossState.handshakeLossTime.clear(); - server->getNonConstConn().lossState.appDataLossTime.clear(); + for (auto& t : server->getNonConstConn().lossState.lossTimes) { + t.clear(); + } server->stopSending(streamId, GenericApplicationErrorCode::UNKNOWN); loopForWrites(); diff --git a/quic/state/QuicStateFunctions.cpp b/quic/state/QuicStateFunctions.cpp index ecb0147a5..91f48147b 100644 --- a/quic/state/QuicStateFunctions.cpp +++ b/quic/state/QuicStateFunctions.cpp @@ -22,22 +22,6 @@ getPreviousOutstandingPacket( return packetNumberSpace == op.packet.header.getPacketNumberSpace(); }); } - -template -std::pair, A> minOptional( - std::pair, A> p1, - std::pair, A> p2) { - if (!p1.first && !p2.first) { - return std::make_pair(folly::none, p1.second); - } - if (!p1.first) { - return p2; - } - if (!p2.first) { - return p1; - } - return *p1.first < *p2.first ? p1 : p2; -} } // namespace namespace quic { @@ -294,26 +278,37 @@ bool hasReceivedPackets(const QuicConnectionStateBase& conn) noexcept { folly::Optional& getLossTime( QuicConnectionStateBase& conn, PacketNumberSpace pnSpace) noexcept { - switch (pnSpace) { - case PacketNumberSpace::Initial: - return conn.lossState.initialLossTime; - case PacketNumberSpace::Handshake: - return conn.lossState.handshakeLossTime; - case PacketNumberSpace::AppData: - return conn.lossState.appDataLossTime; - } - folly::assume_unreachable(); + return conn.lossState.lossTimes[pnSpace]; +} + +bool canSetLossTimerForAppData(const QuicConnectionStateBase& conn) noexcept { + return conn.oneRttWriteCipher != nullptr; } std::pair, PacketNumberSpace> earliestLossTimer( const QuicConnectionStateBase& conn) noexcept { - return minOptional( - minOptional( - std::make_pair( - conn.lossState.initialLossTime, PacketNumberSpace::Initial), - std::make_pair( - conn.lossState.handshakeLossTime, PacketNumberSpace::Handshake)), - std::make_pair( - conn.lossState.appDataLossTime, PacketNumberSpace::AppData)); + bool considerAppData = canSetLossTimerForAppData(conn); + return earliestTimeAndSpace(conn.lossState.lossTimes, considerAppData); } + +std::pair, PacketNumberSpace> earliestTimeAndSpace( + const EnumArray>& times, + bool considerAppData) noexcept { + std::pair, PacketNumberSpace> res = { + folly::none, PacketNumberSpace::Initial}; + for (PacketNumberSpace pns : times.keys()) { + if (!times[pns]) { + continue; + } + if (pns == PacketNumberSpace::AppData && !considerAppData) { + continue; + } + if (!res.first || *res.first > *times[pns]) { + res.first = times[pns]; + res.second = pns; + } + } + return res; +} + } // namespace quic diff --git a/quic/state/QuicStateFunctions.h b/quic/state/QuicStateFunctions.h index 139f81b18..8339a2af5 100644 --- a/quic/state/QuicStateFunctions.h +++ b/quic/state/QuicStateFunctions.h @@ -105,6 +105,12 @@ folly::Optional& getLossTime( QuicConnectionStateBase& conn, PacketNumberSpace pnSpace) noexcept; +bool canSetLossTimerForAppData(const QuicConnectionStateBase& conn) noexcept; + std::pair, PacketNumberSpace> earliestLossTimer( const QuicConnectionStateBase& conn) noexcept; + +std::pair, PacketNumberSpace> earliestTimeAndSpace( + const EnumArray>& times, + bool considerAppData) noexcept; } // namespace quic diff --git a/quic/state/StateData.h b/quic/state/StateData.h index 001a998fa..f4a8435a2 100644 --- a/quic/state/StateData.h +++ b/quic/state/StateData.h @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -445,8 +446,7 @@ struct LossState { // Reordering threshold used uint32_t reorderingThreshold{kReorderingThreshold}; // Timer for time reordering detection or early retransmit alarm. - folly::Optional initialLossTime, handshakeLossTime, - appDataLossTime; + EnumArray> lossTimes; // Current method by which the loss detection alarm is set. AlarmMethod currentAlarmMethod{AlarmMethod::EarlyRetransmitOrReordering}; // Total number of packet retransmitted on this connection, including packet @@ -483,8 +483,10 @@ struct LossState { folly::Optional lastAckedPacketSentTime; // The latest time a packet is acked folly::Optional lastAckedTime; - // The time when last retranmittable packet is sent - TimePoint lastRetransmittablePacketSentTime; + // The time when last retranmittable packet is sent for every packet number + // space + EnumArray> + lastRetransmittablePacketSentTimes; }; class Logger; diff --git a/quic/state/test/QuicStateFunctionsTest.cpp b/quic/state/test/QuicStateFunctionsTest.cpp index c850d7de7..7dfaa498c 100644 --- a/quic/state/test/QuicStateFunctionsTest.cpp +++ b/quic/state/test/QuicStateFunctionsTest.cpp @@ -621,18 +621,26 @@ TEST_F(QuicStateFunctionsTest, EarliestLossTimer) { QuicConnectionStateBase conn(QuicNodeType::Server); EXPECT_FALSE(earliestLossTimer(conn).first.hasValue()); auto currentTime = Clock::now(); - conn.lossState.initialLossTime = currentTime; + + // Before handshake completed + conn.lossState.lossTimes[PacketNumberSpace::Initial] = currentTime; EXPECT_EQ(PacketNumberSpace::Initial, earliestLossTimer(conn).second); EXPECT_EQ(currentTime, earliestLossTimer(conn).first.value()); - conn.lossState.appDataLossTime = currentTime - 1s; - EXPECT_EQ(PacketNumberSpace::AppData, earliestLossTimer(conn).second); - EXPECT_EQ(currentTime - 1s, earliestLossTimer(conn).first.value()); - conn.lossState.handshakeLossTime = currentTime + 1s; - EXPECT_EQ(PacketNumberSpace::AppData, earliestLossTimer(conn).second); - EXPECT_EQ(currentTime - 1s, earliestLossTimer(conn).first.value()); - conn.lossState.appDataLossTime = currentTime + 1s; + conn.lossState.lossTimes[PacketNumberSpace::AppData] = currentTime - 2s; EXPECT_EQ(PacketNumberSpace::Initial, earliestLossTimer(conn).second); EXPECT_EQ(currentTime, earliestLossTimer(conn).first.value()); + conn.lossState.lossTimes[PacketNumberSpace::Handshake] = currentTime - 1s; + EXPECT_EQ(PacketNumberSpace::Handshake, earliestLossTimer(conn).second); + EXPECT_EQ(currentTime - 1s, earliestLossTimer(conn).first.value()); + + conn.oneRttWriteCipher = createNoOpAead(); + + // After one-rtt cipher is available + EXPECT_EQ(PacketNumberSpace::AppData, earliestLossTimer(conn).second); + EXPECT_EQ(currentTime - 2s, earliestLossTimer(conn).first.value()); + conn.lossState.lossTimes[PacketNumberSpace::AppData] = currentTime + 1s; + EXPECT_EQ(PacketNumberSpace::Handshake, earliestLossTimer(conn).second); + EXPECT_EQ(currentTime - 1s, earliestLossTimer(conn).first.value()); } TEST_P(QuicStateFunctionsTest, CloseTranportStateChange) {