diff --git a/quic/api/test/MockQuicSocket.h b/quic/api/test/MockQuicSocket.h index 8ce95d975..2f565b361 100644 --- a/quic/api/test/MockQuicSocket.h +++ b/quic/api/test/MockQuicSocket.h @@ -17,14 +17,11 @@ class MockQuicSocket : public QuicSocket { public: using SharedBuf = std::shared_ptr; - MockQuicSocket(folly::EventBase* /*eventBase*/, ConnectionCallback& connCb) - : setupCb_(&connCb), connCb_(&connCb) {} - MockQuicSocket( folly::EventBase* /*eventBase*/, - ConnectionSetupCallback& setupCb, + ConnectionSetupCallback* setupCb, ConnectionCallbackNew* connCb) - : setupCb_(&setupCb), connCb_(connCb) {} + : setupCb_(setupCb), connCb_(connCb) {} MOCK_CONST_METHOD0(good, bool()); MOCK_CONST_METHOD0(replaySafe, bool()); diff --git a/quic/api/test/Mocks.h b/quic/api/test/Mocks.h index 44ff426fd..bf37e2cfe 100644 --- a/quic/api/test/Mocks.h +++ b/quic/api/test/Mocks.h @@ -101,9 +101,23 @@ class MockWriteCallback : public QuicSocket::WriteCallback { void(std::pair>)); }; -class MockConnectionCallback : public QuicSocket::ConnectionCallback { +class MockConnectionSetupCallback : public QuicSocket::ConnectionSetupCallback { public: - ~MockConnectionCallback() override = default; + ~MockConnectionSetupCallback() override = default; + GMOCK_METHOD1_( + , + noexcept, + , + onConnectionSetupError, + void(std::pair)); + GMOCK_METHOD0_(, noexcept, , onReplaySafe, void()); + GMOCK_METHOD0_(, noexcept, , onTransportReady, void()); + GMOCK_METHOD0_(, noexcept, , onFirstPeerPacketProcessed, void()); +}; + +class MockConnectionCallbackNew : public QuicSocket::ConnectionCallbackNew { + public: + ~MockConnectionCallbackNew() override = default; GMOCK_METHOD1_(, noexcept, , onFlowControlUpdate, void(StreamId)); GMOCK_METHOD1_(, noexcept, , onNewBidirectionalStream, void(StreamId)); @@ -115,19 +129,12 @@ class MockConnectionCallback : public QuicSocket::ConnectionCallback { onStopSending, void(StreamId, ApplicationErrorCode)); GMOCK_METHOD0_(, noexcept, , onConnectionEnd, void()); - void onConnectionSetupError( - std::pair code) noexcept override { - onConnectionError(std::move(code)); - } GMOCK_METHOD1_( , noexcept, , onConnectionError, void(std::pair)); - GMOCK_METHOD0_(, noexcept, , onReplaySafe, void()); - GMOCK_METHOD0_(, noexcept, , onTransportReady, void()); - GMOCK_METHOD0_(, noexcept, , onFirstPeerPacketProcessed, void()); GMOCK_METHOD1_(, noexcept, , onBidirectionalStreamsAvailable, void(uint64_t)); GMOCK_METHOD1_( , @@ -211,9 +218,10 @@ class MockQuicTransport : public QuicServerTransport { MockQuicTransport( folly::EventBase* evb, std::unique_ptr sock, - ConnectionCallback& cb, + ConnectionSetupCallback* connSetupCb, + ConnectionCallbackNew* connCb, std::shared_ptr ctx) - : QuicServerTransport(evb, std::move(sock), cb, ctx) {} + : QuicServerTransport(evb, std::move(sock), connSetupCb, connCb, ctx) {} virtual ~MockQuicTransport() { customDestructor(); diff --git a/quic/api/test/QuicStreamAsyncTransportTest.cpp b/quic/api/test/QuicStreamAsyncTransportTest.cpp index ecee871cf..b7f6c52b8 100644 --- a/quic/api/test/QuicStreamAsyncTransportTest.cpp +++ b/quic/api/test/QuicStreamAsyncTransportTest.cpp @@ -43,7 +43,11 @@ class QuicStreamAsyncTransportTest : public Test { const folly::SocketAddress& /*addr*/, std::shared_ptr ctx) { auto transport = quic::QuicServerTransport::make( - evb, std::move(socket), serverConnectionCB_, std::move(ctx)); + evb, + std::move(socket), + &serverConnectionSetupCB_, + &serverConnectionCB_, + std::move(ctx)); CHECK(serverSocket_.get() == nullptr); serverSocket_ = transport; return transport; @@ -91,13 +95,14 @@ class QuicStreamAsyncTransportTest : public Test { void createClient() { clientEvbThread_ = std::thread([&]() { clientEvb_.loopForever(); }); - EXPECT_CALL(clientConnectionCB_, onTransportReady()).WillOnce(Invoke([&]() { - clientAsyncWrapper_ = - QuicStreamAsyncTransport::createWithNewStream(client_); - ASSERT_TRUE(clientAsyncWrapper_); - clientAsyncWrapper_->setReadCB(&clientReadCB_); - startPromise_.setValue(); - })); + EXPECT_CALL(clientConnectionSetupCB_, onTransportReady()) + .WillOnce(Invoke([&]() { + clientAsyncWrapper_ = + QuicStreamAsyncTransport::createWithNewStream(client_); + ASSERT_TRUE(clientAsyncWrapper_); + clientAsyncWrapper_->setReadCB(&clientReadCB_); + startPromise_.setValue(); + })); EXPECT_CALL(clientReadCB_, isBufferMovable_()) .WillRepeatedly(Return(false)); @@ -128,7 +133,7 @@ class QuicStreamAsyncTransportTest : public Test { &clientEvb_, std::move(sock), std::move(fizzClientContext)); client_->setHostname("echo.com"); client_->addNewPeerAddress(serverAddr_); - client_->start(&clientConnectionCB_); + client_->start(&clientConnectionSetupCB_, &clientConnectionCB_); }); std::move(future).get(1s); @@ -152,7 +157,8 @@ class QuicStreamAsyncTransportTest : public Test { protected: std::shared_ptr server_; folly::SocketAddress serverAddr_; - NiceMock serverConnectionCB_; + NiceMock serverConnectionSetupCB_; + NiceMock serverConnectionCB_; std::shared_ptr serverSocket_; QuicStreamAsyncTransport::UniquePtr serverAsyncWrapper_; folly::test::MockWriteCallback serverWriteCB_; @@ -162,7 +168,8 @@ class QuicStreamAsyncTransportTest : public Test { std::shared_ptr client_; folly::EventBase clientEvb_; std::thread clientEvbThread_; - NiceMock clientConnectionCB_; + NiceMock clientConnectionSetupCB_; + NiceMock clientConnectionCB_; QuicStreamAsyncTransport::UniquePtr clientAsyncWrapper_; folly::Promise startPromise_; folly::test::MockWriteCallback clientWriteCB_; diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index 35eee6cf0..0d6ce091d 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -198,9 +198,11 @@ class TestQuicTransport TestQuicTransport( folly::EventBase* evb, std::unique_ptr socket, - ConnectionCallback& cb) + ConnectionSetupCallback* connSetupCb, + ConnectionCallbackNew* connCb) : QuicTransportBase(evb, std::move(socket)) { - setConnectionCallback(&cb); + setConnectionSetupCallback(connSetupCb); + setConnectionCallbackNew(connCb); auto conn = std::make_unique( FizzServerQuicHandshakeContext::Builder().build()); conn->clientConnectionId = ConnectionId({10, 9, 8, 7}); @@ -504,7 +506,7 @@ class QuicTransportImplTest : public Test { std::make_unique>(evb.get()); socketPtr = socket.get(); transport = std::make_shared( - evb.get(), std::move(socket), connCallback); + evb.get(), std::move(socket), &connSetupCallback, &connCallback); auto& conn = *transport->transportConn; conn.flowControlState.peerAdvertisedInitialMaxStreamOffsetBidiLocal = kDefaultStreamWindowSize; @@ -530,7 +532,8 @@ class QuicTransportImplTest : public Test { protected: std::unique_ptr evb; - NiceMock connCallback; + NiceMock connSetupCallback; + NiceMock connCallback; TestByteEventCallback byteEventCallback; std::shared_ptr transport; folly::test::MockAsyncUDPSocket* socketPtr; @@ -2812,8 +2815,8 @@ TEST_F(QuicTransportImplTest, GetLocalAddressUnboundSocket) { } TEST_F(QuicTransportImplTest, GetLocalAddressBadSocket) { - auto badTransport = - std::make_shared(evb.get(), nullptr, connCallback); + auto badTransport = std::make_shared( + evb.get(), nullptr, &connSetupCallback, &connCallback); badTransport->closeWithoutWrite(); SocketAddress localAddr = badTransport->getLocalAddress(); EXPECT_FALSE(localAddr.isInitialized()); diff --git a/quic/api/test/QuicTransportTest.cpp b/quic/api/test/QuicTransportTest.cpp index a68128d04..52930749f 100644 --- a/quic/api/test/QuicTransportTest.cpp +++ b/quic/api/test/QuicTransportTest.cpp @@ -69,8 +69,8 @@ class QuicTransportTest : public Test { std::unique_ptr sock = std::make_unique>(&evb_); socket_ = sock.get(); - transport_.reset( - new TestQuicTransport(&evb_, std::move(sock), connCallback_)); + transport_.reset(new TestQuicTransport( + &evb_, std::move(sock), &connSetupCallback_, &connCallback_)); // Set the write handshake state to tell the client that the handshake has // a cipher. auto aead = std::make_unique>(); @@ -114,7 +114,8 @@ class QuicTransportTest : public Test { protected: folly::EventBase evb_; MockAsyncUDPSocket* socket_; - NiceMock connCallback_; + NiceMock connSetupCallback_; + NiceMock connCallback_; NiceMock writeCallback_; MockAead* aead_; std::unique_ptr headerCipher_; diff --git a/quic/api/test/TestQuicTransport.h b/quic/api/test/TestQuicTransport.h index 355f4180d..d697c176f 100644 --- a/quic/api/test/TestQuicTransport.h +++ b/quic/api/test/TestQuicTransport.h @@ -21,9 +21,11 @@ class TestQuicTransport TestQuicTransport( folly::EventBase* evb, std::unique_ptr socket, - ConnectionCallback& cb) + ConnectionSetupCallback* connSetupCb, + ConnectionCallbackNew* connCb) : QuicTransportBase(evb, std::move(socket)) { - setConnectionCallback(&cb); + setConnectionSetupCallback(connSetupCb); + setConnectionCallbackNew(connCb); conn_.reset(new QuicServerConnectionState( FizzServerQuicHandshakeContext::Builder().build())); conn_->clientConnectionId = ConnectionId({9, 8, 7, 6}); diff --git a/quic/fizz/client/test/QuicClientTransportTest.cpp b/quic/fizz/client/test/QuicClientTransportTest.cpp index 7c4a97c01..15608fb65 100644 --- a/quic/fizz/client/test/QuicClientTransportTest.cpp +++ b/quic/fizz/client/test/QuicClientTransportTest.cpp @@ -67,9 +67,8 @@ class QuicClientTransportIntegrationTest : public TestWithParam { serverCtx->setSupportedAlpns({"h1q-fb", "hq"}); server_ = createServer(ProcessId::ZERO); serverAddr = server_->getAddress(); - ON_CALL(clientConnCallback, onTransportReady()).WillByDefault(Invoke([&] { - connected_ = true; - })); + ON_CALL(clientConnSetupCallback, onTransportReady()) + .WillByDefault(Invoke([&] { connected_ = true; })); clientCtx = createClientContext(); verifier = createTestCertificateVerifier(); @@ -165,7 +164,7 @@ class QuicClientTransportIntegrationTest : public TestWithParam { } void expectTransportCallbacks() { - EXPECT_CALL(clientConnCallback, onReplaySafe()); + EXPECT_CALL(clientConnSetupCallback, onReplaySafe()); } void expectStatsCallbacks() { @@ -244,7 +243,8 @@ class QuicClientTransportIntegrationTest : public TestWithParam { std::string hostname; folly::EventBase eventbase_; folly::SocketAddress serverAddr; - NiceMock clientConnCallback; + NiceMock clientConnSetupCallback; + NiceMock clientConnCallback; NiceMock readCb; std::shared_ptr client; std::shared_ptr serverCtx; @@ -328,9 +328,9 @@ void QuicClientTransportIntegrationTest::sendRequestAndResponseAndWait( TEST_P(QuicClientTransportIntegrationTest, NetworkTest) { expectTransportCallbacks(); expectStatsCallbacks(); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); - EXPECT_CALL(clientConnCallback, onTransportReady()).WillOnce(Invoke([&] { + EXPECT_CALL(clientConnSetupCallback, onTransportReady()).WillOnce(Invoke([&] { CHECK(client->getConn().oneRttWriteCipher); EXPECT_EQ(client->getConn().peerConnectionIds.size(), 1); EXPECT_EQ( @@ -349,9 +349,9 @@ TEST_P(QuicClientTransportIntegrationTest, NetworkTest) { TEST_P(QuicClientTransportIntegrationTest, FlowControlLimitedTest) { expectTransportCallbacks(); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); - EXPECT_CALL(clientConnCallback, onTransportReady()).WillOnce(Invoke([&] { + EXPECT_CALL(clientConnSetupCallback, onTransportReady()).WillOnce(Invoke([&] { CHECK(client->getConn().oneRttWriteCipher); eventbase_.terminateLoopSoon(); })); @@ -369,13 +369,13 @@ TEST_P(QuicClientTransportIntegrationTest, FlowControlLimitedTest) { } TEST_P(QuicClientTransportIntegrationTest, ALPNTest) { - EXPECT_CALL(clientConnCallback, onTransportReady()).WillOnce(Invoke([&] { + EXPECT_CALL(clientConnSetupCallback, onTransportReady()).WillOnce(Invoke([&] { ASSERT_EQ(client->getAppProtocol(), "h1q-fb"); client->close(folly::none); eventbase_.terminateLoopSoon(); })); ASSERT_EQ(client->getAppProtocol(), folly::none); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); eventbase_.loopForever(); } @@ -399,7 +399,7 @@ TEST_P(QuicClientTransportIntegrationTest, TLSAlert) { ASSERT_EQ(client->getAppProtocol(), folly::none); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); eventbase_.loopForever(); } @@ -418,7 +418,7 @@ TEST_P(QuicClientTransportIntegrationTest, BadServerTest) { EXPECT_NE(localError, nullptr); this->checkTransportSummaryEvent(qLogger); })); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); eventbase_.loop(); } @@ -429,9 +429,9 @@ TEST_P(QuicClientTransportIntegrationTest, NetworkTestConnected) { TransportSettings settings; settings.connectUDP = true; client->setTransportSettings(settings); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); - EXPECT_CALL(clientConnCallback, onTransportReady()).WillOnce(Invoke([&] { + EXPECT_CALL(clientConnSetupCallback, onTransportReady()).WillOnce(Invoke([&] { CHECK(client->getConn().oneRttWriteCipher); eventbase_.terminateLoopSoon(); })); @@ -451,9 +451,9 @@ TEST_P(QuicClientTransportIntegrationTest, SetTransportSettingsAfterStart) { TransportSettings settings; settings.connectUDP = true; client->setTransportSettings(settings); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); - EXPECT_CALL(clientConnCallback, onTransportReady()).WillOnce(Invoke([&] { + EXPECT_CALL(clientConnSetupCallback, onTransportReady()).WillOnce(Invoke([&] { CHECK(client->getConn().oneRttWriteCipher); eventbase_.terminateLoopSoon(); })); @@ -484,7 +484,7 @@ TEST_P(QuicClientTransportIntegrationTest, TestZeroRttSuccess) { return true; }, []() -> Buf { return nullptr; }); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_TRUE(performedValidation); CHECK(client->getConn().zeroRttWriteCipher); EXPECT_TRUE(client->serverInitialParamsSet()); @@ -499,7 +499,7 @@ TEST_P(QuicClientTransportIntegrationTest, TestZeroRttSuccess) { EXPECT_EQ( client->peerAdvertisedInitialMaxStreamDataUni(), kDefaultStreamWindowSize); - EXPECT_CALL(clientConnCallback, onTransportReady()).WillOnce(Invoke([&] { + EXPECT_CALL(clientConnSetupCallback, onTransportReady()).WillOnce(Invoke([&] { ASSERT_EQ(client->getAppProtocol(), "h1q-fb"); CHECK(client->getConn().zeroRttWriteCipher); eventbase_.terminateLoopSoon(); @@ -514,7 +514,7 @@ TEST_P(QuicClientTransportIntegrationTest, TestZeroRttSuccess) { auto data = IOBuf::copyBuffer("hello"); auto expected = std::shared_ptr(IOBuf::copyBuffer("echo ")); expected->prependChain(data->clone()); - EXPECT_CALL(clientConnCallback, onReplaySafe()); + EXPECT_CALL(clientConnSetupCallback, onReplaySafe()); sendRequestAndResponseAndWait(*expected, data->clone(), streamId, &readCb); EXPECT_FALSE(client->getConn().zeroRttWriteCipher); EXPECT_TRUE(client->getConn().statelessResetToken.has_value()); @@ -561,7 +561,7 @@ TEST_P(QuicClientTransportIntegrationTest, ZeroRttRetryPacketTest) { return true; }, []() -> Buf { return nullptr; }); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_TRUE(performedValidation); CHECK(client->getConn().zeroRttWriteCipher); EXPECT_TRUE(client->serverInitialParamsSet()); @@ -576,7 +576,7 @@ TEST_P(QuicClientTransportIntegrationTest, ZeroRttRetryPacketTest) { EXPECT_EQ( client->peerAdvertisedInitialMaxStreamDataUni(), kDefaultStreamWindowSize); - EXPECT_CALL(clientConnCallback, onTransportReady()).WillOnce(Invoke([&] { + EXPECT_CALL(clientConnSetupCallback, onTransportReady()).WillOnce(Invoke([&] { ASSERT_EQ(client->getAppProtocol(), "h1q-fb"); CHECK(client->getConn().zeroRttWriteCipher); eventbase_.terminateLoopSoon(); @@ -592,7 +592,7 @@ TEST_P(QuicClientTransportIntegrationTest, ZeroRttRetryPacketTest) { auto expected = std::shared_ptr(IOBuf::copyBuffer("echo ")); expected->prependChain(data->clone()); - EXPECT_CALL(clientConnCallback, onReplaySafe()).WillOnce(Invoke([&] { + EXPECT_CALL(clientConnSetupCallback, onReplaySafe()).WillOnce(Invoke([&] { EXPECT_TRUE(!client->getConn().retryToken.empty()); })); sendRequestAndResponseAndWait(*expected, data->clone(), streamId, &readCb); @@ -622,7 +622,7 @@ TEST_P(QuicClientTransportIntegrationTest, TestZeroRttRejection) { return true; }, []() -> Buf { return nullptr; }); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_TRUE(performedValidation); CHECK(client->getConn().zeroRttWriteCipher); EXPECT_TRUE(client->serverInitialParamsSet()); @@ -639,7 +639,7 @@ TEST_P(QuicClientTransportIntegrationTest, TestZeroRttRejection) { kDefaultStreamWindowSize); client->serverInitialParamsSet() = false; - EXPECT_CALL(clientConnCallback, onTransportReady()).WillOnce(Invoke([&] { + EXPECT_CALL(clientConnSetupCallback, onTransportReady()).WillOnce(Invoke([&] { ASSERT_EQ(client->getAppProtocol(), "h1q-fb"); CHECK(client->getConn().zeroRttWriteCipher); eventbase_.terminateLoopSoon(); @@ -685,9 +685,9 @@ TEST_P(QuicClientTransportIntegrationTest, TestZeroRttNotAttempted) { return true; }, []() -> Buf { return nullptr; }); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); - EXPECT_CALL(clientConnCallback, onTransportReady()).WillOnce(Invoke([&] { + EXPECT_CALL(clientConnSetupCallback, onTransportReady()).WillOnce(Invoke([&] { EXPECT_FALSE(client->getConn().zeroRttWriteCipher); CHECK(client->getConn().oneRttWriteCipher); eventbase_.terminateLoopSoon(); @@ -726,10 +726,10 @@ TEST_P(QuicClientTransportIntegrationTest, TestZeroRttInvalidAppParams) { return false; }, []() -> Buf { return nullptr; }); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_TRUE(performedValidation); - EXPECT_CALL(clientConnCallback, onTransportReady()).WillOnce(Invoke([&] { + EXPECT_CALL(clientConnSetupCallback, onTransportReady()).WillOnce(Invoke([&] { EXPECT_FALSE(client->getConn().zeroRttWriteCipher); CHECK(client->getConn().oneRttWriteCipher); eventbase_.terminateLoopSoon(); @@ -759,9 +759,9 @@ TEST_P(QuicClientTransportIntegrationTest, ChangeEventBase) { NiceMock readCb2; folly::ScopedEventBaseThread newEvb; expectTransportCallbacks(); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); - EXPECT_CALL(clientConnCallback, onTransportReady()).WillOnce(Invoke([&] { + EXPECT_CALL(clientConnSetupCallback, onTransportReady()).WillOnce(Invoke([&] { CHECK(client->getConn().oneRttWriteCipher); eventbase_.terminateLoopSoon(); })); @@ -799,9 +799,9 @@ TEST_P(QuicClientTransportIntegrationTest, ResetClient) { server2 = nullptr; }; - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); - EXPECT_CALL(clientConnCallback, onTransportReady()).WillOnce(Invoke([&] { + EXPECT_CALL(clientConnSetupCallback, onTransportReady()).WillOnce(Invoke([&] { CHECK(client->getConn().oneRttWriteCipher); eventbase_.terminateLoopSoon(); })); @@ -844,9 +844,9 @@ TEST_P(QuicClientTransportIntegrationTest, TestStatelessResetToken) { server2 = nullptr; }; - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); - EXPECT_CALL(clientConnCallback, onTransportReady()).WillOnce(Invoke([&] { + EXPECT_CALL(clientConnSetupCallback, onTransportReady()).WillOnce(Invoke([&] { token1 = client->getConn().statelessResetToken; eventbase_.terminateLoopSoon(); })); @@ -895,7 +895,7 @@ TEST_P(QuicClientTransportIntegrationTest, D6DEnabledTest) { server_->setTransportSettings(serverSettings); // we only use 1 worker in test - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_EQ(1, statsCallbacks_.size()); EXPECT_CALL(*statsCallbacks_[0], onConnectionD6DStarted()) .WillOnce(Invoke([&] { eventbase_.terminateLoopSoon(); })); @@ -933,7 +933,7 @@ TEST_F(QuicClientTransportTest, ReadErrorCloseTransprot) { TEST_F(QuicClientTransportTest, FirstPacketProcessedCallback) { client->addNewPeerAddress(serverAddr); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); originalConnId = client->getConn().clientConnectionId; ServerConnectionIdParams params(0, 0, 0); @@ -955,7 +955,7 @@ TEST_F(QuicClientTransportTest, FirstPacketProcessedCallback) { headerCipher, initialPacketNum); initialPacketNum++; - EXPECT_CALL(clientConnCallback, onFirstPeerPacketProcessed()).Times(1); + EXPECT_CALL(clientConnSetupCallback, onFirstPeerPacketProcessed()).Times(1); deliverData(serverAddr, ackPacket->coalesce()); EXPECT_FALSE(client->hasWriteCipher()); @@ -971,7 +971,7 @@ TEST_F(QuicClientTransportTest, FirstPacketProcessedCallback) { headerCipher, initialPacketNum); initialPacketNum++; - EXPECT_CALL(clientConnCallback, onFirstPeerPacketProcessed()).Times(0); + EXPECT_CALL(clientConnSetupCallback, onFirstPeerPacketProcessed()).Times(0); deliverData(serverAddr, oneMoreAckPacket->coalesce()); EXPECT_FALSE(client->hasWriteCipher()); @@ -988,7 +988,7 @@ TEST_F(QuicClientTransportTest, CustomTransportParam) { TEST_F(QuicClientTransportTest, CloseSocketOnWriteError) { client->addNewPeerAddress(serverAddr); EXPECT_CALL(*sock, write(_, _)).WillOnce(SetErrnoAndReturn(EBADF, -1)); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_FALSE(client->isClosed()); EXPECT_CALL(clientConnCallback, onConnectionError(_)); @@ -1115,7 +1115,7 @@ TEST_F(QuicClientTransportTest, NetworkUnreachableIsFatalToConn) { setupCryptoLayer(); EXPECT_CALL(clientConnCallback, onConnectionError(_)); EXPECT_CALL(*sock, write(_, _)).WillOnce(SetErrnoAndReturn(ENETUNREACH, -1)); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); loopForWrites(); } @@ -1131,7 +1131,7 @@ TEST_F(QuicClientTransportTest, HappyEyeballsWithSingleV4Address) { EXPECT_FALSE(conn.happyEyeballsState.finished); EXPECT_FALSE(conn.peerAddress.isInitialized()); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_FALSE(client->happyEyeballsConnAttemptDelayTimeout().isScheduled()); EXPECT_TRUE(conn.happyEyeballsState.finished); EXPECT_EQ(conn.peerAddress, serverAddr); @@ -1150,7 +1150,7 @@ TEST_F(QuicClientTransportTest, HappyEyeballsWithSingleV6Address) { EXPECT_FALSE(conn.happyEyeballsState.finished); EXPECT_FALSE(conn.peerAddress.isInitialized()); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_FALSE(client->happyEyeballsConnAttemptDelayTimeout().isScheduled()); EXPECT_TRUE(conn.happyEyeballsState.finished); EXPECT_EQ(conn.peerAddress, serverAddrV6); @@ -1159,7 +1159,7 @@ TEST_F(QuicClientTransportTest, HappyEyeballsWithSingleV6Address) { TEST_F(QuicClientTransportTest, IdleTimerResetOnWritingFirstData) { client->addNewPeerAddress(serverAddr); setupCryptoLayer(); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); loopForWrites(); ASSERT_FALSE(client->getConn().receivedNewPacketBeforeWrite); ASSERT_TRUE(client->idleTimeout().isScheduled()); @@ -1320,7 +1320,7 @@ class QuicClientTransportHappyEyeballsTest return buf->computeChainDataLength(); })); EXPECT_CALL(*secondSock, write(_, _)).Times(0); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); setConnectionIds(); EXPECT_EQ(conn.peerAddress, firstAddress); @@ -1334,8 +1334,8 @@ class QuicClientTransportHappyEyeballsTest EXPECT_FALSE(conn.happyEyeballsState.finished); if (firstPacketType == ServerFirstPacketType::ServerHello) { - EXPECT_CALL(clientConnCallback, onTransportReady()); - EXPECT_CALL(clientConnCallback, onReplaySafe()); + EXPECT_CALL(clientConnSetupCallback, onTransportReady()); + EXPECT_CALL(clientConnSetupCallback, onReplaySafe()); } EXPECT_CALL(*secondSock, write(_, _)).Times(0); EXPECT_CALL(*secondSock, pauseRead()); @@ -1365,7 +1365,7 @@ class QuicClientTransportHappyEyeballsTest return buf->computeChainDataLength(); })); EXPECT_CALL(*secondSock, write(_, _)).Times(0); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); setConnectionIds(); EXPECT_EQ(conn.peerAddress, firstAddress); @@ -1402,8 +1402,8 @@ class QuicClientTransportHappyEyeballsTest socketWrites.clear(); EXPECT_FALSE(conn.happyEyeballsState.finished); if (firstPacketType == ServerFirstPacketType::ServerHello) { - EXPECT_CALL(clientConnCallback, onTransportReady()); - EXPECT_CALL(clientConnCallback, onReplaySafe()); + EXPECT_CALL(clientConnSetupCallback, onTransportReady()); + EXPECT_CALL(clientConnSetupCallback, onReplaySafe()); } EXPECT_CALL(*sock, write(firstAddress, _)) .Times(AtLeast(1)) @@ -1441,7 +1441,7 @@ class QuicClientTransportHappyEyeballsTest return buf->computeChainDataLength(); })); EXPECT_CALL(*secondSock, write(_, _)).Times(0); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); setConnectionIds(); EXPECT_EQ(conn.peerAddress, firstAddress); EXPECT_EQ(conn.happyEyeballsState.secondPeerAddress, secondAddress); @@ -1479,8 +1479,8 @@ class QuicClientTransportHappyEyeballsTest EXPECT_FALSE(conn.happyEyeballsState.finished); if (firstPacketType == ServerFirstPacketType::ServerHello) { - EXPECT_CALL(clientConnCallback, onTransportReady()); - EXPECT_CALL(clientConnCallback, onReplaySafe()); + EXPECT_CALL(clientConnSetupCallback, onTransportReady()); + EXPECT_CALL(clientConnSetupCallback, onReplaySafe()); } EXPECT_CALL(*sock, write(_, _)).Times(0); EXPECT_CALL(*sock, pauseRead()); @@ -1514,7 +1514,7 @@ class QuicClientTransportHappyEyeballsTest EXPECT_CALL(*secondSock, bind(_, _)) .WillOnce(Invoke( [](const folly::SocketAddress&, auto) { throw std::exception(); })); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_EQ(conn.peerAddress, firstAddress); EXPECT_EQ(conn.happyEyeballsState.secondPeerAddress, secondAddress); EXPECT_FALSE(client->happyEyeballsConnAttemptDelayTimeout().isScheduled()); @@ -1531,7 +1531,7 @@ class QuicClientTransportHappyEyeballsTest EXPECT_CALL(*sock, write(firstAddress, _)) .WillOnce(SetErrnoAndReturn(EAGAIN, -1)); EXPECT_CALL(*secondSock, write(_, _)); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_EQ(conn.peerAddress, firstAddress); EXPECT_EQ(conn.happyEyeballsState.secondPeerAddress, secondAddress); // Continue trying first socket @@ -1556,7 +1556,7 @@ class QuicClientTransportHappyEyeballsTest EXPECT_CALL(*sock, pauseRead()).Times(2); EXPECT_CALL(*sock, close()).Times(1); EXPECT_CALL(*secondSock, write(_, _)); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_EQ(conn.peerAddress, firstAddress); EXPECT_EQ(conn.happyEyeballsState.secondPeerAddress, secondAddress); // Give up first socket @@ -1580,7 +1580,7 @@ class QuicClientTransportHappyEyeballsTest // Socket is paused read for the second time when QuicClientTransport dies EXPECT_CALL(*sock, pauseRead()).Times(2); EXPECT_CALL(*sock, close()).Times(1); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_EQ(conn.peerAddress, firstAddress); EXPECT_EQ(conn.happyEyeballsState.secondPeerAddress, secondAddress); // Give up first socket @@ -1613,7 +1613,7 @@ class QuicClientTransportHappyEyeballsTest EXPECT_CALL(*sock, write(firstAddress, _)); EXPECT_CALL(*secondSock, write(_, _)).Times(0); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_EQ(conn.peerAddress, firstAddress); EXPECT_EQ(conn.happyEyeballsState.secondPeerAddress, secondAddress); EXPECT_TRUE(client->happyEyeballsConnAttemptDelayTimeout().isScheduled()); @@ -1649,7 +1649,7 @@ class QuicClientTransportHappyEyeballsTest EXPECT_CALL(*sock, write(firstAddress, _)); EXPECT_CALL(*secondSock, write(_, _)).Times(0); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_EQ(conn.peerAddress, firstAddress); EXPECT_EQ(conn.happyEyeballsState.secondPeerAddress, secondAddress); EXPECT_TRUE(client->happyEyeballsConnAttemptDelayTimeout().isScheduled()); @@ -1690,7 +1690,7 @@ class QuicClientTransportHappyEyeballsTest EXPECT_CALL(*sock, write(firstAddress, _)); EXPECT_CALL(*secondSock, write(_, _)).Times(0); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_EQ(conn.peerAddress, firstAddress); EXPECT_EQ(conn.happyEyeballsState.secondPeerAddress, secondAddress); EXPECT_TRUE(client->happyEyeballsConnAttemptDelayTimeout().isScheduled()); @@ -1734,7 +1734,7 @@ class QuicClientTransportHappyEyeballsTest EXPECT_CALL(*sock, write(firstAddress, _)); EXPECT_CALL(*secondSock, write(_, _)).Times(0); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_EQ(conn.peerAddress, firstAddress); EXPECT_EQ(conn.happyEyeballsState.secondPeerAddress, secondAddress); EXPECT_TRUE(client->happyEyeballsConnAttemptDelayTimeout().isScheduled()); @@ -1770,7 +1770,7 @@ class QuicClientTransportHappyEyeballsTest EXPECT_CALL(*sock, write(firstAddress, _)); EXPECT_CALL(*secondSock, write(_, _)).Times(0); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_EQ(conn.peerAddress, firstAddress); EXPECT_EQ(conn.happyEyeballsState.secondPeerAddress, secondAddress); EXPECT_TRUE(client->happyEyeballsConnAttemptDelayTimeout().isScheduled()); @@ -1811,7 +1811,7 @@ class QuicClientTransportHappyEyeballsTest EXPECT_CALL(*sock, write(firstAddress, _)); EXPECT_CALL(*secondSock, write(_, _)).Times(0); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_EQ(conn.peerAddress, firstAddress); EXPECT_EQ(conn.happyEyeballsState.secondPeerAddress, secondAddress); EXPECT_TRUE(client->happyEyeballsConnAttemptDelayTimeout().isScheduled()); @@ -1860,7 +1860,7 @@ class QuicClientTransportHappyEyeballsTest EXPECT_CALL(*sock, write(firstAddress, _)); EXPECT_CALL(*secondSock, write(_, _)).Times(0); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_EQ(conn.peerAddress, firstAddress); EXPECT_EQ(conn.happyEyeballsState.secondPeerAddress, secondAddress); EXPECT_TRUE(client->happyEyeballsConnAttemptDelayTimeout().isScheduled()); @@ -1897,7 +1897,7 @@ class QuicClientTransportHappyEyeballsTest EXPECT_CALL(*sock, write(firstAddress, _)); EXPECT_CALL(*secondSock, write(_, _)).Times(0); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_EQ(conn.peerAddress, firstAddress); EXPECT_EQ(conn.happyEyeballsState.secondPeerAddress, secondAddress); EXPECT_TRUE(client->happyEyeballsConnAttemptDelayTimeout().isScheduled()); @@ -1931,7 +1931,7 @@ class QuicClientTransportHappyEyeballsTest EXPECT_CALL(*sock, write(firstAddress, _)); EXPECT_CALL(*secondSock, write(_, _)).Times(0); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_EQ(conn.peerAddress, firstAddress); EXPECT_EQ(conn.happyEyeballsState.secondPeerAddress, secondAddress); EXPECT_TRUE(client->happyEyeballsConnAttemptDelayTimeout().isScheduled()); @@ -2155,7 +2155,7 @@ class QuicClientTransportVersionAndRetryTest ~QuicClientTransportVersionAndRetryTest() override = default; void start() override { - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); originalConnId = client->getConn().clientConnectionId; // create server chosen connId with processId = 0 and workerId = 0 ServerConnectionIdParams params(0, 0, 0); @@ -2186,7 +2186,7 @@ class QuicClientVersionParamInvalidTest // force the server to declare that the version negotiated was invalid.; mockClientHandshake->negotiatedVersion = MVFST2; - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); originalConnId = client->getConn().clientConnectionId; } }; @@ -3934,8 +3934,8 @@ TEST_F( EXPECT_THROW(deliverData(packet.second->coalesce()), std::runtime_error); EXPECT_EQ(client->getConn().oneRttWriteCipher.get(), nullptr); - EXPECT_CALL(clientConnCallback, onTransportReady()).Times(0); - EXPECT_CALL(clientConnCallback, onReplaySafe()).Times(0); + EXPECT_CALL(clientConnSetupCallback, onTransportReady()).Times(0); + EXPECT_CALL(clientConnSetupCallback, onReplaySafe()).Times(0); client->close(folly::none); } @@ -4886,8 +4886,8 @@ class QuicZeroRttClientTest : public QuicClientTransportAfterStartTestBase { } void startClient() { - EXPECT_CALL(clientConnCallback, onTransportReady()); - client->start(&clientConnCallback); + EXPECT_CALL(clientConnSetupCallback, onTransportReady()); + client->start(&clientConnSetupCallback, &clientConnCallback); setConnectionIds(); EXPECT_EQ(socketWrites.size(), 1); EXPECT_TRUE( @@ -4947,7 +4947,7 @@ TEST_F(QuicZeroRttClientTest, TestReplaySafeCallback) { loopForWrites(); EXPECT_TRUE(zeroRttPacketsOutstanding()); assertWritten(false, LongHeader::Types::ZeroRtt); - EXPECT_CALL(clientConnCallback, onReplaySafe()); + EXPECT_CALL(clientConnSetupCallback, onReplaySafe()); mockClientHandshake->setZeroRttRejected(false); recvServerHello(); @@ -5021,7 +5021,7 @@ TEST_F(QuicZeroRttClientTest, TestEarlyRetransmit0Rtt) { loopForWrites(); EXPECT_TRUE(zeroRttPacketsOutstanding()); assertWritten(false, LongHeader::Types::ZeroRtt); - EXPECT_CALL(clientConnCallback, onReplaySafe()); + EXPECT_CALL(clientConnSetupCallback, onReplaySafe()); mockClientHandshake->setZeroRttRejected(false); recvServerHello(); @@ -5091,7 +5091,7 @@ TEST_F(QuicZeroRttClientTest, TestZeroRttRejection) { client->writeChain(streamId, IOBuf::copyBuffer("hello"), true); loopForWrites(); EXPECT_TRUE(zeroRttPacketsOutstanding()); - EXPECT_CALL(clientConnCallback, onReplaySafe()); + EXPECT_CALL(clientConnSetupCallback, onReplaySafe()); mockClientHandshake->setZeroRttRejected(true); EXPECT_CALL(*mockQuicPskCache_, removePsk(hostname_)); recvServerHello(); @@ -5353,7 +5353,7 @@ class QuicProcessDataTest : public QuicClientTransportAfterStartTestBase, // force the server to declare that the version negotiated was invalid.; mockClientHandshake->negotiatedVersion = QuicVersion::QUIC_V1; client->setSupportedVersions({QuicVersion::QUIC_V1}); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); setConnectionIds(); } }; diff --git a/quic/fizz/client/test/QuicClientTransportTestUtil.h b/quic/fizz/client/test/QuicClientTransportTestUtil.h index 9052f7e48..4be1de1a3 100644 --- a/quic/fizz/client/test/QuicClientTransportTestUtil.h +++ b/quic/fizz/client/test/QuicClientTransportTestUtil.h @@ -498,10 +498,10 @@ class QuicClientTransportTestBase : public virtual testing::Test { } virtual void start() { - EXPECT_CALL(clientConnCallback, onTransportReady()); - EXPECT_CALL(clientConnCallback, onReplaySafe()); + EXPECT_CALL(clientConnSetupCallback, onTransportReady()); + EXPECT_CALL(clientConnSetupCallback, onReplaySafe()); setUpSocketExpectations(); - client->start(&clientConnCallback); + client->start(&clientConnSetupCallback, &clientConnCallback); setConnectionIds(); EXPECT_TRUE(client->idleTimeout().isScheduled()); @@ -536,7 +536,11 @@ class QuicClientTransportTestBase : public virtual testing::Test { return client->getNonConstConn(); } - MockConnectionCallback& getConnCallback() { + MockConnectionSetupCallback& getConnSetupCallback() { + return clientConnSetupCallback; + } + + MockConnectionCallbackNew& getConnCallback() { return clientConnCallback; } @@ -877,7 +881,8 @@ class QuicClientTransportTestBase : public virtual testing::Test { std::deque socketReads; testing::NiceMock deliveryCallback; testing::NiceMock readCb; - testing::NiceMock clientConnCallback; + testing::NiceMock clientConnSetupCallback; + testing::NiceMock clientConnCallback; folly::test::MockAsyncUDPSocket* sock; std::shared_ptr destructionCallback; diff --git a/quic/samples/echo/EchoHandler.h b/quic/samples/echo/EchoHandler.h index f8e982b61..d871bdde7 100644 --- a/quic/samples/echo/EchoHandler.h +++ b/quic/samples/echo/EchoHandler.h @@ -14,7 +14,8 @@ namespace quic { namespace samples { -class EchoHandler : public quic::QuicSocket::ConnectionCallback, +class EchoHandler : public quic::QuicSocket::ConnectionSetupCallback, + public quic::QuicSocket::ConnectionCallbackNew, public quic::QuicSocket::ReadCallback, public quic::QuicSocket::WriteCallback { public: diff --git a/quic/samples/echo/EchoServer.h b/quic/samples/echo/EchoServer.h index cc8e6628e..7b583e981 100644 --- a/quic/samples/echo/EchoServer.h +++ b/quic/samples/echo/EchoServer.h @@ -46,7 +46,7 @@ class EchoServerTransportFactory : public quic::QuicServerTransportFactory { CHECK_EQ(evb, sock->getEventBase()); auto echoHandler = std::make_unique(evb); auto transport = quic::QuicServerTransport::make( - evb, std::move(sock), *echoHandler, ctx); + evb, std::move(sock), echoHandler.get(), echoHandler.get(), ctx); echoHandler->setQuicSocket(transport); echoHandlers_.push_back(std::move(echoHandler)); return transport; diff --git a/quic/server/QuicServerTransport.cpp b/quic/server/QuicServerTransport.cpp index 7e0df3527..42a11b931 100644 --- a/quic/server/QuicServerTransport.cpp +++ b/quic/server/QuicServerTransport.cpp @@ -36,6 +36,24 @@ QuicServerTransport::QuicServerTransport( conn_->ackStates = AckStates(startingPacketNum); } +QuicServerTransport::QuicServerTransport( + folly::EventBase* evb, + std::unique_ptr sock, + ConnectionSetupCallback* connSetupCb, + ConnectionCallbackNew* connStreamsCb, + std::shared_ptr ctx, + std::unique_ptr cryptoFactory, + PacketNum startingPacketNum) + : QuicServerTransport( + evb, + std::move(sock), + connSetupCb, + connStreamsCb, + std::move(ctx), + std::move(cryptoFactory)) { + conn_->ackStates = AckStates(startingPacketNum); +} + QuicServerTransport::QuicServerTransport( folly::EventBase* evb, std::unique_ptr sock, diff --git a/quic/server/QuicServerTransport.h b/quic/server/QuicServerTransport.h index ae710155e..d57901211 100644 --- a/quic/server/QuicServerTransport.h +++ b/quic/server/QuicServerTransport.h @@ -95,6 +95,15 @@ class QuicServerTransport std::unique_ptr cryptoFactory, PacketNum startingPacketNum); + QuicServerTransport( + folly::EventBase* evb, + std::unique_ptr sock, + ConnectionSetupCallback* connSetupCb, + ConnectionCallbackNew* connStreamsCb, + std::shared_ptr ctx, + std::unique_ptr cryptoFactory, + PacketNum startingPacketNum); + ~QuicServerTransport() override; virtual void setRoutingCallback(RoutingCallback* callback) noexcept; diff --git a/quic/server/handshake/test/DefaultAppTokenValidatorTest.cpp b/quic/server/handshake/test/DefaultAppTokenValidatorTest.cpp index dbd617099..96b4b9f7d 100644 --- a/quic/server/handshake/test/DefaultAppTokenValidatorTest.cpp +++ b/quic/server/handshake/test/DefaultAppTokenValidatorTest.cpp @@ -313,7 +313,8 @@ TEST(DefaultAppTokenValidatorTest, TestInvalidAppParams) { auto quicStats = std::make_shared(); conn.statsCallback = quicStats.get(); - MockConnectionCallback connCallback; + MockConnectionSetupCallback connSetupCallback; + MockConnectionCallbackNew connCallback; AppToken appToken; appToken.transportParams = createTicketTransportParameters( diff --git a/quic/server/test/QuicServerTest.cpp b/quic/server/test/QuicServerTest.cpp index 162828dfa..41ffc3454 100644 --- a/quic/server/test/QuicServerTest.cpp +++ b/quic/server/test/QuicServerTest.cpp @@ -71,9 +71,14 @@ TEST_F(SimpleQuicServerWorkerTest, RejectCid) { auto mockSock = std::make_unique(&eventbase_); EXPECT_CALL(*mockSock, address()).WillRepeatedly(ReturnRef(addr)); - MockConnectionCallback mockConnectionCallback; + MockConnectionSetupCallback mockConnectionSetupCallback; + MockConnectionCallbackNew mockConnectionCallback; MockQuicTransport::Ptr transportPtr = std::make_shared( - &eventbase_, std::move(mockSock), mockConnectionCallback, nullptr); + &eventbase_, + std::move(mockSock), + &mockConnectionSetupCallback, + &mockConnectionCallback, + nullptr); workerCb_ = std::make_shared>(); worker_ = std::make_unique(workerCb_); auto includeCid = getTestConnectionId(0); @@ -164,13 +169,18 @@ class QuicServerWorkerTest : public Test { socketFactory_ = std::make_unique(); EXPECT_CALL(*socketFactory_, _make(_, _)).WillRepeatedly(Return(nullptr)); worker_->setNewConnectionSocketFactory(socketFactory_.get()); - NiceMock connCb; + NiceMock connSetupCb; + NiceMock connCb; std::unique_ptr mockSock = std::make_unique>( &eventbase_); EXPECT_CALL(*mockSock, address()).WillRepeatedly(ReturnRef(fakeAddress_)); transport_.reset(new MockQuicTransport( - worker_->getEventBase(), std::move(mockSock), connCb, nullptr)); + worker_->getEventBase(), + std::move(mockSock), + &connSetupCb, + &connCb, + nullptr)); factory_ = std::make_unique(); EXPECT_CALL(*transport_, getEventBase()) .WillRepeatedly(Return(&eventbase_)); @@ -536,12 +546,17 @@ TEST_F(QuicServerWorkerTest, RateLimit) { std::make_unique([]() { return 2; }, 60s)); EXPECT_CALL(*quicStats_, onConnectionRateLimited()).Times(1); - NiceMock connCb1; + NiceMock connSetupCb1; + NiceMock connCb1; auto mockSock1 = std::make_unique>(&eventbase_); EXPECT_CALL(*mockSock1, address()).WillRepeatedly(ReturnRef(fakeAddress_)); MockQuicTransport::Ptr testTransport1 = std::make_shared( - worker_->getEventBase(), std::move(mockSock1), connCb1, nullptr); + worker_->getEventBase(), + std::move(mockSock1), + &connSetupCb1, + &connCb1, + nullptr); EXPECT_CALL(*testTransport1, getEventBase()) .WillRepeatedly(Return(&eventbase_)); EXPECT_CALL(*testTransport1, getOriginalPeerAddress()) @@ -568,12 +583,17 @@ TEST_F(QuicServerWorkerTest, RateLimit) { eventbase_.loop(); auto caddr2 = folly::SocketAddress("2.3.4.5", 1234); - NiceMock connCb2; + NiceMock connSetupCb2; + NiceMock connCb2; auto mockSock2 = std::make_unique>(&eventbase_); EXPECT_CALL(*mockSock2, address()).WillRepeatedly(ReturnRef(caddr2)); MockQuicTransport::Ptr testTransport2 = std::make_shared( - worker_->getEventBase(), std::move(mockSock2), connCb2, nullptr); + worker_->getEventBase(), + std::move(mockSock2), + &connSetupCb2, + &connCb2, + nullptr); EXPECT_CALL(*testTransport2, getEventBase()) .WillRepeatedly(Return(&eventbase_)); EXPECT_CALL(*testTransport2, getOriginalPeerAddress()) @@ -622,12 +642,17 @@ TEST_F(QuicServerWorkerTest, UnfinishedHandshakeLimit) { worker_->setUnfinishedHandshakeLimit([]() { return 2; }); EXPECT_CALL(*quicStats_, onConnectionRateLimited()).Times(1); - NiceMock connCb1; + NiceMock connSetupCb1; + NiceMock connCb1; auto mockSock1 = std::make_unique>(&eventbase_); EXPECT_CALL(*mockSock1, address()).WillRepeatedly(ReturnRef(fakeAddress_)); MockQuicTransport::Ptr testTransport1 = std::make_shared( - worker_->getEventBase(), std::move(mockSock1), connCb1, nullptr); + worker_->getEventBase(), + std::move(mockSock1), + &connSetupCb1, + &connCb1, + nullptr); EXPECT_CALL(*testTransport1, getEventBase()) .WillRepeatedly(Return(&eventbase_)); EXPECT_CALL(*testTransport1, getOriginalPeerAddress()) @@ -653,12 +678,17 @@ TEST_F(QuicServerWorkerTest, UnfinishedHandshakeLimit) { eventbase_.loop(); auto caddr2 = folly::SocketAddress("2.3.4.5", 1234); - NiceMock connCb2; + NiceMock connSetupCb2; + NiceMock connCb2; auto mockSock2 = std::make_unique>(&eventbase_); EXPECT_CALL(*mockSock2, address()).WillRepeatedly(ReturnRef(caddr2)); MockQuicTransport::Ptr testTransport2 = std::make_shared( - worker_->getEventBase(), std::move(mockSock2), connCb2, nullptr); + worker_->getEventBase(), + std::move(mockSock2), + &connSetupCb2, + &connCb2, + nullptr); EXPECT_CALL(*testTransport2, getEventBase()) .WillRepeatedly(Return(&eventbase_)); EXPECT_CALL(*testTransport2, getOriginalPeerAddress()) @@ -703,12 +733,17 @@ TEST_F(QuicServerWorkerTest, UnfinishedHandshakeLimit) { // Finish a handshake. worker_->onHandshakeFinished(); auto caddr4 = folly::SocketAddress("4.3.4.5", 1234); - NiceMock connCb4; + NiceMock connSetupCb4; + NiceMock connCb4; auto mockSock4 = std::make_unique>(&eventbase_); EXPECT_CALL(*mockSock4, address()).WillRepeatedly(ReturnRef(caddr4)); MockQuicTransport::Ptr testTransport4 = std::make_shared( - worker_->getEventBase(), std::move(mockSock4), connCb4, nullptr); + worker_->getEventBase(), + std::move(mockSock4), + &connSetupCb4, + &connCb4, + nullptr); EXPECT_CALL(*testTransport4, getEventBase()) .WillRepeatedly(Return(&eventbase_)); EXPECT_CALL(*testTransport4, getOriginalPeerAddress()) @@ -824,12 +859,17 @@ TEST_F(QuicServerWorkerTest, TestRetryInvalidInitialDstConnId) { } TEST_F(QuicServerWorkerTest, QuicServerWorkerUnbindBeforeCidAvailable) { - NiceMock connCb; + NiceMock connSetupCb; + NiceMock connCb; auto mockSock = std::make_unique>(&eventbase_); EXPECT_CALL(*mockSock, address()).WillRepeatedly(ReturnRef(fakeAddress_)); MockQuicTransport::Ptr testTransport = std::make_shared( - worker_->getEventBase(), std::move(mockSock), connCb, nullptr); + worker_->getEventBase(), + std::move(mockSock), + &connSetupCb, + &connCb, + nullptr); EXPECT_CALL(*testTransport, getEventBase()) .WillRepeatedly(Return(&eventbase_)); @@ -1000,13 +1040,18 @@ TEST_F(QuicServerWorkerTest, QuicServerNewConnection) { // transport2's connid available. ConnectionId connId2({2, 4, 5, 6}); folly::SocketAddress clientAddr2("2.3.4.5", 2345); - NiceMock connCb; + NiceMock connSetupCb; + NiceMock connCb; auto mockSock = std::make_unique>(&eventbase_); EXPECT_CALL(*mockSock, address()).WillRepeatedly(ReturnRef(fakeAddress_)); MockQuicTransport::Ptr transport2 = std::make_shared( - worker_->getEventBase(), std::move(mockSock), connCb, nullptr); + worker_->getEventBase(), + std::move(mockSock), + &connSetupCb, + &connCb, + nullptr); EXPECT_CALL(*transport2, getEventBase()).WillRepeatedly(Return(&eventbase_)); EXPECT_CALL(*transport2, getOriginalPeerAddress()) .WillRepeatedly(ReturnRef(kClientAddr)); @@ -1384,12 +1429,17 @@ TEST_F(QuicServerWorkerTest, AcceptObserver) { worker_->addAcceptObserver(cb.get()); auto initTestSocketAndTransport = [this]() { - NiceMock connCb; + NiceMock connSetupCb; + NiceMock connCb; auto mockSock = std::make_unique>( &eventbase_); EXPECT_CALL(*mockSock, address()).WillRepeatedly(ReturnRef(fakeAddress_)); MockQuicTransport::Ptr mockTransport = std::make_shared( - worker_->getEventBase(), std::move(mockSock), connCb, nullptr); + worker_->getEventBase(), + std::move(mockSock), + &connSetupCb, + &connCb, + nullptr); EXPECT_CALL(*mockTransport, setRoutingCallback(nullptr)); EXPECT_CALL(*mockTransport, setTransportStatsCallback(nullptr)); EXPECT_CALL(*mockTransport, getEventBase()) @@ -2041,13 +2091,18 @@ class QuicServerTest : public Test { // create mock transport std::shared_ptr transport; eventBase->runInEventBaseThreadAndWait([&] { - NiceMock cb; + NiceMock connSetupcb; + NiceMock connCb; std::unique_ptr mockSock = std::make_unique>( eventBase); EXPECT_CALL(*mockSock, address()).WillRepeatedly(ReturnRef(serverAddr)); transport = std::make_shared( - eventBase, std::move(mockSock), cb, quic::test::createServerCtx()); + eventBase, + std::move(mockSock), + &connSetupcb, + &connCb, + quic::test::createServerCtx()); }); auto makeTransport = @@ -2315,7 +2370,8 @@ class QuicServerTakeoverTest : public Test { Buf& data, folly::Baton<>& baton) { std::shared_ptr transport; - NiceMock cb; + NiceMock connSetupCb; + NiceMock connCb; auto makeTransport = [&](folly::EventBase* eventBase, std::unique_ptr& socket, @@ -2323,7 +2379,7 @@ class QuicServerTakeoverTest : public Test { std::shared_ptr ctx) noexcept { transport = std::make_shared( - eventBase, std::move(socket), cb, ctx); + eventBase, std::move(socket), &connSetupCb, &connCb, ctx); transport->setClientConnectionId(clientConnId); // setup expectations EXPECT_CALL(*transport, getEventBase()) @@ -2882,7 +2938,8 @@ TEST_F(QuicServerTest, ZeroRttPacketRoute) { setUpTransportFactoryForWorkers(evbs); std::shared_ptr transport; - NiceMock cb; + NiceMock connSetupCb; + NiceMock connCb; folly::Baton<> b; // create payload StreamId id = 1; @@ -2900,7 +2957,7 @@ TEST_F(QuicServerTest, ZeroRttPacketRoute) { const folly::SocketAddress&, std::shared_ptr ctx) noexcept { transport = std::make_shared( - eventBase, std::move(socket), cb, ctx); + eventBase, std::move(socket), &connSetupCb, &connCb, ctx); EXPECT_CALL(*transport, getEventBase()) .WillRepeatedly(Return(eventBase)); EXPECT_CALL(*transport, setSupportedVersions(_)); @@ -2976,7 +3033,8 @@ TEST_F(QuicServerTest, ZeroRttBeforeInitial) { setUpTransportFactoryForWorkers(evbs); std::shared_ptr transport; - NiceMock cb; + NiceMock connSetupCb; + NiceMock connCb; folly::Baton<> b; // create payload StreamId id = 1; @@ -2995,7 +3053,7 @@ TEST_F(QuicServerTest, ZeroRttBeforeInitial) { const folly::SocketAddress&, std::shared_ptr ctx) noexcept { transport = std::make_shared( - eventBase, std::move(socket), cb, ctx); + eventBase, std::move(socket), &connSetupCb, &connCb, ctx); EXPECT_CALL(*transport, getEventBase()) .WillRepeatedly(Return(eventBase)); EXPECT_CALL(*transport, setSupportedVersions(_)); diff --git a/quic/server/test/QuicServerTransportTest.cpp b/quic/server/test/QuicServerTransportTest.cpp index 5b5123eeb..d41f23f0d 100644 --- a/quic/server/test/QuicServerTransportTest.cpp +++ b/quic/server/test/QuicServerTransportTest.cpp @@ -2896,14 +2896,14 @@ class QuicUnencryptedServerTransportTest : public QuicServerTransportTest { TEST_F(QuicUnencryptedServerTransportTest, FirstPacketProcessedCallback) { getFakeHandshakeLayer()->allowZeroRttKeys(); - EXPECT_CALL(connCallback, onFirstPeerPacketProcessed()).Times(1); + EXPECT_CALL(connSetupCallback, onFirstPeerPacketProcessed()).Times(1); recvClientHello(); loopForWrites(); AckBlocks acks; acks.insert(0); auto aead = getInitialCipher(); auto headerCipher = getInitialHeaderCipher(); - EXPECT_CALL(connCallback, onFirstPeerPacketProcessed()).Times(0); + EXPECT_CALL(connSetupCallback, onFirstPeerPacketProcessed()).Times(0); deliverData(packetToBufCleartext( createAckPacket( server->getNonConstConn(), @@ -3857,14 +3857,14 @@ class QuicServerTransportHandshakeTest // If 0-rtt is accepted, one rtt write cipher will be available after CHLO // is processed if (GetParam().acceptZeroRtt) { - EXPECT_CALL(connCallback, onTransportReady()); + EXPECT_CALL(connSetupCallback, onTransportReady()); } recvClientHello(); // If 0-rtt is disabled, one rtt write cipher will be available after CFIN // is processed if (!GetParam().acceptZeroRtt) { - EXPECT_CALL(connCallback, onTransportReady()); + EXPECT_CALL(connSetupCallback, onTransportReady()); } // onConnectionIdBound is always invoked after CFIN is processed EXPECT_CALL(routingCallback, onConnectionIdBound(_)); diff --git a/quic/server/test/QuicServerTransportTestUtil.h b/quic/server/test/QuicServerTransportTestUtil.h index ed45a289e..54522b0c2 100644 --- a/quic/server/test/QuicServerTransportTestUtil.h +++ b/quic/server/test/QuicServerTransportTestUtil.h @@ -31,9 +31,15 @@ class TestingQuicServerTransport : public QuicServerTransport { TestingQuicServerTransport( folly::EventBase* evb, std::unique_ptr sock, - ConnectionCallback& cb, + ConnectionSetupCallback* connSetupCb, + ConnectionCallbackNew* connCb, std::shared_ptr ctx) - : QuicServerTransport(evb, std::move(sock), cb, std::move(ctx)) {} + : QuicServerTransport( + evb, + std::move(sock), + connSetupCb, + connCb, + std::move(ctx)) {} QuicTransportBase* getTransport() { return this; @@ -135,7 +141,7 @@ class QuicServerTransportTestBase : public virtual testing::Test { connIdAlgo_ = std::make_unique(); ccFactory_ = std::make_shared(); server = std::make_shared( - &evb, std::move(sock), connCallback, serverCtx); + &evb, std::move(sock), &connSetupCallback, &connCallback, serverCtx); server->setCongestionControllerFactory(ccFactory_); server->setCongestionControl(CongestionControlType::Cubic); server->setRoutingCallback(&routingCallback); @@ -192,7 +198,11 @@ class QuicServerTransportTestBase : public virtual testing::Test { return server->getNonConstConn(); } - MockConnectionCallback& getConnCallback() { + MockConnectionSetupCallback& getConnSetupCallback() { + return connSetupCallback; + } + + MockConnectionCallbackNew& getConnCallback() { return connCallback; } @@ -548,7 +558,8 @@ class QuicServerTransportTestBase : public virtual testing::Test { folly::EventBase evb; folly::SocketAddress serverAddr; folly::SocketAddress clientAddr; - testing::NiceMock connCallback; + testing::NiceMock connSetupCallback; + testing::NiceMock connCallback; testing::NiceMock routingCallback; testing::NiceMock handshakeFinishedCallback; folly::Optional clientConnectionId; diff --git a/quic/server/test/QuicSocketTest.cpp b/quic/server/test/QuicSocketTest.cpp index 747479d76..fca7ae3d4 100644 --- a/quic/server/test/QuicSocketTest.cpp +++ b/quic/server/test/QuicSocketTest.cpp @@ -21,7 +21,7 @@ using folly::IOBuf; class QuicSocketTest : public Test { public: void SetUp() override { - socket_ = std::make_shared(&evb_, handler_); + socket_ = std::make_shared(&evb_, &handler_, &handler_); handler_.setQuicSocket(socket_); }