diff --git a/quic/QuicConstants.h b/quic/QuicConstants.h
index b1551addd..cc24ea0cf 100644
--- a/quic/QuicConstants.h
+++ b/quic/QuicConstants.h
@@ -595,4 +595,8 @@ enum class DataPathType : uint8_t {
   ContinuousMemory = 1,
 };
 
+// Stream priority level; must be in [0, 7]
+using PriorityLevel = uint8_t;
+constexpr uint8_t kDefaultMaxPriority = 7;
+
 } // namespace quic
diff --git a/quic/api/QuicPacketScheduler.cpp b/quic/api/QuicPacketScheduler.cpp
index e40ff4f6e..e5f4514de 100644
--- a/quic/api/QuicPacketScheduler.cpp
+++ b/quic/api/QuicPacketScheduler.cpp
@@ -9,6 +9,109 @@
 #include
 #include
 
+namespace {
+using namespace quic;
+
+/**
+ * A helper iterator adaptor class that starts iteration of streams from a
+ * specific stream id.
+ */
+class MiddleStartingIterationWrapper {
+ public:
+  using MapType = std::set<StreamId>;
+
+  class MiddleStartingIterator
+      : public boost::iterator_facade<
+            MiddleStartingIterator,
+            const MiddleStartingIterationWrapper::MapType::value_type,
+            boost::forward_traversal_tag> {
+    friend class boost::iterator_core_access;
+
+   public:
+    using MapType = MiddleStartingIterationWrapper::MapType;
+
+    MiddleStartingIterator() = delete;
+
+    MiddleStartingIterator(
+        const MapType* streams,
+        const MapType::key_type& start)
+        : streams_(streams) {
+      itr_ = streams_->lower_bound(start);
+      checkForWrapAround();
+      // We don't want to mark it as wrapped around initially, instead just
+      // act as if start was the first element.
+      wrappedAround_ = false;
+    }
+
+    MiddleStartingIterator(const MapType* streams, MapType::const_iterator itr)
+        : streams_(streams), itr_(itr) {
+      checkForWrapAround();
+      // We don't want to mark it as wrapped around initially, instead just
+      // act as if start was the first element.
+      wrappedAround_ = false;
+    }
+
+    FOLLY_NODISCARD const MapType::value_type& dereference() const {
+      return *itr_;
+    }
+
+    FOLLY_NODISCARD MapType::const_iterator rawIterator() const {
+      return itr_;
+    }
+
+    FOLLY_NODISCARD bool equal(const MiddleStartingIterator& other) const {
+      return wrappedAround_ == other.wrappedAround_ && itr_ == other.itr_;
+    }
+
+    void increment() {
+      ++itr_;
+      checkForWrapAround();
+    }
+
+    void checkForWrapAround() {
+      if (itr_ == streams_->cend()) {
+        wrappedAround_ = true;
+        itr_ = streams_->cbegin();
+      }
+    }
+
+   private:
+    friend class MiddleStartingIterationWrapper;
+    bool wrappedAround_{false};
+    const MapType* streams_{nullptr};
+    MapType::const_iterator itr_;
+  };
+
+  MiddleStartingIterationWrapper(
+      const MapType& streams,
+      const MapType::key_type& start)
+      : streams_(streams), start_(&streams_, start) {}
+
+  MiddleStartingIterationWrapper(
+      const MapType& streams,
+      const MapType::const_iterator& start)
+      : streams_(streams), start_(&streams_, start) {}
+
+  FOLLY_NODISCARD MiddleStartingIterator cbegin() const {
+    return start_;
+  }
+
+  FOLLY_NODISCARD MiddleStartingIterator cend() const {
+    MiddleStartingIterator itr(start_);
+    itr.wrappedAround_ = true;
+    return itr;
+  }
+
+ private:
+  const MapType& streams_;
+  const MiddleStartingIterator start_;
+};
+
+using WritableStreamItr =
+    MiddleStartingIterationWrapper::MiddleStartingIterator;
+
+} // namespace
+
 namespace quic {
 
 bool hasAcksToSchedule(const AckState& ackState) {
@@ -225,31 +328,70 @@ RetransmissionScheduler::RetransmissionScheduler(
     const QuicConnectionStateBase& conn)
     : conn_(conn) {}
 
+// Return true if this stream wrote some data
+bool RetransmissionScheduler::writeStreamLossBuffers(
+    PacketBuilderInterface& builder,
+    StreamId id) {
+  auto stream = conn_.streamManager->findStream(id);
+  CHECK(stream);
+  bool wroteStreamFrame = false;
+  for (auto buffer = stream->lossBuffer.cbegin();
+       buffer != stream->lossBuffer.cend();
+       ++buffer) {
+    auto bufferLen = buffer->data.chainLength();
+    auto dataLen = writeStreamFrameHeader(
+        builder,
+        stream->id,
+        buffer->offset,
+        bufferLen, // writeBufferLen -- only the len of the single buffer.
+        bufferLen, // flowControlLen -- not relevant, already flow controlled.
+        buffer->eof,
+        folly::none /* skipLenHint */);
+    if (dataLen) {
+      wroteStreamFrame = true;
+      writeStreamFrameData(builder, buffer->data, *dataLen);
+      VLOG(4) << "Wrote retransmitted stream=" << stream->id
+              << " offset=" << buffer->offset << " bytes=" << *dataLen
+              << " fin=" << (buffer->eof && *dataLen == bufferLen) << " "
+              << conn_;
+    } else {
+      // Either we filled the packet or ran out of data for this stream (EOF?)
+      break;
+    }
+  }
+  return wroteStreamFrame;
+}
+
 void RetransmissionScheduler::writeRetransmissionStreams(
     PacketBuilderInterface& builder) {
-  for (auto streamId : conn_.streamManager->lossStreams()) {
-    auto stream = conn_.streamManager->findStream(streamId);
-    CHECK(stream);
-    for (auto buffer = stream->lossBuffer.cbegin();
-         buffer != stream->lossBuffer.cend();
-         ++buffer) {
-      auto bufferLen = buffer->data.chainLength();
-      auto dataLen = writeStreamFrameHeader(
-          builder,
-          stream->id,
-          buffer->offset,
-          bufferLen, // writeBufferLen -- only the len of the single buffer.
-          bufferLen, // flowControlLen -- not relevant, already flow controlled.
-          buffer->eof,
-          folly::none /* skipLenHint */);
-      if (dataLen) {
-        writeStreamFrameData(builder, buffer->data, *dataLen);
-        VLOG(4) << "Wrote retransmitted stream=" << stream->id
-                << " offset=" << buffer->offset << " bytes=" << *dataLen
-                << " fin=" << (buffer->eof && *dataLen == bufferLen) << " "
-                << conn_;
-      } else {
-        return;
+  auto& lossStreams = conn_.streamManager->lossStreams();
+  for (size_t index = 0;
+       index < lossStreams.levels.size() && builder.remainingSpaceInPkt() > 0;
+       index++) {
+    auto& level = lossStreams.levels[index];
+    if (level.streams.empty()) {
+      // No data here, keep going
+      continue;
+    }
+    if (level.incremental) {
+      // Round robin the streams at this level
+      MiddleStartingIterationWrapper wrapper(level.streams, level.next);
+      auto writableStreamItr = wrapper.cbegin();
+      while (writableStreamItr != wrapper.cend()) {
+        if (writeStreamLossBuffers(builder, *writableStreamItr)) {
+          writableStreamItr++;
+        } else {
+          // We didn't write anything
+          break;
+        }
+      }
+      level.next = writableStreamItr.rawIterator();
+    } else {
+      // walk the sequential streams in order until we run out of space
+      for (auto stream : level.streams) {
+        if (!writeStreamLossBuffers(builder, stream)) {
+          break;
+        }
       }
     }
   }
 }
@@ -288,6 +430,54 @@ StreamId StreamFrameScheduler::writeStreamsHelper(
   return *writableStreamItr;
 }
 
+void StreamFrameScheduler::writeStreamsHelper(
+    PacketBuilderInterface& builder,
+    PriorityQueue& writableStreams,
+    uint64_t& connWritableBytes,
+    bool streamPerPacket) {
+  // Fill a packet with non-control stream data, in priority order
+  for (size_t index = 0; index < writableStreams.levels.size() &&
+       builder.remainingSpaceInPkt() > 0 && connWritableBytes > 0;
+       index++) {
+    PriorityQueue::Level& level = writableStreams.levels[index];
+    if (level.streams.empty()) {
+      // No data here, keep going
+      continue;
+    }
+    if (level.incremental) {
+      // Round robin the streams at this level
+      MiddleStartingIterationWrapper wrapper(level.streams, level.next);
+      auto writableStreamItr = wrapper.cbegin();
+      while (writableStreamItr != wrapper.cend() && connWritableBytes > 0) {
+        if (writeNextStreamFrame(
+                builder, *writableStreamItr, connWritableBytes)) {
+          writableStreamItr++;
+          if (streamPerPacket) {
+            level.next = writableStreamItr.rawIterator();
+            return;
+          }
+        } else {
+          // Either we filled the packet, ran out of flow control,
+          // or ran out of data at this level
+          break;
+        }
+      }
+      level.next = writableStreamItr.rawIterator();
+    } else {
+      // walk the sequential streams in order until we run out of space
+      for (auto streamIt = level.streams.begin();
+           streamIt != level.streams.end() && connWritableBytes > 0;
+           ++streamIt) {
+        if (!writeNextStreamFrame(builder, *streamIt, connWritableBytes)) {
+          break;
+        } else if (streamPerPacket) {
+          return;
+        }
+      }
+    }
+  }
+}
+
 void StreamFrameScheduler::writeStreams(PacketBuilderInterface& builder) {
   DCHECK(conn_.streamManager->hasWritable());
   uint64_t connWritableBytes = getSendConnFlowControlBytesWire(conn_);
@@ -308,12 +498,11 @@ void StreamFrameScheduler::writeStreams(PacketBuilderInterface& builder) {
   if (connWritableBytes == 0) {
     return;
   }
-  const auto& writableStreams = conn_.streamManager->writableStreams();
+  auto& writableStreams = conn_.streamManager->writableStreams();
   if (!writableStreams.empty()) {
-    conn_.schedulingState.nextScheduledStream = writeStreamsHelper(
+    writeStreamsHelper(
         builder,
         writableStreams,
-        conn_.schedulingState.nextScheduledStream,
         connWritableBytes,
         conn_.transportSettings.streamFramePerPacket);
   }
diff --git a/quic/api/QuicPacketScheduler.h b/quic/api/QuicPacketScheduler.h
index 565d5fb86..344fd3f08 100644
--- a/quic/api/QuicPacketScheduler.h
+++ b/quic/api/QuicPacketScheduler.h
@@ -68,6 +68,8 @@ class RetransmissionScheduler {
   bool hasPendingData() const;
 
  private:
+  bool writeStreamLossBuffers(PacketBuilderInterface& builder, StreamId id);
+
   const QuicConnectionStateBase& conn_;
 };
 
@@ -84,84 +86,6 @@ class StreamFrameScheduler {
   bool hasPendingData() const;
 
  private:
-  /**
-   * A helper iterator adaptor class that starts iteration of streams from a
-   * specific stream id.
-   */
-  class MiddleStartingIterationWrapper {
-   public:
-    using MapType = std::set<StreamId>;
-
-    class MiddleStartingIterator
-        : public boost::iterator_facade<
-              MiddleStartingIterator,
-              const MiddleStartingIterationWrapper::MapType::value_type,
-              boost::forward_traversal_tag> {
-      friend class boost::iterator_core_access;
-
-     public:
-      using MapType = MiddleStartingIterationWrapper::MapType;
-
-      MiddleStartingIterator() = default;
-
-      MiddleStartingIterator(
-          const MapType* streams,
-          const MapType::key_type& start)
-          : streams_(streams) {
-        itr_ = streams_->lower_bound(start);
-        checkForWrapAround();
-        // We don't want to mark it as wrapped around initially, instead just
-        // act as if start was the first element.
-        wrappedAround_ = false;
-      }
-
-      const MapType::value_type& dereference() const {
-        return *itr_;
-      }
-
-      bool equal(const MiddleStartingIterator& other) const {
-        return wrappedAround_ == other.wrappedAround_ && itr_ == other.itr_;
-      }
-
-      void increment() {
-        ++itr_;
-        checkForWrapAround();
-      }
-
-      void checkForWrapAround() {
-        if (itr_ == streams_->cend()) {
-          wrappedAround_ = true;
-          itr_ = streams_->cbegin();
-        }
-      }
-
-     private:
-      friend class MiddleStartingIterationWrapper;
-      bool wrappedAround_{false};
-      const MapType* streams_{nullptr};
-      MapType::const_iterator itr_;
-    };
-
-    MiddleStartingIterationWrapper(
-        const MapType& streams,
-        const MapType::key_type& start)
-        : streams_(streams), start_(start) {}
-
-    MiddleStartingIterator cbegin() const {
-      return MiddleStartingIterator(&streams_, start_);
-    }
-
-    MiddleStartingIterator cend() const {
-      MiddleStartingIterator itr(&streams_, start_);
-      itr.wrappedAround_ = true;
-      return itr;
-    }
-
-   private:
-    const MapType& streams_;
-    const MapType::key_type& start_;
-  };
-
   StreamId writeStreamsHelper(
       PacketBuilderInterface& builder,
       const std::set<StreamId>& writableStreams,
@@ -169,8 +93,11 @@ class StreamFrameScheduler {
       uint64_t& connWritableBytes,
       bool streamPerPacket);
 
-  using WritableStreamItr =
-      MiddleStartingIterationWrapper::MiddleStartingIterator;
+  void writeStreamsHelper(
+      PacketBuilderInterface& builder,
+      PriorityQueue& writableStreams,
+      uint64_t& connWritableBytes,
+      bool streamPerPacket);
 
   /**
    * Helper function to write either stream data if stream is not flow
diff --git a/quic/api/QuicSocket.h b/quic/api/QuicSocket.h
index 289e3fa68..47e3d5c3e 100644
--- a/quic/api/QuicSocket.h
+++ b/quic/api/QuicSocket.h
@@ -445,6 +445,13 @@ class QuicSocket {
    */
   virtual bool isPartiallyReliableTransport() const = 0;
 
+  /**
+   * Set stream priority.
+   * level: can only be in [0, 7].
+   */
+  virtual folly::Expected<folly::Unit, LocalErrorCode>
+  setStreamPriority(StreamId id, PriorityLevel level, bool incremental) = 0;
+
   /**
    * ===== Read API ====
    */
diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp
index ba447f043..3da05b0fa 100644
--- a/quic/api/QuicTransportBase.cpp
+++ b/quic/api/QuicTransportBase.cpp
@@ -2907,6 +2907,27 @@ bool QuicTransportBase::isPartiallyReliableTransport() const {
   return conn_->partialReliabilityEnabled;
 }
 
+folly::Expected<folly::Unit, LocalErrorCode>
+QuicTransportBase::setStreamPriority(
+    StreamId id,
+    PriorityLevel level,
+    bool incremental) {
+  if (closeState_ != CloseState::OPEN) {
+    return folly::makeUnexpected(LocalErrorCode::CONNECTION_CLOSED);
+  }
+  if (level > kDefaultMaxPriority) {
+    return folly::makeUnexpected(LocalErrorCode::INVALID_OPERATION);
+  }
+  if (!conn_->streamManager->streamExists(id)) {
+    // It's not an error to try to prioritize a non-existent stream.
+    return folly::unit;
+  }
+  // It's not an error to prioritize a stream after it's sent its FIN - this
+  // can reprioritize retransmissions.
+  conn_->streamManager->setStreamPriority(id, level, incremental);
+  return folly::unit;
+}
+
 void QuicTransportBase::setCongestionControl(CongestionControlType type) {
   DCHECK(conn_);
   if (!conn_->congestionController ||
diff --git a/quic/api/QuicTransportBase.h b/quic/api/QuicTransportBase.h
index 7d6a7527e..a3d041aa5 100644
--- a/quic/api/QuicTransportBase.h
+++ b/quic/api/QuicTransportBase.h
@@ -317,6 +317,11 @@ class QuicTransportBase : public QuicSocket {
 
   bool isPartiallyReliableTransport() const override;
 
+  folly::Expected<folly::Unit, LocalErrorCode> setStreamPriority(
+      StreamId id,
+      PriorityLevel level,
+      bool incremental) override;
+
   /**
    * Invoke onCanceled on all the delivery callbacks registered for streamId.
    */
diff --git a/quic/api/test/MockQuicSocket.h b/quic/api/test/MockQuicSocket.h
index de3d6f2bc..c8cd7f115 100644
--- a/quic/api/test/MockQuicSocket.h
+++ b/quic/api/test/MockQuicSocket.h
@@ -96,6 +96,9 @@ class MockQuicSocket : public QuicSocket {
       SharedBuf));
   MOCK_CONST_METHOD0(isKnobSupported, bool());
   MOCK_CONST_METHOD0(isPartiallyReliableTransport, bool());
+  MOCK_METHOD3(
+      setStreamPriority,
+      folly::Expected<folly::Unit, LocalErrorCode>(StreamId, uint8_t, bool));
   MOCK_METHOD3(
       setReadCallback,
       folly::Expected<folly::Unit, LocalErrorCode>(
diff --git a/quic/api/test/QuicPacketSchedulerTest.cpp b/quic/api/test/QuicPacketSchedulerTest.cpp
index 56e6a0063..609274b20 100644
--- a/quic/api/test/QuicPacketSchedulerTest.cpp
+++ b/quic/api/test/QuicPacketSchedulerTest.cpp
@@ -1068,7 +1068,7 @@ TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerAllFit) {
       folly::IOBuf::copyBuffer("some data"),
       false);
   scheduler.writeStreams(builder);
-  EXPECT_EQ(conn.schedulingState.nextScheduledStream, 0);
+  EXPECT_EQ(conn.streamManager->writableStreams().getNextScheduledStream(), 0);
 }
 
 TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRoundRobin) {
@@ -1112,9 +1112,8 @@ TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRoundRobin) {
       folly::IOBuf::copyBuffer("some data"),
       false);
   // Force the wraparound initially.
-  conn.schedulingState.nextScheduledStream = stream3 + 8;
   scheduler.writeStreams(builder);
-  EXPECT_EQ(conn.schedulingState.nextScheduledStream, 4);
+  EXPECT_EQ(conn.streamManager->writableStreams().getNextScheduledStream(), 4);
 
   // Should write frames for stream2, stream3, followed by stream1 again.
   NiceMock<MockQuicPacketBuilder> builder2;
@@ -1176,10 +1175,10 @@ TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRoundRobinStreamPerPacket) {
       *conn.streamManager->findStream(stream3),
       folly::IOBuf::copyBuffer("some data"),
       false);
-  // Force the wraparound initially.
-  conn.schedulingState.nextScheduledStream = stream3 + 8;
+  // The default is to wraparound initially.
   scheduler.writeStreams(builder1);
-  EXPECT_EQ(conn.schedulingState.nextScheduledStream, 4);
+  EXPECT_EQ(
+      conn.streamManager->writableStreams().getNextScheduledStream(), stream2);
 
   // Should write frames for stream2, stream3, followed by stream1 again.
   NiceMock<MockQuicPacketBuilder> builder2;
@@ -1205,6 +1204,75 @@
   EXPECT_EQ(*frames[2].asWriteStreamFrame(), f3);
 }
 
+TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerSequential) {
+  QuicClientConnectionState conn(
+      FizzClientQuicHandshakeContext::Builder().build());
+  conn.streamManager->setMaxLocalBidirectionalStreams(10);
+  conn.flowControlState.peerAdvertisedMaxOffset = 100000;
+  conn.flowControlState.peerAdvertisedInitialMaxStreamOffsetBidiRemote =
+      100000;
+  auto connId = getTestConnectionId();
+  StreamFrameScheduler scheduler(conn);
+  ShortHeader shortHeader1(
+      ProtectionType::KeyPhaseZero,
+      connId,
+      getNextPacketNum(conn, PacketNumberSpace::AppData));
+  RegularQuicPacketBuilder builder1(
+      conn.udpSendPacketLen,
+      std::move(shortHeader1),
+      conn.ackStates.appDataAckState.largestAckedByPeer.value_or(0));
+  auto stream1 =
+      conn.streamManager->createNextBidirectionalStream().value()->id;
+  auto stream2 =
+      conn.streamManager->createNextBidirectionalStream().value()->id;
+  auto stream3 =
+      conn.streamManager->createNextBidirectionalStream().value()->id;
+  conn.streamManager->findStream(stream1)->priority = Priority(0, false);
+  conn.streamManager->findStream(stream2)->priority = Priority(0, false);
+  conn.streamManager->findStream(stream3)->priority = Priority(0, false);
+  auto largeBuf = folly::IOBuf::createChain(conn.udpSendPacketLen * 2, 4096);
+  auto curBuf = largeBuf.get();
+  do {
+    curBuf->append(curBuf->capacity());
+    curBuf = curBuf->next();
+  } while (curBuf != largeBuf.get());
+  auto chainLen = largeBuf->computeChainDataLength();
+  writeDataToQuicStream(
+      *conn.streamManager->findStream(stream1), std::move(largeBuf), false);
+  writeDataToQuicStream(
+      *conn.streamManager->findStream(stream2),
+      folly::IOBuf::copyBuffer("some data"),
+      false);
+  writeDataToQuicStream(
+      *conn.streamManager->findStream(stream3),
+      folly::IOBuf::copyBuffer("some data"),
+      false);
+  // The default is to wraparound initially.
+  scheduler.writeStreams(builder1);
+  EXPECT_EQ(
+      conn.streamManager->writableStreams().getNextScheduledStream(
+          Priority(0, false)),
+      stream1);
+
+  // Should write frames for stream1, stream2, stream3, in that order.
+  NiceMock<MockQuicPacketBuilder> builder2;
+  EXPECT_CALL(builder2, remainingSpaceInPkt()).WillRepeatedly(Return(4096));
+  EXPECT_CALL(builder2, appendFrame(_)).WillRepeatedly(Invoke([&](auto f) {
+    builder2.frames_.push_back(f);
+  }));
+  scheduler.writeStreams(builder2);
+  auto& frames = builder2.frames_;
+  ASSERT_EQ(frames.size(), 3);
+  WriteStreamFrame f1(stream1, 0, chainLen, false);
+  WriteStreamFrame f2(stream2, 0, 9, false);
+  WriteStreamFrame f3(stream3, 0, 9, false);
+  ASSERT_TRUE(frames[0].asWriteStreamFrame());
+  EXPECT_EQ(*frames[0].asWriteStreamFrame(), f1);
+  ASSERT_TRUE(frames[1].asWriteStreamFrame());
+  EXPECT_EQ(*frames[1].asWriteStreamFrame(), f2);
+  ASSERT_TRUE(frames[2].asWriteStreamFrame());
+  EXPECT_EQ(*frames[2].asWriteStreamFrame(), f3);
+}
+
 TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRoundRobinControl) {
   QuicClientConnectionState conn(
       FizzClientQuicHandshakeContext::Builder().build());
@@ -1255,10 +1323,10 @@ TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRoundRobinControl) {
       *conn.streamManager->findStream(stream4),
       folly::IOBuf::copyBuffer("some data"),
       false);
-  // Force the wraparound initially.
-  conn.schedulingState.nextScheduledStream = stream4 + 8;
+  // The default is to wraparound initially.
   scheduler.writeStreams(builder);
-  EXPECT_EQ(conn.schedulingState.nextScheduledStream, stream3);
+  EXPECT_EQ(
+      conn.streamManager->writableStreams().getNextScheduledStream(), stream3);
   EXPECT_EQ(conn.schedulingState.nextScheduledControlStream, stream2);
 
   // Should write frames for stream2, stream4, followed by stream 3 then 1.
@@ -1283,7 +1351,8 @@
   ASSERT_TRUE(frames[3].asWriteStreamFrame());
   EXPECT_EQ(*frames[3].asWriteStreamFrame(), f4);
 
-  EXPECT_EQ(conn.schedulingState.nextScheduledStream, stream3);
+  EXPECT_EQ(
+      conn.streamManager->writableStreams().getNextScheduledStream(), stream3);
   EXPECT_EQ(conn.schedulingState.nextScheduledControlStream, stream2);
 }
 
@@ -1307,7 +1376,7 @@ TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerOneStream) {
   auto stream1 = conn.streamManager->createNextBidirectionalStream().value();
   writeDataToQuicStream(*stream1, folly::IOBuf::copyBuffer("some data"), false);
   scheduler.writeStreams(builder);
-  EXPECT_EQ(conn.schedulingState.nextScheduledStream, 0);
+  EXPECT_EQ(conn.streamManager->writableStreams().getNextScheduledStream(), 0);
 }
 
@@ -1344,8 +1413,8 @@ TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRemoveOne) {
 
   // Manually remove a stream and set the next scheduled to that stream.
   builder.frames_.clear();
+  conn.streamManager->writableStreams().setNextScheduledStream(stream2);
   conn.streamManager->removeWritable(*conn.streamManager->findStream(stream2));
-  conn.schedulingState.nextScheduledStream = stream2;
   scheduler.writeStreams(builder);
   ASSERT_EQ(builder.frames_.size(), 1);
   ASSERT_TRUE(builder.frames_[0].asWriteStreamFrame());
diff --git a/quic/api/test/QuicTransportTest.cpp b/quic/api/test/QuicTransportTest.cpp
index cd3196bb5..6870763d7 100644
--- a/quic/api/test/QuicTransportTest.cpp
+++ b/quic/api/test/QuicTransportTest.cpp
@@ -147,11 +147,7 @@ void dropPackets(QuicServerConnectionState& conn) {
         }),
         std::move(*itr->second));
     stream->retransmissionBuffer.erase(itr);
-    if (std::find(
-            conn.streamManager->lossStreams().begin(),
-            conn.streamManager->lossStreams().end(),
-            streamFrame->streamId) ==
-        conn.streamManager->lossStreams().end()) {
+    if (conn.streamManager->lossStreams().count(streamFrame->streamId) == 0) {
       conn.streamManager->addLoss(streamFrame->streamId);
     }
   }
@@ -462,6 +458,7 @@ TEST_F(QuicTransportTest, WriteSmall) {
 
   EXPECT_CALL(*socket_, write(_, _)).WillOnce(Invoke(bufLength));
   transport_->writeChain(stream, buf->clone(), false, false);
+  transport_->setStreamPriority(stream, 0, false);
   loopForWrites();
   auto& conn = transport_->getConnectionState();
   verifyCorrectness(conn, 0, stream, *buf);
@@ -1594,6 +1591,8 @@ TEST_F(QuicTransportTest, NonWritableStreamAPI) {
   EXPECT_EQ(LocalErrorCode::STREAM_CLOSED, res1.error());
   auto res2 = transport_->notifyPendingWriteOnStream(streamId, &writeCallback_);
   EXPECT_EQ(LocalErrorCode::STREAM_CLOSED, res2.error());
+  auto res3 = transport_->setStreamPriority(streamId, 0, false);
+  EXPECT_FALSE(res3.hasError());
 }
 
 TEST_F(QuicTransportTest, RstWrittenStream) {
@@ -2880,7 +2879,7 @@ TEST_F(QuicTransportTest, WriteStreamFromMiddleOfMap) {
   conn.outstandings.packets.clear();
 
   // Start from stream2 instead of stream1
-  conn.schedulingState.nextScheduledStream = s2;
+  conn.streamManager->writableStreams().setNextScheduledStream(s2);
   writableBytes = kDefaultUDPSendPacketLen - 100;
 
   EXPECT_CALL(*socket_, write(_, _)).WillOnce(Invoke(bufLength));
@@ -2903,7 +2902,7 @@ TEST_F(QuicTransportTest, WriteStreamFromMiddleOfMap) {
   conn.outstandings.packets.clear();
 
   // Test wrap around
-  conn.schedulingState.nextScheduledStream = s2;
+  conn.streamManager->writableStreams().setNextScheduledStream(s2);
   writableBytes = kDefaultUDPSendPacketLen;
   EXPECT_CALL(*socket_, write(_, _)).WillOnce(Invoke(bufLength));
   writeQuicDataToSocket(
diff --git a/quic/fizz/client/test/QuicClientTransportTest.cpp b/quic/fizz/client/test/QuicClientTransportTest.cpp
index a742cf7dc..90da31591 100644
--- a/quic/fizz/client/test/QuicClientTransportTest.cpp
+++ b/quic/fizz/client/test/QuicClientTransportTest.cpp
@@ -4734,14 +4734,10 @@ TEST_F(QuicClientTransportAfterStartTest, ResetClearsPendingLoss) {
   CHECK_NOTNULL(findPacketWithStream(client->getNonConstConn(), streamId));
   markPacketLoss(client->getNonConstConn(), *forceLossPacket, false);
   auto& pendingLossStreams = client->getConn().streamManager->lossStreams();
-  auto it =
-      std::find(pendingLossStreams.begin(), pendingLossStreams.end(), streamId);
-  ASSERT_TRUE(it != pendingLossStreams.end());
+  ASSERT_TRUE(pendingLossStreams.count(streamId) > 0);
 
   client->resetStream(streamId, GenericApplicationErrorCode::UNKNOWN);
-  it =
-      std::find(pendingLossStreams.begin(), pendingLossStreams.end(), streamId);
-  ASSERT_TRUE(it == pendingLossStreams.end());
+  ASSERT_TRUE(pendingLossStreams.count(streamId) == 0);
 }
 
 TEST_F(QuicClientTransportAfterStartTest, LossAfterResetStream) {
@@ -4763,9 +4759,7 @@ TEST_F(QuicClientTransportAfterStartTest, LossAfterResetStream) {
       client->getNonConstConn().streamManager->getStream(streamId));
   ASSERT_TRUE(stream->lossBuffer.empty());
   auto& pendingLossStreams = client->getConn().streamManager->lossStreams();
-  auto it =
-      std::find(pendingLossStreams.begin(), pendingLossStreams.end(), streamId);
-  ASSERT_TRUE(it == pendingLossStreams.end());
+  ASSERT_TRUE(pendingLossStreams.count(streamId) == 0);
 }
 
 TEST_F(QuicClientTransportAfterStartTest, SendResetAfterEom) {
diff --git a/quic/state/QuicPriorityQueue.h b/quic/state/QuicPriorityQueue.h
new file mode 100644
index 000000000..64b89d42f
--- /dev/null
+++ b/quic/state/QuicPriorityQueue.h
@@ -0,0 +1,179 @@
+#pragma once
+
+#include
+#include
+#include
+
+#include
+
+namespace quic {
+
+constexpr uint8_t kDefaultPriorityLevels = kDefaultMaxPriority + 1;
+
+/**
+ * Priority is expressed as a level [0,7] and an incremental flag.
+ */
+struct Priority {
+  uint8_t level : 3;
+  bool incremental : 1;
+
+  Priority(uint8_t l, bool i) : level(l), incremental(i) {}
+};
+
+/**
+ * Default priority, urgency = 3, incremental = true
+ * Note this is different from the priority draft where default incremental = 0
+ */
+const Priority kDefaultPriority(3, true);
+
+/**
+ * Priority queue for Quic streams. It represents each level/incremental
+ * bucket as an entry in a vector. Each entry holds a set of streams (sorted
+ * by stream ID, ascending). There is also a map of all streams currently in
+ * the queue, mapping from ID -> bucket index. The interface is almost
+ * identical to std::set (insert, erase, count, clear), except that insert
+ * takes an optional priority parameter.
+ */
+struct PriorityQueue {
+  struct Level {
+    std::set<StreamId> streams;
+    mutable decltype(streams)::const_iterator next{streams.end()};
+    bool incremental{false};
+  };
+  std::vector<Level> levels;
+  std::map<StreamId, size_t> writableStreams;
+
+  PriorityQueue() {
+    levels.resize(kDefaultPriorityLevels * 2);
+    for (size_t index = 1; index < levels.size(); index += 2) {
+      levels[index].incremental = true;
+    }
+  }
+
+  static size_t priority2index(Priority pri, size_t max) {
+    auto index = pri.level * 2 + uint8_t(pri.incremental);
+    DCHECK_LT(index, max) << "Logic error: level=" << pri.level
+                          << " incremental=" << pri.incremental;
+    return index;
+  }
+
+  /**
+   * Update stream priority if the stream already exists in the PriorityQueue.
+   *
+   * This is a no-op if the stream doesn't exist, or its priority is the same
+   * as the input.
+   */
+  void updateIfExist(StreamId id, Priority priority = kDefaultPriority) {
+    auto iter = writableStreams.find(id);
+    if (iter == writableStreams.end()) {
+      return;
+    }
+    auto index = priority2index(priority, levels.size());
+    if (iter->second == index) {
+      // no need to update
+      return;
+    }
+    eraseFromLevel(iter->second, iter->first);
+    iter->second = index;
+    auto res = levels[index].streams.insert(id);
+    DCHECK(res.second) << "PriorityQueue inconsistent: stream=" << id
+                       << " already at level=" << index;
+  }
+
+  void insertOrUpdate(StreamId id, Priority pri = kDefaultPriority) {
+    auto it = writableStreams.find(id);
+    auto index = priority2index(pri, levels.size());
+    if (it != writableStreams.end()) {
+      if (it->second == index) {
+        // No op, this stream is already inserted at the correct priority level
+        return;
+      }
+      VLOG(4) << "Updating priority of stream=" << id << " from " << it->second
+              << " to " << index;
+      // Meh, too hard. Just erase it and start over.
+      eraseFromLevel(it->second, it->first);
+      it->second = index;
+    } else {
+      writableStreams.emplace(id, index);
+    }
+    auto res = levels[index].streams.insert(id);
+    DCHECK(res.second) << "PriorityQueue inconsistent: stream=" << id
+                       << " already at level=" << index;
+  }
+
+  void erase(StreamId id) {
+    auto it = find(id);
+    erase(it);
+  }
+
+  // Only used for testing
+  void clear() {
+    writableStreams.clear();
+    for (auto& level : levels) {
+      level.streams.clear();
+      level.next = level.streams.end();
+    }
+  }
+
+  FOLLY_NODISCARD size_t count(StreamId id) const {
+    return writableStreams.count(id);
+  }
+
+  FOLLY_NODISCARD bool empty() const {
+    return writableStreams.empty();
+  }
+
+  // Testing helper to override scheduling state
+  void setNextScheduledStream(StreamId id) {
+    auto it = writableStreams.find(id);
+    CHECK(it != writableStreams.end());
+    auto& level = levels[it->second];
+    auto streamIt = level.streams.find(id);
+    CHECK(streamIt != level.streams.end());
+    level.next = streamIt;
+  }
+
+  // Only used for testing
+  FOLLY_NODISCARD StreamId
+  getNextScheduledStream(Priority pri = kDefaultPriority) const {
+    auto& level = levels[priority2index(pri, levels.size())];
+    if (level.next == level.streams.end()) {
+      CHECK(!level.streams.empty());
+      return *level.streams.begin();
+    }
+    return *level.next;
+  }
+
+ private:
+  using WSIterator = decltype(writableStreams)::iterator;
+
+  WSIterator find(StreamId id) {
+    return writableStreams.find(id);
+  }
+
+  void eraseFromLevel(size_t levelIndex, StreamId id) {
+    auto& level = levels[levelIndex];
+    auto streamIt = level.streams.find(id);
+    if (streamIt != level.streams.end()) {
+      if (streamIt == level.next) {
+        level.next = level.streams.erase(streamIt);
+      } else {
+        level.streams.erase(streamIt);
+      }
+    } else {
+      LOG(DFATAL) << "Stream=" << id
+                  << " not found in PriorityQueue level=" << levelIndex;
+    }
+  }
+
+  // Helper function to erase an iter from writableStreams and its
+  // corresponding item from levels.
+  void erase(WSIterator it) {
+    if (it != writableStreams.end()) {
+      eraseFromLevel(it->second, it->first);
+      writableStreams.erase(it);
+    }
+  }
+};
+
+} // namespace quic
diff --git a/quic/state/QuicStreamManager.cpp b/quic/state/QuicStreamManager.cpp
index a4fa3c10c..f70084940 100644
--- a/quic/state/QuicStreamManager.cpp
+++ b/quic/state/QuicStreamManager.cpp
@@ -225,6 +225,20 @@ bool QuicStreamManager::consumeMaxLocalUnidirectionalStreamIdIncreased() {
   return res;
 }
 
+void QuicStreamManager::setStreamPriority(
+    StreamId id,
+    PriorityLevel level,
+    bool incremental) {
+  auto stream = findStream(id);
+  if (stream) {
+    stream->priority = Priority(level, incremental);
+    // If this stream is already in the writable or loss queues, update the
+    // priority there.
+    writableStreams_.updateIfExist(id, stream->priority);
+    lossStreams_.updateIfExist(id, stream->priority);
+  }
+}
+
 void QuicStreamManager::refreshTransportSettings(
     const TransportSettings& settings) {
   transportSettings_ = &settings;
@@ -480,9 +494,11 @@ void QuicStreamManager::removeClosedStream(StreamId streamId) {
 
 void QuicStreamManager::updateLossStreams(QuicStreamState& stream) {
   if (stream.lossBuffer.empty()) {
+    // No-op if not present
     lossStreams_.erase(stream.id);
   } else {
-    lossStreams_.emplace(stream.id);
+    // No-op if already inserted
+    lossStreams_.insertOrUpdate(stream.id, stream.priority);
   }
 }
diff --git a/quic/state/QuicStreamManager.h b/quic/state/QuicStreamManager.h
index ac0f34cef..41f024cdb 100644
--- a/quic/state/QuicStreamManager.h
+++ b/quic/state/QuicStreamManager.h
@@ -195,18 +195,24 @@ class QuicStreamManager {
     }
   }
 
-  const auto& lossStreams() const {
+  auto& lossStreams() const {
     return lossStreams_;
   }
 
+  // This should only be used in testing code.
   void addLoss(StreamId streamId) {
-    lossStreams_.insert(streamId);
+    auto stream = findStream(streamId);
+    if (stream) {
+      lossStreams_.insertOrUpdate(streamId, stream->priority);
+    }
   }
 
   bool hasLoss() const {
     return !lossStreams_.empty();
   }
 
+  void setStreamPriority(StreamId id, PriorityLevel level, bool incremental);
+
   // TODO figure out a better interface here.
   /*
   * Returns a mutable reference to the container holding the writable stream
@@ -247,7 +253,7 @@ class QuicStreamManager {
     if (stream.isControl) {
       writableControlStreams_.insert(stream.id);
     } else {
-      writableStreams_.insert(stream.id);
+      writableStreams_.insertOrUpdate(stream.id, stream.priority);
    }
  }
 
@@ -880,7 +886,7 @@ class QuicStreamManager {
   folly::F14FastSet<StreamId> flowControlUpdated_;
 
   // Data structure to keep track of streams that have detected lost data
-  folly::F14FastSet<StreamId> lossStreams_;
+  PriorityQueue lossStreams_;
 
   // Set of streams that have pending reads
   folly::F14FastSet<StreamId> readableStreams_;
 
   // Set of streams that have pending peeks
   folly::F14FastSet<StreamId> peekableStreams_;
 
   // Set of !control streams that have writable data
-  std::set<StreamId> writableStreams_;
+  PriorityQueue writableStreams_;
 
   // Set of control streams that have writable data
   std::set<StreamId> writableControlStreams_;
diff --git a/quic/state/StateData.h b/quic/state/StateData.h
index d4b95be11..1c3769fdf 100644
--- a/quic/state/StateData.h
+++ b/quic/state/StateData.h
@@ -721,7 +721,6 @@ struct QuicConnectionStateBase : public folly::DelayedDestruction {
   uint64_t peerMaxUdpPayloadSize{kDefaultUDPSendPacketLen};
 
   struct PacketSchedulingState {
-    StreamId nextScheduledStream{0};
     StreamId nextScheduledControlStream{0};
   };
diff --git a/quic/state/StreamData.h b/quic/state/StreamData.h
index 57fc431d3..20cf11afd 100644
--- a/quic/state/StreamData.h
+++ b/quic/state/StreamData.h
@@ -12,6 +12,7 @@
 #include
 #include
 #include
+#include <quic/state/QuicPriorityQueue.h>
 
 namespace quic {
 
@@ -209,6 +210,8 @@ struct QuicStreamState : public QuicStreamLike {
   // lastHolbTime indicates whether the stream is HOL blocked at the moment.
   uint32_t holbCount{0};
 
+  Priority priority{kDefaultPriority};
+
   // Returns true if both send and receive state machines are in a terminal
   // state
   bool inTerminalStates() const {
diff --git a/quic/state/test/CMakeLists.txt b/quic/state/test/CMakeLists.txt
index 7ea375bab..d2fad8b3b 100644
--- a/quic/state/test/CMakeLists.txt
+++ b/quic/state/test/CMakeLists.txt
@@ -27,6 +27,8 @@ quic_add_test(TARGET QuicStreamFunctionsTest
 )
 
 quic_add_test(TARGET QuicStreamManagerTest
+  SOURCES
+  QuicPriorityQueueTest.cpp
   QuicStreamManagerTest.cpp
   DEPENDS
   mvfst_client
diff --git a/quic/state/test/QuicPriorityQueueTest.cpp b/quic/state/test/QuicPriorityQueueTest.cpp
new file mode 100644
index 000000000..d2d93a208
--- /dev/null
+++ b/quic/state/test/QuicPriorityQueueTest.cpp
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ *
+ */
+
+#include
+#include
+
+#include
+
+namespace quic::test {
+
+class QuicPriorityQueueTest : public testing::Test {
+ public:
+  PriorityQueue queue_;
+};
+
+TEST_F(QuicPriorityQueueTest, TestBasic) {
+  EXPECT_TRUE(queue_.empty());
+  EXPECT_EQ(queue_.count(0), 0);
+
+  StreamId id = 0;
+  // Insert two streams for every level and incremental
+  for (uint8_t i = 0; i < queue_.levels.size(); i++) {
+    queue_.insertOrUpdate(id++, Priority(i / 2, i & 0x1));
+    queue_.setNextScheduledStream(id - 1);
+    queue_.insertOrUpdate(id++, Priority(i / 2, i & 0x1));
+  }
+
+  for (int16_t i = id - 1; i >= 0; i--) {
+    EXPECT_EQ(queue_.count(i), 1);
+  }
+
+  for (auto& level : queue_.levels) {
+    EXPECT_EQ(level.streams.size(), 2);
+  }
+
+  for (uint8_t i = 0; i < queue_.levels.size(); i++) {
+    id = i * 2;
+    EXPECT_EQ(queue_.getNextScheduledStream(Priority(i / 2, i & 0x1)), id);
+    queue_.erase(id);
+    EXPECT_EQ(queue_.count(id), 0);
+    EXPECT_EQ(queue_.getNextScheduledStream(Priority(i / 2, i & 0x1)), id + 1);
+  }
+
+  queue_.clear();
+  EXPECT_TRUE(queue_.empty());
+}
+
+TEST_F(QuicPriorityQueueTest, TestUpdate) {
+  queue_.insertOrUpdate(0, Priority(0, false));
+  EXPECT_EQ(queue_.count(0), 1);
+
+  // Update no-op
+  queue_.insertOrUpdate(0, Priority(0, false));
+  EXPECT_EQ(queue_.count(0), 1);
+
+  // Update move to different bucket
+  queue_.insertOrUpdate(0, Priority(0, true));
+  EXPECT_EQ(queue_.count(0), 1);
+
+  EXPECT_EQ(queue_.getNextScheduledStream(Priority(0, true)), 0);
+}
+
+TEST_F(QuicPriorityQueueTest, UpdateIfExist) {
+  queue_.updateIfExist(0);
+  EXPECT_EQ(0, queue_.count(0));
+
+  queue_.insertOrUpdate(0, Priority(0, false));
+  EXPECT_EQ(queue_.getNextScheduledStream(Priority(0, false)), 0);
+  queue_.updateIfExist(0, Priority(1, true));
+  EXPECT_EQ(queue_.getNextScheduledStream(Priority(1, true)), 0);
+}
+
+} // namespace quic::test
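Reviewer note (not part of the patch): the two mechanisms the schedulers now rely on are easy to check in isolation. Below is a minimal, self-contained C++ sketch of (a) the level * 2 + incremental bucket indexing that PriorityQueue::priority2index performs, and (b) the persisted round-robin cursor (Level::next) that incremental buckets use. Everything here is a simplified re-implementation for illustration only; it is not mvfst code, and the names merely mirror the patch.

// Standalone model (illustration only) of the patch's bucketing and
// round-robin scheduling over one incremental bucket.
#include <cstdint>
#include <iostream>
#include <set>

using StreamId = uint64_t;

struct Priority {
  uint8_t level; // [0, 7]
  bool incremental; // round-robin within the level when true
};

// Same mapping as PriorityQueue::priority2index: sequential buckets get the
// even indices, incremental buckets the odd ones.
size_t priority2index(Priority pri) {
  return pri.level * 2 + (pri.incremental ? 1 : 0);
}

int main() {
  // Three streams sharing the incremental bucket of level 3 (index 7).
  std::set<StreamId> bucket{0, 4, 8};
  auto next = bucket.begin(); // plays the role of Level::next

  // One scheduling pass: start at the cursor and wrap at the end, which is
  // the behavior MiddleStartingIterationWrapper provides over std::set.
  for (size_t i = 0; i < bucket.size(); i++) {
    std::cout << "write stream " << *next << "\n";
    if (++next == bucket.end()) {
      next = bucket.begin();
    }
  }
  std::cout << "Priority(3, incremental) -> bucket "
            << priority2index({3, true}) << "\n";
  return 0;
}

Because the cursor lives on the Level rather than being recomputed per packet, a stream that fills the packet mid-round resumes exactly where it left off on the next write, which is what the level.next = writableStreamItr.rawIterator() assignments in the patch preserve.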
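A second reviewer sketch, for the subtlest part of MiddleStartingIterator: a wrapped iteration's begin() and end() sit on the same underlying std::set position, so equality must also compare a wrapped-around flag, or the loop would terminate immediately. The Cursor type below is invented for illustration and only models wrappedAround_ / equal() / checkForWrapAround() from the patch; it is not the real iterator.

// Standalone model (illustration only): iterate a std::set starting from the
// middle, visiting every element exactly once.
#include <cassert>
#include <set>

struct Cursor {
  std::set<int>::const_iterator itr;
  bool wrapped{false}; // models wrappedAround_
  bool operator==(const Cursor& other) const {
    // Same comparison as MiddleStartingIterator::equal().
    return wrapped == other.wrapped && itr == other.itr;
  }
};

int main() {
  std::set<int> s{1, 2, 3};
  // Start mid-set at 2; a full pass should visit 2, 3, then wrap to 1.
  Cursor cur{s.lower_bound(2), false};
  Cursor end{cur.itr, true}; // same position, but marked as wrapped
  int visited = 0;
  while (!(cur == end)) {
    visited++;
    if (++cur.itr == s.end()) { // models checkForWrapAround()
      cur.itr = s.begin();
      cur.wrapped = true;
    }
  }
  assert(visited == 3);
  return 0;
}

This also shows why the patch's constructors reset wrappedAround_ to false after the initial checkForWrapAround(): when the start key is past the last stream, lower_bound() already returns end() and wraps, and without the reset the freshly constructed begin() would immediately compare equal to cend().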