1
0
mirror of https://github.com/facebookincubator/mvfst.git synced 2025-08-08 09:42:06 +03:00

Implement group streams receiver api in transport

Summary: Implement group streams receiver api in transport

Reviewed By: mjoras

Differential Revision: D36419901

fbshipit-source-id: 98bfefa1a4205fde8764f2e4300f51156667e024
This commit is contained in:
Konstantin Tsoy
2022-06-06 17:11:39 -07:00
committed by Facebook GitHub Bot
parent 451b519280
commit e208ceffdd
6 changed files with 342 additions and 33 deletions

View File

@@ -854,12 +854,21 @@ void QuicTransportBase::invokeReadDataAndCallbacks() {
peekCallbacks_.erase(streamId);
VLOG(10) << "invoking read error callbacks on stream=" << streamId << " "
<< *this;
readCb->readError(streamId, QuicError(*stream->streamReadError));
if (!stream->groupId) {
readCb->readError(streamId, QuicError(*stream->streamReadError));
} else {
readCb->readErrorWithGroup(
streamId, *stream->groupId, QuicError(*stream->streamReadError));
}
} else if (
readCb && callback->second.resumed && stream->hasReadableData()) {
VLOG(10) << "invoking read callbacks on stream=" << streamId << " "
<< *this;
readCb->readAvailable(streamId);
if (!stream->groupId) {
readCb->readAvailable(streamId);
} else {
readCb->readAvailableWithGroup(streamId, *stream->groupId);
}
}
}
if (self->datagramCallback_ && !conn_->datagramState.readBuffer.empty()) {
@@ -1538,11 +1547,24 @@ void QuicTransportBase::handleCancelByteEventCallbacks() {
}
}
void QuicTransportBase::handleNewStreamCallbacks(
std::vector<StreamId>& streamStorage) {
streamStorage =
conn_->streamManager->consumeNewPeerStreams(std::move(streamStorage));
void QuicTransportBase::logStreamOpenEvent(StreamId streamId) {
if (getSocketObserverContainer() &&
getSocketObserverContainer()
->hasObserversForEvent<
SocketObserverInterface::Events::streamEvents>()) {
getSocketObserverContainer()
->invokeInterfaceMethod<SocketObserverInterface::Events::streamEvents>(
[event = SocketObserverInterface::StreamOpenEvent(
streamId,
getStreamInitiator(streamId),
getStreamDirectionality(streamId))](
auto observer, auto observed) {
observer->streamOpened(observed, event);
});
}
}
void QuicTransportBase::handleNewStreams(std::vector<StreamId>& streamStorage) {
const auto& newPeerStreamIds = streamStorage;
for (const auto& streamId : newPeerStreamIds) {
CHECK_NOTNULL(connCallback_);
@@ -1552,30 +1574,60 @@ void QuicTransportBase::handleNewStreamCallbacks(
connCallback_->onNewUnidirectionalStream(streamId);
}
if (getSocketObserverContainer() &&
getSocketObserverContainer()
->hasObserversForEvent<
SocketObserverInterface::Events::streamEvents>()) {
getSocketObserverContainer()
->invokeInterfaceMethod<
SocketObserverInterface::Events::streamEvents>(
[event = SocketObserverInterface::StreamOpenEvent(
streamId,
getStreamInitiator(streamId),
getStreamDirectionality(streamId))](
auto observer, auto observed) {
observer->streamOpened(observed, event);
});
}
logStreamOpenEvent(streamId);
if (closeState_ != CloseState::OPEN) {
return;
}
}
streamStorage.clear();
}
void QuicTransportBase::handleNewGroupedStreams(
std::vector<StreamId>& streamStorage) {
const auto& newPeerStreamIds = streamStorage;
for (const auto& streamId : newPeerStreamIds) {
CHECK_NOTNULL(connCallback_);
auto stream = conn_->streamManager->getStream(streamId);
CHECK(stream->groupId);
if (isBidirectionalStream(streamId)) {
connCallback_->onNewBidirectionalStreamInGroup(
streamId, *stream->groupId);
} else {
connCallback_->onNewUnidirectionalStreamInGroup(
streamId, *stream->groupId);
}
logStreamOpenEvent(streamId);
if (closeState_ != CloseState::OPEN) {
return;
}
}
streamStorage.clear();
}
void QuicTransportBase::handleNewStreamCallbacks(
std::vector<StreamId>& streamStorage) {
streamStorage =
conn_->streamManager->consumeNewPeerStreams(std::move(streamStorage));
handleNewStreams(streamStorage);
}
void QuicTransportBase::handleNewGroupedStreamCallbacks(
std::vector<StreamId>& streamStorage) {
auto newStreamGroups = conn_->streamManager->consumeNewPeerStreamGroups();
for (auto newStreamGroupId : newStreamGroups) {
if (isBidirectionalStream(newStreamGroupId)) {
connCallback_->onNewBidirectionalStreamGroup(newStreamGroupId);
} else {
connCallback_->onNewUnidirectionalStreamGroup(newStreamGroupId);
}
}
streamStorage = conn_->streamManager->consumeNewGroupedPeerStreams(
std::move(streamStorage));
handleNewGroupedStreams(streamStorage);
}
void QuicTransportBase::handleDeliveryCallbacks() {
auto deliverableStreamId = conn_->streamManager->popDeliverable();
while (deliverableStreamId.has_value()) {
@@ -1722,6 +1774,11 @@ void QuicTransportBase::processCallbacksAfterNetworkData() {
return;
}
handleNewGroupedStreamCallbacks(tempStorage);
if (closeState_ != CloseState::OPEN) {
return;
}
handlePingCallbacks();
if (closeState_ != CloseState::OPEN) {
return;
@@ -2845,7 +2902,12 @@ void QuicTransportBase::cancelAllAppCallbacks(const QuicError& err) noexcept {
for (auto& cb : readCallbacksCopy) {
readCallbacks_.erase(cb.first);
if (cb.second.readCb) {
cb.second.readCb->readError(cb.first, err);
auto stream = conn_->streamManager->getStream(cb.first);
if (!stream->groupId) {
cb.second.readCb->readError(cb.first, err);
} else {
cb.second.readCb->readErrorWithGroup(cb.first, *stream->groupId, err);
}
}
}
@@ -2901,8 +2963,14 @@ void QuicTransportBase::resetNonControlStreams(
auto readCallbackIt = readCallbacks_.find(id);
if (readCallbackIt != readCallbacks_.end() &&
readCallbackIt->second.readCb) {
readCallbackIt->second.readCb->readError(
id, QuicError(error, errorMsg.str()));
auto stream = conn_->streamManager->getStream(id);
if (!stream->groupId) {
readCallbackIt->second.readCb->readError(
id, QuicError(error, errorMsg.str()));
} else {
readCallbackIt->second.readCb->readErrorWithGroup(
id, *stream->groupId, QuicError(error, errorMsg.str()));
}
}
peekCallbacks_.erase(id);
stopSending(id, error);

View File

@@ -709,6 +709,7 @@ class QuicTransportBase : public QuicSocket, QuicStreamPrioritiesObserver {
void handleAckEventCallbacks();
void handleCancelByteEventCallbacks();
void handleNewStreamCallbacks(std::vector<StreamId>& newPeerStreams);
void handleNewGroupedStreamCallbacks(std::vector<StreamId>& newPeerStreams);
void handleDeliveryCallbacks();
void handleStreamFlowControlUpdatedCallbacks(
std::vector<StreamId>& streamStorage);
@@ -933,6 +934,18 @@ class QuicTransportBase : public QuicSocket, QuicStreamPrioritiesObserver {
// (value >= threshold), the connection is considered to be in backround mode.
folly::Optional<PriorityLevel> backgroundPriorityThreshold_;
folly::Optional<float> backgroundUtilizationFactor_;
private:
/**
* Helper funtions to handle new streams.
*/
void handleNewStreams(std::vector<StreamId>& newPeerStreams);
void handleNewGroupedStreams(std::vector<StreamId>& newPeerStreams);
/**
* Helper to log new stream event to observer.
*/
void logStreamOpenEvent(StreamId streamId);
};
std::ostream& operator<<(std::ostream& os, const QuicTransportBase& qt);

View File

@@ -46,7 +46,17 @@ class MockReadCallback : public QuicSocket::ReadCallback {
public:
~MockReadCallback() override = default;
MOCK_METHOD((void), readAvailable, (StreamId), (noexcept));
MOCK_METHOD(
(void),
readAvailableWithGroup,
(StreamId, StreamGroupId),
(noexcept));
MOCK_METHOD((void), readError, (StreamId, QuicError), (noexcept));
MOCK_METHOD(
(void),
readErrorWithGroup,
(StreamId, StreamGroupId, QuicError),
(noexcept));
};
class MockPeekCallback : public QuicSocket::PeekCallback {
@@ -91,7 +101,27 @@ class MockConnectionCallback : public QuicSocket::ConnectionCallback {
MOCK_METHOD((void), onFlowControlUpdate, (StreamId), (noexcept));
MOCK_METHOD((void), onNewBidirectionalStream, (StreamId), (noexcept));
MOCK_METHOD(
(void),
onNewBidirectionalStreamGroup,
(StreamGroupId),
(noexcept));
MOCK_METHOD(
(void),
onNewBidirectionalStreamInGroup,
(StreamId, StreamGroupId),
(noexcept));
MOCK_METHOD((void), onNewUnidirectionalStream, (StreamId), (noexcept));
MOCK_METHOD(
(void),
onNewUnidirectionalStreamGroup,
(StreamGroupId),
(noexcept));
MOCK_METHOD(
(void),
onNewUnidirectionalStreamInGroup,
(StreamId, StreamGroupId),
(noexcept));
MOCK_METHOD(
(void),
onStopSending,

View File

@@ -39,15 +39,26 @@ enum class TestFrameType : uint8_t {
EXPIRED_DATA,
REJECTED_DATA,
MAX_STREAMS,
DATAGRAM
DATAGRAM,
STREAM_GROUP
};
// A made up encoding decoding of a stream.
Buf encodeStreamBuffer(StreamId id, StreamBuffer data) {
Buf encodeStreamBuffer(
StreamId id,
StreamBuffer data,
folly::Optional<StreamGroupId> groupId = folly::none) {
auto buf = IOBuf::create(10);
folly::io::Appender appender(buf.get(), 10);
appender.writeBE(static_cast<uint8_t>(TestFrameType::STREAM));
if (!groupId) {
appender.writeBE(static_cast<uint8_t>(TestFrameType::STREAM));
} else {
appender.writeBE(static_cast<uint8_t>(TestFrameType::STREAM_GROUP));
}
appender.writeBE(id);
if (groupId) {
appender.writeBE(*groupId);
}
auto dataBuf = data.data.move();
dataBuf->coalesce();
appender.writeBE<uint32_t>(dataBuf->length());
@@ -116,6 +127,23 @@ std::pair<StreamId, StreamBuffer> decodeStreamBuffer(
StreamBuffer(std::move(dataBuffer.first), dataBuffer.second, eof));
}
struct StreamGroupIdBuf {
StreamId id;
StreamGroupId groupId;
StreamBuffer buf;
};
StreamGroupIdBuf decodeStreamGroupBuffer(folly::io::Cursor& cursor) {
auto streamId = cursor.readBE<StreamId>();
auto groupId = cursor.readBE<StreamGroupId>();
auto dataBuffer = decodeDataBuffer(cursor);
bool eof = (bool)cursor.readBE<uint8_t>();
return StreamGroupIdBuf{
streamId,
groupId,
StreamBuffer(std::move(dataBuffer.first), dataBuffer.second, eof)};
}
StreamBuffer decodeCryptoBuffer(folly::io::Cursor& cursor) {
auto dataBuffer = decodeDataBuffer(cursor);
return StreamBuffer(std::move(dataBuffer.first), dataBuffer.second, false);
@@ -258,6 +286,16 @@ class TestQuicTransport
auto buffer = decodeDatagramFrame(cursor);
auto frame = DatagramFrame(buffer.second, std::move(buffer.first));
handleDatagram(*conn_, frame, data.receiveTimePoint);
} else if (type == TestFrameType::STREAM_GROUP) {
auto res = decodeStreamGroupBuffer(cursor);
QuicStreamState* stream =
conn_->streamManager->getStream(res.id, res.groupId);
if (!stream) {
continue;
}
appendDataToReadBuffer(*stream, std::move(res.buf));
conn_->streamManager->updateReadableStreams(*stream);
conn_->streamManager->updatePeekableStreams(*stream);
} else {
auto buffer = decodeStreamBuffer(cursor);
QuicStreamState* stream = conn_->streamManager->getStream(buffer.first);
@@ -346,8 +384,11 @@ class TestQuicTransport
void onReadError(const folly::AsyncSocketException&) noexcept {}
void addDataToStream(StreamId id, StreamBuffer data) {
auto buf = encodeStreamBuffer(id, std::move(data));
void addDataToStream(
StreamId id,
StreamBuffer data,
folly::Optional<StreamGroupId> groupId = folly::none) {
auto buf = encodeStreamBuffer(id, std::move(data), std::move(groupId));
SocketAddress addr("127.0.0.1", 1000);
onNetworkData(addr, NetworkData(std::move(buf), Clock::now()));
}
@@ -4038,5 +4079,160 @@ TEST_P(QuicTransportImplTestBase, BackgroundModeChangeWithStreamChanges) {
manager.removeClosedStream(stream2Id);
}
class QuicTransportImplTestWithGroups : public QuicTransportImplTestBase {};
INSTANTIATE_TEST_SUITE_P(
QuicTransportImplTestWithGroups,
QuicTransportImplTestWithGroups,
::testing::Values(DelayedStreamNotifsTestParam{
.notifyOnNewStreamsExplicitly = true}));
TEST_P(QuicTransportImplTestWithGroups, ReadCallbackWithGroupsDataAvailable) {
auto transportSettings = transport->getTransportSettings();
transportSettings.maxStreamGroupsAdvertized = 16;
transport->setTransportSettings(transportSettings);
transport->getConnectionState().streamManager->refreshTransportSettings(
transportSettings);
auto groupId = transport->createBidirectionalStreamGroup();
EXPECT_TRUE(groupId.hasValue());
auto stream1 = transport->createBidirectionalStreamInGroup(*groupId).value();
auto stream2 = transport->createBidirectionalStreamInGroup(*groupId).value();
NiceMock<MockReadCallback> readCb1;
NiceMock<MockReadCallback> readCb2;
transport->setReadCallback(stream1, &readCb1);
transport->setReadCallback(stream2, &readCb2);
transport->addDataToStream(
stream1,
StreamBuffer(folly::IOBuf::copyBuffer("actual stream data"), 0),
*groupId);
transport->addDataToStream(
stream2,
StreamBuffer(folly::IOBuf::copyBuffer("actual stream data"), 10),
*groupId);
EXPECT_CALL(readCb1, readAvailableWithGroup(stream1, *groupId));
transport->driveReadCallbacks();
transport->addDataToStream(
stream2,
StreamBuffer(folly::IOBuf::copyBuffer("actual stream data"), 0),
*groupId);
EXPECT_CALL(readCb1, readAvailableWithGroup(stream1, *groupId));
EXPECT_CALL(readCb2, readAvailableWithGroup(stream2, *groupId));
transport->driveReadCallbacks();
EXPECT_CALL(readCb1, readAvailableWithGroup(stream1, *groupId));
EXPECT_CALL(readCb2, readAvailableWithGroup(stream2, *groupId));
transport->driveReadCallbacks();
EXPECT_CALL(readCb2, readAvailableWithGroup(stream2, *groupId));
transport->setReadCallback(stream1, nullptr);
transport->driveReadCallbacks();
transport.reset();
}
TEST_P(QuicTransportImplTestWithGroups, ReadErrorCallbackWithGroups) {
auto transportSettings = transport->getTransportSettings();
transportSettings.maxStreamGroupsAdvertized = 16;
transport->setTransportSettings(transportSettings);
transport->getConnectionState().streamManager->refreshTransportSettings(
transportSettings);
auto groupId = transport->createBidirectionalStreamGroup();
EXPECT_TRUE(groupId.hasValue());
auto stream1 = transport->createBidirectionalStreamInGroup(*groupId).value();
NiceMock<MockReadCallback> readCb1;
transport->setReadCallback(stream1, &readCb1);
transport->addStreamReadError(stream1, LocalErrorCode::NO_ERROR);
transport->addDataToStream(
stream1,
StreamBuffer(folly::IOBuf::copyBuffer("actual stream data"), 0),
*groupId);
EXPECT_CALL(readCb1, readErrorWithGroup(stream1, *groupId, _));
transport->driveReadCallbacks();
transport.reset();
}
TEST_P(
QuicTransportImplTestWithGroups,
ReadCallbackWithGroupsCancellCallbacks) {
auto transportSettings = transport->getTransportSettings();
transportSettings.maxStreamGroupsAdvertized = 16;
transport->setTransportSettings(transportSettings);
transport->getConnectionState().streamManager->refreshTransportSettings(
transportSettings);
auto groupId = transport->createBidirectionalStreamGroup();
EXPECT_TRUE(groupId.hasValue());
auto stream1 = transport->createBidirectionalStreamInGroup(*groupId).value();
auto stream2 = transport->createBidirectionalStreamInGroup(*groupId).value();
NiceMock<MockReadCallback> readCb1;
NiceMock<MockReadCallback> readCb2;
transport->setReadCallback(stream1, &readCb1);
transport->setReadCallback(stream2, &readCb2);
transport->addDataToStream(
stream1,
StreamBuffer(folly::IOBuf::copyBuffer("actual stream data"), 0),
*groupId);
transport->addDataToStream(
stream2,
StreamBuffer(folly::IOBuf::copyBuffer("actual stream data"), 10),
*groupId);
EXPECT_CALL(readCb1, readErrorWithGroup(stream1, *groupId, _));
EXPECT_CALL(readCb2, readErrorWithGroup(stream2, *groupId, _));
QuicError error =
QuicError(TransportErrorCode::PROTOCOL_VIOLATION, "test error");
transport->cancelAllAppCallbacks(error);
transport.reset();
}
TEST_P(QuicTransportImplTestWithGroups, onNewStreamsAndGroupsCallbacks) {
auto transportSettings = transport->getTransportSettings();
transportSettings.maxStreamGroupsAdvertized = 16;
transport->setTransportSettings(transportSettings);
transport->getConnectionState().streamManager->refreshTransportSettings(
transportSettings);
auto readData = folly::IOBuf::copyBuffer("actual stream data");
StreamGroupId groupId = 0x00;
StreamId stream1 = 0x00;
EXPECT_CALL(connCallback, onNewBidirectionalStreamGroup(groupId));
EXPECT_CALL(connCallback, onNewBidirectionalStreamInGroup(stream1, groupId));
transport->addDataToStream(
stream1, StreamBuffer(readData->clone(), 0, true), groupId);
StreamId stream2 = 0x04;
EXPECT_CALL(connCallback, onNewBidirectionalStreamInGroup(stream2, groupId));
transport->addDataToStream(
stream2, StreamBuffer(readData->clone(), 0, true), groupId);
StreamGroupId groupIdUni = 0x02;
StreamId uniStream3 = 0xa;
EXPECT_CALL(connCallback, onNewUnidirectionalStreamGroup(groupIdUni));
EXPECT_CALL(
connCallback, onNewUnidirectionalStreamInGroup(uniStream3, groupIdUni));
transport->addDataToStream(
uniStream3, StreamBuffer(readData->clone(), 0, true), groupIdUni);
transport.reset();
}
} // namespace test
} // namespace quic

View File

@@ -494,7 +494,8 @@ void QuicClientTransport::processPacketData(
<< " len=" << frame.data->computeChainDataLength()
<< " fin=" << frame.fin << " packetNum=" << packetNum << " "
<< *this;
auto stream = conn_->streamManager->getStream(frame.streamId);
auto stream = conn_->streamManager->getStream(
frame.streamId, frame.streamGroupId);
pktHasRetransmittableData = true;
if (!stream) {
VLOG(10) << "Could not find stream=" << frame.streamId << " "

View File

@@ -1083,7 +1083,8 @@ void onServerReadDataFromOpen(
<< " fin=" << frame.fin << " " << conn;
pktHasRetransmittableData = true;
isNonProbingPacket = true;
auto stream = conn.streamManager->getStream(frame.streamId);
auto stream = conn.streamManager->getStream(
frame.streamId, frame.streamGroupId);
// Ignore data from closed streams that we don't have the
// state for any more.
if (stream) {