1
0
mirror of https://github.com/facebookincubator/mvfst.git synced 2025-08-08 09:42:06 +03:00

Defer stream available callbacks until connCallback_ is set

Summary: If the peer sends higher stream limits in the settings, these callbacks may be invoked before the app has a chance to set the connection callback.

Reviewed By: sharmafb

Differential Revision: D64214648

fbshipit-source-id: 6a8a9b8d4d9e02a2baad672d69a43ba61daba918
This commit is contained in:
Alan Frindell
2024-10-25 15:02:38 -07:00
committed by Facebook GitHub Bot
parent 02a98533da
commit 0c061201af
5 changed files with 37 additions and 5 deletions

View File

@@ -646,6 +646,14 @@ void QuicTransportBaseLite::setConnectionSetupCallback(
void QuicTransportBaseLite::setConnectionCallback( void QuicTransportBaseLite::setConnectionCallback(
folly::MaybeManagedPtr<ConnectionCallback> callback) { folly::MaybeManagedPtr<ConnectionCallback> callback) {
connCallback_ = callback; connCallback_ = callback;
if (connCallback_) {
runOnEvbAsync([](auto self) { self->processCallbacksAfterNetworkData(); });
}
}
void QuicTransportBaseLite::setConnectionCallbackFromCtor(
folly::MaybeManagedPtr<ConnectionCallback> callback) {
connCallback_ = callback;
} }
Optional<LocalErrorCode> QuicTransportBaseLite::setControlStream(StreamId id) { Optional<LocalErrorCode> QuicTransportBaseLite::setControlStream(StreamId id) {
@@ -1501,6 +1509,10 @@ void QuicTransportBaseLite::processCallbacksAfterNetworkData() {
if (closeState_ != CloseState::OPEN) { if (closeState_ != CloseState::OPEN) {
return; return;
} }
if (!connCallback_ || !conn_->streamManager) {
return;
}
// We reuse this storage for storing streams which need callbacks. // We reuse this storage for storing streams which need callbacks.
std::vector<StreamId> tempStorage; std::vector<StreamId> tempStorage;

View File

@@ -471,6 +471,9 @@ class QuicTransportBaseLite : virtual public QuicSocketLite,
virtual void createBufAccessor(size_t /* capacity */) {} virtual void createBufAccessor(size_t /* capacity */) {}
protected: protected:
void setConnectionCallbackFromCtor(
folly::MaybeManagedPtr<ConnectionCallback> callback);
/** /**
* A wrapper around writeSocketData * A wrapper around writeSocketData
* *

View File

@@ -235,8 +235,6 @@ class TestQuicTransport
ConnectionCallback* connCb) ConnectionCallback* connCb)
: QuicTransportBase(std::move(evb), std::move(socket)), : QuicTransportBase(std::move(evb), std::move(socket)),
observerContainer_(std::make_shared<SocketObserverContainer>(this)) { observerContainer_(std::make_shared<SocketObserverContainer>(this)) {
setConnectionSetupCallback(connSetupCb);
setConnectionCallback(connCb);
auto conn = std::make_unique<QuicServerConnectionState>( auto conn = std::make_unique<QuicServerConnectionState>(
FizzServerQuicHandshakeContext::Builder().build()); FizzServerQuicHandshakeContext::Builder().build());
conn->clientConnectionId = ConnectionId({10, 9, 8, 7}); conn->clientConnectionId = ConnectionId({10, 9, 8, 7});
@@ -247,6 +245,8 @@ class TestQuicTransport
aead = test::createNoOpAead(); aead = test::createNoOpAead();
headerCipher = test::createNoOpHeaderCipher(); headerCipher = test::createNoOpHeaderCipher();
connIdAlgo_ = std::make_unique<DefaultConnectionIdAlgo>(); connIdAlgo_ = std::make_unique<DefaultConnectionIdAlgo>();
setConnectionSetupCallback(connSetupCb);
setConnectionCallbackFromCtor(connCb);
} }
~TestQuicTransport() override { ~TestQuicTransport() override {
@@ -711,6 +711,23 @@ TEST_P(QuicTransportImplTestBase, IdleTimeoutExpiredDestroysTransport) {
transport->invokeIdleTimeout(); transport->invokeIdleTimeout();
} }
TEST_P(QuicTransportImplTestBase, DelayConnCallback) {
transport->transportConn->streamManager->setMaxLocalBidirectionalStreams(
0, /*force=*/true);
transport->setConnectionCallback(nullptr);
transport->addMaxStreamsFrame(
MaxStreamsFrame(10, /*isBidirectionalIn=*/true));
transport->setConnectionCallback(&connCallback);
EXPECT_CALL(connCallback, onBidirectionalStreamsAvailable(_))
.WillOnce(Invoke([](uint64_t numAvailableStreams) {
EXPECT_EQ(numAvailableStreams, 10);
}));
transport->getEventBase()->loopOnce();
transport.reset();
}
TEST_P(QuicTransportImplTestBase, IdleTimeoutStreamMaessage) { TEST_P(QuicTransportImplTestBase, IdleTimeoutStreamMaessage) {
auto stream1 = transport->createBidirectionalStream().value(); auto stream1 = transport->createBidirectionalStream().value();
auto stream2 = transport->createBidirectionalStream().value(); auto stream2 = transport->createBidirectionalStream().value();

View File

@@ -26,8 +26,6 @@ class TestQuicTransport
ConnectionCallback* connCb) ConnectionCallback* connCb)
: QuicTransportBase(std::move(evb), std::move(socket)), : QuicTransportBase(std::move(evb), std::move(socket)),
observerContainer_(std::make_shared<SocketObserverContainer>(this)) { observerContainer_(std::make_shared<SocketObserverContainer>(this)) {
setConnectionSetupCallback(connSetupCb);
setConnectionCallback(connCb);
conn_.reset(new QuicServerConnectionState( conn_.reset(new QuicServerConnectionState(
FizzServerQuicHandshakeContext::Builder().build())); FizzServerQuicHandshakeContext::Builder().build()));
conn_->clientConnectionId = ConnectionId({9, 8, 7, 6}); conn_->clientConnectionId = ConnectionId({9, 8, 7, 6});
@@ -36,6 +34,8 @@ class TestQuicTransport
conn_->observerContainer = observerContainer_; conn_->observerContainer = observerContainer_;
aead = test::createNoOpAead(); aead = test::createNoOpAead();
headerCipher = test::createNoOpHeaderCipher(); headerCipher = test::createNoOpHeaderCipher();
setConnectionSetupCallback(connSetupCb);
setConnectionCallbackFromCtor(connCb);
} }
~TestQuicTransport() override { ~TestQuicTransport() override {

View File

@@ -67,7 +67,7 @@ QuicServerTransport::QuicServerTransport(
conn_.reset(tempConn.release()); conn_.reset(tempConn.release());
conn_->observerContainer = wrappedObserverContainer_.getWeakPtr(); conn_->observerContainer = wrappedObserverContainer_.getWeakPtr();
setConnectionSetupCallback(connSetupCb); setConnectionSetupCallback(connSetupCb);
setConnectionCallback(connStreamsCb); setConnectionCallbackFromCtor(connStreamsCb);
registerAllTransportKnobParamHandlers(); registerAllTransportKnobParamHandlers();
} }