diff --git a/quic/api/QuicPacketScheduler.cpp b/quic/api/QuicPacketScheduler.cpp index a725ed4ab..87101639d 100644 --- a/quic/api/QuicPacketScheduler.cpp +++ b/quic/api/QuicPacketScheduler.cpp @@ -509,10 +509,7 @@ CloningScheduler::CloningScheduler( cipherOverhead_(cipherOverhead) {} bool CloningScheduler::hasData() const { - return frameScheduler_.hasData() || - (!conn_.outstandings.packets.empty() && - conn_.outstandings.packets.size() != - conn_.outstandings.handshakePacketsCount); + return frameScheduler_.hasData() || (!conn_.outstandings.packets.empty()); } SchedulingResult CloningScheduler::scheduleFramesForPacket( @@ -533,12 +530,15 @@ SchedulingResult CloningScheduler::scheduleFramesForPacket( auto header = builder.getPacketHeader(); std::move(builder).releaseOutputBuffer(); // Look for an outstanding packet that's no larger than the writableBytes - // This is a loop, but it builds at most one packet. - for (auto iter = conn_.outstandings.packets.rbegin(); - iter != conn_.outstandings.packets.rend(); - ++iter) { - auto opPnSpace = iter->packet.header.getPacketNumberSpace(); - if (opPnSpace != PacketNumberSpace::AppData) { + for (auto& outstandingPacket : conn_.outstandings.packets) { + auto opPnSpace = outstandingPacket.packet.header.getPacketNumberSpace(); + // Reusing the RegularQuicPacketBuilder throughout loop bodies will lead to + // frames belong to different original packets being written into the same + // clone packet. So re-create a RegularQuicPacketBuilder every time. + // TODO: We can avoid the copy & rebuild of the header by creating an + // independent header builder. + auto builderPnSpace = builder.getPacketHeader().getPacketNumberSpace(); + if (opPnSpace != builderPnSpace) { continue; } size_t prevSize = 0; @@ -550,8 +550,6 @@ SchedulingResult CloningScheduler::scheduleFramesForPacket( // Reusing the same builder throughout loop bodies will lead to frames // belong to different original packets being written into the same clone // packet. So re-create a builder every time. - auto builderPnSpace = header.getPacketNumberSpace(); - CHECK_EQ(builderPnSpace, PacketNumberSpace::AppData); std::unique_ptr internalBuilder; if (conn_.transportSettings.dataPathType == DataPathType::ChainedMemory) { internalBuilder = std::make_unique( @@ -566,19 +564,16 @@ SchedulingResult CloningScheduler::scheduleFramesForPacket( header, getAckState(conn_, builderPnSpace).largestAckedByPeer.value_or(0)); } - // We shouldn't clone Handshake packet. - if (iter->isHandshake) { - continue; - } // If the packet is already a clone that has been processed, we don't clone // it again. - if (iter->associatedEvent && - conn_.outstandings.packetEvents.count(*iter->associatedEvent) == 0) { + if (outstandingPacket.associatedEvent && + conn_.outstandings.packetEvents.count( + *outstandingPacket.associatedEvent) == 0) { continue; } // I think this only fail if udpSendPacketLen somehow shrinks in the middle // of a connection. - if (iter->encodedSize > writableBytes + cipherOverhead_) { + if (outstandingPacket.encodedSize > writableBytes + cipherOverhead_) { continue; } @@ -595,7 +590,7 @@ SchedulingResult CloningScheduler::scheduleFramesForPacket( // network just fine; Or we can throw away the built packet and send a ping. // Rebuilder will write the rest of frames - auto rebuildResult = rebuilder.rebuildFromPacket(*iter); + auto rebuildResult = rebuilder.rebuildFromPacket(outstandingPacket); if (rebuildResult) { return SchedulingResult( std::move(rebuildResult), std::move(*internalBuilder).buildPacket()); diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index 6687b90bb..6de3e2f3c 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -413,7 +413,9 @@ void QuicTransportBase::closeImpl( // Don't need outstanding packets. conn_->outstandings.packets.clear(); + conn_->outstandings.initialPacketsCount = 0; conn_->outstandings.handshakePacketsCount = 0; + conn_->outstandings.clonedPacketsCount = 0; // We don't need no congestion control. conn_->congestionController = nullptr; diff --git a/quic/api/QuicTransportFunctions.cpp b/quic/api/QuicTransportFunctions.cpp index 08f09213b..c181bc2a7 100644 --- a/quic/api/QuicTransportFunctions.cpp +++ b/quic/api/QuicTransportFunctions.cpp @@ -140,7 +140,8 @@ uint64_t writeQuicDataToSocketImpl( aead, headerCipher, version); - connection.pendingEvents.numProbePackets = 0; + CHECK_GE(connection.pendingEvents.numProbePackets, written); + connection.pendingEvents.numProbePackets -= written; } auto schedulerBuilder = FrameScheduler::Builder( @@ -648,7 +649,6 @@ void updateConnection( } if (packetEvent) { DCHECK(conn.outstandings.packetEvents.count(*packetEvent)); - DCHECK(!isHandshake); pkt.associatedEvent = std::move(packetEvent); conn.lossState.totalBytesCloned += encodedSize; } @@ -675,17 +675,24 @@ void updateConnection( conn.pathValidationLimiter->onPacketSent(pkt.encodedSize); } if (pkt.isHandshake) { - ++conn.outstandings.handshakePacketsCount; + if (!pkt.associatedEvent) { + if (packetNumberSpace == PacketNumberSpace::Initial) { + ++conn.outstandings.initialPacketsCount; + } else { + CHECK_EQ(packetNumberSpace, PacketNumberSpace::Handshake); + ++conn.outstandings.handshakePacketsCount; + } + } conn.lossState.lastHandshakePacketSentTime = pkt.time; } conn.lossState.lastRetransmittablePacketSentTime = pkt.time; if (pkt.associatedEvent) { - CHECK_EQ(packetNumberSpace, PacketNumberSpace::AppData); ++conn.outstandings.clonedPacketsCount; ++conn.lossState.timeoutBasedRtxCount; } auto opCount = conn.outstandings.packets.size(); + DCHECK_GE(opCount, conn.outstandings.initialPacketsCount); DCHECK_GE(opCount, conn.outstandings.handshakePacketsCount); DCHECK_GE(opCount, conn.outstandings.clonedPacketsCount); } @@ -775,8 +782,31 @@ uint64_t writeCryptoAndAckDataToSocket( .cryptoFrames()) .build(); auto builder = LongHeaderBuilder(packetType); + uint64_t written = 0; + auto& cryptoStream = + *getCryptoStream(*connection.cryptoState, encryptionLevel); + if ((connection.pendingEvents.numProbePackets && + cryptoStream.retransmissionBuffer.size()) || + scheduler.hasData()) { + written = writeProbingDataToSocket( + sock, + connection, + srcConnId, + dstConnId, + builder, + LongHeader::typeToPacketNumberSpace(packetType), + scheduler, + std::min( + packetLimit, connection.pendingEvents.numProbePackets), + cleartextCipher, + headerCipher, + version, + token); + CHECK_GE(connection.pendingEvents.numProbePackets, written); + connection.pendingEvents.numProbePackets -= written; + } // Crypto data is written without aead protection. - auto written = writeConnectionDataToSocket( + written += writeConnectionDataToSocket( sock, connection, srcConnId, @@ -785,7 +815,7 @@ uint64_t writeCryptoAndAckDataToSocket( LongHeader::typeToPacketNumberSpace(packetType), scheduler, congestionControlWritableBytes, - packetLimit, + packetLimit - written, cleartextCipher, headerCipher, version, @@ -794,7 +824,7 @@ uint64_t writeCryptoAndAckDataToSocket( << " written crypto and acks data type=" << packetType << " packets=" << written << " " << connection; - DCHECK_GE(packetLimit, written); + CHECK_GE(packetLimit, written); return written; } @@ -1205,7 +1235,8 @@ uint64_t writeProbingDataToSocket( uint8_t probesToSend, const Aead& aead, const PacketNumberCipher& headerCipher, - QuicVersion version) { + QuicVersion version, + const std::string& token) { // Skip a packet number for probing packets to elicit acks increaseNextPacketNum(connection, pnSpace); CloningScheduler cloningScheduler( @@ -1222,9 +1253,12 @@ uint64_t writeProbingDataToSocket( probesToSend, aead, headerCipher, - version); + version, + token); if (probesToSend && !written) { // Fall back to send a ping: + // TODO: Now that Probes can be used for handshake packets. We need to make + // sure we only send Ping here, no other Simple frames. sendSimpleFrame(connection, PingFrame()); auto pingScheduler = std::move(FrameScheduler::Builder( connection, diff --git a/quic/api/QuicTransportFunctions.h b/quic/api/QuicTransportFunctions.h index 6fe93f968..27db8d307 100644 --- a/quic/api/QuicTransportFunctions.h +++ b/quic/api/QuicTransportFunctions.h @@ -278,7 +278,8 @@ uint64_t writeProbingDataToSocket( uint8_t probesToSend, const Aead& aead, const PacketNumberCipher& headerCipher, - QuicVersion version); + QuicVersion version, + const std::string& token = std::string()); HeaderBuilder LongHeaderBuilder(LongHeader::Types packetType); HeaderBuilder ShortHeaderBuilder(); diff --git a/quic/api/test/QuicPacketSchedulerTest.cpp b/quic/api/test/QuicPacketSchedulerTest.cpp index 8219f2427..e23c6abd3 100644 --- a/quic/api/test/QuicPacketSchedulerTest.cpp +++ b/quic/api/test/QuicPacketSchedulerTest.cpp @@ -452,7 +452,76 @@ TEST_F(QuicPacketSchedulerTest, CloningSchedulerTest) { auto result = cloningScheduler.scheduleFramesForPacket( std::move(builder), kDefaultUDPSendPacketLen); EXPECT_TRUE(result.packetEvent.has_value() && result.packet.has_value()); - EXPECT_EQ(packetNum, *result.packetEvent); + EXPECT_EQ(packetNum, result.packetEvent->packetNumber); +} + +TEST_F(QuicPacketSchedulerTest, WriteOnlyOutstandingPacketsTest) { + QuicClientConnectionState conn( + FizzClientQuicHandshakeContext::Builder().build()); + FrameScheduler noopScheduler("frame"); + ASSERT_FALSE(noopScheduler.hasData()); + CloningScheduler cloningScheduler(noopScheduler, conn, "CopyCat", 0); + EXPECT_FALSE(cloningScheduler.hasData()); + auto packetNum = addOutstandingPacket(conn); + // There needs to have retransmittable frame for the rebuilder to work + conn.outstandings.packets.back().packet.frames.push_back( + MaxDataFrame(conn.flowControlState.advertisedMaxOffset)); + EXPECT_TRUE(cloningScheduler.hasData()); + + ASSERT_FALSE(noopScheduler.hasData()); + ShortHeader header( + ProtectionType::KeyPhaseOne, + conn.clientConnectionId.value_or(getTestConnectionId()), + getNextPacketNum(conn, PacketNumberSpace::AppData)); + RegularQuicPacketBuilder regularBuilder( + conn.udpSendPacketLen, + std::move(header), + conn.ackStates.appDataAckState.largestAckedByPeer.value_or(0)); + + // Create few frames + ConnectionCloseFrame connCloseFrame( + QuicErrorCode(TransportErrorCode::FRAME_ENCODING_ERROR), + "The sun is in the sky."); + MaxStreamsFrame maxStreamFrame(999, true); + PingFrame pingFrame; + AckBlocks ackBlocks; + ackBlocks.insert(10, 100); + ackBlocks.insert(200, 1000); + AckFrameMetaData ackMeta(ackBlocks, 0us, kDefaultAckDelayExponent); + + // Write those framses with a regular builder + writeFrame(connCloseFrame, regularBuilder); + writeFrame(QuicSimpleFrame(maxStreamFrame), regularBuilder); + writeFrame(QuicSimpleFrame(pingFrame), regularBuilder); + writeAckFrame(ackMeta, regularBuilder); + + auto result = cloningScheduler.scheduleFramesForPacket( + std::move(regularBuilder), kDefaultUDPSendPacketLen); + EXPECT_TRUE(result.packetEvent.hasValue() && result.packet.hasValue()); + EXPECT_EQ(packetNum, result.packetEvent->packetNumber); + // written packet should not have any frame in the builder + auto& writtenPacket = *result.packet; + auto shortHeader = writtenPacket.packet.header.asShort(); + CHECK(shortHeader); + EXPECT_EQ(ProtectionType::KeyPhaseOne, shortHeader->getProtectionType()); + EXPECT_EQ( + conn.ackStates.appDataAckState.nextPacketNum, + shortHeader->getPacketSequenceNum()); + + // Test that the only frame that's written is maxdataframe + EXPECT_GE(writtenPacket.packet.frames.size(), 1); + auto& writtenFrame = writtenPacket.packet.frames.at(0); + auto maxDataFrame = writtenFrame.asMaxDataFrame(); + CHECK(maxDataFrame); + for (auto& frame : writtenPacket.packet.frames) { + bool present = false; + /* the next four frames should not be written */ + present |= frame.asConnectionCloseFrame() ? true : false; + present |= frame.asQuicSimpleFrame() ? true : false; + present |= frame.asQuicSimpleFrame() ? true : false; + present |= frame.asWriteAckFrame() ? true : false; + ASSERT_FALSE(present); + } } TEST_F(QuicPacketSchedulerTest, DoNotCloneProcessedClonedPacket) { @@ -467,7 +536,8 @@ TEST_F(QuicPacketSchedulerTest, DoNotCloneProcessedClonedPacket) { conn.outstandings.packets.back().packet.frames.push_back( MaxDataFrame(conn.flowControlState.advertisedMaxOffset)); addOutstandingPacket(conn); - conn.outstandings.packets.back().associatedEvent = 1; + conn.outstandings.packets.back().associatedEvent = + PacketEvent(PacketNumberSpace::AppData, 1); // There needs to have retransmittable frame for the rebuilder to work conn.outstandings.packets.back().packet.frames.push_back( MaxDataFrame(conn.flowControlState.advertisedMaxOffset)); @@ -483,10 +553,10 @@ TEST_F(QuicPacketSchedulerTest, DoNotCloneProcessedClonedPacket) { auto result = cloningScheduler.scheduleFramesForPacket( std::move(builder), kDefaultUDPSendPacketLen); EXPECT_TRUE(result.packetEvent.has_value() && result.packet.has_value()); - EXPECT_EQ(expected, *result.packetEvent); + EXPECT_EQ(expected, result.packetEvent->packetNumber); } -TEST_F(QuicPacketSchedulerTest, CloneSchedulerHasDataIgnoresNonAppData) { +TEST_F(QuicPacketSchedulerTest, CloneSchedulerHasHandshakeData) { QuicClientConnectionState conn( FizzClientQuicHandshakeContext::Builder().build()); FrameScheduler noopScheduler("frame"); @@ -494,9 +564,25 @@ TEST_F(QuicPacketSchedulerTest, CloneSchedulerHasDataIgnoresNonAppData) { EXPECT_FALSE(cloningScheduler.hasData()); addHandshakeOutstandingPacket(conn); + EXPECT_TRUE(cloningScheduler.hasData()); +} + +TEST_F(QuicPacketSchedulerTest, CloneSchedulerHasInitialData) { + QuicClientConnectionState conn( + FizzClientQuicHandshakeContext::Builder().build()); + FrameScheduler noopScheduler("frame"); + CloningScheduler cloningScheduler(noopScheduler, conn, "CopyCat", 0); EXPECT_FALSE(cloningScheduler.hasData()); addInitialOutstandingPacket(conn); + EXPECT_TRUE(cloningScheduler.hasData()); +} + +TEST_F(QuicPacketSchedulerTest, CloneSchedulerHasAppDataData) { + QuicClientConnectionState conn( + FizzClientQuicHandshakeContext::Builder().build()); + FrameScheduler noopScheduler("frame"); + CloningScheduler cloningScheduler(noopScheduler, conn, "CopyCat", 0); EXPECT_FALSE(cloningScheduler.hasData()); addOutstandingPacket(conn); @@ -528,7 +614,7 @@ TEST_F(QuicPacketSchedulerTest, DoNotCloneHandshake) { auto result = cloningScheduler.scheduleFramesForPacket( std::move(builder), kDefaultUDPSendPacketLen); EXPECT_TRUE(result.packetEvent.has_value() && result.packet.has_value()); - EXPECT_EQ(expected, *result.packetEvent); + EXPECT_EQ(expected, result.packetEvent->packetNumber); } TEST_F(QuicPacketSchedulerTest, CloneSchedulerUseNormalSchedulerFirst) { @@ -586,7 +672,8 @@ TEST_F(QuicPacketSchedulerTest, CloneWillGenerateNewWindowUpdate) { auto stream = conn.streamManager->createNextBidirectionalStream().value(); FrameScheduler noopScheduler("frame"); CloningScheduler cloningScheduler(noopScheduler, conn, "GiantsShoulder", 0); - auto expectedPacketEvent = addOutstandingPacket(conn); + PacketEvent expectedPacketEvent( + PacketNumberSpace::AppData, addOutstandingPacket(conn)); ASSERT_EQ(1, conn.outstandings.packets.size()); conn.outstandings.packets.back().packet.frames.push_back(MaxDataFrame(1000)); conn.outstandings.packets.back().packet.frames.push_back( @@ -691,7 +778,7 @@ TEST_F(QuicPacketSchedulerTest, CloningSchedulerWithInplaceBuilder) { auto result = cloningScheduler.scheduleFramesForPacket( std::move(builder), kDefaultUDPSendPacketLen); EXPECT_TRUE(result.packetEvent.has_value() && result.packet.has_value()); - EXPECT_EQ(packetNum, *result.packetEvent); + EXPECT_EQ(packetNum, result.packetEvent->packetNumber); // Something was written into the buffer: EXPECT_TRUE(bufAccessor.ownsBuffer()); @@ -764,7 +851,7 @@ TEST_F(QuicPacketSchedulerTest, CloningSchedulerWithInplaceBuilderFullPacket) { std::move(internalBuilder), conn.udpSendPacketLen); EXPECT_TRUE( cloneResult.packetEvent.has_value() && cloneResult.packet.has_value()); - EXPECT_EQ(packetNum, *cloneResult.packetEvent); + EXPECT_EQ(packetNum, cloneResult.packetEvent->packetNumber); // Something was written into the buffer: EXPECT_TRUE(bufAccessor.ownsBuffer()); diff --git a/quic/api/test/QuicTransportFunctionsTest.cpp b/quic/api/test/QuicTransportFunctionsTest.cpp index cb63f8cc3..8abf150a7 100644 --- a/quic/api/test/QuicTransportFunctionsTest.cpp +++ b/quic/api/test/QuicTransportFunctionsTest.cpp @@ -629,7 +629,8 @@ TEST_F(QuicTransportFunctionsTest, TestImplicitAck) { initialStream->writeBuffer.append(data->clone()); updateConnection( *conn, folly::none, packet.packet, TimePoint(), getEncodedSize(packet)); - EXPECT_EQ(1, conn->outstandings.handshakePacketsCount); + EXPECT_EQ(1, conn->outstandings.initialPacketsCount); + EXPECT_EQ(0, conn->outstandings.handshakePacketsCount); EXPECT_EQ(1, conn->outstandings.packets.size()); EXPECT_EQ(1, initialStream->retransmissionBuffer.size()); @@ -642,7 +643,8 @@ TEST_F(QuicTransportFunctionsTest, TestImplicitAck) { initialStream->writeBuffer.append(data->clone()); updateConnection( *conn, folly::none, packet.packet, TimePoint(), getEncodedSize(packet)); - EXPECT_EQ(2, conn->outstandings.handshakePacketsCount); + EXPECT_EQ(2, conn->outstandings.initialPacketsCount); + EXPECT_EQ(0, conn->outstandings.handshakePacketsCount); EXPECT_EQ(2, conn->outstandings.packets.size()); EXPECT_EQ(3, initialStream->retransmissionBuffer.size()); EXPECT_TRUE(initialStream->writeBuffer.empty()); @@ -654,7 +656,7 @@ TEST_F(QuicTransportFunctionsTest, TestImplicitAck) { initialStream->retransmissionBuffer.erase(0); initialStream->lossBuffer.emplace_back(std::move(firstBuf), 0, false); conn->outstandings.packets.pop_front(); - conn->outstandings.handshakePacketsCount--; + conn->outstandings.initialPacketsCount--; auto handshakeStream = getCryptoStream(*conn->cryptoState, EncryptionLevel::Handshake); @@ -666,7 +668,8 @@ TEST_F(QuicTransportFunctionsTest, TestImplicitAck) { handshakeStream->writeBuffer.append(data->clone()); updateConnection( *conn, folly::none, packet.packet, TimePoint(), getEncodedSize(packet)); - EXPECT_EQ(2, conn->outstandings.handshakePacketsCount); + EXPECT_EQ(1, conn->outstandings.initialPacketsCount); + EXPECT_EQ(1, conn->outstandings.handshakePacketsCount); EXPECT_EQ(2, conn->outstandings.packets.size()); EXPECT_EQ(1, handshakeStream->retransmissionBuffer.size()); @@ -676,7 +679,8 @@ TEST_F(QuicTransportFunctionsTest, TestImplicitAck) { handshakeStream->writeBuffer.append(data->clone()); updateConnection( *conn, folly::none, packet.packet, TimePoint(), getEncodedSize(packet)); - EXPECT_EQ(3, conn->outstandings.handshakePacketsCount); + EXPECT_EQ(1, conn->outstandings.initialPacketsCount); + EXPECT_EQ(2, conn->outstandings.handshakePacketsCount); EXPECT_EQ(3, conn->outstandings.packets.size()); EXPECT_EQ(2, handshakeStream->retransmissionBuffer.size()); EXPECT_TRUE(handshakeStream->writeBuffer.empty()); @@ -695,6 +699,7 @@ TEST_F(QuicTransportFunctionsTest, TestImplicitAck) { conn->outstandings.handshakePacketsCount--; implicitAckCryptoStream(*conn, EncryptionLevel::Initial); + EXPECT_EQ(0, conn->outstandings.initialPacketsCount); EXPECT_EQ(1, conn->outstandings.handshakePacketsCount); EXPECT_EQ(1, conn->outstandings.packets.size()); EXPECT_TRUE(initialStream->retransmissionBuffer.empty()); @@ -702,6 +707,7 @@ TEST_F(QuicTransportFunctionsTest, TestImplicitAck) { EXPECT_TRUE(initialStream->lossBuffer.empty()); implicitAckCryptoStream(*conn, EncryptionLevel::Handshake); + EXPECT_EQ(0, conn->outstandings.initialPacketsCount); EXPECT_EQ(0, conn->outstandings.handshakePacketsCount); EXPECT_TRUE(conn->outstandings.packets.empty()); EXPECT_TRUE(handshakeStream->retransmissionBuffer.empty()); @@ -960,7 +966,7 @@ TEST_F(QuicTransportFunctionsTest, TestUpdateConnectionWithCloneResult) { MaxDataFrame maxDataFrame(maxDataAmt); conn->pendingEvents.connWindowUpdate = true; writePacket.frames.push_back(std::move(maxDataFrame)); - PacketEvent event = 1; + PacketEvent event(PacketNumberSpace::AppData, 1); conn->outstandings.packetEvents.insert(event); auto futureMoment = thisMoment + 50ms; MockClock::mockNow = [=]() { return futureMoment; }; @@ -1960,7 +1966,7 @@ TEST_F(QuicTransportFunctionsTest, UpdateConnectionCloneCounter) { MaxDataFrame(conn->flowControlState.advertisedMaxOffset); conn->pendingEvents.connWindowUpdate = true; packet.packet.frames.emplace_back(connWindowUpdate); - PacketEvent packetEvent = 100; + PacketEvent packetEvent(PacketNumberSpace::AppData, 100); conn->outstandings.packetEvents.insert(packetEvent); updateConnection(*conn, packetEvent, packet.packet, TimePoint(), 123); EXPECT_EQ(1, conn->outstandings.clonedPacketsCount); @@ -1982,7 +1988,9 @@ TEST_F(QuicTransportFunctionsTest, ClearBlockedFromPendingEvents) { TEST_F(QuicTransportFunctionsTest, ClonedBlocked) { auto conn = createConn(); - auto packetEvent = conn->ackStates.appDataAckState.nextPacketNum; + PacketEvent packetEvent( + PacketNumberSpace::AppData, + conn->ackStates.appDataAckState.nextPacketNum); auto packet = buildEmptyPacket(*conn, PacketNumberSpace::AppData); auto stream = conn->streamManager->createNextBidirectionalStream().value(); StreamDataBlockedFrame blockedFrame(stream->id, 1000); @@ -2043,7 +2051,9 @@ TEST_F(QuicTransportFunctionsTest, ClearRstFromPendingEvents) { TEST_F(QuicTransportFunctionsTest, ClonedRst) { auto conn = createConn(); - auto packetEvent = conn->ackStates.appDataAckState.nextPacketNum; + PacketEvent packetEvent( + PacketNumberSpace::AppData, + conn->ackStates.appDataAckState.nextPacketNum); auto stream = conn->streamManager->createNextBidirectionalStream().value(); auto packet = buildEmptyPacket(*conn, PacketNumberSpace::AppData); RstStreamFrame rstStreamFrame( @@ -2073,7 +2083,7 @@ TEST_F(QuicTransportFunctionsTest, TimeoutBasedRetxCountUpdate) { RstStreamFrame rstStreamFrame( stream->id, GenericApplicationErrorCode::UNKNOWN, 0); packet.packet.frames.push_back(rstStreamFrame); - PacketEvent packetEvent = 100; + PacketEvent packetEvent(PacketNumberSpace::AppData, 100); conn->outstandings.packetEvents.insert(packetEvent); updateConnection(*conn, packetEvent, packet.packet, TimePoint(), 500); EXPECT_EQ(247, conn->lossState.timeoutBasedRtxCount); @@ -2408,20 +2418,16 @@ TEST_F(QuicTransportFunctionsTest, WriteProbingWithInplaceBuilder) { StreamFrameScheduler streamScheduler(*conn); ASSERT_FALSE(streamScheduler.hasPendingData()); - // The last packet may not be a full packet - auto lastPacketSize = conn->outstandings.packets.back().encodedSize; - size_t expectedOutstandingPacketsCount = 5; - if (lastPacketSize < conn->udpSendPacketLen) { - expectedOutstandingPacketsCount++; - } + // The first packet has be a full packet + auto firstPacketSize = conn->outstandings.packets.front().encodedSize; + auto outstandingPacketsCount = conn->outstandings.packets.size(); + ASSERT_EQ(firstPacketSize, conn->udpSendPacketLen); EXPECT_CALL(mockSock, write(_, _)) .Times(1) .WillOnce(Invoke([&](const folly::SocketAddress&, const std::unique_ptr& buf) { EXPECT_FALSE(buf->isChained()); - // If the last packet isn't full, it may have the stream length field - // but the clone won't have it. - EXPECT_LE(buf->length(), lastPacketSize); + EXPECT_EQ(buf->length(), firstPacketSize); return buf->length(); })); writeProbingDataToSocketForTest( @@ -2431,56 +2437,11 @@ TEST_F(QuicTransportFunctionsTest, WriteProbingWithInplaceBuilder) { *aead, *headerCipher, getVersion(*conn)); - EXPECT_EQ( - conn->outstandings.packets.size(), expectedOutstandingPacketsCount + 1); + EXPECT_EQ(conn->outstandings.packets.size(), outstandingPacketsCount + 1); EXPECT_EQ(0, bufPtr->length()); EXPECT_EQ(0, bufPtr->headroom()); // Clone again, this time 2 pacckets. - if (lastPacketSize < conn->udpSendPacketLen) { - EXPECT_CALL(mockSock, writeGSO(_, _, _)) - .Times(1) - .WillOnce(Invoke([&](const folly::SocketAddress&, - const std::unique_ptr& buf, - int gso) { - EXPECT_FALSE(buf->isChained()); - EXPECT_LE(gso, lastPacketSize); - EXPECT_LE(buf->length(), lastPacketSize * 2); - return buf->length(); - })); - } else { - EXPECT_CALL(mockSock, writeGSO(_, _, _)) - .Times(1) - .WillOnce(Invoke([&](const folly::SocketAddress&, - const std::unique_ptr& buf, - int gso) { - EXPECT_FALSE(buf->isChained()); - EXPECT_EQ(conn->udpSendPacketLen, gso); - EXPECT_EQ(buf->length(), conn->udpSendPacketLen * 4); - return buf->length(); - })); - } - writeProbingDataToSocketForTest( - mockSock, - *conn, - 2 /* probesToSend */, - *aead, - *headerCipher, - getVersion(*conn)); - EXPECT_EQ(0, bufPtr->length()); - EXPECT_EQ(0, bufPtr->headroom()); - EXPECT_EQ( - conn->outstandings.packets.size(), expectedOutstandingPacketsCount + 3); - - // Clear out all the small packets: - while (conn->outstandings.packets.back().encodedSize < - conn->udpSendPacketLen) { - conn->outstandings.packets.pop_back(); - } - ASSERT_FALSE(conn->outstandings.packets.empty()); - auto currentOutstandingPackets = conn->outstandings.packets.size(); - - // Clone 2 full size packets EXPECT_CALL(mockSock, writeGSO(_, _, _)) .Times(1) .WillOnce(Invoke([&](const folly::SocketAddress&, @@ -2498,9 +2459,9 @@ TEST_F(QuicTransportFunctionsTest, WriteProbingWithInplaceBuilder) { *aead, *headerCipher, getVersion(*conn)); - EXPECT_EQ(conn->outstandings.packets.size(), currentOutstandingPackets + 2); EXPECT_EQ(0, bufPtr->length()); EXPECT_EQ(0, bufPtr->headroom()); + EXPECT_EQ(conn->outstandings.packets.size(), outstandingPacketsCount + 3); } } // namespace test diff --git a/quic/api/test/QuicTransportTest.cpp b/quic/api/test/QuicTransportTest.cpp index e6f10281a..14ca6b7c2 100644 --- a/quic/api/test/QuicTransportTest.cpp +++ b/quic/api/test/QuicTransportTest.cpp @@ -1037,14 +1037,13 @@ TEST_F(QuicTransportTest, ClonePathChallenge) { // Force a timeout with no data so that it clones the packet transport_->lossTimeout().timeoutExpired(); - // On PTO, endpoint sends 2 probing packets, thus 1+2=3 - EXPECT_EQ(conn.outstandings.packets.size(), 3); + EXPECT_EQ(conn.outstandings.packets.size(), 2); numPathChallengePackets = std::count_if( conn.outstandings.packets.begin(), conn.outstandings.packets.end(), findFrameInPacketFunc()); - EXPECT_EQ(numPathChallengePackets, 3); + EXPECT_EQ(numPathChallengePackets, 2); } TEST_F(QuicTransportTest, OnlyClonePathValidationIfOutstanding) { @@ -1275,6 +1274,7 @@ TEST_F(QuicTransportTest, SendNewConnectionIdFrame) { TEST_F(QuicTransportTest, CloneNewConnectionIdFrame) { auto& conn = transport_->getConnectionState(); // knock every handshake outstanding packets out + conn.outstandings.initialPacketsCount = 0; conn.outstandings.handshakePacketsCount = 0; conn.outstandings.packets.clear(); for (auto& t : conn.lossState.lossTimes) { @@ -1296,13 +1296,12 @@ TEST_F(QuicTransportTest, CloneNewConnectionIdFrame) { // Force a timeout with no data so that it clones the packet transport_->lossTimeout().timeoutExpired(); - // On PTO, endpoint sends 2 probing packets, thus 1+2=3 - EXPECT_EQ(conn.outstandings.packets.size(), 3); + EXPECT_EQ(conn.outstandings.packets.size(), 2); numNewConnIdPackets = std::count_if( conn.outstandings.packets.begin(), conn.outstandings.packets.end(), findFrameInPacketFunc()); - EXPECT_EQ(numNewConnIdPackets, 3); + EXPECT_EQ(numNewConnIdPackets, 2); } TEST_F(QuicTransportTest, BusyWriteLoopDetection) { @@ -1415,6 +1414,7 @@ TEST_F(QuicTransportTest, SendRetireConnectionIdFrame) { TEST_F(QuicTransportTest, CloneRetireConnectionIdFrame) { auto& conn = transport_->getConnectionState(); // knock every handshake outstanding packets out + conn.outstandings.initialPacketsCount = 0; conn.outstandings.handshakePacketsCount = 0; conn.outstandings.packets.clear(); for (auto& t : conn.lossState.lossTimes) { @@ -1436,14 +1436,13 @@ TEST_F(QuicTransportTest, CloneRetireConnectionIdFrame) { // Force a timeout with no data so that it clones the packet transport_->lossTimeout().timeoutExpired(); - // On PTO, endpoint sends 2 probing packets, thus 1+2=3 - EXPECT_EQ(conn.outstandings.packets.size(), 3); + EXPECT_EQ(conn.outstandings.packets.size(), 2); numRetireConnIdPackets = std::count_if( conn.outstandings.packets.begin(), conn.outstandings.packets.end(), findFrameInPacketFunc< QuicSimpleFrame::Type::RetireConnectionIdFrame_E>()); - EXPECT_EQ(numRetireConnIdPackets, 3); + EXPECT_EQ(numRetireConnIdPackets, 2); } TEST_F(QuicTransportTest, ResendRetireConnectionIdOnLoss) { diff --git a/quic/client/QuicClientTransport.cpp b/quic/client/QuicClientTransport.cpp index cbd87957c..523c79458 100644 --- a/quic/client/QuicClientTransport.cpp +++ b/quic/client/QuicClientTransport.cpp @@ -729,11 +729,14 @@ void QuicClientTransport::writeData() { ? conn_->pacer->updateAndGetWriteBatchSize(Clock::now()) : conn_->transportSettings.writeConnectionDataPacketsLimit); if (conn_->initialWriteCipher) { - CryptoStreamScheduler initialScheduler( - *conn_, - *getCryptoStream(*conn_->cryptoState, EncryptionLevel::Initial)); + auto& initialCryptoStream = + *getCryptoStream(*conn_->cryptoState, EncryptionLevel::Initial); + CryptoStreamScheduler initialScheduler(*conn_, initialCryptoStream); - if (initialScheduler.hasData() || + if ((initialCryptoStream.retransmissionBuffer.size() && + conn_->outstandings.initialPacketsCount && + conn_->pendingEvents.numProbePackets) || + initialScheduler.hasData() || (conn_->ackStates.initialAckState.needsToSendAckImmediately && hasAcksToSchedule(conn_->ackStates.initialAckState))) { CHECK(conn_->initialHeaderCipher); @@ -749,15 +752,18 @@ void QuicClientTransport::writeData() { packetLimit, clientConn_->retryToken); } - if (!packetLimit) { + if (!packetLimit && !conn_->pendingEvents.numProbePackets) { return; } } if (conn_->handshakeWriteCipher) { - CryptoStreamScheduler handshakeScheduler( - *conn_, - *getCryptoStream(*conn_->cryptoState, EncryptionLevel::Handshake)); - if (handshakeScheduler.hasData() || + auto& handshakeCryptoStream = + *getCryptoStream(*conn_->cryptoState, EncryptionLevel::Handshake); + CryptoStreamScheduler handshakeScheduler(*conn_, handshakeCryptoStream); + if ((conn_->outstandings.handshakePacketsCount && + handshakeCryptoStream.retransmissionBuffer.size() && + conn_->pendingEvents.numProbePackets) || + handshakeScheduler.hasData() || (conn_->ackStates.handshakeAckState.needsToSendAckImmediately && hasAcksToSchedule(conn_->ackStates.handshakeAckState))) { CHECK(conn_->handshakeWriteHeaderCipher); @@ -772,7 +778,7 @@ void QuicClientTransport::writeData() { version, packetLimit); } - if (!packetLimit) { + if (!packetLimit && !conn_->pendingEvents.numProbePackets) { return; } } @@ -788,7 +794,7 @@ void QuicClientTransport::writeData() { version, packetLimit); } - if (!packetLimit) { + if (!packetLimit && !conn_->pendingEvents.numProbePackets) { return; } if (conn_->oneRttWriteCipher) { diff --git a/quic/codec/QuicPacketRebuilder.cpp b/quic/codec/QuicPacketRebuilder.cpp index 33c3f5fa1..43ef75d05 100644 --- a/quic/codec/QuicPacketRebuilder.cpp +++ b/quic/codec/QuicPacketRebuilder.cpp @@ -31,9 +31,11 @@ PacketEvent PacketRebuilder::cloneOutstandingPacket(OutstandingPacket& packet) { conn_.outstandings.packetEvents.count(*packet.associatedEvent)); if (!packet.associatedEvent) { auto packetNum = packet.packet.header.getPacketSequenceNum(); - DCHECK(!conn_.outstandings.packetEvents.count(packetNum)); - packet.associatedEvent = packetNum; - conn_.outstandings.packetEvents.insert(packetNum); + auto packetNumberSpace = packet.packet.header.getPacketNumberSpace(); + PacketEvent event(packetNumberSpace, packetNum); + DCHECK(!conn_.outstandings.packetEvents.count(event)); + packet.associatedEvent = event; + conn_.outstandings.packetEvents.insert(event); ++conn_.outstandings.clonedPacketsCount; } return *packet.associatedEvent; @@ -44,11 +46,12 @@ folly::Optional PacketRebuilder::rebuildFromPacket( // TODO: if PMTU changes between the transmission of the original packet and // now, then we cannot clone everything in the packet. - // TODO: make sure this cannot be called on handshake packets. bool writeSuccess = false; bool windowUpdateWritten = false; bool shouldWriteWindowUpdate = false; bool notPureAck = false; + auto encryptionLevel = + protectionTypeToEncryptionLevel(packet.packet.header.getProtectionType()); for (auto iter = packet.packet.frames.cbegin(); iter != packet.packet.frames.cend(); iter++) { @@ -109,15 +112,8 @@ folly::Optional PacketRebuilder::rebuildFromPacket( } case QuicWriteFrame::Type::WriteCryptoFrame_E: { const WriteCryptoFrame& cryptoFrame = *frame.asWriteCryptoFrame(); - // initialStream and handshakeStream can only be in handshake packet, - // so they are not clonable - CHECK(!packet.isHandshake); - // key update not supported - DCHECK( - packet.packet.header.getProtectionType() == - ProtectionType::KeyPhaseZero); - auto& stream = conn_.cryptoState->oneRttStream; - auto buf = cloneCryptoRetransmissionBuffer(cryptoFrame, stream); + auto stream = getCryptoStream(*conn_.cryptoState, encryptionLevel); + auto buf = cloneCryptoRetransmissionBuffer(cryptoFrame, *stream); // No crypto data found to be cloned, just skip if (!buf) { diff --git a/quic/fizz/client/test/QuicClientTransportTest.cpp b/quic/fizz/client/test/QuicClientTransportTest.cpp index 74ef3e92f..f9f75151d 100644 --- a/quic/fizz/client/test/QuicClientTransportTest.cpp +++ b/quic/fizz/client/test/QuicClientTransportTest.cpp @@ -2411,8 +2411,8 @@ class QuicClientTransportHappyEyeballsTest : public QuicClientTransportTest { EXPECT_TRUE(conn.happyEyeballsState.shouldWriteToFirstSocket); EXPECT_TRUE(conn.happyEyeballsState.shouldWriteToSecondSocket); - EXPECT_CALL(*sock, write(firstAddress, _)); - EXPECT_CALL(*secondSock, write(secondAddress, _)); + EXPECT_CALL(*sock, write(firstAddress, _)).Times(2); + EXPECT_CALL(*secondSock, write(secondAddress, _)).Times(2); client->lossTimeout().cancelTimeout(); client->lossTimeout().timeoutExpired(); } @@ -2452,7 +2452,7 @@ class QuicClientTransportHappyEyeballsTest : public QuicClientTransportTest { EXPECT_TRUE(conn.happyEyeballsState.shouldWriteToSecondSocket); EXPECT_CALL(*sock, write(_, _)).Times(0); - EXPECT_CALL(*secondSock, write(secondAddress, _)); + EXPECT_CALL(*secondSock, write(secondAddress, _)).Times(2); client->lossTimeout().cancelTimeout(); client->lossTimeout().timeoutExpired(); } @@ -2487,8 +2487,8 @@ class QuicClientTransportHappyEyeballsTest : public QuicClientTransportTest { EXPECT_TRUE(conn.happyEyeballsState.shouldWriteToFirstSocket); EXPECT_TRUE(conn.happyEyeballsState.shouldWriteToSecondSocket); - EXPECT_CALL(*sock, write(firstAddress, _)); - EXPECT_CALL(*secondSock, write(secondAddress, _)); + EXPECT_CALL(*sock, write(firstAddress, _)).Times(2); + EXPECT_CALL(*secondSock, write(secondAddress, _)).Times(2); client->lossTimeout().cancelTimeout(); client->lossTimeout().timeoutExpired(); } @@ -2527,7 +2527,7 @@ class QuicClientTransportHappyEyeballsTest : public QuicClientTransportTest { EXPECT_TRUE(conn.happyEyeballsState.shouldWriteToFirstSocket); EXPECT_FALSE(conn.happyEyeballsState.shouldWriteToSecondSocket); - EXPECT_CALL(*sock, write(firstAddress, _)); + EXPECT_CALL(*sock, write(firstAddress, _)).Times(2); EXPECT_CALL(*secondSock, write(_, _)).Times(0); client->lossTimeout().cancelTimeout(); client->lossTimeout().timeoutExpired(); @@ -2564,8 +2564,8 @@ class QuicClientTransportHappyEyeballsTest : public QuicClientTransportTest { EXPECT_TRUE(conn.happyEyeballsState.shouldWriteToFirstSocket); EXPECT_TRUE(conn.happyEyeballsState.shouldWriteToSecondSocket); - EXPECT_CALL(*sock, write(firstAddress, _)); - EXPECT_CALL(*secondSock, write(secondAddress, _)); + EXPECT_CALL(*sock, write(firstAddress, _)).Times(2); + EXPECT_CALL(*secondSock, write(secondAddress, _)).Times(2); client->lossTimeout().cancelTimeout(); client->lossTimeout().timeoutExpired(); } @@ -5174,23 +5174,6 @@ TEST_F(QuicClientTransportAfterStartTest, SetCongestionControlBbr) { EXPECT_TRUE(isConnectionPaced(client->getConn())); } -TEST_F( - QuicClientTransportAfterStartTest, - TestOneRttPacketWillNotRescheduleHandshakeAlarm) { - EXPECT_TRUE(client->lossTimeout().isScheduled()); - auto timeRemaining1 = client->lossTimeout().getTimeRemaining(); - - auto sleepAmountMillis = 10; - usleep(sleepAmountMillis * 1000); - auto streamId = client->createBidirectionalStream().value(); - client->writeChain(streamId, IOBuf::copyBuffer("hello"), true, false); - loopForWrites(); - - EXPECT_TRUE(client->lossTimeout().isScheduled()); - auto timeRemaining2 = client->lossTimeout().getTimeRemaining(); - EXPECT_GE(timeRemaining1.count() - timeRemaining2.count(), sleepAmountMillis); -} - TEST_F(QuicClientTransportAfterStartTest, PingIsRetransmittable) { PingFrame pingFrame; ShortHeader header( @@ -5566,72 +5549,6 @@ TEST_F(QuicZeroRttClientTest, TestZeroRttRejectionWithSmallerFlowControl) { EXPECT_THROW(recvServerHello(), std::runtime_error); } -TEST_F( - QuicZeroRttClientTest, - TestZeroRttPacketWillNotRescheduleHandshakeAlarm) { - EXPECT_CALL(*mockQuicPskCache_, getPsk(hostname_)) - .WillOnce(InvokeWithoutArgs([]() { - QuicCachedPsk quicCachedPsk; - quicCachedPsk.transportParams.initialMaxStreamDataBidiLocal = - kDefaultStreamWindowSize; - quicCachedPsk.transportParams.initialMaxStreamDataBidiRemote = - kDefaultStreamWindowSize; - quicCachedPsk.transportParams.initialMaxStreamDataUni = - kDefaultStreamWindowSize; - quicCachedPsk.transportParams.initialMaxData = - kDefaultConnectionWindowSize; - quicCachedPsk.transportParams.idleTimeout = kDefaultIdleTimeout.count(); - quicCachedPsk.transportParams.maxRecvPacketSize = - kDefaultUDPReadBufferSize; - quicCachedPsk.transportParams.initialMaxStreamsBidi = - std::numeric_limits::max(); - quicCachedPsk.transportParams.initialMaxStreamsUni = - std::numeric_limits::max(); - return quicCachedPsk; - })); - bool performedValidation = false; - client->setEarlyDataAppParamsFunctions( - [&](const folly::Optional&, const Buf&) { - performedValidation = true; - return true; - }, - []() -> Buf { return nullptr; }); - startClient(); - EXPECT_TRUE(performedValidation); - - EXPECT_TRUE(client->lossTimeout().isScheduled()); - auto timeRemaining1 = client->lossTimeout().getTimeRemaining(); - - auto initialUDPSendPacketLen = client->getConn().udpSendPacketLen; - socketWrites.clear(); - - auto sleepAmountMillis = 10; - usleep(sleepAmountMillis * 1000); - auto streamId = client->createBidirectionalStream().value(); - client->writeChain(streamId, IOBuf::copyBuffer("hello"), true, false); - loopForWrites(); - - EXPECT_TRUE(client->lossTimeout().isScheduled()); - auto timeRemaining2 = client->lossTimeout().getTimeRemaining(); - EXPECT_GE(timeRemaining1.count() - timeRemaining2.count(), sleepAmountMillis); - - EXPECT_TRUE(zeroRttPacketsOutstanding()); - mockClientHandshake->setZeroRttRejected(false); - assertWritten(false, LongHeader::Types::ZeroRtt); - EXPECT_CALL(clientConnCallback, onReplaySafe()); - recvServerHello(); - - // All the data is still there. - EXPECT_TRUE(zeroRttPacketsOutstanding()); - // Transport parameters did not change since zero rtt was accepted. - verifyTransportParameters( - kDefaultConnectionWindowSize, - kDefaultStreamWindowSize, - kDefaultIdleTimeout, - kDefaultAckDelayExponent, - initialUDPSendPacketLen); -} - class QuicZeroRttHappyEyeballsClientTransportTest : public QuicZeroRttClientTest { public: diff --git a/quic/loss/QuicLossFunctions.cpp b/quic/loss/QuicLossFunctions.cpp index dcfbf7835..5373aa0e5 100644 --- a/quic/loss/QuicLossFunctions.cpp +++ b/quic/loss/QuicLossFunctions.cpp @@ -52,7 +52,12 @@ void onPTOAlarm(QuicConnectionStateBase& conn) { if (conn.lossState.ptoCount == conn.transportSettings.maxNumPTOs) { throw QuicInternalException("Exceeded max PTO", LocalErrorCode::NO_ERROR); } - conn.pendingEvents.numProbePackets = kPacketToSendForPTO; + + // If there is only one packet outstanding, no point to clone it twice in the + // same write loop. + conn.pendingEvents.numProbePackets = + std::min( + conn.outstandings.packets.size(), kPacketToSendForPTO); } void markPacketLoss( diff --git a/quic/loss/QuicLossFunctions.h b/quic/loss/QuicLossFunctions.h index a01644423..7eee4483f 100644 --- a/quic/loss/QuicLossFunctions.h +++ b/quic/loss/QuicLossFunctions.h @@ -46,9 +46,6 @@ inline std::ostream& operator<<( std::ostream& os, const LossState::AlarmMethod& alarmMethod) { switch (alarmMethod) { - case LossState::AlarmMethod::Handshake: - os << "Handshake"; - break; case LossState::AlarmMethod::EarlyRetransmitOrReordering: os << "EarlyRetransmitOrReordering"; break; @@ -77,19 +74,6 @@ calculateAlarmDuration(const QuicConnectionStateBase& conn) { alarmDuration = 0us; } alarmMethod = LossState::AlarmMethod::EarlyRetransmitOrReordering; - } else if (conn.outstandings.handshakePacketsCount > 0) { - if (conn.lossState.srtt == 0us) { - alarmDuration = conn.transportSettings.initialRtt * 2; - } else { - alarmDuration = conn.lossState.srtt * 2; - } - alarmDuration += conn.lossState.maxAckDelay; - alarmDuration *= - 1ULL << std::min(conn.lossState.handshakeAlarmCount, (uint16_t)15); - alarmMethod = LossState::AlarmMethod::Handshake; - // Handshake packet loss timer shouldn't be affected by other packets. - lastSentPacketTime = conn.lossState.lastHandshakePacketSentTime; - DCHECK_NE(lastSentPacketTime.time_since_epoch().count(), 0); } else { auto ptoTimeout = calculatePTO(conn); ptoTimeout *= 1ULL << std::min(conn.lossState.ptoCount, (uint32_t)31); @@ -165,6 +149,7 @@ void setLossDetectionAlarm(QuicConnectionStateBase& conn, Timeout& timeout) { VLOG_IF(10, !timeout.isLossTimeoutScheduled()) << __func__ << " alarm not scheduled" << " outstanding=" << totalPacketsOutstanding + << " initialPackets=" << conn.outstandings.initialPacketsCount << " handshakePackets=" << conn.outstandings.handshakePacketsCount << " " << nodeToString(conn.nodeType) << " " << conn; return; @@ -175,7 +160,11 @@ void setLossDetectionAlarm(QuicConnectionStateBase& conn, Timeout& timeout) { VLOG(10) << __func__ << " setting transmission" << " alarm=" << alarmDuration.first.count() << "ms" << " method=" << conn.lossState.currentAlarmMethod + << " haDataToWrite=" << hasDataToWrite << " outstanding=" << totalPacketsOutstanding + << " outstanding clone=" << conn.outstandings.clonedPacketsCount + << " packetEvents=" << conn.outstandings.packetEvents.size() + << " initialPackets=" << conn.outstandings.initialPacketsCount << " handshakePackets=" << conn.outstandings.handshakePacketsCount << " " << nodeToString(conn.nodeType) << " " << conn; timeout.scheduleLossTimeout(alarmDuration.first); @@ -240,9 +229,15 @@ folly::Optional detectLossPackets( if (pkt.associatedEvent) { conn.outstandings.packetEvents.erase(*pkt.associatedEvent); } - if (pkt.isHandshake) { - DCHECK(conn.outstandings.handshakePacketsCount); - --conn.outstandings.handshakePacketsCount; + if (pkt.isHandshake && !processed) { + if (currentPacketNumberSpace == PacketNumberSpace::Initial) { + CHECK(conn.outstandings.initialPacketsCount); + --conn.outstandings.initialPacketsCount; + } else { + CHECK_EQ(PacketNumberSpace::Handshake, currentPacketNumberSpace); + CHECK(conn.outstandings.handshakePacketsCount); + --conn.outstandings.handshakePacketsCount; + } } VLOG(10) << __func__ << " lost packetNum=" << currentPacketNum << " handshake=" << pkt.isHandshake << " " << conn; @@ -291,63 +286,6 @@ folly::Optional detectLossPackets( void onPTOAlarm(QuicConnectionStateBase& conn); -template -void onHandshakeAlarm( - QuicConnectionStateBase& conn, - const LossVisitor& lossVisitor) { - // TODO: This code marks all outstanding handshake packets as loss. - // Alternatively we can experiment with only retransmit them without marking - // loss - VLOG(10) << __func__ << " " << conn; - ++conn.lossState.ptoCount; - ++conn.lossState.totalPTOCount; - ++conn.lossState.handshakeAlarmCount; - QUIC_STATS(conn.statsCallback, onPTO); - QUIC_TRACE( - handshake_alarm, - conn, - conn.lossState.largestSent.value_or(0), - conn.lossState.handshakeAlarmCount, - (uint64_t)conn.outstandings.handshakePacketsCount, - (uint64_t)conn.outstandings.packets.size()); - if (conn.qLogger) { - conn.qLogger->addLossAlarm( - conn.lossState.largestSent.value_or(0), - conn.lossState.handshakeAlarmCount, - (uint64_t)conn.outstandings.packets.size(), - kHandshakeAlarm); - } - CongestionController::LossEvent lossEvent(ClockType::now()); - auto iter = conn.outstandings.packets.begin(); - while (iter != conn.outstandings.packets.end()) { - // the word "handshake" in our code base is unfortunately overloaded. - if (iter->isHandshake) { - auto& packet = *iter; - auto currentPacketNum = packet.packet.header.getPacketSequenceNum(); - auto currentPacketNumSpace = packet.packet.header.getPacketNumberSpace(); - VLOG(10) << "HandshakeAlarm, removing packetNum=" << currentPacketNum - << " packetNumSpace=" << currentPacketNumSpace << " " << conn; - lossEvent.addLostPacket(std::move(packet)); - lossVisitor(conn, packet.packet, false, currentPacketNum); - DCHECK(conn.outstandings.handshakePacketsCount); - --conn.outstandings.handshakePacketsCount; - ++conn.lossState.timeoutBasedRtxCount; - ++conn.lossState.rtxCount; - iter = conn.outstandings.packets.erase(iter); - } else { - iter++; - } - } - if (conn.congestionController && lossEvent.largestLostPacketNum.hasValue()) { - conn.congestionController->onRemoveBytesFromInflight(lossEvent.lostBytes); - } - if (conn.nodeType == QuicNodeType::Client && conn.oneRttWriteCipher) { - // When sending client finished, we should also send a 1-rtt probe packet to - // elicit an ack. - conn.pendingEvents.numProbePackets = kPacketToSendForPTO; - } -} - /* * Function invoked when loss detection timer fires */ @@ -379,9 +317,6 @@ void onLossDetectionAlarm( conn.congestionController->onPacketAckOrLoss( folly::none, std::move(lossEvent)); } - } else if ( - conn.lossState.currentAlarmMethod == LossState::AlarmMethod::Handshake) { - onHandshakeAlarm(conn, lossVisitor); } else { onPTOAlarm(conn); } @@ -389,6 +324,7 @@ void onLossDetectionAlarm( VLOG(10) << __func__ << " setLossDetectionAlarm=" << conn.pendingEvents.setLossDetectionAlarm << " outstanding=" << conn.outstandings.packets.size() + << " initialPackets=" << conn.outstandings.initialPacketsCount << " handshakePackets=" << conn.outstandings.handshakePacketsCount << " " << conn; } @@ -415,7 +351,6 @@ folly::Optional handleAckForLoss( // TODO: Should we NOT reset these counters if the received Ack frame // doesn't ack anything that's in OP list? conn.lossState.ptoCount = 0; - conn.lossState.handshakeAlarmCount = 0; largestAcked = std::max( largestAcked.value_or(*ack.largestAckedPacket), *ack.largestAckedPacket); @@ -432,6 +367,7 @@ folly::Optional handleAckForLoss( << " setLossDetectionAlarm=" << conn.pendingEvents.setLossDetectionAlarm << " outstanding=" << conn.outstandings.packets.size() + << " initialPackets=" << conn.outstandings.initialPacketsCount << " handshakePackets=" << conn.outstandings.handshakePacketsCount << " " << conn; return lossEvent; diff --git a/quic/loss/test/QuicLossFunctionsTest.cpp b/quic/loss/test/QuicLossFunctionsTest.cpp index f9a09b0a6..26aa73d55 100644 --- a/quic/loss/test/QuicLossFunctionsTest.cpp +++ b/quic/loss/test/QuicLossFunctionsTest.cpp @@ -38,6 +38,7 @@ class MockLossTimeout { }; enum class PacketType { + Initial, Handshake, ZeroRtt, OneRtt, @@ -138,6 +139,16 @@ PacketNum QuicLossFunctionsTest::sendPacket( folly::Optional header; bool isHandshake = false; switch (packetType) { + case PacketType::Initial: + header = LongHeader( + LongHeader::Types::Initial, + *conn.clientConnectionId, + *conn.serverConnectionId, + conn.ackStates.initialAckState.nextPacketNum, + *conn.version); + conn.outstandings.initialPacketsCount++; + isHandshake = true; + break; case PacketType::Handshake: header = LongHeader( LongHeader::Types::Handshake, @@ -145,6 +156,7 @@ PacketNum QuicLossFunctionsTest::sendPacket( *conn.serverConnectionId, conn.ackStates.handshakeAckState.nextPacketNum, *conn.version); + conn.outstandings.handshakePacketsCount++; isHandshake = true; break; case PacketType::ZeroRtt: @@ -187,7 +199,6 @@ PacketNum QuicLossFunctionsTest::sendPacket( packet.packet, time, encodedSize, isHandshake, encodedSize); outstandingPacket.associatedEvent = associatedEvent; if (isHandshake) { - conn.outstandings.handshakePacketsCount++; conn.lossState.lastHandshakePacketSentTime = time; } conn.lossState.lastRetransmittablePacketSentTime = time; @@ -202,7 +213,9 @@ PacketNum QuicLossFunctionsTest::sendPacket( conn.outstandings.packets.end(), [&associatedEvent](const auto& packet) { auto packetNum = packet.packet.header.getPacketSequenceNum(); - return packetNum == *associatedEvent; + auto packetNumSpace = packet.packet.header.getPacketNumberSpace(); + return packetNum == associatedEvent->packetNumber && + packetNumSpace == associatedEvent->packetNumberSpace; }); if (it != conn.outstandings.packets.end()) { if (!it->associatedEvent) { @@ -222,12 +235,18 @@ PacketNum QuicLossFunctionsTest::sendPacket( TEST_F(QuicLossFunctionsTest, AllPacketsProcessed) { auto conn = createConn(); EXPECT_CALL(*transportInfoCb_, onPTO()).Times(0); - auto pkt1 = conn->ackStates.appDataAckState.nextPacketNum; - sendPacket(*conn, Clock::now(), pkt1, PacketType::OneRtt); - auto pkt2 = conn->ackStates.appDataAckState.nextPacketNum; - sendPacket(*conn, Clock::now(), pkt2, PacketType::OneRtt); - auto pkt3 = conn->ackStates.appDataAckState.nextPacketNum; - sendPacket(*conn, Clock::now(), pkt3, PacketType::OneRtt); + PacketEvent packetEvent1( + PacketNumberSpace::AppData, + conn->ackStates.appDataAckState.nextPacketNum); + sendPacket(*conn, Clock::now(), packetEvent1, PacketType::OneRtt); + PacketEvent packetEvent2( + PacketNumberSpace::AppData, + conn->ackStates.appDataAckState.nextPacketNum); + sendPacket(*conn, Clock::now(), packetEvent2, PacketType::OneRtt); + PacketEvent packetEvent3( + PacketNumberSpace::AppData, + conn->ackStates.appDataAckState.nextPacketNum); + sendPacket(*conn, Clock::now(), packetEvent3, PacketType::OneRtt); EXPECT_CALL(timeout, cancelLossTimeout()).Times(1); setLossDetectionAlarm(*conn, timeout); EXPECT_FALSE(conn->pendingEvents.setLossDetectionAlarm); @@ -292,7 +311,8 @@ TEST_F(QuicLossFunctionsTest, TestOnPTOSkipProcessed) { // By adding an associatedEvent that doesn't exist in the // outstandings.packetEvents, they are all processed and will skip lossVisitor for (auto i = 0; i < 10; i++) { - sendPacket(*conn, TimePoint(), i, PacketType::OneRtt); + PacketEvent packetEvent(PacketNumberSpace::AppData, i); + sendPacket(*conn, TimePoint(), packetEvent, PacketType::OneRtt); } EXPECT_EQ(10, conn->outstandings.packets.size()); std::vector lostPackets; @@ -707,7 +727,6 @@ TEST_F(QuicLossFunctionsTest, TestHandleAckedPacket) { auto mockQLogger = std::make_shared(VantagePoint::Server); conn->qLogger = mockQLogger; conn->lossState.ptoCount = 10; - conn->lossState.handshakeAlarmCount = 5; conn->lossState.reorderingThreshold = 10; sendPacket(*conn, TimePoint(), folly::none, PacketType::OneRtt); @@ -735,7 +754,6 @@ TEST_F(QuicLossFunctionsTest, TestHandleAckedPacket) { Clock::now()); EXPECT_EQ(0, conn->lossState.ptoCount); - EXPECT_EQ(0, conn->lossState.handshakeAlarmCount); EXPECT_TRUE(conn->outstandings.packets.empty()); EXPECT_FALSE(conn->pendingEvents.setLossDetectionAlarm); EXPECT_FALSE(testLossMarkFuncCalled); @@ -980,9 +998,7 @@ TEST_F(QuicLossFunctionsTest, PTONoLongerMarksPacketsToBeRetransmitted) { EXPECT_TRUE(lostPackets.empty()); } -TEST_F( - QuicLossFunctionsTest, - WhenHandshakeOutstandingAlarmMarksAllHandshakeAsLoss) { +TEST_F(QuicLossFunctionsTest, PTOWithHandshakePackets) { auto conn = createConn(); auto mockQLogger = std::make_shared(VantagePoint::Server); conn->qLogger = mockQLogger; @@ -991,10 +1007,10 @@ TEST_F( conn->congestionController = std::move(mockCongestionController); EXPECT_CALL(*rawCongestionController, onPacketSent(_)) .WillRepeatedly(Return()); - EXPECT_CALL(*mockQLogger, addLossAlarm(5, 1, 10, kHandshakeAlarm)); + EXPECT_CALL(*mockQLogger, addLossAlarm(_, _, _, _)); std::vector lostPackets; PacketNum expectedLargestLostNum = 0; - conn->lossState.currentAlarmMethod = LossState::AlarmMethod::Handshake; + conn->lossState.currentAlarmMethod = LossState::AlarmMethod::PTO; for (auto i = 0; i < 10; i++) { // Half are handshakes auto sentPacketNum = sendPacket( @@ -1005,43 +1021,15 @@ TEST_F( expectedLargestLostNum = std::max( expectedLargestLostNum, i % 2 ? sentPacketNum : expectedLargestLostNum); } - uint64_t expectedLostBytes = std::accumulate( - conn->outstandings.packets.begin(), - conn->outstandings.packets.end(), - 0, - [](uint64_t num, const OutstandingPacket& packet) { - return packet.isHandshake ? num + packet.encodedSize : num; - }); - EXPECT_CALL( - *rawCongestionController, onRemoveBytesFromInflight(expectedLostBytes)) - .Times(1); + EXPECT_CALL(*transportInfoCb_, onPTO()); onLossDetectionAlarm( *conn, testingLossMarkFunc(lostPackets)); - // Half are lost - EXPECT_EQ(5, lostPackets.size()); - EXPECT_EQ(1, conn->lossState.handshakeAlarmCount); - EXPECT_EQ(5, conn->lossState.timeoutBasedRtxCount); - EXPECT_EQ(conn->pendingEvents.numProbePackets, 0); - EXPECT_EQ(5, conn->lossState.rtxCount); -} - -TEST_F(QuicLossFunctionsTest, HandshakeAlarmWithOneRttCipher) { - auto conn = createClientConn(); - auto mockQLogger = std::make_shared(VantagePoint::Client); - conn->qLogger = mockQLogger; - conn->oneRttWriteCipher = createNoOpAead(); - conn->lossState.currentAlarmMethod = LossState::AlarmMethod::Handshake; - std::vector lostPackets; - EXPECT_CALL(*mockQLogger, addLossAlarm(1, 1, 1, kHandshakeAlarm)); - sendPacket(*conn, TimePoint(100ms), folly::none, PacketType::Handshake); - onLossDetectionAlarm( - *conn, testingLossMarkFunc(lostPackets)); - - // Half should be marked as loss - EXPECT_EQ(lostPackets.size(), 1); - EXPECT_EQ(conn->lossState.handshakeAlarmCount, 1); + EXPECT_EQ(0, lostPackets.size()); + EXPECT_EQ(1, conn->lossState.ptoCount); + EXPECT_EQ(0, conn->lossState.timeoutBasedRtxCount); EXPECT_EQ(conn->pendingEvents.numProbePackets, kPacketToSendForPTO); + EXPECT_EQ(0, conn->lossState.rtxCount); } TEST_F(QuicLossFunctionsTest, EmptyOutstandingNoTimeout) { @@ -1050,38 +1038,6 @@ TEST_F(QuicLossFunctionsTest, EmptyOutstandingNoTimeout) { setLossDetectionAlarm(*conn, timeout); } -TEST_F(QuicLossFunctionsTest, AlarmDurationHandshakeOutstanding) { - auto conn = createConn(); - conn->lossState.maxAckDelay = 25ms; - TimePoint lastPacketSentTime = Clock::now(); - std::chrono::milliseconds packetSentDelay = 10ms; - auto thisMoment = lastPacketSentTime + packetSentDelay; - MockClock::mockNow = [=]() { return thisMoment; }; - sendPacket(*conn, lastPacketSentTime, folly::none, PacketType::Handshake); - - MockClock::mockNow = [=]() { return thisMoment; }; - auto duration = calculateAlarmDuration(*conn); - EXPECT_EQ( - conn->transportSettings.initialRtt * 2 - packetSentDelay + 25ms, - duration.first); - EXPECT_EQ(duration.second, LossState::AlarmMethod::Handshake); - - conn->lossState.srtt = 100ms; - duration = calculateAlarmDuration(*conn); - EXPECT_EQ( - std::chrono::duration_cast(225ms) - - packetSentDelay, - duration.first); - - conn->lossState.maxAckDelay = 45ms; - conn->lossState.handshakeAlarmCount = 2; - duration = calculateAlarmDuration(*conn); - EXPECT_EQ( - std::chrono::duration_cast(980ms) - - packetSentDelay, - duration.first); -} - TEST_F(QuicLossFunctionsTest, AlarmDurationHasLossTime) { auto conn = createConn(); TimePoint lastPacketSentTime = Clock::now(); @@ -1182,7 +1138,8 @@ TEST_F(QuicLossFunctionsTest, SkipLossVisitor) { PacketNum lastSent; for (size_t i = 0; i < 5; i++) { lastSent = conn->ackStates.appDataAckState.nextPacketNum; - sendPacket(*conn, Clock::now(), lastSent, PacketType::OneRtt); + PacketEvent packetEvent(PacketNumberSpace::AppData, lastSent); + sendPacket(*conn, Clock::now(), packetEvent, PacketType::OneRtt); } detectLossPackets( *conn, @@ -1210,7 +1167,7 @@ TEST_F(QuicLossFunctionsTest, NoDoubleProcess) { }; // Send 6 packets, so when we ack the last one, we mark the first two loss PacketNum lastSent; - PacketEvent event = 0; + PacketEvent event(PacketNumberSpace::AppData, 0); for (size_t i = 0; i < 6; i++) { lastSent = sendPacket(*conn, Clock::now(), event, PacketType::OneRtt); } @@ -1232,8 +1189,10 @@ TEST_F(QuicLossFunctionsTest, NoDoubleProcess) { TEST_F(QuicLossFunctionsTest, DetectPacketLossClonedPacketsCounter) { auto conn = createConn(); - auto packet1 = conn->ackStates.appDataAckState.nextPacketNum; - sendPacket(*conn, Clock::now(), packet1, PacketType::OneRtt); + PacketEvent packetEvent1( + PacketNumberSpace::AppData, + conn->ackStates.appDataAckState.nextPacketNum); + sendPacket(*conn, Clock::now(), packetEvent1, PacketType::OneRtt); sendPacket(*conn, Clock::now(), folly::none, PacketType::OneRtt); sendPacket(*conn, Clock::now(), folly::none, PacketType::OneRtt); sendPacket(*conn, Clock::now(), folly::none, PacketType::OneRtt); @@ -1321,21 +1280,6 @@ TEST_F(QuicLossFunctionsTest, TestTotalPTOCount) { EXPECT_EQ(101, conn->lossState.totalPTOCount); } -TEST_F(QuicLossFunctionsTest, HandshakeAlarmPTOCountingAndCallbacks) { - auto conn = createConn(); - auto mockQLogger = std::make_shared(VantagePoint::Server); - conn->qLogger = mockQLogger; - conn->lossState.ptoCount = 22; - conn->lossState.totalPTOCount = 100; - conn->lossState.handshakeAlarmCount = 3; - EXPECT_CALL(*mockQLogger, addLossAlarm(0, 4, 0, kHandshakeAlarm)); - EXPECT_CALL(*transportInfoCb_, onPTO()); - onHandshakeAlarm(*conn, [](const auto&, auto, bool, PacketNum) {}); - EXPECT_EQ(101, conn->lossState.totalPTOCount); - EXPECT_EQ(23, conn->lossState.ptoCount); - EXPECT_EQ(4, conn->lossState.handshakeAlarmCount); -} - TEST_F(QuicLossFunctionsTest, TestExceedsMaxPTOThrows) { auto conn = createConn(); auto mockQLogger = std::make_shared(VantagePoint::Server); @@ -1425,17 +1369,19 @@ TEST_F(QuicLossFunctionsTest, TestZeroRttRejectedWithClones) { // By adding an associatedEvent that doesn't exist in the // outstandings.packetEvents, they are all processed and will skip lossVisitor std::set zeroRttPackets; - folly::Optional lastPacket; + folly::Optional lastPacketEvent; for (auto i = 0; i < 2; i++) { - lastPacket = - sendPacket(*conn, TimePoint(), lastPacket, PacketType::ZeroRtt); - zeroRttPackets.emplace(*lastPacket); + auto packetNum = + sendPacket(*conn, TimePoint(), lastPacketEvent, PacketType::ZeroRtt); + lastPacketEvent = PacketEvent(PacketNumberSpace::AppData, packetNum); + zeroRttPackets.emplace(packetNum); } zeroRttPackets.emplace( sendPacket(*conn, TimePoint(), folly::none, PacketType::ZeroRtt)); for (auto zeroRttPacketNum : zeroRttPackets) { - lastPacket = - sendPacket(*conn, TimePoint(), zeroRttPacketNum, PacketType::OneRtt); + PacketEvent zeroRttPacketEvent( + PacketNumberSpace::AppData, zeroRttPacketNum); + sendPacket(*conn, TimePoint(), zeroRttPacketEvent, PacketType::OneRtt); } EXPECT_EQ(6, conn->outstandings.packets.size()); @@ -1504,15 +1450,56 @@ TEST_F(QuicLossFunctionsTest, TimeThreshold) { PacketNumberSpace::AppData); } +TEST_F(QuicLossFunctionsTest, OutstandingInitialCounting) { + auto conn = createConn(); + // Simplify the test by never triggering timer threshold + conn->lossState.srtt = 100s; + PacketNum largestSent = 0; + while (largestSent < 10) { + largestSent = + sendPacket(*conn, Clock::now(), folly::none, PacketType::Initial); + } + EXPECT_EQ(10, conn->outstandings.initialPacketsCount); + auto noopLossVisitor = [&](auto& /* conn */, + auto& /* packet */, + bool /* processed */, + PacketNum /* currentPacketNum */) {}; + detectLossPackets( + *conn, + largestSent, + noopLossVisitor, + TimePoint(100ms), + PacketNumberSpace::Initial); + // [1, 6] are removed, [7, 10] are still in OP list + EXPECT_EQ(4, conn->outstandings.initialPacketsCount); +} + +TEST_F(QuicLossFunctionsTest, OutstandingHandshakeCounting) { + auto conn = createConn(); + // Simplify the test by never triggering timer threshold + conn->lossState.srtt = 100s; + PacketNum largestSent = 0; + while (largestSent < 10) { + largestSent = + sendPacket(*conn, Clock::now(), folly::none, PacketType::Handshake); + } + EXPECT_EQ(10, conn->outstandings.handshakePacketsCount); + auto noopLossVisitor = [&](auto& /* conn */, + auto& /* packet */, + bool /* processed */, + PacketNum /* currentPacketNum */) {}; + detectLossPackets( + *conn, + largestSent, + noopLossVisitor, + TimePoint(100ms), + PacketNumberSpace::Handshake); + // [1, 6] are removed, [7, 10] are still in OP list + EXPECT_EQ(4, conn->outstandings.handshakePacketsCount); +} + TEST_P(QuicLossFunctionsTest, CappedShiftNoCrash) { auto conn = createConn(); - conn->lossState.handshakeAlarmCount = - std::numeric_limitslossState.handshakeAlarmCount)>::max(); - sendPacket(*conn, Clock::now(), folly::none, PacketType::Handshake); - ASSERT_GT(conn->outstandings.handshakePacketsCount, 0); - calculateAlarmDuration(*conn); - - conn->lossState.handshakeAlarmCount = 0; conn->outstandings.handshakePacketsCount = 0; conn->outstandings.packets.clear(); conn->lossState.ptoCount = diff --git a/quic/server/QuicServerTransport.cpp b/quic/server/QuicServerTransport.cpp index 0920b27c9..c2bf206ac 100644 --- a/quic/server/QuicServerTransport.cpp +++ b/quic/server/QuicServerTransport.cpp @@ -208,10 +208,13 @@ void QuicServerTransport::writeData() { ? conn_->pacer->updateAndGetWriteBatchSize(Clock::now()) : conn_->transportSettings.writeConnectionDataPacketsLimit); if (conn_->initialWriteCipher) { - CryptoStreamScheduler initialScheduler( - *conn_, - *getCryptoStream(*conn_->cryptoState, EncryptionLevel::Initial)); - if (initialScheduler.hasData() || + auto& initialCryptoStream = + *getCryptoStream(*conn_->cryptoState, EncryptionLevel::Initial); + CryptoStreamScheduler initialScheduler(*conn_, initialCryptoStream); + if ((conn_->pendingEvents.numProbePackets && + initialCryptoStream.retransmissionBuffer.size() && + conn_->outstandings.initialPacketsCount) || + initialScheduler.hasData() || (conn_->ackStates.initialAckState.needsToSendAckImmediately && hasAcksToSchedule(conn_->ackStates.initialAckState))) { CHECK(conn_->initialWriteCipher); @@ -227,15 +230,18 @@ void QuicServerTransport::writeData() { version, packetLimit); } - if (!packetLimit) { + if (!packetLimit && !conn_->pendingEvents.numProbePackets) { return; } } if (conn_->handshakeWriteCipher) { - CryptoStreamScheduler handshakeScheduler( - *conn_, - *getCryptoStream(*conn_->cryptoState, EncryptionLevel::Handshake)); - if (handshakeScheduler.hasData() || + auto& handshakeCryptoStream = + *getCryptoStream(*conn_->cryptoState, EncryptionLevel::Handshake); + CryptoStreamScheduler handshakeScheduler(*conn_, handshakeCryptoStream); + if ((conn_->outstandings.handshakePacketsCount && + handshakeCryptoStream.retransmissionBuffer.size() && + conn_->pendingEvents.numProbePackets) || + handshakeScheduler.hasData() || (conn_->ackStates.handshakeAckState.needsToSendAckImmediately && hasAcksToSchedule(conn_->ackStates.handshakeAckState))) { CHECK(conn_->handshakeWriteCipher); @@ -251,7 +257,7 @@ void QuicServerTransport::writeData() { version, packetLimit); } - if (!packetLimit) { + if (!packetLimit && !conn_->pendingEvents.numProbePackets) { return; } } diff --git a/quic/server/test/QuicServerTransportTest.cpp b/quic/server/test/QuicServerTransportTest.cpp index e64ada637..b55b3b6a4 100644 --- a/quic/server/test/QuicServerTransportTest.cpp +++ b/quic/server/test/QuicServerTransportTest.cpp @@ -1179,6 +1179,7 @@ TEST_F(QuicServerTransportTest, TestOpenAckStreamFrame) { // Remove any packets that might have been queued. server->getNonConstConn().outstandings.packets.clear(); + server->getNonConstConn().outstandings.initialPacketsCount = 0; server->getNonConstConn().outstandings.handshakePacketsCount = 0; server->writeChain(streamId, data->clone(), false, false); loopForWrites(); @@ -1781,6 +1782,7 @@ TEST_F(QuicServerTransportTest, TestCloneStopSending) { server->getNonConstConn().qLogger = qLogger; server->getNonConstConn().streamManager->getStream(streamId); // knock every handshake outstanding packets out + server->getNonConstConn().outstandings.initialPacketsCount = 0; server->getNonConstConn().outstandings.handshakePacketsCount = 0; server->getNonConstConn().outstandings.packets.clear(); for (auto& t : server->getNonConstConn().lossState.lossTimes) { diff --git a/quic/state/AckHandlers.cpp b/quic/state/AckHandlers.cpp index 06e98a43b..165425a41 100644 --- a/quic/state/AckHandlers.cpp +++ b/quic/state/AckHandlers.cpp @@ -48,6 +48,7 @@ void processAckFrame( // acks which leads to different number of packets being acked usually. ack.ackedPackets.reserve(kDefaultRxPacketsBeforeAckAfterInit); auto currentPacketIt = getLastOutstandingPacket(conn, pnSpace); + uint64_t initialPacketAcked = 0; uint64_t handshakePacketAcked = 0; uint64_t clonedPacketsAcked = 0; folly::Optional @@ -107,8 +108,15 @@ void processAckFrame( VLOG(10) << __func__ << " acked packetNum=" << currentPacketNum << " space=" << currentPacketNumberSpace << " handshake=" << (int)rPacketIt->isHandshake << " " << conn; - if (rPacketIt->isHandshake) { - ++handshakePacketAcked; + bool needsProcess = !rPacketIt->associatedEvent || + conn.outstandings.packetEvents.count(*rPacketIt->associatedEvent); + if (rPacketIt->isHandshake && needsProcess) { + if (currentPacketNumberSpace == PacketNumberSpace::Initial) { + ++initialPacketAcked; + } else { + CHECK_EQ(PacketNumberSpace::Handshake, currentPacketNumberSpace); + ++handshakePacketAcked; + } } ack.ackedBytes += rPacketIt->encodedSize; if (rPacketIt->associatedEvent) { @@ -124,8 +132,8 @@ void processAckFrame( } // Only invoke AckVisitor if the packet doesn't have an associated // PacketEvent; or the PacketEvent is in conn.outstandings.packetEvents - if (!rPacketIt->associatedEvent || - conn.outstandings.packetEvents.count(*rPacketIt->associatedEvent)) { + if (needsProcess /*!rPacketIt->associatedEvent || + conn.outstandings.packetEvents.count(*rPacketIt->associatedEvent)*/) { for (auto& packetFrame : rPacketIt->packet.frames) { ackVisitor(*rPacketIt, packetFrame, frame); } @@ -177,20 +185,23 @@ void processAckFrame( if (lastAckedPacketSentTime) { conn.lossState.lastAckedPacketSentTime = *lastAckedPacketSentTime; } - DCHECK_GE(conn.outstandings.handshakePacketsCount, handshakePacketAcked); + CHECK_GE(conn.outstandings.initialPacketsCount, initialPacketAcked); + conn.outstandings.initialPacketsCount -= initialPacketAcked; + CHECK_GE(conn.outstandings.handshakePacketsCount, handshakePacketAcked); conn.outstandings.handshakePacketsCount -= handshakePacketAcked; - DCHECK_GE(conn.outstandings.clonedPacketsCount, clonedPacketsAcked); + CHECK_GE(conn.outstandings.clonedPacketsCount, clonedPacketsAcked); conn.outstandings.clonedPacketsCount -= clonedPacketsAcked; auto updatedOustandingPacketsCount = conn.outstandings.packets.size(); - DCHECK_GE( - updatedOustandingPacketsCount, conn.outstandings.handshakePacketsCount); - DCHECK_GE( - updatedOustandingPacketsCount, conn.outstandings.clonedPacketsCount); + CHECK_GE( + updatedOustandingPacketsCount, + conn.outstandings.handshakePacketsCount + + conn.outstandings.initialPacketsCount); + CHECK_GE(updatedOustandingPacketsCount, conn.outstandings.clonedPacketsCount); auto lossEvent = handleAckForLoss(conn, lossVisitor, ack, pnSpace); if (conn.congestionController && (ack.largestAckedPacket.has_value() || lossEvent)) { if (lossEvent) { - DCHECK(lossEvent->largestLostSentTime && lossEvent->smallestLostSentTime); + CHECK(lossEvent->largestLostSentTime && lossEvent->smallestLostSentTime); lossEvent->persistentCongestion = isPersistentCongestion( conn, *lossEvent->smallestLostSentTime, diff --git a/quic/state/CMakeLists.txt b/quic/state/CMakeLists.txt index 83cb0dd22..888be7a83 100644 --- a/quic/state/CMakeLists.txt +++ b/quic/state/CMakeLists.txt @@ -8,6 +8,7 @@ add_library( QuicStreamManager.cpp QuicStreamUtilities.cpp StateData.cpp + PacketEvent.cpp PendingPathRateLimiter.cpp ) diff --git a/quic/state/PacketEvent.cpp b/quic/state/PacketEvent.cpp new file mode 100644 index 000000000..97d82a970 --- /dev/null +++ b/quic/state/PacketEvent.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + * + */ + +#include +#include +#include + +namespace quic { + +PacketEvent::PacketEvent( + PacketNumberSpace packetNumberSpaceIn, + PacketNum packetNumberIn) + : packetNumberSpace(packetNumberSpaceIn), packetNumber(packetNumberIn) {} + +bool operator==(const PacketEvent& lhs, const PacketEvent& rhs) { + return static_cast>( + lhs.packetNumberSpace) == + static_cast>( + rhs.packetNumberSpace) && + lhs.packetNumber == rhs.packetNumber; +} + +size_t PacketEventHash::operator()(const PacketEvent& packetEvent) const + noexcept { + return folly::hash::hash_combine( + static_cast>( + packetEvent.packetNumberSpace), + packetEvent.packetNumber); +} +} // namespace quic diff --git a/quic/state/PacketEvent.h b/quic/state/PacketEvent.h new file mode 100644 index 000000000..11a3b29e4 --- /dev/null +++ b/quic/state/PacketEvent.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + * + */ + +#pragma once + +#include + +namespace quic { + +/** + * There are cases that we may clone an outstanding packet and resend it as is. + * When that happens, we assign a PacketEvent to both the original and cloned + * packet if no PacketEvent is already associated with the original packet. If + * the original packet already has a PacketEvent, we copy that value into the + * cloned packet. + * A connection maintains a set of PacketEvents. When a packet with a + * PacketEvent is acked or lost, we search the set. If the PacketEvent is + * present in the set, we process the ack or loss event (e.g. update RTT, notify + * CongestionController, and detect loss with this packet) as well as frames in + * the packet. Then we remove the PacketEvent from the set. If the PacketEvent + * is absent in the set, we consider all frames contained in the packet are + * already processed. We will still handle the ack or loss event and update the + * connection. But no frame will be processed. + * + * TODO: Current PacketNum is an alias to uint64_t. We should just make + * PacketNum be a type with both the space and the number, then PacketEvent will + * just be an alias to this type. + */ +struct PacketEvent { + PacketNumberSpace packetNumberSpace; + PacketNum packetNumber; + + PacketEvent() = delete; + PacketEvent(PacketNumberSpace packetNumberSpaceIn, PacketNum packetNumberIn); +}; + +// To work with F14 Set: +bool operator==(const PacketEvent& lhs, const PacketEvent& rhs); + +struct PacketEventHash { + size_t operator()(const PacketEvent& packetEvent) const noexcept; +}; +} // namespace quic diff --git a/quic/state/StateData.h b/quic/state/StateData.h index f1ed6d1b1..b65101520 100644 --- a/quic/state/StateData.h +++ b/quic/state/StateData.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -102,23 +103,6 @@ struct NetworkDataSingle { } }; -/** - * There are cases that we may clone an outstanding packet and resend it as is. - * When that happens, we assign a PacketEvent to both the original and cloned - * packet if no PacketEvent is already associated with the original packet. If - * the original packet already has a PacketEvent, we copy that value into the - * cloned packet. - * A connection maintains a set of PacketEvents. When a packet with a - * PacketEvent is acked or lost, we search the set. If the PacketEvent is - * present in the set, we process the ack or loss event (e.g. update RTT, notify - * CongestionController, and detect loss with this packet) as well as frames in - * the packet. Then we remove the PacketEvent from the set. If the PacketEvent - * is absent in the set, we consider all frames contained in the packet are - * already processed. We will still handle the ack or loss event and update the - * connection. But no frame will be processed. - */ -using PacketEvent = PacketNum; - // Data structure to represent outstanding retransmittable packets struct OutstandingPacket { // Structure representing the frames that are outstanding including the header @@ -187,9 +171,14 @@ struct OutstandingsInfo { // associatedEvent or if it's not in this set, there is no need to process its // frames upon ack or loss. // TODO: Enforce only AppTraffic packets to be clonable - folly::F14FastSet packetEvents; + folly::F14FastSet packetEvents; - // Number of handshake packets outstanding. + // Number of outstanding packets in Initial space, not including cloned + // Initial packets. + uint64_t initialPacketsCount{0}; + + // Number of outstanding packets in Handshake space, not including cloned + // Handshake packets. uint64_t handshakePacketsCount{0}; // Number of packets are clones or cloned. @@ -449,7 +438,7 @@ using Resets = folly::F14FastMap; using FrameList = std::vector; struct LossState { - enum class AlarmMethod { EarlyRetransmitOrReordering, Handshake, PTO }; + enum class AlarmMethod { EarlyRetransmitOrReordering, PTO }; // Smooth rtt std::chrono::microseconds srtt{0us}; // Latest rtt @@ -458,10 +447,8 @@ struct LossState { std::chrono::microseconds rttvar{0us}; // Number of packet loss timer fired before receiving an ack uint32_t ptoCount{0}; - // The number of times the handshake packets have been retransmitted without - // receiving an ack. - uint16_t handshakeAlarmCount{0}; - // The time when last handshake packet was sent + // The time when last handshake packet (including both Initial and Handshake + // space) was sent TimePoint lastHandshakePacketSentTime; // Latest packet number sent // TODO: this also needs to be 3 numbers now... diff --git a/quic/state/test/AckHandlersTest.cpp b/quic/state/test/AckHandlersTest.cpp index fd15d699d..b351cad9c 100644 --- a/quic/state/test/AckHandlersTest.cpp +++ b/quic/state/test/AckHandlersTest.cpp @@ -457,7 +457,8 @@ TEST_P(AckHandlersTest, TestHandshakeCounterUpdate) { QuicServerConnectionState conn; StreamId stream = 1; for (PacketNum packetNum = 0; packetNum < 10; packetNum++) { - auto regularPacket = createNewPacket(packetNum, GetParam()); + auto regularPacket = createNewPacket( + packetNum, (packetNum % 2 ? GetParam() : PacketNumberSpace::AppData)); WriteStreamFrame frame( stream, 100 * packetNum + 0, 100 * packetNum + 100, false); regularPacket.frames.emplace_back(std::move(frame)); @@ -465,9 +466,13 @@ TEST_P(AckHandlersTest, TestHandshakeCounterUpdate) { std::move(regularPacket), Clock::now(), 0, - packetNum % 2, + packetNum % 2 && GetParam() != PacketNumberSpace::AppData, packetNum / 2); - conn.outstandings.handshakePacketsCount += packetNum % 2; + if (GetParam() == PacketNumberSpace::Initial) { + conn.outstandings.initialPacketsCount += packetNum % 2; + } else if (GetParam() == PacketNumberSpace::Handshake) { + conn.outstandings.handshakePacketsCount += packetNum % 2; + } } ReadAckFrame ackFrame; @@ -482,10 +487,21 @@ TEST_P(AckHandlersTest, TestHandshakeCounterUpdate) { [&](const auto&, const auto&, const ReadAckFrame&) {}, testLossHandler(lostPackets), Clock::now()); - // When [3, 7] are acked, [0, 2] will also be marked loss, due to reordering - // threshold - EXPECT_EQ(1, conn.outstandings.handshakePacketsCount); - EXPECT_EQ(2, conn.outstandings.packets.size()); + // When [3, 7] are acked, [0, 2] may also be marked loss if they are in the + // same packet number space, due to reordering threshold + if (GetParam() == PacketNumberSpace::Initial) { + EXPECT_EQ(1, conn.outstandings.initialPacketsCount); + // AppData packets won't be acked by an ack in Initial space: + // So 0, 2, 4, 6, 8 and 9 are left in OP list + EXPECT_EQ(6, conn.outstandings.packets.size()); + } else if (GetParam() == PacketNumberSpace::Handshake) { + EXPECT_EQ(1, conn.outstandings.handshakePacketsCount); + // AppData packets won't be acked by an ack in Handshake space: + // So 0, 2, 4, 6, 8 and 9 are left in OP list + EXPECT_EQ(6, conn.outstandings.packets.size()); + } else { + EXPECT_EQ(2, conn.outstandings.packets.size()); + } } TEST_P(AckHandlersTest, PurgeAcks) { @@ -571,7 +587,7 @@ TEST_P(AckHandlersTest, SkipAckVisitor) { std::move(regularPacket), Clock::now(), 1, false, 1); // Give this outstandingPacket an associatedEvent that's not in // outstandings.packetEvents - outstandingPacket.associatedEvent = 0; + outstandingPacket.associatedEvent.emplace(PacketNumberSpace::AppData, 0); conn.outstandings.packets.push_back(std::move(outstandingPacket)); conn.outstandings.clonedPacketsCount++; @@ -611,17 +627,20 @@ TEST_P(AckHandlersTest, NoDoubleProcess) { OutstandingPacket outstandingPacket1( std::move(regularPacket1), Clock::now(), 1, false, 1); - outstandingPacket1.associatedEvent = packetNum1; + outstandingPacket1.associatedEvent.emplace( + PacketNumberSpace::AppData, packetNum1); OutstandingPacket outstandingPacket2( std::move(regularPacket2), Clock::now(), 1, false, 1); // The seconds packet has the same PacketEvent - outstandingPacket2.associatedEvent = packetNum1; + outstandingPacket2.associatedEvent.emplace( + PacketNumberSpace::AppData, packetNum1); conn.outstandings.packets.push_back(std::move(outstandingPacket1)); conn.outstandings.packets.push_back(std::move(outstandingPacket2)); conn.outstandings.clonedPacketsCount += 2; - conn.outstandings.packetEvents.insert(packetNum1); + conn.outstandings.packetEvents.insert( + PacketEvent(PacketNumberSpace::AppData, packetNum1)); // A counting ack visitor uint16_t ackVisitorCounter = 0; @@ -673,7 +692,8 @@ TEST_P(AckHandlersTest, ClonedPacketsCounter) { regularPacket1.frames.push_back(frame); OutstandingPacket outstandingPacket1( std::move(regularPacket1), Clock::now(), 1, false, 1); - outstandingPacket1.associatedEvent = packetNum1; + outstandingPacket1.associatedEvent.emplace( + PacketNumberSpace::AppData, packetNum1); conn.ackStates.appDataAckState.nextPacketNum++; auto packetNum2 = conn.ackStates.appDataAckState.nextPacketNum; @@ -685,7 +705,8 @@ TEST_P(AckHandlersTest, ClonedPacketsCounter) { conn.outstandings.packets.push_back(std::move(outstandingPacket1)); conn.outstandings.packets.push_back(std::move(outstandingPacket2)); conn.outstandings.clonedPacketsCount = 1; - conn.outstandings.packetEvents.insert(packetNum1); + conn.outstandings.packetEvents.emplace( + PacketNumberSpace::AppData, packetNum1); ReadAckFrame ackFrame; ackFrame.largestAcked = packetNum2; diff --git a/quic/state/test/PacketEventTest.cpp b/quic/state/test/PacketEventTest.cpp new file mode 100644 index 000000000..4adf6d790 --- /dev/null +++ b/quic/state/test/PacketEventTest.cpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + * + */ + +#include +#include + +using namespace folly; +using namespace testing; + +namespace quic { +namespace test { +TEST(PacketEventTest, EqTest) { + PacketEvent initialEvent(PacketNumberSpace::Initial, 0); + PacketEvent initialEvent0(PacketNumberSpace::Initial, 0); + EXPECT_TRUE(initialEvent == initialEvent0); + + PacketEvent initialEvent1(PacketNumberSpace::Initial, 1); + EXPECT_FALSE(initialEvent0 == initialEvent1); + + PacketEvent handshakeEvent(PacketNumberSpace::Handshake, 0); + EXPECT_FALSE(handshakeEvent == initialEvent); +} + +TEST(PacketEventTest, HashTest) { + PacketEventHash hashObj; + PacketEvent initialEvent0(PacketNumberSpace::Initial, 0); + PacketEvent initialEvent1(PacketNumberSpace::Initial, 1); + EXPECT_NE(hashObj(initialEvent0), hashObj(initialEvent1)); + + PacketEvent handshakeEvent0(PacketNumberSpace::Handshake, 0); + EXPECT_NE(hashObj(initialEvent0), hashObj(handshakeEvent0)); +} +} // namespace test +} // namespace quic