diff --git a/quic/api/QuicSocket.h b/quic/api/QuicSocket.h index 9b2208bb7..7f2f9dc80 100644 --- a/quic/api/QuicSocket.h +++ b/quic/api/QuicSocket.h @@ -86,6 +86,20 @@ class QuicSocket { * Called after the transport successfully processes the received packet. */ virtual void onFirstPeerPacketProcessed() noexcept {} + + /** + * Called when more bidirectional streams become available for creation + * (max local bidirectional stream ID was increased). + */ + virtual void onBidirectionalStreamsAvailable( + uint64_t /*numStreamsAvailable*/) noexcept {} + + /** + * Called when more unidirectional streams become available for creation + * (max local unidirectional stream ID was increased). + */ + virtual void onUnidirectionalStreamsAvailable( + uint64_t /*numStreamsAvailable*/) noexcept {} }; /** diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index 017bcb307..c3a63b0cf 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -1101,6 +1101,25 @@ void QuicTransportBase::invokeDataRejectedCallbacks() { self->conn_->streamManager->clearDataRejected(); } +void QuicTransportBase::invokeStreamsAvailableCallbacks() { + if (conn_->streamManager->consumeMaxLocalBidirectionalStreamIdIncreased()) { + // check in case new streams were created in preceding callbacks + // and max is already reached + auto numStreams = getNumOpenableBidirectionalStreams(); + if (numStreams > 0) { + connCallback_->onBidirectionalStreamsAvailable(numStreams); + } + } + if (conn_->streamManager->consumeMaxLocalUnidirectionalStreamIdIncreased()) { + // check in case new streams were created in preceding callbacks + // and max is already reached + auto numStreams = getNumOpenableUnidirectionalStreams(); + if (numStreams > 0) { + connCallback_->onUnidirectionalStreamsAvailable(numStreams); + } + } +} + folly::Expected, LocalErrorCode> QuicTransportBase::sendDataRejected(StreamId id, uint64_t offset) { if (!conn_->partialReliabilityEnabled) { @@ -1589,6 +1608,8 @@ void QuicTransportBase::processCallbacksAfterNetworkData() { } } } + + invokeStreamsAvailableCallbacks(); } void QuicTransportBase::onNetworkData( diff --git a/quic/api/QuicTransportBase.h b/quic/api/QuicTransportBase.h index 37c211d62..0d3969427 100644 --- a/quic/api/QuicTransportBase.h +++ b/quic/api/QuicTransportBase.h @@ -487,6 +487,7 @@ class QuicTransportBase : public QuicSocket { void invokePeekDataAndCallbacks(); void invokeDataExpiredCallbacks(); void invokeDataRejectedCallbacks(); + void invokeStreamsAvailableCallbacks(); void updateReadLooper(); void updatePeekLooper(); void updateWriteLooper(bool thisIteration); diff --git a/quic/api/test/Mocks.h b/quic/api/test/Mocks.h index 62f34e4ec..9c28108a8 100644 --- a/quic/api/test/Mocks.h +++ b/quic/api/test/Mocks.h @@ -110,6 +110,13 @@ class MockConnectionCallback : public QuicSocket::ConnectionCallback { GMOCK_METHOD0_(, noexcept, , onReplaySafe, void()); GMOCK_METHOD0_(, noexcept, , onTransportReady, void()); GMOCK_METHOD0_(, noexcept, , onFirstPeerPacketProcessed, void()); + GMOCK_METHOD1_(, noexcept, , onBidirectionalStreamsAvailable, void(uint64_t)); + GMOCK_METHOD1_( + , + noexcept, + , + onUnidirectionalStreamsAvailable, + void(uint64_t)); }; class MockDeliveryCallback : public QuicSocket::DeliveryCallback { diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index 3f470caec..372f81494 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -33,7 +33,8 @@ enum class TestFrameType : uint8_t { STREAM, CRYPTO, EXPIRED_DATA, - REJECTED_DATA + REJECTED_DATA, + MAX_STREAMS }; // A made up encoding decoding of a stream. @@ -84,6 +85,16 @@ Buf encodeMinStreamDataFrame(const MinStreamDataFrame& frame) { return buf; } +// A made up encoding of a MaxStreamsFrame. +Buf encodeMaxStreamsFrame(const MaxStreamsFrame& frame) { + auto buf = IOBuf::create(25); + folly::io::Appender appender(buf.get(), 25); + appender.writeBE(static_cast(TestFrameType::MAX_STREAMS)); + appender.writeBE(frame.isForBidirectionalStream() ? 1 : 0); + appender.writeBE(frame.maxStreams); + return buf; +} + std::pair decodeDataBuffer(folly::io::Cursor& cursor) { Buf outData; auto len = cursor.readBE(); @@ -122,6 +133,12 @@ MinStreamDataFrame decodeMinStreamDataFrame(folly::io::Cursor& cursor) { return frame; } +MaxStreamsFrame decodeMaxStreamsFrame(folly::io::Cursor& cursor) { + bool isBidi = cursor.readBE(); + auto maxStreams = cursor.readBE(); + return MaxStreamsFrame(maxStreams, isBidi); +} + class TestPingCallback : public QuicSocket::PingCallback { public: void pingAcknowledged() noexcept override {} @@ -194,6 +211,15 @@ class TestQuicTransport } onRecvMinStreamDataFrame(stream, minDataFrame, packetNum_); packetNum_++; + } else if (type == TestFrameType::MAX_STREAMS) { + auto maxStreamsFrame = decodeMaxStreamsFrame(cursor); + if (maxStreamsFrame.isForBidirectionalStream()) { + conn_->streamManager->setMaxLocalBidirectionalStreams( + maxStreamsFrame.maxStreams); + } else { + conn_->streamManager->setMaxLocalUnidirectionalStreams( + maxStreamsFrame.maxStreams); + } } else { auto buffer = decodeStreamBuffer(cursor); QuicStreamState* stream = conn_->streamManager->getStream(buffer.first); @@ -304,6 +330,12 @@ class TestQuicTransport onNetworkData(addr, NetworkData(std::move(buf), Clock::now())); } + void addMaxStreamsFrame(MaxStreamsFrame frame) { + auto buf = encodeMaxStreamsFrame(frame); + SocketAddress addr("127.0.0.1", 1000); + onNetworkData(addr, NetworkData(std::move(buf), Clock::now())); + } + void addStreamReadError(StreamId id, QuicErrorCode ex) { QuicStreamState* stream = conn_->streamManager->getStream(id); stream->streamReadError = ex; @@ -1023,6 +1055,74 @@ TEST_F(QuicTransportImplTest, CreateStreamLimitsUnidirectionalFew) { transport.reset(); } +TEST_F(QuicTransportImplTest, onBidiStreamsAvailableCallback) { + transport->transportConn->streamManager->setMaxLocalBidirectionalStreams( + 0, /*force=*/true); + + EXPECT_CALL(connCallback, onBidirectionalStreamsAvailable(_)) + .WillOnce(Invoke([](uint64_t numAvailableStreams) { + EXPECT_EQ(numAvailableStreams, 1); + })); + transport->addMaxStreamsFrame(MaxStreamsFrame(1, /*isBidirectionalIn=*/true)); + EXPECT_EQ(transport->getNumOpenableBidirectionalStreams(), 1); + + // same value max streams frame doesn't trigger callback + transport->addMaxStreamsFrame(MaxStreamsFrame(1, /*isBidirectionalIn=*/true)); +} + +TEST_F(QuicTransportImplTest, onBidiStreamsAvailableCallbackAfterExausted) { + transport->transportConn->streamManager->setMaxLocalBidirectionalStreams( + 0, /*force=*/true); + + EXPECT_CALL(connCallback, onBidirectionalStreamsAvailable(_)).Times(2); + transport->addMaxStreamsFrame(MaxStreamsFrame( + 1, + /*isBidirectionalIn=*/true)); + EXPECT_EQ(transport->getNumOpenableBidirectionalStreams(), 1); + + auto result = transport->createBidirectionalStream(); + EXPECT_TRUE(result); + EXPECT_EQ(transport->getNumOpenableBidirectionalStreams(), 0); + + transport->addMaxStreamsFrame(MaxStreamsFrame( + 2, + /*isBidirectionalIn=*/true)); +} + +TEST_F(QuicTransportImplTest, oneUniStreamsAvailableCallback) { + transport->transportConn->streamManager->setMaxLocalUnidirectionalStreams( + 0, /*force=*/true); + + EXPECT_CALL(connCallback, onUnidirectionalStreamsAvailable(_)) + .WillOnce(Invoke([](uint64_t numAvailableStreams) { + EXPECT_EQ(numAvailableStreams, 1); + })); + transport->addMaxStreamsFrame( + MaxStreamsFrame(1, /*isBidirectionalIn=*/false)); + EXPECT_EQ(transport->getNumOpenableUnidirectionalStreams(), 1); + + // same value max streams frame doesn't trigger callback + transport->addMaxStreamsFrame( + MaxStreamsFrame(1, /*isBidirectionalIn=*/false)); +} + +TEST_F(QuicTransportImplTest, onUniStreamsAvailableCallbackAfterExausted) { + transport->transportConn->streamManager->setMaxLocalUnidirectionalStreams( + 0, /*force=*/true); + + EXPECT_CALL(connCallback, onUnidirectionalStreamsAvailable(_)).Times(2); + transport->addMaxStreamsFrame( + MaxStreamsFrame(1, /*isBidirectionalIn=*/false)); + EXPECT_EQ(transport->getNumOpenableUnidirectionalStreams(), 1); + + auto result = transport->createUnidirectionalStream(); + EXPECT_TRUE(result); + EXPECT_EQ(transport->getNumOpenableUnidirectionalStreams(), 0); + + transport->addMaxStreamsFrame( + MaxStreamsFrame(2, /*isBidirectionalIn=*/false)); +} + TEST_F(QuicTransportImplTest, ReadDataAlsoChecksLossAlarm) { transport->transportConn->oneRttWriteCipher = test::createNoOpAead(); auto stream = transport->createBidirectionalStream().value(); diff --git a/quic/state/QuicStreamManager.cpp b/quic/state/QuicStreamManager.cpp index caebd60ab..b84026c39 100644 --- a/quic/state/QuicStreamManager.cpp +++ b/quic/state/QuicStreamManager.cpp @@ -155,6 +155,7 @@ void QuicStreamManager::setMaxLocalBidirectionalStreams( initialLocalBidirectionalStreamId_; if (force || maxStreamId > maxLocalBidirectionalStreamId_) { maxLocalBidirectionalStreamId_ = maxStreamId; + maxLocalBidirectionalStreamIdIncreased_ = true; } } @@ -170,6 +171,7 @@ void QuicStreamManager::setMaxLocalUnidirectionalStreams( initialLocalUnidirectionalStreamId_; if (force || maxStreamId > maxLocalUnidirectionalStreamId_) { maxLocalUnidirectionalStreamId_ = maxStreamId; + maxLocalUnidirectionalStreamIdIncreased_ = true; } } @@ -211,6 +213,18 @@ void QuicStreamManager::setMaxRemoteUnidirectionalStreamsInternal( } } +bool QuicStreamManager::consumeMaxLocalBidirectionalStreamIdIncreased() { + bool res = maxLocalBidirectionalStreamIdIncreased_; + maxLocalBidirectionalStreamIdIncreased_ = false; + return res; +} + +bool QuicStreamManager::consumeMaxLocalUnidirectionalStreamIdIncreased() { + bool res = maxLocalUnidirectionalStreamIdIncreased_; + maxLocalUnidirectionalStreamIdIncreased_ = false; + return res; +} + void QuicStreamManager::refreshTransportSettings( const TransportSettings& settings) { transportSettings_ = &settings; @@ -371,16 +385,15 @@ QuicStreamManager::createStream(StreamId streamId) { if (existingStream) { return existingStream; } - auto& nextAcceptableStreamId = isUnidirectionalStream(streamId) + bool isUni = isUnidirectionalStream(streamId); + auto& nextAcceptableStreamId = isUni ? nextAcceptableLocalUnidirectionalStreamId_ : nextAcceptableLocalBidirectionalStreamId_; - auto maxStreamId = isUnidirectionalStream(streamId) - ? maxLocalUnidirectionalStreamId_ - : maxLocalBidirectionalStreamId_; + auto maxStreamId = + isUni ? maxLocalUnidirectionalStreamId_ : maxLocalBidirectionalStreamId_; - auto& openLocalStreams = isUnidirectionalStream(streamId) - ? openUnidirectionalLocalStreams_ - : openBidirectionalLocalStreams_; + auto& openLocalStreams = + isUni ? openUnidirectionalLocalStreams_ : openBidirectionalLocalStreams_; auto openedResult = openLocalStreamIfNotClosed( streamId, openLocalStreams, nextAcceptableStreamId, maxStreamId); if (openedResult != LocalErrorCode::NO_ERROR) { @@ -523,4 +536,5 @@ void QuicStreamManager::setStreamAsControl(QuicStreamState& stream) { bool QuicStreamManager::isAppIdle() const { return isAppIdle_; } + } // namespace quic diff --git a/quic/state/QuicStreamManager.h b/quic/state/QuicStreamManager.h index 8ebf571cf..3acfb336a 100644 --- a/quic/state/QuicStreamManager.h +++ b/quic/state/QuicStreamManager.h @@ -125,24 +125,36 @@ class QuicStreamManager { bool streamExists(StreamId streamId); uint64_t openableLocalBidirectionalStreams() { + CHECK_GE( + maxLocalBidirectionalStreamId_, + nextAcceptableLocalBidirectionalStreamId_); return (maxLocalBidirectionalStreamId_ - nextAcceptableLocalBidirectionalStreamId_) / detail::kStreamIncrement; } uint64_t openableLocalUnidirectionalStreams() { + CHECK_GE( + maxLocalUnidirectionalStreamId_, + nextAcceptableLocalUnidirectionalStreamId_); return (maxLocalUnidirectionalStreamId_ - nextAcceptableLocalUnidirectionalStreamId_) / detail::kStreamIncrement; } uint64_t openableRemoteBidirectionalStreams() { + CHECK_GE( + maxRemoteBidirectionalStreamId_, + nextAcceptablePeerBidirectionalStreamId_); return (maxRemoteBidirectionalStreamId_ - nextAcceptablePeerBidirectionalStreamId_) / detail::kStreamIncrement; } uint64_t openableRemoteUnidirectionalStreams() { + CHECK_GE( + maxRemoteUnidirectionalStreamId_, + nextAcceptablePeerUnidirectionalStreamId_); return (maxRemoteUnidirectionalStreamId_ - nextAcceptablePeerUnidirectionalStreamId_) / detail::kStreamIncrement; @@ -312,6 +324,18 @@ class QuicStreamManager { */ void setMaxRemoteUnidirectionalStreams(uint64_t maxStreams); + /* + * Returns true if MaxLocalBidirectionalStreamId was increased + * since last call of this function (resets flag). + */ + bool consumeMaxLocalBidirectionalStreamIdIncreased(); + + /* + * Returns true if MaxLocalUnidirectionalStreamId was increased + * since last call of this function (resets flag). + */ + bool consumeMaxLocalUnidirectionalStreamIdIncreased(); + void refreshTransportSettings(const TransportSettings& settings); /* @@ -831,6 +855,9 @@ class QuicStreamManager { bool isAppIdle_{false}; const TransportSettings* transportSettings_; + + bool maxLocalBidirectionalStreamIdIncreased_{false}; + bool maxLocalUnidirectionalStreamIdIncreased_{false}; }; } // namespace quic