diff --git a/quic/api/QuicTransportBaseLite.cpp b/quic/api/QuicTransportBaseLite.cpp index 0042f391c..4aa1373a8 100644 --- a/quic/api/QuicTransportBaseLite.cpp +++ b/quic/api/QuicTransportBaseLite.cpp @@ -646,6 +646,14 @@ void QuicTransportBaseLite::setConnectionSetupCallback( void QuicTransportBaseLite::setConnectionCallback( folly::MaybeManagedPtr callback) { connCallback_ = callback; + if (connCallback_) { + runOnEvbAsync([](auto self) { self->processCallbacksAfterNetworkData(); }); + } +} + +void QuicTransportBaseLite::setConnectionCallbackFromCtor( + folly::MaybeManagedPtr callback) { + connCallback_ = callback; } Optional QuicTransportBaseLite::setControlStream(StreamId id) { @@ -1501,6 +1509,10 @@ void QuicTransportBaseLite::processCallbacksAfterNetworkData() { if (closeState_ != CloseState::OPEN) { return; } + if (!connCallback_ || !conn_->streamManager) { + return; + } + // We reuse this storage for storing streams which need callbacks. std::vector tempStorage; diff --git a/quic/api/QuicTransportBaseLite.h b/quic/api/QuicTransportBaseLite.h index a0679b64f..61e4a9545 100644 --- a/quic/api/QuicTransportBaseLite.h +++ b/quic/api/QuicTransportBaseLite.h @@ -471,6 +471,9 @@ class QuicTransportBaseLite : virtual public QuicSocketLite, virtual void createBufAccessor(size_t /* capacity */) {} protected: + void setConnectionCallbackFromCtor( + folly::MaybeManagedPtr callback); + /** * A wrapper around writeSocketData * diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index f13df3951..265b84ce7 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -235,8 +235,6 @@ class TestQuicTransport ConnectionCallback* connCb) : QuicTransportBase(std::move(evb), std::move(socket)), observerContainer_(std::make_shared(this)) { - setConnectionSetupCallback(connSetupCb); - setConnectionCallback(connCb); auto conn = std::make_unique( FizzServerQuicHandshakeContext::Builder().build()); conn->clientConnectionId = ConnectionId({10, 9, 8, 7}); @@ -247,6 +245,8 @@ class TestQuicTransport aead = test::createNoOpAead(); headerCipher = test::createNoOpHeaderCipher(); connIdAlgo_ = std::make_unique(); + setConnectionSetupCallback(connSetupCb); + setConnectionCallbackFromCtor(connCb); } ~TestQuicTransport() override { @@ -711,6 +711,23 @@ TEST_P(QuicTransportImplTestBase, IdleTimeoutExpiredDestroysTransport) { transport->invokeIdleTimeout(); } +TEST_P(QuicTransportImplTestBase, DelayConnCallback) { + transport->transportConn->streamManager->setMaxLocalBidirectionalStreams( + 0, /*force=*/true); + transport->setConnectionCallback(nullptr); + + transport->addMaxStreamsFrame( + MaxStreamsFrame(10, /*isBidirectionalIn=*/true)); + + transport->setConnectionCallback(&connCallback); + EXPECT_CALL(connCallback, onBidirectionalStreamsAvailable(_)) + .WillOnce(Invoke([](uint64_t numAvailableStreams) { + EXPECT_EQ(numAvailableStreams, 10); + })); + transport->getEventBase()->loopOnce(); + transport.reset(); +} + TEST_P(QuicTransportImplTestBase, IdleTimeoutStreamMaessage) { auto stream1 = transport->createBidirectionalStream().value(); auto stream2 = transport->createBidirectionalStream().value(); diff --git a/quic/api/test/TestQuicTransport.h b/quic/api/test/TestQuicTransport.h index abdfb3feb..fba9d16fd 100644 --- a/quic/api/test/TestQuicTransport.h +++ b/quic/api/test/TestQuicTransport.h @@ -26,8 +26,6 @@ class TestQuicTransport ConnectionCallback* connCb) : QuicTransportBase(std::move(evb), std::move(socket)), observerContainer_(std::make_shared(this)) { - setConnectionSetupCallback(connSetupCb); - setConnectionCallback(connCb); conn_.reset(new QuicServerConnectionState( FizzServerQuicHandshakeContext::Builder().build())); conn_->clientConnectionId = ConnectionId({9, 8, 7, 6}); @@ -36,6 +34,8 @@ class TestQuicTransport conn_->observerContainer = observerContainer_; aead = test::createNoOpAead(); headerCipher = test::createNoOpHeaderCipher(); + setConnectionSetupCallback(connSetupCb); + setConnectionCallbackFromCtor(connCb); } ~TestQuicTransport() override { diff --git a/quic/server/QuicServerTransport.cpp b/quic/server/QuicServerTransport.cpp index e92894427..fdcd49b85 100644 --- a/quic/server/QuicServerTransport.cpp +++ b/quic/server/QuicServerTransport.cpp @@ -67,7 +67,7 @@ QuicServerTransport::QuicServerTransport( conn_.reset(tempConn.release()); conn_->observerContainer = wrappedObserverContainer_.getWeakPtr(); setConnectionSetupCallback(connSetupCb); - setConnectionCallback(connStreamsCb); + setConnectionCallbackFromCtor(connStreamsCb); registerAllTransportKnobParamHandlers(); }