From e208ceffdd824e876c823b39a4ad8646ab538a6f Mon Sep 17 00:00:00 2001 From: Konstantin Tsoy Date: Mon, 6 Jun 2022 17:11:39 -0700 Subject: [PATCH] Implement group streams receiver api in transport Summary: Implement group streams receiver api in transport Reviewed By: mjoras Differential Revision: D36419901 fbshipit-source-id: 98bfefa1a4205fde8764f2e4300f51156667e024 --- quic/api/QuicTransportBase.cpp | 120 ++++++++++--- quic/api/QuicTransportBase.h | 13 ++ quic/api/test/Mocks.h | 30 ++++ quic/api/test/QuicTransportBaseTest.cpp | 206 ++++++++++++++++++++++- quic/client/QuicClientTransport.cpp | 3 +- quic/server/state/ServerStateMachine.cpp | 3 +- 6 files changed, 342 insertions(+), 33 deletions(-) diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index e45ed4d24..cbac632af 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -854,12 +854,21 @@ void QuicTransportBase::invokeReadDataAndCallbacks() { peekCallbacks_.erase(streamId); VLOG(10) << "invoking read error callbacks on stream=" << streamId << " " << *this; - readCb->readError(streamId, QuicError(*stream->streamReadError)); + if (!stream->groupId) { + readCb->readError(streamId, QuicError(*stream->streamReadError)); + } else { + readCb->readErrorWithGroup( + streamId, *stream->groupId, QuicError(*stream->streamReadError)); + } } else if ( readCb && callback->second.resumed && stream->hasReadableData()) { VLOG(10) << "invoking read callbacks on stream=" << streamId << " " << *this; - readCb->readAvailable(streamId); + if (!stream->groupId) { + readCb->readAvailable(streamId); + } else { + readCb->readAvailableWithGroup(streamId, *stream->groupId); + } } } if (self->datagramCallback_ && !conn_->datagramState.readBuffer.empty()) { @@ -1538,11 +1547,24 @@ void QuicTransportBase::handleCancelByteEventCallbacks() { } } -void QuicTransportBase::handleNewStreamCallbacks( - std::vector& streamStorage) { - streamStorage = - conn_->streamManager->consumeNewPeerStreams(std::move(streamStorage)); +void QuicTransportBase::logStreamOpenEvent(StreamId streamId) { + if (getSocketObserverContainer() && + getSocketObserverContainer() + ->hasObserversForEvent< + SocketObserverInterface::Events::streamEvents>()) { + getSocketObserverContainer() + ->invokeInterfaceMethod( + [event = SocketObserverInterface::StreamOpenEvent( + streamId, + getStreamInitiator(streamId), + getStreamDirectionality(streamId))]( + auto observer, auto observed) { + observer->streamOpened(observed, event); + }); + } +} +void QuicTransportBase::handleNewStreams(std::vector& streamStorage) { const auto& newPeerStreamIds = streamStorage; for (const auto& streamId : newPeerStreamIds) { CHECK_NOTNULL(connCallback_); @@ -1552,30 +1574,60 @@ void QuicTransportBase::handleNewStreamCallbacks( connCallback_->onNewUnidirectionalStream(streamId); } - if (getSocketObserverContainer() && - getSocketObserverContainer() - ->hasObserversForEvent< - SocketObserverInterface::Events::streamEvents>()) { - getSocketObserverContainer() - ->invokeInterfaceMethod< - SocketObserverInterface::Events::streamEvents>( - [event = SocketObserverInterface::StreamOpenEvent( - streamId, - getStreamInitiator(streamId), - getStreamDirectionality(streamId))]( - auto observer, auto observed) { - observer->streamOpened(observed, event); - }); - } - + logStreamOpenEvent(streamId); if (closeState_ != CloseState::OPEN) { return; } } - streamStorage.clear(); } +void QuicTransportBase::handleNewGroupedStreams( + std::vector& streamStorage) { + const auto& newPeerStreamIds = streamStorage; + for (const auto& streamId : newPeerStreamIds) { + CHECK_NOTNULL(connCallback_); + auto stream = conn_->streamManager->getStream(streamId); + CHECK(stream->groupId); + if (isBidirectionalStream(streamId)) { + connCallback_->onNewBidirectionalStreamInGroup( + streamId, *stream->groupId); + } else { + connCallback_->onNewUnidirectionalStreamInGroup( + streamId, *stream->groupId); + } + + logStreamOpenEvent(streamId); + if (closeState_ != CloseState::OPEN) { + return; + } + } + streamStorage.clear(); +} + +void QuicTransportBase::handleNewStreamCallbacks( + std::vector& streamStorage) { + streamStorage = + conn_->streamManager->consumeNewPeerStreams(std::move(streamStorage)); + handleNewStreams(streamStorage); +} + +void QuicTransportBase::handleNewGroupedStreamCallbacks( + std::vector& streamStorage) { + auto newStreamGroups = conn_->streamManager->consumeNewPeerStreamGroups(); + for (auto newStreamGroupId : newStreamGroups) { + if (isBidirectionalStream(newStreamGroupId)) { + connCallback_->onNewBidirectionalStreamGroup(newStreamGroupId); + } else { + connCallback_->onNewUnidirectionalStreamGroup(newStreamGroupId); + } + } + + streamStorage = conn_->streamManager->consumeNewGroupedPeerStreams( + std::move(streamStorage)); + handleNewGroupedStreams(streamStorage); +} + void QuicTransportBase::handleDeliveryCallbacks() { auto deliverableStreamId = conn_->streamManager->popDeliverable(); while (deliverableStreamId.has_value()) { @@ -1722,6 +1774,11 @@ void QuicTransportBase::processCallbacksAfterNetworkData() { return; } + handleNewGroupedStreamCallbacks(tempStorage); + if (closeState_ != CloseState::OPEN) { + return; + } + handlePingCallbacks(); if (closeState_ != CloseState::OPEN) { return; @@ -2845,7 +2902,12 @@ void QuicTransportBase::cancelAllAppCallbacks(const QuicError& err) noexcept { for (auto& cb : readCallbacksCopy) { readCallbacks_.erase(cb.first); if (cb.second.readCb) { - cb.second.readCb->readError(cb.first, err); + auto stream = conn_->streamManager->getStream(cb.first); + if (!stream->groupId) { + cb.second.readCb->readError(cb.first, err); + } else { + cb.second.readCb->readErrorWithGroup(cb.first, *stream->groupId, err); + } } } @@ -2901,8 +2963,14 @@ void QuicTransportBase::resetNonControlStreams( auto readCallbackIt = readCallbacks_.find(id); if (readCallbackIt != readCallbacks_.end() && readCallbackIt->second.readCb) { - readCallbackIt->second.readCb->readError( - id, QuicError(error, errorMsg.str())); + auto stream = conn_->streamManager->getStream(id); + if (!stream->groupId) { + readCallbackIt->second.readCb->readError( + id, QuicError(error, errorMsg.str())); + } else { + readCallbackIt->second.readCb->readErrorWithGroup( + id, *stream->groupId, QuicError(error, errorMsg.str())); + } } peekCallbacks_.erase(id); stopSending(id, error); diff --git a/quic/api/QuicTransportBase.h b/quic/api/QuicTransportBase.h index 0d9a75312..ff8b8df6a 100644 --- a/quic/api/QuicTransportBase.h +++ b/quic/api/QuicTransportBase.h @@ -709,6 +709,7 @@ class QuicTransportBase : public QuicSocket, QuicStreamPrioritiesObserver { void handleAckEventCallbacks(); void handleCancelByteEventCallbacks(); void handleNewStreamCallbacks(std::vector& newPeerStreams); + void handleNewGroupedStreamCallbacks(std::vector& newPeerStreams); void handleDeliveryCallbacks(); void handleStreamFlowControlUpdatedCallbacks( std::vector& streamStorage); @@ -933,6 +934,18 @@ class QuicTransportBase : public QuicSocket, QuicStreamPrioritiesObserver { // (value >= threshold), the connection is considered to be in backround mode. folly::Optional backgroundPriorityThreshold_; folly::Optional backgroundUtilizationFactor_; + + private: + /** + * Helper funtions to handle new streams. + */ + void handleNewStreams(std::vector& newPeerStreams); + void handleNewGroupedStreams(std::vector& newPeerStreams); + + /** + * Helper to log new stream event to observer. + */ + void logStreamOpenEvent(StreamId streamId); }; std::ostream& operator<<(std::ostream& os, const QuicTransportBase& qt); diff --git a/quic/api/test/Mocks.h b/quic/api/test/Mocks.h index db6736e85..9643fc2a6 100644 --- a/quic/api/test/Mocks.h +++ b/quic/api/test/Mocks.h @@ -46,7 +46,17 @@ class MockReadCallback : public QuicSocket::ReadCallback { public: ~MockReadCallback() override = default; MOCK_METHOD((void), readAvailable, (StreamId), (noexcept)); + MOCK_METHOD( + (void), + readAvailableWithGroup, + (StreamId, StreamGroupId), + (noexcept)); MOCK_METHOD((void), readError, (StreamId, QuicError), (noexcept)); + MOCK_METHOD( + (void), + readErrorWithGroup, + (StreamId, StreamGroupId, QuicError), + (noexcept)); }; class MockPeekCallback : public QuicSocket::PeekCallback { @@ -91,7 +101,27 @@ class MockConnectionCallback : public QuicSocket::ConnectionCallback { MOCK_METHOD((void), onFlowControlUpdate, (StreamId), (noexcept)); MOCK_METHOD((void), onNewBidirectionalStream, (StreamId), (noexcept)); + MOCK_METHOD( + (void), + onNewBidirectionalStreamGroup, + (StreamGroupId), + (noexcept)); + MOCK_METHOD( + (void), + onNewBidirectionalStreamInGroup, + (StreamId, StreamGroupId), + (noexcept)); MOCK_METHOD((void), onNewUnidirectionalStream, (StreamId), (noexcept)); + MOCK_METHOD( + (void), + onNewUnidirectionalStreamGroup, + (StreamGroupId), + (noexcept)); + MOCK_METHOD( + (void), + onNewUnidirectionalStreamInGroup, + (StreamId, StreamGroupId), + (noexcept)); MOCK_METHOD( (void), onStopSending, diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index cb4000849..3ecbda504 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -39,15 +39,26 @@ enum class TestFrameType : uint8_t { EXPIRED_DATA, REJECTED_DATA, MAX_STREAMS, - DATAGRAM + DATAGRAM, + STREAM_GROUP }; // A made up encoding decoding of a stream. -Buf encodeStreamBuffer(StreamId id, StreamBuffer data) { +Buf encodeStreamBuffer( + StreamId id, + StreamBuffer data, + folly::Optional groupId = folly::none) { auto buf = IOBuf::create(10); folly::io::Appender appender(buf.get(), 10); - appender.writeBE(static_cast(TestFrameType::STREAM)); + if (!groupId) { + appender.writeBE(static_cast(TestFrameType::STREAM)); + } else { + appender.writeBE(static_cast(TestFrameType::STREAM_GROUP)); + } appender.writeBE(id); + if (groupId) { + appender.writeBE(*groupId); + } auto dataBuf = data.data.move(); dataBuf->coalesce(); appender.writeBE(dataBuf->length()); @@ -116,6 +127,23 @@ std::pair decodeStreamBuffer( StreamBuffer(std::move(dataBuffer.first), dataBuffer.second, eof)); } +struct StreamGroupIdBuf { + StreamId id; + StreamGroupId groupId; + StreamBuffer buf; +}; + +StreamGroupIdBuf decodeStreamGroupBuffer(folly::io::Cursor& cursor) { + auto streamId = cursor.readBE(); + auto groupId = cursor.readBE(); + auto dataBuffer = decodeDataBuffer(cursor); + bool eof = (bool)cursor.readBE(); + return StreamGroupIdBuf{ + streamId, + groupId, + StreamBuffer(std::move(dataBuffer.first), dataBuffer.second, eof)}; +} + StreamBuffer decodeCryptoBuffer(folly::io::Cursor& cursor) { auto dataBuffer = decodeDataBuffer(cursor); return StreamBuffer(std::move(dataBuffer.first), dataBuffer.second, false); @@ -258,6 +286,16 @@ class TestQuicTransport auto buffer = decodeDatagramFrame(cursor); auto frame = DatagramFrame(buffer.second, std::move(buffer.first)); handleDatagram(*conn_, frame, data.receiveTimePoint); + } else if (type == TestFrameType::STREAM_GROUP) { + auto res = decodeStreamGroupBuffer(cursor); + QuicStreamState* stream = + conn_->streamManager->getStream(res.id, res.groupId); + if (!stream) { + continue; + } + appendDataToReadBuffer(*stream, std::move(res.buf)); + conn_->streamManager->updateReadableStreams(*stream); + conn_->streamManager->updatePeekableStreams(*stream); } else { auto buffer = decodeStreamBuffer(cursor); QuicStreamState* stream = conn_->streamManager->getStream(buffer.first); @@ -346,8 +384,11 @@ class TestQuicTransport void onReadError(const folly::AsyncSocketException&) noexcept {} - void addDataToStream(StreamId id, StreamBuffer data) { - auto buf = encodeStreamBuffer(id, std::move(data)); + void addDataToStream( + StreamId id, + StreamBuffer data, + folly::Optional groupId = folly::none) { + auto buf = encodeStreamBuffer(id, std::move(data), std::move(groupId)); SocketAddress addr("127.0.0.1", 1000); onNetworkData(addr, NetworkData(std::move(buf), Clock::now())); } @@ -4038,5 +4079,160 @@ TEST_P(QuicTransportImplTestBase, BackgroundModeChangeWithStreamChanges) { manager.removeClosedStream(stream2Id); } +class QuicTransportImplTestWithGroups : public QuicTransportImplTestBase {}; + +INSTANTIATE_TEST_SUITE_P( + QuicTransportImplTestWithGroups, + QuicTransportImplTestWithGroups, + ::testing::Values(DelayedStreamNotifsTestParam{ + .notifyOnNewStreamsExplicitly = true})); + +TEST_P(QuicTransportImplTestWithGroups, ReadCallbackWithGroupsDataAvailable) { + auto transportSettings = transport->getTransportSettings(); + transportSettings.maxStreamGroupsAdvertized = 16; + transport->setTransportSettings(transportSettings); + transport->getConnectionState().streamManager->refreshTransportSettings( + transportSettings); + + auto groupId = transport->createBidirectionalStreamGroup(); + EXPECT_TRUE(groupId.hasValue()); + auto stream1 = transport->createBidirectionalStreamInGroup(*groupId).value(); + auto stream2 = transport->createBidirectionalStreamInGroup(*groupId).value(); + + NiceMock readCb1; + NiceMock readCb2; + + transport->setReadCallback(stream1, &readCb1); + transport->setReadCallback(stream2, &readCb2); + + transport->addDataToStream( + stream1, + StreamBuffer(folly::IOBuf::copyBuffer("actual stream data"), 0), + *groupId); + + transport->addDataToStream( + stream2, + StreamBuffer(folly::IOBuf::copyBuffer("actual stream data"), 10), + *groupId); + + EXPECT_CALL(readCb1, readAvailableWithGroup(stream1, *groupId)); + transport->driveReadCallbacks(); + + transport->addDataToStream( + stream2, + StreamBuffer(folly::IOBuf::copyBuffer("actual stream data"), 0), + *groupId); + + EXPECT_CALL(readCb1, readAvailableWithGroup(stream1, *groupId)); + EXPECT_CALL(readCb2, readAvailableWithGroup(stream2, *groupId)); + transport->driveReadCallbacks(); + + EXPECT_CALL(readCb1, readAvailableWithGroup(stream1, *groupId)); + EXPECT_CALL(readCb2, readAvailableWithGroup(stream2, *groupId)); + transport->driveReadCallbacks(); + + EXPECT_CALL(readCb2, readAvailableWithGroup(stream2, *groupId)); + transport->setReadCallback(stream1, nullptr); + transport->driveReadCallbacks(); + transport.reset(); +} + +TEST_P(QuicTransportImplTestWithGroups, ReadErrorCallbackWithGroups) { + auto transportSettings = transport->getTransportSettings(); + transportSettings.maxStreamGroupsAdvertized = 16; + transport->setTransportSettings(transportSettings); + transport->getConnectionState().streamManager->refreshTransportSettings( + transportSettings); + + auto groupId = transport->createBidirectionalStreamGroup(); + EXPECT_TRUE(groupId.hasValue()); + auto stream1 = transport->createBidirectionalStreamInGroup(*groupId).value(); + + NiceMock readCb1; + + transport->setReadCallback(stream1, &readCb1); + + transport->addStreamReadError(stream1, LocalErrorCode::NO_ERROR); + transport->addDataToStream( + stream1, + StreamBuffer(folly::IOBuf::copyBuffer("actual stream data"), 0), + *groupId); + + EXPECT_CALL(readCb1, readErrorWithGroup(stream1, *groupId, _)); + transport->driveReadCallbacks(); + + transport.reset(); +} + +TEST_P( + QuicTransportImplTestWithGroups, + ReadCallbackWithGroupsCancellCallbacks) { + auto transportSettings = transport->getTransportSettings(); + transportSettings.maxStreamGroupsAdvertized = 16; + transport->setTransportSettings(transportSettings); + transport->getConnectionState().streamManager->refreshTransportSettings( + transportSettings); + + auto groupId = transport->createBidirectionalStreamGroup(); + EXPECT_TRUE(groupId.hasValue()); + auto stream1 = transport->createBidirectionalStreamInGroup(*groupId).value(); + auto stream2 = transport->createBidirectionalStreamInGroup(*groupId).value(); + + NiceMock readCb1; + NiceMock readCb2; + + transport->setReadCallback(stream1, &readCb1); + transport->setReadCallback(stream2, &readCb2); + + transport->addDataToStream( + stream1, + StreamBuffer(folly::IOBuf::copyBuffer("actual stream data"), 0), + *groupId); + + transport->addDataToStream( + stream2, + StreamBuffer(folly::IOBuf::copyBuffer("actual stream data"), 10), + *groupId); + + EXPECT_CALL(readCb1, readErrorWithGroup(stream1, *groupId, _)); + EXPECT_CALL(readCb2, readErrorWithGroup(stream2, *groupId, _)); + QuicError error = + QuicError(TransportErrorCode::PROTOCOL_VIOLATION, "test error"); + transport->cancelAllAppCallbacks(error); + transport.reset(); +} + +TEST_P(QuicTransportImplTestWithGroups, onNewStreamsAndGroupsCallbacks) { + auto transportSettings = transport->getTransportSettings(); + transportSettings.maxStreamGroupsAdvertized = 16; + transport->setTransportSettings(transportSettings); + transport->getConnectionState().streamManager->refreshTransportSettings( + transportSettings); + + auto readData = folly::IOBuf::copyBuffer("actual stream data"); + + StreamGroupId groupId = 0x00; + StreamId stream1 = 0x00; + EXPECT_CALL(connCallback, onNewBidirectionalStreamGroup(groupId)); + EXPECT_CALL(connCallback, onNewBidirectionalStreamInGroup(stream1, groupId)); + transport->addDataToStream( + stream1, StreamBuffer(readData->clone(), 0, true), groupId); + + StreamId stream2 = 0x04; + EXPECT_CALL(connCallback, onNewBidirectionalStreamInGroup(stream2, groupId)); + transport->addDataToStream( + stream2, StreamBuffer(readData->clone(), 0, true), groupId); + + StreamGroupId groupIdUni = 0x02; + StreamId uniStream3 = 0xa; + EXPECT_CALL(connCallback, onNewUnidirectionalStreamGroup(groupIdUni)); + EXPECT_CALL( + connCallback, onNewUnidirectionalStreamInGroup(uniStream3, groupIdUni)); + transport->addDataToStream( + uniStream3, StreamBuffer(readData->clone(), 0, true), groupIdUni); + + transport.reset(); +} + } // namespace test } // namespace quic diff --git a/quic/client/QuicClientTransport.cpp b/quic/client/QuicClientTransport.cpp index ec9aa2a0b..5b1c002a8 100644 --- a/quic/client/QuicClientTransport.cpp +++ b/quic/client/QuicClientTransport.cpp @@ -494,7 +494,8 @@ void QuicClientTransport::processPacketData( << " len=" << frame.data->computeChainDataLength() << " fin=" << frame.fin << " packetNum=" << packetNum << " " << *this; - auto stream = conn_->streamManager->getStream(frame.streamId); + auto stream = conn_->streamManager->getStream( + frame.streamId, frame.streamGroupId); pktHasRetransmittableData = true; if (!stream) { VLOG(10) << "Could not find stream=" << frame.streamId << " " diff --git a/quic/server/state/ServerStateMachine.cpp b/quic/server/state/ServerStateMachine.cpp index b09a0c77c..d4a917c82 100644 --- a/quic/server/state/ServerStateMachine.cpp +++ b/quic/server/state/ServerStateMachine.cpp @@ -1083,7 +1083,8 @@ void onServerReadDataFromOpen( << " fin=" << frame.fin << " " << conn; pktHasRetransmittableData = true; isNonProbingPacket = true; - auto stream = conn.streamManager->getStream(frame.streamId); + auto stream = conn.streamManager->getStream( + frame.streamId, frame.streamGroupId); // Ignore data from closed streams that we don't have the // state for any more. if (stream) {