diff --git a/quic/api/QuicPacketScheduler.cpp b/quic/api/QuicPacketScheduler.cpp index 38b099d79..37485d5cd 100644 --- a/quic/api/QuicPacketScheduler.cpp +++ b/quic/api/QuicPacketScheduler.cpp @@ -232,11 +232,12 @@ bool RetransmissionScheduler::hasPendingData() const { StreamFrameScheduler::StreamFrameScheduler(QuicConnectionStateBase& conn) : conn_(conn) {} -void StreamFrameScheduler::writeStreams(PacketBuilderInterface& builder) { - uint64_t connWritableBytes = getSendConnFlowControlBytesWire(conn_); - MiddleStartingIterationWrapper wrapper( - conn_.streamManager->writableStreams(), - conn_.schedulingState.nextScheduledStream); +StreamId StreamFrameScheduler::writeStreamsHelper( + PacketBuilderInterface& builder, + const std::set& writableStreams, + StreamId nextScheduledStream, + uint64_t& connWritableBytes) { + MiddleStartingIterationWrapper wrapper(writableStreams, nextScheduledStream); auto writableStreamItr = wrapper.cbegin(); // This will write the stream frames in a round robin fashion ordered by // stream id. The iterator will wrap around the collection at the end, and we @@ -250,7 +251,36 @@ void StreamFrameScheduler::writeStreams(PacketBuilderInterface& builder) { break; } } - conn_.schedulingState.nextScheduledStream = *writableStreamItr; + return *writableStreamItr; +} + +void StreamFrameScheduler::writeStreams(PacketBuilderInterface& builder) { + DCHECK(conn_.streamManager->hasWritable()); + uint64_t connWritableBytes = getSendConnFlowControlBytesWire(conn_); + if (connWritableBytes == 0) { + return; + } + // Write the control streams first as a naive binary priority mechanism. + const auto& writableControlStreams = + conn_.streamManager->writableControlStreams(); + if (!writableControlStreams.empty()) { + conn_.schedulingState.nextScheduledControlStream = writeStreamsHelper( + builder, + writableControlStreams, + conn_.schedulingState.nextScheduledControlStream, + connWritableBytes); + } + if (connWritableBytes == 0) { + return; + } + const auto& writableStreams = conn_.streamManager->writableStreams(); + if (!writableStreams.empty()) { + conn_.schedulingState.nextScheduledStream = writeStreamsHelper( + builder, + writableStreams, + conn_.schedulingState.nextScheduledStream, + connWritableBytes); + } } // namespace quic bool StreamFrameScheduler::hasPendingData() const { diff --git a/quic/api/QuicPacketScheduler.h b/quic/api/QuicPacketScheduler.h index 59a6b2d3b..d955f4759 100644 --- a/quic/api/QuicPacketScheduler.h +++ b/quic/api/QuicPacketScheduler.h @@ -159,6 +159,12 @@ class StreamFrameScheduler { const MapType::key_type& start_; }; + StreamId writeStreamsHelper( + PacketBuilderInterface& builder, + const std::set& writableStreams, + StreamId nextScheduledStream, + uint64_t& connWritableBytes); + using WritableStreamItr = MiddleStartingIterationWrapper::MiddleStartingIterator; diff --git a/quic/api/test/QuicPacketSchedulerTest.cpp b/quic/api/test/QuicPacketSchedulerTest.cpp index ba433893c..1f0d0cccb 100644 --- a/quic/api/test/QuicPacketSchedulerTest.cpp +++ b/quic/api/test/QuicPacketSchedulerTest.cpp @@ -762,6 +762,70 @@ TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRoundRobin) { EXPECT_EQ(*frames[2].asWriteStreamFrame(), f3); } +TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRoundRobinControl) { + QuicClientConnectionState conn; + conn.streamManager->setMaxLocalBidirectionalStreams(10); + conn.flowControlState.peerAdvertisedMaxOffset = 100000; + conn.flowControlState.peerAdvertisedInitialMaxStreamOffsetBidiRemote = 100000; + auto connId = getTestConnectionId(); + StreamFrameScheduler scheduler(conn); + ShortHeader shortHeader1( + ProtectionType::KeyPhaseZero, + connId, + getNextPacketNum(conn, PacketNumberSpace::AppData)); + RegularQuicPacketBuilder builder1( + conn.udpSendPacketLen, + std::move(shortHeader1), + conn.ackStates.appDataAckState.largestAckedByPeer); + auto stream1 = conn.streamManager->createNextBidirectionalStream().value(); + auto stream2 = conn.streamManager->createNextBidirectionalStream().value(); + auto stream3 = conn.streamManager->createNextBidirectionalStream().value(); + auto stream4 = conn.streamManager->createNextBidirectionalStream().value(); + conn.streamManager->setStreamAsControl(*stream2); + conn.streamManager->setStreamAsControl(*stream4); + auto largeBuf = folly::IOBuf::createChain(conn.udpSendPacketLen * 2, 4096); + auto curBuf = largeBuf.get(); + do { + curBuf->append(curBuf->capacity()); + curBuf = curBuf->next(); + } while (curBuf != largeBuf.get()); + auto chainLen = largeBuf->computeChainDataLength(); + writeDataToQuicStream(*stream1, std::move(largeBuf), false); + writeDataToQuicStream(*stream2, folly::IOBuf::copyBuffer("some data"), false); + writeDataToQuicStream(*stream3, folly::IOBuf::copyBuffer("some data"), false); + writeDataToQuicStream(*stream4, folly::IOBuf::copyBuffer("some data"), false); + // Force the wraparound initially. + conn.schedulingState.nextScheduledStream = stream4->id + 8; + scheduler.writeStreams(builder1); + EXPECT_EQ(conn.schedulingState.nextScheduledStream, stream3->id); + EXPECT_EQ(conn.schedulingState.nextScheduledControlStream, stream2->id); + + // Should write frames for stream2, stream4, followed by stream 3 then 1. + MockQuicPacketBuilder builder2; + EXPECT_CALL(builder2, remainingSpaceInPkt()).WillRepeatedly(Return(4096)); + EXPECT_CALL(builder2, appendFrame(_)).WillRepeatedly(Invoke([&](auto f) { + builder2.frames_.push_back(f); + })); + scheduler.writeStreams(builder2); + auto& frames = builder2.frames_; + ASSERT_EQ(frames.size(), 4); + WriteStreamFrame f1(stream2->id, 0, 9, false); + WriteStreamFrame f2(stream4->id, 0, 9, false); + WriteStreamFrame f3(stream3->id, 0, 9, false); + WriteStreamFrame f4(stream1->id, 0, chainLen, false); + ASSERT_TRUE(frames[0].asWriteStreamFrame()); + EXPECT_EQ(*frames[0].asWriteStreamFrame(), f1); + ASSERT_TRUE(frames[1].asWriteStreamFrame()); + EXPECT_EQ(*frames[1].asWriteStreamFrame(), f2); + ASSERT_TRUE(frames[2].asWriteStreamFrame()); + EXPECT_EQ(*frames[2].asWriteStreamFrame(), f3); + ASSERT_TRUE(frames[3].asWriteStreamFrame()); + EXPECT_EQ(*frames[3].asWriteStreamFrame(), f4); + + EXPECT_EQ(conn.schedulingState.nextScheduledStream, stream3->id); + EXPECT_EQ(conn.schedulingState.nextScheduledControlStream, stream2->id); +} + TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerOneStream) { QuicClientConnectionState conn; conn.streamManager->setMaxLocalBidirectionalStreams(10); @@ -808,7 +872,7 @@ TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRemoveOne) { // Manually remove a stream and set the next scheduled to that stream. builder.frames_.clear(); - conn.streamManager->removeWritable(stream2->id); + conn.streamManager->removeWritable(*stream2); conn.schedulingState.nextScheduledStream = stream2->id; scheduler.writeStreams(builder); ASSERT_EQ(builder.frames_.size(), 1); diff --git a/quic/api/test/QuicTransportFunctionsTest.cpp b/quic/api/test/QuicTransportFunctionsTest.cpp index a70f844f2..f0f65651b 100644 --- a/quic/api/test/QuicTransportFunctionsTest.cpp +++ b/quic/api/test/QuicTransportFunctionsTest.cpp @@ -1777,7 +1777,8 @@ TEST_F(QuicTransportFunctionsTest, HasAppDataToWrite) { auto conn = createConn(); conn->flowControlState.peerAdvertisedMaxOffset = 1000; conn->flowControlState.sumCurWriteOffset = 800; - conn->streamManager->addWritable(0); + QuicStreamState stream(0, *conn); + conn->streamManager->addWritable(stream); EXPECT_EQ(WriteDataReason::NO_WRITE, hasNonAckDataToWrite(*conn)); conn->oneRttWriteCipher = test::createNoOpAead(); diff --git a/quic/state/QuicStreamManager.cpp b/quic/state/QuicStreamManager.cpp index 3d8032fd1..50fe644ff 100644 --- a/quic/state/QuicStreamManager.cpp +++ b/quic/state/QuicStreamManager.cpp @@ -364,6 +364,7 @@ void QuicStreamManager::removeClosedStream(StreamId streamId) { readableStreams_.erase(streamId); peekableStreams_.erase(streamId); writableStreams_.erase(streamId); + writableControlStreams_.erase(streamId); blockedStreams_.erase(streamId); deliverableStreams_.erase(streamId); windowUpdates_.erase(streamId); @@ -463,9 +464,9 @@ void QuicStreamManager::updateReadableStreams(QuicStreamState& stream) { void QuicStreamManager::updateWritableStreams(QuicStreamState& stream) { if (stream.hasWritableData() && !stream.streamWriteError.hasValue()) { - stream.conn.streamManager->addWritable(stream.id); + stream.conn.streamManager->addWritable(stream); } else { - stream.conn.streamManager->removeWritable(stream.id); + stream.conn.streamManager->removeWritable(stream); } } diff --git a/quic/state/QuicStreamManager.h b/quic/state/QuicStreamManager.h index dc3fb64c6..e6d9c665d 100644 --- a/quic/state/QuicStreamManager.h +++ b/quic/state/QuicStreamManager.h @@ -216,32 +216,50 @@ class QuicStreamManager { return writableStreams_; } + // TODO figure out a better interface here. + /* + * Returns a mutable reference to the container holding the writable stream + * IDs. + */ + auto& writableControlStreams() { + return writableControlStreams_; + } + /* * Returns if there are any writable streams. */ bool hasWritable() const { - return !writableStreams_.empty(); + return !writableStreams_.empty() || !writableControlStreams_.empty(); } /* * Returns if the current writable streams contains the given id. */ bool writableContains(StreamId streamId) const { - return writableStreams_.count(streamId) > 0; + return writableStreams_.count(streamId) > 0 || + writableControlStreams_.count(streamId) > 0; } /* * Add a writable stream id. */ - void addWritable(StreamId streamId) { - writableStreams_.insert(streamId); + void addWritable(const QuicStreamState& stream) { + if (stream.isControl) { + writableControlStreams_.insert(stream.id); + } else { + writableStreams_.insert(stream.id); + } } /* * Remove a writable stream id. */ - void removeWritable(StreamId streamId) { - writableStreams_.erase(streamId); + void removeWritable(const QuicStreamState& stream) { + if (stream.isControl) { + writableControlStreams_.erase(stream.id); + } else { + writableStreams_.erase(stream.id); + } } /* @@ -249,6 +267,7 @@ class QuicStreamManager { */ void clearWritable() { writableStreams_.clear(); + writableControlStreams_.clear(); } /* @@ -737,9 +756,12 @@ class QuicStreamManager { // List of streams that have pending peeks std::set peekableStreams_; - // List of streams that have writable data + // List of !control streams that have writable data std::set writableStreams_; + // List of control streams that have writable data + std::set writableControlStreams_; + // List of streams that were blocked std::unordered_map blockedStreams_; diff --git a/quic/state/StateData.h b/quic/state/StateData.h index ce62f1ff1..b2b8d633b 100644 --- a/quic/state/StateData.h +++ b/quic/state/StateData.h @@ -676,6 +676,7 @@ struct QuicConnectionStateBase { struct PacketSchedulingState { StreamId nextScheduledStream{0}; + StreamId nextScheduledControlStream{0}; }; PacketSchedulingState schedulingState; diff --git a/quic/state/test/QuicStreamFunctionsTest.cpp b/quic/state/test/QuicStreamFunctionsTest.cpp index 73b82cce0..5281315fd 100644 --- a/quic/state/test/QuicStreamFunctionsTest.cpp +++ b/quic/state/test/QuicStreamFunctionsTest.cpp @@ -1511,7 +1511,7 @@ TEST_F(QuicStreamFunctionsTest, RemovedClosedState) { auto streamId = stream->id; conn.streamManager->readableStreams().emplace(streamId); conn.streamManager->peekableStreams().emplace(streamId); - conn.streamManager->addWritable(streamId); + conn.streamManager->addWritable(*stream); conn.streamManager->queueBlocked(streamId, 0); conn.streamManager->addDeliverable(streamId); conn.streamManager->addLoss(streamId);