Mirror of https://github.com/facebookincubator/mvfst.git

Add stream prioritization

Summary: Adds a top-level API to set stream priorities, mirroring what is currently proposed in httpbis. For now, all streams default to the highest urgency with round-robin (incremental) scheduling, which mirrors the current behavior in mvfst.

Reviewed By: mjoras

Differential Revision: D20318260

fbshipit-source-id: eec625e2ab641f7fa6266517776a2ca9798e8f89
Author: Yang Chi
Date: 2020-11-10 20:06:48 -08:00
Committed by: Facebook GitHub Bot
Parent: e2269b49f6
Commit: 000a0e23ca
17 changed files with 642 additions and 141 deletions
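
The summary above describes the new top-level API; the sketch below shows how a caller might use it. This is a hypothetical caller-side fragment, not code from this commit: sock and streamId are assumed to exist already, while the signature and error codes are taken from the QuicSocket and QuicTransportBase diffs further down.

// Hypothetical usage sketch; assumes sock is an existing quic::QuicSocket*
// and streamId refers to a stream created on it.
auto res = sock->setStreamPriority(streamId, 1, /*incremental=*/true);
if (res.hasError()) {
  // LocalErrorCode::CONNECTION_CLOSED if the transport is closed;
  // LocalErrorCode::INVALID_OPERATION if level > 7.
  LOG(ERROR) << "setStreamPriority failed";
}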


@@ -595,4 +595,8 @@ enum class DataPathType : uint8_t {
ContinuousMemory = 1,
};
// Stream priority level, can only be in [0, 7]
using PriorityLevel = uint8_t;
constexpr uint8_t kDefaultMaxPriority = 7;
} // namespace quic


@@ -9,6 +9,109 @@
#include <quic/api/QuicPacketScheduler.h>
#include <quic/flowcontrol/QuicFlowController.h>
namespace {
using namespace quic;
/**
* A helper iterator adaptor class that starts iteration of streams from a
* specific stream id.
*/
class MiddleStartingIterationWrapper {
public:
using MapType = std::set<StreamId>;
class MiddleStartingIterator
: public boost::iterator_facade<
MiddleStartingIterator,
const MiddleStartingIterationWrapper::MapType::value_type,
boost::forward_traversal_tag> {
friend class boost::iterator_core_access;
public:
using MapType = MiddleStartingIterationWrapper::MapType;
MiddleStartingIterator() = delete;
MiddleStartingIterator(
const MapType* streams,
const MapType::key_type& start)
: streams_(streams) {
itr_ = streams_->lower_bound(start);
checkForWrapAround();
// We don't want to mark it as wrapped around initially; instead, just
// act as if start was the first element.
wrappedAround_ = false;
}
MiddleStartingIterator(const MapType* streams, MapType::const_iterator itr)
: streams_(streams), itr_(itr) {
checkForWrapAround();
// We don't want to mark it as wrapped around initially; instead, just
// act as if start was the first element.
wrappedAround_ = false;
}
FOLLY_NODISCARD const MapType::value_type& dereference() const {
return *itr_;
}
FOLLY_NODISCARD MapType::const_iterator rawIterator() const {
return itr_;
}
FOLLY_NODISCARD bool equal(const MiddleStartingIterator& other) const {
return wrappedAround_ == other.wrappedAround_ && itr_ == other.itr_;
}
void increment() {
++itr_;
checkForWrapAround();
}
void checkForWrapAround() {
if (itr_ == streams_->cend()) {
wrappedAround_ = true;
itr_ = streams_->cbegin();
}
}
private:
friend class MiddleStartingIterationWrapper;
bool wrappedAround_{false};
const MapType* streams_{nullptr};
MapType::const_iterator itr_;
};
MiddleStartingIterationWrapper(
const MapType& streams,
const MapType::key_type& start)
: streams_(streams), start_(&streams_, start) {}
MiddleStartingIterationWrapper(
const MapType& streams,
const MapType::const_iterator& start)
: streams_(streams), start_(&streams_, start) {}
FOLLY_NODISCARD MiddleStartingIterator cbegin() const {
return start_;
}
FOLLY_NODISCARD MiddleStartingIterator cend() const {
MiddleStartingIterator itr(start_);
itr.wrappedAround_ = true;
return itr;
}
private:
const MapType& streams_;
const MiddleStartingIterator start_;
};
using WritableStreamItr =
MiddleStartingIterationWrapper::MiddleStartingIterator;
} // namespace
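To make the wrap-around behavior concrete, here is a small standalone sketch, assuming the wrapper class above is in scope (the stream IDs are arbitrary): iterating {1, 5, 9} starting from 6 visits 9, then wraps to 1, then 5.

// Sketch: iterate a std::set starting from the middle, wrapping around once.
std::set<quic::StreamId> streams{1, 5, 9};
MiddleStartingIterationWrapper wrapper(streams, /*start=*/6);
for (auto itr = wrapper.cbegin(); itr != wrapper.cend(); ++itr) {
  VLOG(4) << "visiting stream=" << *itr; // visits 9, then 1, then 5
}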
namespace quic {
bool hasAcksToSchedule(const AckState& ackState) {
@@ -225,11 +328,13 @@ RetransmissionScheduler::RetransmissionScheduler(
const QuicConnectionStateBase& conn)
: conn_(conn) {}
void RetransmissionScheduler::writeRetransmissionStreams(
PacketBuilderInterface& builder) {
for (auto streamId : conn_.streamManager->lossStreams()) {
auto stream = conn_.streamManager->findStream(streamId);
// Return true if this stream wrote some data
bool RetransmissionScheduler::writeStreamLossBuffers(
PacketBuilderInterface& builder,
StreamId id) {
auto stream = conn_.streamManager->findStream(id);
CHECK(stream);
bool wroteStreamFrame = false;
for (auto buffer = stream->lossBuffer.cbegin();
buffer != stream->lossBuffer.cend();
++buffer) {
@@ -243,13 +348,50 @@ void RetransmissionScheduler::writeRetransmissionStreams(
buffer->eof,
folly::none /* skipLenHint */);
if (dataLen) {
wroteStreamFrame = true;
writeStreamFrameData(builder, buffer->data, *dataLen);
VLOG(4) << "Wrote retransmitted stream=" << stream->id
<< " offset=" << buffer->offset << " bytes=" << *dataLen
<< " fin=" << (buffer->eof && *dataLen == bufferLen) << " "
<< conn_;
} else {
return;
// Either we filled the packet or ran out of data for this stream (EOF?)
break;
}
}
return wroteStreamFrame;
}
void RetransmissionScheduler::writeRetransmissionStreams(
PacketBuilderInterface& builder) {
auto& lossStreams = conn_.streamManager->lossStreams();
for (size_t index = 0;
index < lossStreams.levels.size() && builder.remainingSpaceInPkt() > 0;
index++) {
auto& level = lossStreams.levels[index];
if (level.streams.empty()) {
// No data here, keep going
continue;
}
if (level.incremental) {
// Round robin the streams at this level
MiddleStartingIterationWrapper wrapper(level.streams, level.next);
auto writableStreamItr = wrapper.cbegin();
while (writableStreamItr != wrapper.cend()) {
if (writeStreamLossBuffers(builder, *writableStreamItr)) {
writableStreamItr++;
} else {
// We didn't write anything
break;
}
}
level.next = writableStreamItr.rawIterator();
} else {
// walk the sequential streams in order until we run out of space
for (auto stream : level.streams) {
if (!writeStreamLossBuffers(builder, stream)) {
break;
}
}
}
}
@@ -288,6 +430,54 @@ StreamId StreamFrameScheduler::writeStreamsHelper(
return *writableStreamItr;
}
void StreamFrameScheduler::writeStreamsHelper(
PacketBuilderInterface& builder,
PriorityQueue& writableStreams,
uint64_t& connWritableBytes,
bool streamPerPacket) {
// Fill a packet with non-control stream data, in priority order
for (size_t index = 0; index < writableStreams.levels.size() &&
builder.remainingSpaceInPkt() > 0 && connWritableBytes > 0;
index++) {
PriorityQueue::Level& level = writableStreams.levels[index];
if (level.streams.empty()) {
// No data here, keep going
continue;
}
if (level.incremental) {
// Round robin the streams at this level
MiddleStartingIterationWrapper wrapper(level.streams, level.next);
auto writableStreamItr = wrapper.cbegin();
while (writableStreamItr != wrapper.cend() && connWritableBytes > 0) {
if (writeNextStreamFrame(
builder, *writableStreamItr, connWritableBytes)) {
writableStreamItr++;
if (streamPerPacket) {
level.next = writableStreamItr.rawIterator();
return;
}
} else {
// Either we filled the packet, ran out of flow control,
// or ran out of data at this level
break;
}
}
level.next = writableStreamItr.rawIterator();
} else {
// walk the sequential streams in order until we run out of space
for (auto streamIt = level.streams.begin();
streamIt != level.streams.end() && connWritableBytes > 0;
++streamIt) {
if (!writeNextStreamFrame(builder, *streamIt, connWritableBytes)) {
break;
} else if (streamPerPacket) {
return;
}
}
}
}
}
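A hypothetical trace of the loop above with concrete values (a walk-through under assumed state, not output from the code):

// Suppose writableStreams holds:
//   level index 1 (urgency 0, incremental): streams {4, 8}, level.next -> 8
//   level index 2 (urgency 1, sequential):  streams {12}
// With ample packet space and flow control, one call writes:
//   stream 8 first (iteration starts at level.next), then wraps to stream 4,
//   then stream 12 from the sequential bucket, in ascending stream id order.
// level.next ends up pointing at stream 8 again, the stream after the last
// incremental stream written, so the next call resumes the round robin there.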
void StreamFrameScheduler::writeStreams(PacketBuilderInterface& builder) {
DCHECK(conn_.streamManager->hasWritable());
uint64_t connWritableBytes = getSendConnFlowControlBytesWire(conn_);
@@ -308,12 +498,11 @@ void StreamFrameScheduler::writeStreams(PacketBuilderInterface& builder) {
if (connWritableBytes == 0) {
return;
}
const auto& writableStreams = conn_.streamManager->writableStreams();
auto& writableStreams = conn_.streamManager->writableStreams();
if (!writableStreams.empty()) {
conn_.schedulingState.nextScheduledStream = writeStreamsHelper(
writeStreamsHelper(
builder,
writableStreams,
conn_.schedulingState.nextScheduledStream,
connWritableBytes,
conn_.transportSettings.streamFramePerPacket);
}


@@ -68,6 +68,8 @@ class RetransmissionScheduler {
bool hasPendingData() const;
private:
bool writeStreamLossBuffers(PacketBuilderInterface& builder, StreamId id);
const QuicConnectionStateBase& conn_;
};
@@ -84,84 +86,6 @@ class StreamFrameScheduler {
bool hasPendingData() const;
private:
/**
* A helper iterator adaptor class that starts iteration of streams from a
* specific stream id.
*/
class MiddleStartingIterationWrapper {
public:
using MapType = std::set<StreamId>;
class MiddleStartingIterator
: public boost::iterator_facade<
MiddleStartingIterator,
const MiddleStartingIterationWrapper::MapType::value_type,
boost::forward_traversal_tag> {
friend class boost::iterator_core_access;
public:
using MapType = MiddleStartingIterationWrapper::MapType;
MiddleStartingIterator() = default;
MiddleStartingIterator(
const MapType* streams,
const MapType::key_type& start)
: streams_(streams) {
itr_ = streams_->lower_bound(start);
checkForWrapAround();
// We don't want to mark it as wrapped around initially, instead just
// act as if start was the first element.
wrappedAround_ = false;
}
const MapType::value_type& dereference() const {
return *itr_;
}
bool equal(const MiddleStartingIterator& other) const {
return wrappedAround_ == other.wrappedAround_ && itr_ == other.itr_;
}
void increment() {
++itr_;
checkForWrapAround();
}
void checkForWrapAround() {
if (itr_ == streams_->cend()) {
wrappedAround_ = true;
itr_ = streams_->cbegin();
}
}
private:
friend class MiddleStartingIterationWrapper;
bool wrappedAround_{false};
const MapType* streams_{nullptr};
MapType::const_iterator itr_;
};
MiddleStartingIterationWrapper(
const MapType& streams,
const MapType::key_type& start)
: streams_(streams), start_(start) {}
MiddleStartingIterator cbegin() const {
return MiddleStartingIterator(&streams_, start_);
}
MiddleStartingIterator cend() const {
MiddleStartingIterator itr(&streams_, start_);
itr.wrappedAround_ = true;
return itr;
}
private:
const MapType& streams_;
const MapType::key_type& start_;
};
StreamId writeStreamsHelper(
PacketBuilderInterface& builder,
const std::set<StreamId>& writableStreams,
@@ -169,8 +93,11 @@ class StreamFrameScheduler {
uint64_t& connWritableBytes,
bool streamPerPacket);
using WritableStreamItr =
MiddleStartingIterationWrapper::MiddleStartingIterator;
void writeStreamsHelper(
PacketBuilderInterface& builder,
PriorityQueue& writableStreams,
uint64_t& connWritableBytes,
bool streamPerPacket);
/**
* Helper function to write either stream data if stream is not flow


@@ -445,6 +445,13 @@ class QuicSocket {
*/
virtual bool isPartiallyReliableTransport() const = 0;
/**
* Set stream priority.
* level: can only be in [0, 7].
*/
virtual folly::Expected<folly::Unit, LocalErrorCode>
setStreamPriority(StreamId id, PriorityLevel level, bool incremental) = 0;
/**
* ===== Read API ====
*/


@@ -2907,6 +2907,27 @@ bool QuicTransportBase::isPartiallyReliableTransport() const {
return conn_->partialReliabilityEnabled;
}
folly::Expected<folly::Unit, LocalErrorCode>
QuicTransportBase::setStreamPriority(
StreamId id,
PriorityLevel level,
bool incremental) {
if (closeState_ != CloseState::OPEN) {
return folly::makeUnexpected(LocalErrorCode::CONNECTION_CLOSED);
}
if (level > kDefaultMaxPriority) {
return folly::makeUnexpected(LocalErrorCode::INVALID_OPERATION);
}
if (!conn_->streamManager->streamExists(id)) {
// It's not an error to try to prioritize a non-existent stream.
return folly::unit;
}
// It's not an error to prioritize a stream after it's sent its FIN - this
// can reprioritize retransmissions.
conn_->streamManager->setStreamPriority(id, level, incremental);
return folly::unit;
}
void QuicTransportBase::setCongestionControl(CongestionControlType type) {
DCHECK(conn_);
if (!conn_->congestionController ||


@@ -317,6 +317,11 @@ class QuicTransportBase : public QuicSocket {
bool isPartiallyReliableTransport() const override;
folly::Expected<folly::Unit, LocalErrorCode> setStreamPriority(
StreamId id,
PriorityLevel level,
bool incremental) override;
/**
* Invoke onCanceled on all the delivery callbacks registered for streamId.
*/


@@ -96,6 +96,9 @@ class MockQuicSocket : public QuicSocket {
SharedBuf));
MOCK_CONST_METHOD0(isKnobSupported, bool());
MOCK_CONST_METHOD0(isPartiallyReliableTransport, bool());
MOCK_METHOD3(
setStreamPriority,
folly::Expected<folly::Unit, LocalErrorCode>(StreamId, uint8_t, bool));
MOCK_METHOD3(
setReadCallback,
folly::Expected<folly::Unit, LocalErrorCode>(


@@ -1068,7 +1068,7 @@ TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerAllFit) {
folly::IOBuf::copyBuffer("some data"),
false);
scheduler.writeStreams(builder);
EXPECT_EQ(conn.schedulingState.nextScheduledStream, 0);
EXPECT_EQ(conn.streamManager->writableStreams().getNextScheduledStream(), 0);
}
TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRoundRobin) {
@@ -1112,9 +1112,8 @@ TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRoundRobin) {
folly::IOBuf::copyBuffer("some data"),
false);
// Force the wraparound initially.
conn.schedulingState.nextScheduledStream = stream3 + 8;
scheduler.writeStreams(builder);
EXPECT_EQ(conn.schedulingState.nextScheduledStream, 4);
EXPECT_EQ(conn.streamManager->writableStreams().getNextScheduledStream(), 4);
// Should write frames for stream2, stream3, followed by stream1 again.
NiceMock<MockQuicPacketBuilder> builder2;
@@ -1176,10 +1175,10 @@ TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRoundRobinStreamPerPacket) {
*conn.streamManager->findStream(stream3),
folly::IOBuf::copyBuffer("some data"),
false);
// Force the wraparound initially.
conn.schedulingState.nextScheduledStream = stream3 + 8;
// The default is to wrap around initially.
scheduler.writeStreams(builder1);
EXPECT_EQ(conn.schedulingState.nextScheduledStream, 4);
EXPECT_EQ(
conn.streamManager->writableStreams().getNextScheduledStream(), stream2);
// Should write frames for stream2, stream3, followed by stream1 again.
NiceMock<MockQuicPacketBuilder> builder2;
@@ -1205,6 +1204,75 @@ TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRoundRobinStreamPerPacket) {
EXPECT_EQ(*frames[2].asWriteStreamFrame(), f3);
}
TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerSequential) {
QuicClientConnectionState conn(
FizzClientQuicHandshakeContext::Builder().build());
conn.streamManager->setMaxLocalBidirectionalStreams(10);
conn.flowControlState.peerAdvertisedMaxOffset = 100000;
conn.flowControlState.peerAdvertisedInitialMaxStreamOffsetBidiRemote = 100000;
auto connId = getTestConnectionId();
StreamFrameScheduler scheduler(conn);
ShortHeader shortHeader1(
ProtectionType::KeyPhaseZero,
connId,
getNextPacketNum(conn, PacketNumberSpace::AppData));
RegularQuicPacketBuilder builder1(
conn.udpSendPacketLen,
std::move(shortHeader1),
conn.ackStates.appDataAckState.largestAckedByPeer.value_or(0));
auto stream1 =
conn.streamManager->createNextBidirectionalStream().value()->id;
auto stream2 =
conn.streamManager->createNextBidirectionalStream().value()->id;
auto stream3 =
conn.streamManager->createNextBidirectionalStream().value()->id;
conn.streamManager->findStream(stream1)->priority = Priority(0, false);
conn.streamManager->findStream(stream2)->priority = Priority(0, false);
conn.streamManager->findStream(stream3)->priority = Priority(0, false);
auto largeBuf = folly::IOBuf::createChain(conn.udpSendPacketLen * 2, 4096);
auto curBuf = largeBuf.get();
do {
curBuf->append(curBuf->capacity());
curBuf = curBuf->next();
} while (curBuf != largeBuf.get());
auto chainLen = largeBuf->computeChainDataLength();
writeDataToQuicStream(
*conn.streamManager->findStream(stream1), std::move(largeBuf), false);
writeDataToQuicStream(
*conn.streamManager->findStream(stream2),
folly::IOBuf::copyBuffer("some data"),
false);
writeDataToQuicStream(
*conn.streamManager->findStream(stream3),
folly::IOBuf::copyBuffer("some data"),
false);
// The default is to wrap around initially.
scheduler.writeStreams(builder1);
EXPECT_EQ(
conn.streamManager->writableStreams().getNextScheduledStream(
Priority(0, false)),
stream1);
// Should write frames for stream1, stream2, stream3, in that order.
NiceMock<MockQuicPacketBuilder> builder2;
EXPECT_CALL(builder2, remainingSpaceInPkt()).WillRepeatedly(Return(4096));
EXPECT_CALL(builder2, appendFrame(_)).WillRepeatedly(Invoke([&](auto f) {
builder2.frames_.push_back(f);
}));
scheduler.writeStreams(builder2);
auto& frames = builder2.frames_;
ASSERT_EQ(frames.size(), 3);
WriteStreamFrame f1(stream1, 0, chainLen, false);
WriteStreamFrame f2(stream2, 0, 9, false);
WriteStreamFrame f3(stream3, 0, 9, false);
ASSERT_TRUE(frames[0].asWriteStreamFrame());
EXPECT_EQ(*frames[0].asWriteStreamFrame(), f1);
ASSERT_TRUE(frames[1].asWriteStreamFrame());
EXPECT_EQ(*frames[1].asWriteStreamFrame(), f2);
ASSERT_TRUE(frames[2].asWriteStreamFrame());
EXPECT_EQ(*frames[2].asWriteStreamFrame(), f3);
}
TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRoundRobinControl) {
QuicClientConnectionState conn(
FizzClientQuicHandshakeContext::Builder().build());
@@ -1255,10 +1323,10 @@ TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRoundRobinControl) {
*conn.streamManager->findStream(stream4),
folly::IOBuf::copyBuffer("some data"),
false);
// Force the wraparound initially.
conn.schedulingState.nextScheduledStream = stream4 + 8;
// The default is to wrap around initially.
scheduler.writeStreams(builder);
EXPECT_EQ(conn.schedulingState.nextScheduledStream, stream3);
EXPECT_EQ(
conn.streamManager->writableStreams().getNextScheduledStream(), stream3);
EXPECT_EQ(conn.schedulingState.nextScheduledControlStream, stream2);
// Should write frames for stream2, stream4, followed by stream 3 then 1.
@@ -1283,7 +1351,8 @@ TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRoundRobinControl) {
ASSERT_TRUE(frames[3].asWriteStreamFrame());
EXPECT_EQ(*frames[3].asWriteStreamFrame(), f4);
EXPECT_EQ(conn.schedulingState.nextScheduledStream, stream3);
EXPECT_EQ(
conn.streamManager->writableStreams().getNextScheduledStream(), stream3);
EXPECT_EQ(conn.schedulingState.nextScheduledControlStream, stream2);
}
@@ -1307,7 +1376,7 @@ TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerOneStream) {
auto stream1 = conn.streamManager->createNextBidirectionalStream().value();
writeDataToQuicStream(*stream1, folly::IOBuf::copyBuffer("some data"), false);
scheduler.writeStreams(builder);
EXPECT_EQ(conn.schedulingState.nextScheduledStream, 0);
EXPECT_EQ(conn.streamManager->writableStreams().getNextScheduledStream(), 0);
}
TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRemoveOne) {
@@ -1344,8 +1413,8 @@ TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerRemoveOne) {
// Manually remove a stream and set the next scheduled to that stream.
builder.frames_.clear();
conn.streamManager->writableStreams().setNextScheduledStream(stream2);
conn.streamManager->removeWritable(*conn.streamManager->findStream(stream2));
conn.schedulingState.nextScheduledStream = stream2;
scheduler.writeStreams(builder);
ASSERT_EQ(builder.frames_.size(), 1);
ASSERT_TRUE(builder.frames_[0].asWriteStreamFrame());


@@ -147,11 +147,7 @@ void dropPackets(QuicServerConnectionState& conn) {
}),
std::move(*itr->second));
stream->retransmissionBuffer.erase(itr);
if (std::find(
conn.streamManager->lossStreams().begin(),
conn.streamManager->lossStreams().end(),
streamFrame->streamId) ==
conn.streamManager->lossStreams().end()) {
if (conn.streamManager->lossStreams().count(streamFrame->streamId) == 0) {
conn.streamManager->addLoss(streamFrame->streamId);
}
}
@@ -462,6 +458,7 @@ TEST_F(QuicTransportTest, WriteSmall) {
EXPECT_CALL(*socket_, write(_, _)).WillOnce(Invoke(bufLength));
transport_->writeChain(stream, buf->clone(), false, false);
transport_->setStreamPriority(stream, 0, false);
loopForWrites();
auto& conn = transport_->getConnectionState();
verifyCorrectness(conn, 0, stream, *buf);
@@ -1594,6 +1591,8 @@ TEST_F(QuicTransportTest, NonWritableStreamAPI) {
EXPECT_EQ(LocalErrorCode::STREAM_CLOSED, res1.error());
auto res2 = transport_->notifyPendingWriteOnStream(streamId, &writeCallback_);
EXPECT_EQ(LocalErrorCode::STREAM_CLOSED, res2.error());
auto res3 = transport_->setStreamPriority(streamId, 0, false);
EXPECT_FALSE(res3.hasError());
}
TEST_F(QuicTransportTest, RstWrittenStream) {
@@ -2880,7 +2879,7 @@ TEST_F(QuicTransportTest, WriteStreamFromMiddleOfMap) {
conn.outstandings.packets.clear();
// Start from stream2 instead of stream1
conn.schedulingState.nextScheduledStream = s2;
conn.streamManager->writableStreams().setNextScheduledStream(s2);
writableBytes = kDefaultUDPSendPacketLen - 100;
EXPECT_CALL(*socket_, write(_, _)).WillOnce(Invoke(bufLength));
@@ -2903,7 +2902,7 @@ TEST_F(QuicTransportTest, WriteStreamFromMiddleOfMap) {
conn.outstandings.packets.clear();
// Test wrap around
conn.schedulingState.nextScheduledStream = s2;
conn.streamManager->writableStreams().setNextScheduledStream(s2);
writableBytes = kDefaultUDPSendPacketLen;
EXPECT_CALL(*socket_, write(_, _)).WillOnce(Invoke(bufLength));
writeQuicDataToSocket(


@@ -4734,14 +4734,10 @@ TEST_F(QuicClientTransportAfterStartTest, ResetClearsPendingLoss) {
CHECK_NOTNULL(findPacketWithStream(client->getNonConstConn(), streamId));
markPacketLoss(client->getNonConstConn(), *forceLossPacket, false);
auto& pendingLossStreams = client->getConn().streamManager->lossStreams();
auto it =
std::find(pendingLossStreams.begin(), pendingLossStreams.end(), streamId);
ASSERT_TRUE(it != pendingLossStreams.end());
ASSERT_TRUE(pendingLossStreams.count(streamId) > 0);
client->resetStream(streamId, GenericApplicationErrorCode::UNKNOWN);
it =
std::find(pendingLossStreams.begin(), pendingLossStreams.end(), streamId);
ASSERT_TRUE(it == pendingLossStreams.end());
ASSERT_TRUE(pendingLossStreams.count(streamId) == 0);
}
TEST_F(QuicClientTransportAfterStartTest, LossAfterResetStream) {
@@ -4763,9 +4759,7 @@ TEST_F(QuicClientTransportAfterStartTest, LossAfterResetStream) {
client->getNonConstConn().streamManager->getStream(streamId));
ASSERT_TRUE(stream->lossBuffer.empty());
auto& pendingLossStreams = client->getConn().streamManager->lossStreams();
auto it =
std::find(pendingLossStreams.begin(), pendingLossStreams.end(), streamId);
ASSERT_TRUE(it == pendingLossStreams.end());
ASSERT_TRUE(pendingLossStreams.count(streamId) == 0);
}
TEST_F(QuicClientTransportAfterStartTest, SendResetAfterEom) {


@@ -0,0 +1,179 @@
#pragma once
#include <glog/logging.h>
#include <map>
#include <set>
#include <quic/codec/Types.h>
namespace quic {
constexpr uint8_t kDefaultPriorityLevels = kDefaultMaxPriority + 1;
/**
* Priority is expressed as a level [0,7] and an incremental flag.
*/
struct Priority {
uint8_t level : 3;
bool incremental : 1;
Priority(uint8_t l, bool i) : level(l), incremental(i) {}
};
/**
* Default priority, urgency = 3, incremental = true
* Note this is different from the priority draft, where the default is incremental = 0
*/
const Priority kDefaultPriority(3, true);
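As an aside, the two bit-fields above typically pack a Priority into a single byte on common ABIs; whether a bool and a uint8_t bit-field share a storage unit is implementation-defined, so treat this check as an assumption rather than a guarantee:

static_assert(
    sizeof(Priority) == 1,
    "expected level + incremental to pack into one byte on this ABI");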
/**
* Priority queue for Quic streams. It represents each level/incremental bucket
* as an entry in a vector. Each entry holds a set of streams (sorted by
* stream ID, ascending). There is also a map of all streams currently in the
* queue, mapping from ID -> bucket index. The interface is almost identical
* to std::set (insert, erase, count, clear), except that insert takes an
* optional priority parameter.
*/
struct PriorityQueue {
struct Level {
std::set<StreamId> streams;
mutable decltype(streams)::const_iterator next{streams.end()};
bool incremental{false};
};
std::vector<Level> levels;
std::map<StreamId, size_t> writableStreams;
PriorityQueue() {
levels.resize(kDefaultPriorityLevels * 2);
for (size_t index = 1; index < levels.size(); index += 2) {
levels[index].incremental = true;
}
}
static size_t priority2index(Priority pri, size_t max) {
auto index = pri.level * 2 + uint8_t(pri.incremental);
DCHECK_LT(index, max) << "Logic error: level=" << pri.level
<< " incremental=" << pri.incremental;
return index;
}
/**
* Update stream priority if the stream already exists in the PriorityQueue
*
* This is a no-op if the stream doesn't exist, or its priority is the same as
* the input.
*/
void updateIfExist(StreamId id, Priority priority = kDefaultPriority) {
auto iter = writableStreams.find(id);
if (iter == writableStreams.end()) {
return;
}
auto index = priority2index(priority, levels.size());
if (iter->second == index) {
// no need to update
return;
}
eraseFromLevel(iter->second, iter->first);
iter->second = index;
auto res = levels[index].streams.insert(id);
DCHECK(res.second) << "PriorityQueue inconsistent: stream=" << id
<< " already at level=" << index;
}
void insertOrUpdate(StreamId id, Priority pri = kDefaultPriority) {
auto it = writableStreams.find(id);
auto index = priority2index(pri, levels.size());
if (it != writableStreams.end()) {
if (it->second == index) {
// No op, this stream is already inserted at the correct priority level
return;
}
VLOG(4) << "Updating priority of stream=" << id << " from " << it->second
<< " to " << index;
// Meh, too hard. Just erase it and start over.
eraseFromLevel(it->second, it->first);
it->second = index;
} else {
writableStreams.emplace(id, index);
}
auto res = levels[index].streams.insert(id);
DCHECK(res.second) << "PriorityQueue inconsistent: stream=" << id
<< " already at level=" << index;
}
void erase(StreamId id) {
auto it = find(id);
erase(it);
}
// Only used for testing
void clear() {
writableStreams.clear();
for (auto& level : levels) {
level.streams.clear();
level.next = level.streams.end();
}
}
FOLLY_NODISCARD size_t count(StreamId id) const {
return writableStreams.count(id);
}
FOLLY_NODISCARD bool empty() const {
return writableStreams.empty();
}
// Testing helper to override scheduling state
void setNextScheduledStream(StreamId id) {
auto it = writableStreams.find(id);
CHECK(it != writableStreams.end());
auto& level = levels[it->second];
auto streamIt = level.streams.find(id);
CHECK(streamIt != level.streams.end());
level.next = streamIt;
}
// Only used for testing
FOLLY_NODISCARD StreamId
getNextScheduledStream(Priority pri = kDefaultPriority) const {
auto& level = levels[priority2index(pri, levels.size())];
if (level.next == level.streams.end()) {
CHECK(!level.streams.empty());
return *level.streams.begin();
}
return *level.next;
}
private:
using WSIterator = decltype(writableStreams)::iterator;
WSIterator find(StreamId id) {
return writableStreams.find(id);
}
void eraseFromLevel(size_t levelIndex, StreamId id) {
auto& level = levels[levelIndex];
auto streamIt = level.streams.find(id);
if (streamIt != level.streams.end()) {
if (streamIt == level.next) {
level.next = level.streams.erase(streamIt);
} else {
level.streams.erase(streamIt);
}
} else {
LOG(DFATAL) << "Stream=" << id
<< " not found in PriorityQueue level=" << levelIndex;
}
}
// Helper function to erase an iter from writableStream and its corresponding
// item from levels.
void erase(WSIterator it) {
if (it != writableStreams.end()) {
eraseFromLevel(it->second, it->first);
writableStreams.erase(it);
}
}
};
} // namespace quic
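A short sketch of the bucket arithmetic and the queue in action; the index values follow directly from the code above, while the stream IDs and the snippet itself are illustrative:

// priority2index maps Priority(level, incremental) to level * 2 + incremental:
//   Priority(0, false) -> index 0  (most urgent, sequential)
//   Priority(3, true)  -> index 7  (kDefaultPriority, of 16 buckets total)
quic::PriorityQueue queue;
queue.insertOrUpdate(4, quic::Priority(0, false)); // urgent, sequential bucket
queue.insertOrUpdate(8);                           // defaults to Priority(3, true)
CHECK_EQ(queue.getNextScheduledStream(quic::Priority(0, false)), 4);
queue.updateIfExist(8, quic::Priority(0, false));  // move stream 8 into bucket 0
CHECK_EQ(queue.count(8), 1);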


@@ -225,6 +225,20 @@ bool QuicStreamManager::consumeMaxLocalUnidirectionalStreamIdIncreased() {
return res;
}
void QuicStreamManager::setStreamPriority(
StreamId id,
PriorityLevel level,
bool incremental) {
auto stream = findStream(id);
if (stream) {
stream->priority = Priority(level, incremental);
// If this stream is already in the writable or loss queues, update the
// priority there.
writableStreams_.updateIfExist(id, stream->priority);
lossStreams_.updateIfExist(id, stream->priority);
}
}
void QuicStreamManager::refreshTransportSettings(
const TransportSettings& settings) {
transportSettings_ = &settings;
@@ -480,9 +494,11 @@ void QuicStreamManager::removeClosedStream(StreamId streamId) {
void QuicStreamManager::updateLossStreams(QuicStreamState& stream) {
if (stream.lossBuffer.empty()) {
// No-op if not present
lossStreams_.erase(stream.id);
} else {
lossStreams_.emplace(stream.id);
// No-op if already inserted
lossStreams_.insertOrUpdate(stream.id, stream.priority);
}
}


@@ -195,18 +195,24 @@ class QuicStreamManager {
}
}
const auto& lossStreams() const {
auto& lossStreams() const {
return lossStreams_;
}
// This should be only used in testing code.
void addLoss(StreamId streamId) {
lossStreams_.insert(streamId);
auto stream = findStream(streamId);
if (stream) {
lossStreams_.insertOrUpdate(streamId, stream->priority);
}
}
bool hasLoss() const {
return !lossStreams_.empty();
}
void setStreamPriority(StreamId id, PriorityLevel level, bool incremental);
// TODO figure out a better interface here.
/*
* Returns a mutable reference to the container holding the writable stream
@@ -247,7 +253,7 @@ class QuicStreamManager {
if (stream.isControl) {
writableControlStreams_.insert(stream.id);
} else {
writableStreams_.insert(stream.id);
writableStreams_.insertOrUpdate(stream.id, stream.priority);
}
}
@@ -880,7 +886,7 @@ class QuicStreamManager {
folly::F14FastSet<StreamId> flowControlUpdated_;
// Data structure to keep track of stream that have detected lost data
folly::F14FastSet<StreamId> lossStreams_;
PriorityQueue lossStreams_;
// Set of streams that have pending reads
folly::F14FastSet<StreamId> readableStreams_;
@@ -889,7 +895,7 @@ class QuicStreamManager {
folly::F14FastSet<StreamId> peekableStreams_;
// Set of !control streams that have writable data
std::set<StreamId> writableStreams_;
PriorityQueue writableStreams_;
// Set of control streams that have writable data
std::set<StreamId> writableControlStreams_;


@@ -721,7 +721,6 @@ struct QuicConnectionStateBase : public folly::DelayedDestruction {
uint64_t peerMaxUdpPayloadSize{kDefaultUDPSendPacketLen};
struct PacketSchedulingState {
StreamId nextScheduledStream{0};
StreamId nextScheduledControlStream{0};
};


@@ -12,6 +12,7 @@
#include <quic/QuicConstants.h>
#include <quic/codec/Types.h>
#include <quic/common/SmallVec.h>
#include <quic/state/QuicPriorityQueue.h>
namespace quic {
@@ -209,6 +210,8 @@ struct QuicStreamState : public QuicStreamLike {
// lastHolbTime indicates whether the stream is HOL blocked at the moment.
uint32_t holbCount{0};
Priority priority{kDefaultPriority};
// Returns true if both send and receive state machines are in a terminal
// state
bool inTerminalStates() const {


@@ -27,6 +27,8 @@ quic_add_test(TARGET QuicStreamFunctionsTest
)
quic_add_test(TARGET QuicStreamManagerTest
SOURCES
QuicPriorityQueueTest.cpp
QuicStreamManagerTest.cpp
DEPENDS
mvfst_client


@@ -0,0 +1,78 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <quic/state/QuicPriorityQueue.h>
namespace quic::test {
class QuicPriorityQueueTest : public testing::Test {
public:
PriorityQueue queue_;
};
TEST_F(QuicPriorityQueueTest, TestBasic) {
EXPECT_TRUE(queue_.empty());
EXPECT_EQ(queue_.count(0), 0);
StreamId id = 0;
// Insert two streams into every level/incremental bucket
for (uint8_t i = 0; i < queue_.levels.size(); i++) {
queue_.insertOrUpdate(id++, Priority(i / 2, i & 0x1));
queue_.setNextScheduledStream(id - 1);
queue_.insertOrUpdate(id++, Priority(i / 2, i & 0x1));
}
for (int16_t i = id - 1; i >= 0; i--) {
EXPECT_EQ(queue_.count(i), 1);
}
for (auto& level : queue_.levels) {
EXPECT_EQ(level.streams.size(), 2);
}
for (uint8_t i = 0; i < queue_.levels.size(); i++) {
id = i * 2;
EXPECT_EQ(queue_.getNextScheduledStream(Priority(i / 2, i & 0x1)), id);
queue_.erase(id);
EXPECT_EQ(queue_.count(id), 0);
EXPECT_EQ(queue_.getNextScheduledStream(Priority(i / 2, i & 0x1)), id + 1);
}
queue_.clear();
EXPECT_TRUE(queue_.empty());
}
TEST_F(QuicPriorityQueueTest, TestUpdate) {
queue_.insertOrUpdate(0, Priority(0, false));
EXPECT_EQ(queue_.count(0), 1);
// Update no-op
queue_.insertOrUpdate(0, Priority(0, false));
EXPECT_EQ(queue_.count(0), 1);
// Update move to different bucket
queue_.insertOrUpdate(0, Priority(0, true));
EXPECT_EQ(queue_.count(0), 1);
EXPECT_EQ(queue_.getNextScheduledStream(Priority(0, true)), 0);
}
TEST_F(QuicPriorityQueueTest, UpdateIfExist) {
queue_.updateIfExist(0);
EXPECT_EQ(0, queue_.count(0));
queue_.insertOrUpdate(0, Priority(0, false));
EXPECT_EQ(queue_.getNextScheduledStream(Priority(0, false)), 0);
queue_.updateIfExist(0, Priority(1, true));
EXPECT_EQ(queue_.getNextScheduledStream(Priority(1, true)), 0);
}
} // namespace quic::test