diff --git a/quic/api/test/QuicTransportTest.cpp b/quic/api/test/QuicTransportTest.cpp index 9c2435ebd..943ca250a 100644 --- a/quic/api/test/QuicTransportTest.cpp +++ b/quic/api/test/QuicTransportTest.cpp @@ -1001,7 +1001,9 @@ TEST_F(QuicTransportTest, ClonePathChallenge) { conn.outstandingHandshakePacketsCount = 0; conn.outstandingPureAckPacketsCount = 0; conn.outstandingPackets.clear(); - conn.lossState.lossTime.clear(); + conn.lossState.initialLossTime.clear(); + conn.lossState.handshakeLossTime.clear(); + conn.lossState.appDataLossTime.clear(); PathChallengeFrame pathChallenge(123); conn.pendingEvents.pathChallenge = pathChallenge; @@ -1062,7 +1064,9 @@ TEST_F(QuicTransportTest, OnlyClonePathValidationIfOutstanding) { conn.outstandingHandshakePacketsCount = 0; conn.outstandingPureAckPacketsCount = 0; conn.outstandingPackets.clear(); - conn.lossState.lossTime.clear(); + conn.lossState.initialLossTime.clear(); + conn.lossState.handshakeLossTime.clear(); + conn.lossState.appDataLossTime.clear(); PathChallengeFrame pathChallenge(123); conn.pendingEvents.pathChallenge = pathChallenge; @@ -1225,7 +1229,9 @@ TEST_F(QuicTransportTest, ClonePathResponse) { conn.outstandingHandshakePacketsCount = 0; conn.outstandingPureAckPacketsCount = 0; conn.outstandingPackets.clear(); - conn.lossState.lossTime.clear(); + conn.lossState.initialLossTime.clear(); + conn.lossState.handshakeLossTime.clear(); + conn.lossState.appDataLossTime.clear(); EXPECT_EQ(conn.pendingEvents.frames.size(), 0); PathResponseFrame pathResponse(123); @@ -1337,7 +1343,9 @@ TEST_F(QuicTransportTest, CloneNewConnectionIdFrame) { conn.outstandingHandshakePacketsCount = 0; conn.outstandingPureAckPacketsCount = 0; conn.outstandingPackets.clear(); - conn.lossState.lossTime.clear(); + conn.lossState.initialLossTime.clear(); + conn.lossState.handshakeLossTime.clear(); + conn.lossState.appDataLossTime.clear(); NewConnectionIdFrame newConnId( 1, ConnectionId({2, 4, 2, 3}), StatelessResetToken()); diff --git a/quic/loss/QuicLossFunctions.h b/quic/loss/QuicLossFunctions.h index 822b1a172..254a55caa 100644 --- a/quic/loss/QuicLossFunctions.h +++ b/quic/loss/QuicLossFunctions.h @@ -65,7 +65,18 @@ calculateAlarmDuration(const QuicConnectionStateBase& conn) { folly::Optional alarmMethod; TimePoint lastSentPacketTime = conn.lossState.lastRetransmittablePacketSentTime; - if (conn.outstandingHandshakePacketsCount > 0) { + auto lossTimeAndSpace = earliestLossTimer(conn); + if (lossTimeAndSpace.first) { + if (*lossTimeAndSpace.first > lastSentPacketTime) { + // We do this so that lastSentPacketTime + alarmDuration = lossTime + alarmDuration = std::chrono::duration_cast( + *lossTimeAndSpace.first - lastSentPacketTime); + } else { + // This should trigger an immediate alarm. + alarmDuration = 0us; + } + alarmMethod = LossState::AlarmMethod::EarlyRetransmitOrReordering; + } else if (conn.outstandingHandshakePacketsCount > 0) { if (conn.lossState.srtt == 0us) { alarmDuration = kDefaultInitialRtt * 2; } else { @@ -78,16 +89,6 @@ calculateAlarmDuration(const QuicConnectionStateBase& conn) { // Handshake packet loss timer shouldn't be affected by other packets. lastSentPacketTime = conn.lossState.lastHandshakePacketSentTime; DCHECK_NE(lastSentPacketTime.time_since_epoch().count(), 0); - } else if (conn.lossState.lossTime) { - if (*conn.lossState.lossTime > lastSentPacketTime) { - // We do this so that lastSentPacketTime + alarmDuration = lossTime - alarmDuration = std::chrono::duration_cast( - *conn.lossState.lossTime - lastSentPacketTime); - } else { - // This should trigger an immediate alarm. - alarmDuration = 0us; - } - alarmMethod = LossState::AlarmMethod::EarlyRetransmitOrReordering; } else { auto ptoTimeout = calculatePTO(conn); ptoTimeout *= 1 << std::min(conn.lossState.ptoCount, (uint32_t)31); @@ -208,7 +209,7 @@ folly::Optional detectLossPackets( const LossVisitor& lossVisitor, TimePoint lossTime, PacketNumberSpace pnSpace) { - conn.lossState.lossTime.clear(); + getLossTime(conn, pnSpace).clear(); std::chrono::microseconds delayUntilLost = std::max(conn.lossState.srtt, conn.lossState.lrtt) * 9 / 8; VLOG(10) << __func__ << " outstanding=" << conn.outstandingPackets.size() @@ -272,10 +273,9 @@ folly::Optional detectLossPackets( iter = conn.outstandingPackets.erase(iter); } - // Because we handle handshake timer before lossTime, lossTime is used by - // AppData space only. So this is fine. auto earliest = getFirstOutstandingPacket(conn, pnSpace); - for (; earliest != conn.outstandingPackets.end(); earliest++) { + for (; earliest != conn.outstandingPackets.end(); + earliest = getNextOutstandingPacket(conn, pnSpace, earliest + 1)) { if (!earliest->pureAck && (!earliest->associatedEvent || conn.outstandingPacketEvents.count(*earliest->associatedEvent))) { @@ -289,7 +289,7 @@ folly::Optional detectLossPackets( << conn.outstandingPackets.empty() << " delayUntilLost" << delayUntilLost.count() << "us" << " " << conn; - conn.lossState.lossTime = delayUntilLost + earliest->time; + getLossTime(conn, pnSpace) = delayUntilLost + earliest->time; } if (lossEvent.largestLostPacketNum.hasValue()) { DCHECK(lossEvent.largestLostSentTime && lossEvent.smallestLostSentTime); @@ -374,18 +374,16 @@ void onLossDetectionAlarm( VLOG(10) << "Transmission alarm fired with no outstanding packets " << conn; return; } - // TODO: The specs prioritize EarlyRetransmitOrReordering over crypto timer. - if (conn.lossState.currentAlarmMethod == LossState::AlarmMethod::Handshake) { - onHandshakeAlarm(conn, lossVisitor); - } else if ( - conn.lossState.currentAlarmMethod == + if (conn.lossState.currentAlarmMethod == LossState::AlarmMethod::EarlyRetransmitOrReordering) { + auto lossTimeAndSpace = earliestLossTimer(conn); + CHECK(lossTimeAndSpace.first); auto lossEvent = detectLossPackets( conn, - getAckState(conn, PacketNumberSpace::AppData).largestAckedByPeer, + getAckState(conn, lossTimeAndSpace.second).largestAckedByPeer, lossVisitor, now, - PacketNumberSpace::AppData); + lossTimeAndSpace.second); if (conn.congestionController && lossEvent) { DCHECK(lossEvent->largestLostSentTime && lossEvent->smallestLostSentTime); lossEvent->persistentCongestion = isPersistentCongestion( @@ -395,6 +393,9 @@ void onLossDetectionAlarm( conn.congestionController->onPacketAckOrLoss( folly::none, std::move(lossEvent)); } + } else if ( + conn.lossState.currentAlarmMethod == LossState::AlarmMethod::Handshake) { + onHandshakeAlarm(conn, lossVisitor); } else { onPTOAlarm(conn); } diff --git a/quic/loss/test/QuicLossFunctionsTest.cpp b/quic/loss/test/QuicLossFunctionsTest.cpp index ed52d95a5..53a1bc76f 100644 --- a/quic/loss/test/QuicLossFunctionsTest.cpp +++ b/quic/loss/test/QuicLossFunctionsTest.cpp @@ -766,7 +766,52 @@ TEST_F(QuicLossFunctionsTest, TestTimeReordering) { ->packet.header, [](const auto& h) { return h.getPacketSequenceNum(); }); EXPECT_EQ(packetNum, 6); - EXPECT_TRUE(conn->lossState.lossTime); + EXPECT_TRUE(conn->lossState.appDataLossTime); +} + +TEST_F(QuicLossFunctionsTest, LossTimePreemptsCryptoTimer) { + std::vector lostPackets; + auto conn = createConn(); + conn->lossState.srtt = 100ms; + conn->lossState.lrtt = 100ms; + auto expectedDelayUntilLost = 900000us / 8; + auto sendTime = Clock::now(); + // Send two: + sendPacket(*conn, sendTime, false, folly::none, PacketType::Handshake); + PacketNum second = sendPacket( + *conn, sendTime + 1ms, false, folly::none, PacketType::Handshake); + auto lossTime = sendTime + 50ms; + detectLossPackets( + *conn, + second, + testingLossMarkFunc(lostPackets), + lossTime, + PacketNumberSpace::Handshake); + EXPECT_TRUE(lostPackets.empty()); + EXPECT_TRUE(conn->lossState.handshakeLossTime.hasValue()); + EXPECT_EQ( + expectedDelayUntilLost + sendTime, + conn->lossState.handshakeLossTime.value()); + + MockClock::mockNow = [=]() { return sendTime; }; + auto alarm = calculateAlarmDuration(*conn); + EXPECT_EQ( + std::chrono::duration_cast( + expectedDelayUntilLost), + alarm.first); + EXPECT_EQ(LossState::AlarmMethod::EarlyRetransmitOrReordering, alarm.second); + // Manual set lossState. Calling setLossDetectionAlarm requries a Timeout + conn->lossState.currentAlarmMethod = alarm.second; + + // Second packet gets acked: + getAckState(*conn, PacketNumberSpace::Handshake).largestAckedByPeer = second; + conn->outstandingPackets.pop_back(); + MockClock::mockNow = [=]() { return sendTime + expectedDelayUntilLost + 5s; }; + onLossDetectionAlarm( + *conn, testingLossMarkFunc(lostPackets)); + EXPECT_EQ(1, lostPackets.size()); + EXPECT_FALSE(conn->lossState.handshakeLossTime.hasValue()); + EXPECT_TRUE(conn->outstandingPackets.empty()); } TEST_F(QuicLossFunctionsTest, PTONoLongerMarksPacketsToBeRetransmitted) { @@ -805,6 +850,7 @@ TEST_F( .WillRepeatedly(Return()); std::vector lostPackets; PacketNum expectedLargestLostNum = 0; + conn->lossState.currentAlarmMethod = LossState::AlarmMethod::Handshake; for (auto i = 0; i < 10; i++) { // Half are handshakes auto sentPacketNum = sendPacket( @@ -840,6 +886,7 @@ TEST_F( TEST_F(QuicLossFunctionsTest, HandshakeAlarmWithOneRttCipher) { auto conn = createClientConn(); conn->oneRttWriteCipher = createNoOpAead(); + conn->lossState.currentAlarmMethod = LossState::AlarmMethod::Handshake; std::vector lostPackets; sendPacket( *conn, TimePoint(100ms), false, folly::none, PacketType::Handshake); @@ -933,7 +980,7 @@ TEST_F(QuicLossFunctionsTest, AlarmDurationHasLossTime) { TimePoint lastPacketSentTime = Clock::now(); auto thisMoment = lastPacketSentTime; MockClock::mockNow = [=]() { return thisMoment; }; - conn->lossState.lossTime = thisMoment + 100ms; + conn->lossState.appDataLossTime = thisMoment + 100ms; conn->lossState.srtt = 200ms; conn->lossState.lrtt = 150ms; @@ -951,7 +998,7 @@ TEST_F(QuicLossFunctionsTest, AlarmDurationLossTimeIsZero) { TimePoint lastPacketSentTime = Clock::now(); auto thisMoment = lastPacketSentTime + 200ms; MockClock::mockNow = [=]() { return thisMoment; }; - conn->lossState.lossTime = lastPacketSentTime + 100ms; + conn->lossState.appDataLossTime = lastPacketSentTime + 100ms; conn->lossState.srtt = 200ms; conn->lossState.lrtt = 150ms; diff --git a/quic/server/test/QuicServerTransportTest.cpp b/quic/server/test/QuicServerTransportTest.cpp index 5ef036928..dc9f11133 100644 --- a/quic/server/test/QuicServerTransportTest.cpp +++ b/quic/server/test/QuicServerTransportTest.cpp @@ -1540,7 +1540,9 @@ TEST_F(QuicServerTransportTest, TestCloneStopSending) { server->getNonConstConn().outstandingHandshakePacketsCount = 0; server->getNonConstConn().outstandingPureAckPacketsCount = 0; server->getNonConstConn().outstandingPackets.clear(); - server->getNonConstConn().lossState.lossTime.clear(); + server->getNonConstConn().lossState.initialLossTime.clear(); + server->getNonConstConn().lossState.handshakeLossTime.clear(); + server->getNonConstConn().lossState.appDataLossTime.clear(); server->stopSending(streamId, GenericApplicationErrorCode::UNKNOWN); loopForWrites(); diff --git a/quic/state/QuicStateFunctions.cpp b/quic/state/QuicStateFunctions.cpp index c2ea9fbe7..0a018dcb9 100644 --- a/quic/state/QuicStateFunctions.cpp +++ b/quic/state/QuicStateFunctions.cpp @@ -11,6 +11,24 @@ #include #include +namespace { +template +std::pair, A> minOptional( + std::pair, A> p1, + std::pair, A> p2) { + if (!p1.first && !p2.first) { + return std::make_pair(folly::none, p1.second); + } + if (!p1.first) { + return p2; + } + if (!p2.first) { + return p1; + } + return *p1.first < *p2.first ? p1 : p2; +} +} // namespace + namespace quic { void updateRtt( @@ -240,4 +258,30 @@ bool hasReceivedPackets(const QuicConnectionStateBase& conn) noexcept { conn.ackStates.handshakeAckState.largestReceivedPacketNum || conn.ackStates.appDataAckState.largestReceivedPacketNum; } + +folly::Optional& getLossTime( + QuicConnectionStateBase& conn, + PacketNumberSpace pnSpace) noexcept { + switch (pnSpace) { + case PacketNumberSpace::Initial: + return conn.lossState.initialLossTime; + case PacketNumberSpace::Handshake: + return conn.lossState.handshakeLossTime; + case PacketNumberSpace::AppData: + return conn.lossState.appDataLossTime; + } + folly::assume_unreachable(); +} + +std::pair, PacketNumberSpace> earliestLossTimer( + const QuicConnectionStateBase& conn) noexcept { + return minOptional( + minOptional( + std::make_pair( + conn.lossState.initialLossTime, PacketNumberSpace::Initial), + std::make_pair( + conn.lossState.handshakeLossTime, PacketNumberSpace::Handshake)), + std::make_pair( + conn.lossState.appDataLossTime, PacketNumberSpace::AppData)); +} } // namespace quic diff --git a/quic/state/QuicStateFunctions.h b/quic/state/QuicStateFunctions.h index 387b1c512..b19d30939 100644 --- a/quic/state/QuicStateFunctions.h +++ b/quic/state/QuicStateFunctions.h @@ -114,4 +114,11 @@ bool hasNotReceivedNewPacketsSinceLastCloseSent( void updateLargestReceivedPacketsAtLastCloseSent( QuicConnectionStateBase& conn) noexcept; + +folly::Optional& getLossTime( + QuicConnectionStateBase& conn, + PacketNumberSpace pnSpace) noexcept; + +std::pair, PacketNumberSpace> earliestLossTimer( + const QuicConnectionStateBase& conn) noexcept; } // namespace quic diff --git a/quic/state/StateData.h b/quic/state/StateData.h index deb30bb4d..d417eb2ac 100644 --- a/quic/state/StateData.h +++ b/quic/state/StateData.h @@ -313,9 +313,10 @@ struct LossState { // Reordering threshold used uint32_t reorderingThreshold{kReorderingThreshold}; // Timer for time reordering detection or early retransmit alarm. - folly::Optional lossTime; + folly::Optional initialLossTime, handshakeLossTime, + appDataLossTime; // Current method by which the loss detection alarm is set. - AlarmMethod currentAlarmMethod{AlarmMethod::Handshake}; + AlarmMethod currentAlarmMethod{AlarmMethod::EarlyRetransmitOrReordering}; // Total number of packet retransmitted on this connection, including packet // clones, retransmitted clones, handshake and rejected zero rtt packets. uint32_t rtxCount{0}; diff --git a/quic/state/test/QuicStateFunctionsTest.cpp b/quic/state/test/QuicStateFunctionsTest.cpp index 7321c9168..9cb217c8c 100644 --- a/quic/state/test/QuicStateFunctionsTest.cpp +++ b/quic/state/test/QuicStateFunctionsTest.cpp @@ -529,6 +529,24 @@ TEST_P(QuicStateFunctionsTest, HasNotReceivedNewPacketsSinceLastClose) { EXPECT_TRUE(hasReceivedPacketsAtLastCloseSent(conn)); } +TEST_F(QuicStateFunctionsTest, EarliestLossTimer) { + QuicConnectionStateBase conn(QuicNodeType::Server); + EXPECT_FALSE(earliestLossTimer(conn).first.hasValue()); + auto currentTime = Clock::now(); + conn.lossState.initialLossTime = currentTime; + EXPECT_EQ(PacketNumberSpace::Initial, earliestLossTimer(conn).second); + EXPECT_EQ(currentTime, earliestLossTimer(conn).first.value()); + conn.lossState.appDataLossTime = currentTime - 1s; + EXPECT_EQ(PacketNumberSpace::AppData, earliestLossTimer(conn).second); + EXPECT_EQ(currentTime - 1s, earliestLossTimer(conn).first.value()); + conn.lossState.handshakeLossTime = currentTime + 1s; + EXPECT_EQ(PacketNumberSpace::AppData, earliestLossTimer(conn).second); + EXPECT_EQ(currentTime - 1s, earliestLossTimer(conn).first.value()); + conn.lossState.appDataLossTime= currentTime + 1s; + EXPECT_EQ(PacketNumberSpace::Initial, earliestLossTimer(conn).second); + EXPECT_EQ(currentTime, earliestLossTimer(conn).first.value()); +} + INSTANTIATE_TEST_CASE_P( QuicStateFunctionsTests, QuicStateFunctionsTest,