/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include using namespace testing; using namespace folly; using OnDataAvailableParams = folly::AsyncUDPSocket::ReadCallback::OnDataAvailableParams; const folly::SocketAddress kClientAddr("1.2.3.4", 1234); const folly::SocketAddress kClientAddr2("1.2.3.5", 1235); const folly::SocketAddress kClientAddr3("1.2.3.6", 1236); namespace quic { namespace { using PacketDropReason = QuicTransportStatsCallback::PacketDropReason; } // namespace namespace test { MATCHER_P(NetworkDataMatches, networkData, "") { for (size_t i = 0; i < arg.packets.size(); ++i) { folly::IOBufEqualTo eq; bool equals = eq(*arg.packets[i], networkData); if (equals) { return true; } } return false; } /** * QuicServerWorker test without a connection to drive any real behavior. Use * QuicServerWorkerTest for most cases. */ class SimpleQuicServerWorkerTest : public Test { protected: folly::EventBase eventbase_; std::unique_ptr worker_; std::shared_ptr workerCb_; folly::test::MockAsyncUDPSocket* rawSocket_{nullptr}; }; TEST_F(SimpleQuicServerWorkerTest, RejectCid) { folly::SocketAddress addr("::1", 0); auto mockSock = std::make_unique(&eventbase_); EXPECT_CALL(*mockSock, address()).WillRepeatedly(ReturnRef(addr)); MockConnectionSetupCallback mockConnectionSetupCallback; MockConnectionCallback mockConnectionCallback; MockQuicTransport::Ptr transportPtr = std::make_shared( &eventbase_, std::move(mockSock), &mockConnectionSetupCallback, &mockConnectionCallback, nullptr); workerCb_ = std::make_shared>(); worker_ = std::make_unique(workerCb_); auto includeCid = getTestConnectionId(0); auto excludeCid = getTestConnectionId(1); EXPECT_FALSE(worker_->rejectConnectionId(includeCid)); EXPECT_FALSE(worker_->rejectConnectionId(excludeCid)); worker_->onConnectionIdAvailable(transportPtr, includeCid); EXPECT_TRUE(worker_->rejectConnectionId(includeCid)); EXPECT_FALSE(worker_->rejectConnectionId(excludeCid)); QuicServerTransport::SourceIdentity sourceId(addr, includeCid); std::vector cidDataVec; cidDataVec.emplace_back(includeCid, 0); EXPECT_CALL(*transportPtr, setRoutingCallback(nullptr)).Times(1); worker_->onConnectionUnbound(transportPtr.get(), sourceId, cidDataVec); EXPECT_FALSE(worker_->rejectConnectionId(includeCid)); EXPECT_FALSE(worker_->rejectConnectionId(excludeCid)); } TEST_F(SimpleQuicServerWorkerTest, TurnOffPMTU) { auto sock = std::make_unique>(&eventbase_); rawSocket_ = sock.get(); DCHECK(sock->getEventBase()); EXPECT_CALL(*sock, getNetworkSocket()) .WillRepeatedly(Return(folly::NetworkSocket())); workerCb_ = std::make_shared>(); worker_ = std::make_unique(workerCb_); worker_->setSocket(std::move(sock)); folly::SocketAddress addr("::1", 0); // We check versions in bind() worker_->setSupportedVersions({QuicVersion::MVFST}); EXPECT_CALL(*rawSocket_, setDFAndTurnOffPMTU()).Times(1); worker_->bind(addr); } std::unique_ptr createData(size_t size) { std::string data; data.resize(size); return folly::IOBuf::copyBuffer(data); } class QuicServerWorkerTest : public Test { public: void SetUp() override { fakeAddress_ = folly::SocketAddress("111.111.111.111", 44444); auto sock = std::make_unique>( &eventbase_); DCHECK(sock->getEventBase()); socketPtr_ = sock.get(); workerCb_ = std::make_shared>(); worker_ = std::make_unique(workerCb_); auto quicStats = std::make_unique>(); TransportSettings settings; settings.statelessResetTokenSecret = getRandSecret(); tokenSecret_ = getRandSecret(); settings.retryTokenSecret = tokenSecret_; worker_->setSupportedVersions({QuicVersion::MVFST}); worker_->setTransportStatsCallback(std::move(quicStats)); worker_->setTransportSettings(settings); worker_->setSocket(std::move(sock)); worker_->setWorkerId(42); worker_->setProcessId(ProcessId::ONE); worker_->setHostId(hostId_); worker_->setConnectionIdAlgo(std::make_unique()); worker_->setCongestionControllerFactory( std::make_shared()); quicStats_ = (MockQuicStats*)worker_->getTransportStatsCallback(); auto cb = [&](const folly::SocketAddress& addr, std::unique_ptr& routingData, std::unique_ptr& networkData, folly::Optional quicVersion, bool isForwardedData) { worker_->dispatchPacketData( addr, std::move(*routingData.get()), std::move(*networkData.get()), std::move(quicVersion), isForwardedData); }; EXPECT_CALL(*workerCb_, routeDataToWorkerLong(_, _, _, _, _)) .WillRepeatedly(Invoke(cb)); socketFactory_ = std::make_unique(); EXPECT_CALL(*socketFactory_, _make(_, _)).WillRepeatedly(Return(nullptr)); worker_->setNewConnectionSocketFactory(socketFactory_.get()); NiceMock connSetupCb; NiceMock connCb; std::unique_ptr mockSock = std::make_unique>( &eventbase_); EXPECT_CALL(*mockSock, address()).WillRepeatedly(ReturnRef(fakeAddress_)); transport_.reset(new MockQuicTransport( worker_->getEventBase(), std::move(mockSock), &connSetupCb, &connCb, nullptr)); factory_ = std::make_unique(); EXPECT_CALL(*transport_, getEventBase()) .WillRepeatedly(Return(&eventbase_)); EXPECT_CALL(*transport_, getOriginalPeerAddress()) .WillRepeatedly(ReturnRef(kClientAddr)); EXPECT_CALL(*transport_, hasShutdown()) .WillRepeatedly(ReturnPointee(&hasShutdown_)); worker_->setTransportFactory(factory_.get()); EXPECT_CALL(*quicStats_, onTokenDecryptFailure()).Times(0); } void createQuicConnection( const folly::SocketAddress& addr, ConnectionId connId, MockQuicTransport::Ptr transportOverride = nullptr); void expectConnectionCreation( const folly::SocketAddress& addr, MockQuicTransport::Ptr transportOverride = nullptr); void testSendReset( Buf packet, ConnectionId connId, ShortHeader shortHeader, QuicTransportStatsCallback::PacketDropReason dropReason); std::string testSendRetry( ConnectionId& srcConnId, ConnectionId& dstConnId, const folly::SocketAddress& clientAddr); std::string testSendRetryUnfinished( ConnectionId& srcConnId, ConnectionId& dstConnId, const folly::SocketAddress& clientAddr); void testSendInitialWithRetryToken( const std::string& retryToken, ConnectionId& srcConnId, ConnectionId& dstConnId, const folly::SocketAddress& clientAddr); void expectConnCreateRefused(); void createQuicConnectionDuringShedding( const folly::SocketAddress& addr, ConnectionId connId); void expectServerRetryPacketWrite( std::string& encryptedRetryToken, ConnectionId& dstConnId, const folly::SocketAddress& clientAddr); protected: folly::SocketAddress fakeAddress_; folly::EventBase eventbase_; std::unique_ptr worker_; MockQuicTransport::Ptr transport_; std::shared_ptr workerCb_; std::unique_ptr factory_; std::unique_ptr listenerSocketFactory_; std::unique_ptr socketFactory_; MockQuicStats* quicStats_{nullptr}; folly::test::MockAsyncUDPSocket* socketPtr_{nullptr}; uint16_t hostId_{49}; bool hasShutdown_{false}; TokenSecret tokenSecret_; }; void QuicServerWorkerTest::expectConnectionCreation( const folly::SocketAddress& addr, MockQuicTransport::Ptr transportOverride) { MockQuicTransport::Ptr transport = transport_; if (transportOverride) { transport = transportOverride; } EXPECT_CALL(*factory_, _make(_, _, _, _)).WillOnce(Return(transport)); EXPECT_CALL(*transport, setSupportedVersions(_)); EXPECT_CALL(*transport, setOriginalPeerAddress(addr)); EXPECT_CALL(*transport, setRoutingCallback(worker_.get())); EXPECT_CALL(*transport, setConnectionIdAlgo(_)); EXPECT_CALL(*transport, setServerConnectionIdParams(_)) .WillOnce(Invoke([](ServerConnectionIdParams params) { EXPECT_EQ(params.processId, 1); EXPECT_EQ(params.workerId, 42); })); EXPECT_CALL(*transport, setTransportSettings(_)); EXPECT_CALL(*transport, accept()); EXPECT_CALL(*transport, setTransportStatsCallback(quicStats_)); } void QuicServerWorkerTest::expectConnCreateRefused() { MockQuicTransport::Ptr transport = transport_; EXPECT_CALL(*factory_, _make(_, _, _, _)).WillOnce(Return(nullptr)); EXPECT_CALL(*transport, setSupportedVersions(_)).Times(0); EXPECT_CALL(*transport, setOriginalPeerAddress(_)).Times(0); EXPECT_CALL(*transport, setRoutingCallback(worker_.get())).Times(0); EXPECT_CALL(*transport, setConnectionIdAlgo(_)).Times(0); EXPECT_CALL(*transport, setServerConnectionIdParams(_)).Times(0); EXPECT_CALL(*transport, setTransportSettings(_)).Times(0); EXPECT_CALL(*transport, accept()).Times(0); EXPECT_CALL(*transport, setTransportStatsCallback(quicStats_)).Times(0); EXPECT_CALL(*transport, onNetworkData(_, _)).Times(0); } void QuicServerWorkerTest::createQuicConnectionDuringShedding( const folly::SocketAddress& addr, ConnectionId connId) { PacketNum num = 1; QuicVersion version = QuicVersion::MVFST; LongHeader header(LongHeader::Types::Initial, connId, connId, num, version); RoutingData routingData(HeaderForm::Long, true, false, true, connId, connId); auto data = createData(kMinInitialPacketSize + 10); expectConnCreateRefused(); worker_->dispatchPacketData( addr, std::move(routingData), NetworkData(data->clone(), Clock::now()), version); const auto& addrMap = worker_->getSrcToTransportMap(); EXPECT_EQ(0, addrMap.count(std::make_pair(addr, connId))); eventbase_.loopIgnoreKeepAlive(); } void QuicServerWorkerTest::createQuicConnection( const folly::SocketAddress& addr, ConnectionId connId, MockQuicTransport::Ptr transportOverride) { PacketNum num = 1; QuicVersion version = QuicVersion::MVFST; LongHeader header(LongHeader::Types::Initial, connId, connId, num, version); RoutingData routingData(HeaderForm::Long, true, false, true, connId, connId); auto data = createData(kMinInitialPacketSize + 10); MockQuicTransport::Ptr transport = transport_; if (transportOverride) { transport = transportOverride; } expectConnectionCreation(addr, transport); EXPECT_CALL(*transport, onNetworkData(addr, NetworkDataMatches(*data))); worker_->dispatchPacketData( addr, std::move(routingData), NetworkData(data->clone(), Clock::now()), version); const auto& addrMap = worker_->getSrcToTransportMap(); EXPECT_EQ(addrMap.count(std::make_pair(addr, connId)), 1); eventbase_.loopIgnoreKeepAlive(); } void QuicServerWorkerTest::testSendReset( Buf packet, ConnectionId, ShortHeader shortHeader, QuicTransportStatsCallback::PacketDropReason dropReason) { EXPECT_CALL(*quicStats_, onPacketDropped(dropReason)).Times(1); // should write reset packet EXPECT_CALL(*quicStats_, onWrite(_)).Times(1); EXPECT_CALL(*quicStats_, onPacketSent()).Times(1); EXPECT_CALL(*quicStats_, onStatelessReset()).Times(1); // verify that the packet that gets written is stateless reset packet EXPECT_CALL(*socketPtr_, write(_, _)) .WillOnce(Invoke([&](const folly::SocketAddress&, const std::unique_ptr& buf) { QuicReadCodec codec(QuicNodeType::Client); auto aead = createNoOpAead(); // Make the decrypt fail EXPECT_CALL(*aead, _tryDecrypt(_, _, _)) .WillRepeatedly( Invoke([&](auto&, auto, auto) { return folly::none; })); codec.setOneRttReadCipher(std::move(aead)); codec.setOneRttHeaderCipher(test::createNoOpHeaderCipher()); StatelessResetToken token = generateStatelessResetToken(); codec.setStatelessResetToken(token); overridePacketWithToken(*buf, token); AckStates ackStates; auto packetQueue = bufToQueue(buf->clone()); auto res = codec.parsePacket(packetQueue, ackStates); EXPECT_NE(res.statelessReset(), nullptr); return buf->computeChainDataLength(); })); RoutingData routingData( HeaderForm::Short, false, false, false, shortHeader.getConnectionId(), folly::none); worker_->dispatchPacketData( kClientAddr, std::move(routingData), NetworkData(packet->clone(), Clock::now()), folly::none); eventbase_.loopIgnoreKeepAlive(); } /** * Helper method that expects a socketPtr_->write() contains a retry packet, and * retrieves the token value from its header. Sets the retry token value to the * encryptedRetryToken parameter. */ void QuicServerWorkerTest::expectServerRetryPacketWrite( std::string& encryptedRetryToken, ConnectionId& dstConnId, const folly::SocketAddress& clientAddr) { EXPECT_CALL(*quicStats_, onConnectionRateLimited()).Times(1); EXPECT_CALL(*quicStats_, onWrite(_)).Times(1); EXPECT_CALL(*quicStats_, onPacketSent()).Times(1); EXPECT_CALL(*socketPtr_, write(_, _)) .WillOnce(Invoke([&](const folly::SocketAddress&, const std::unique_ptr& buf) { // Validate that the retry packet has been created QuicReadCodec codec(QuicNodeType::Client); auto packetQueue = bufToQueue(buf->clone()); AckStates ackStates; auto parsedPacket = codec.parsePacket(packetQueue, ackStates); EXPECT_NE(parsedPacket.retryPacket(), nullptr); // Validate the generated retry token RetryPacket* retryPacket = parsedPacket.retryPacket(); EXPECT_TRUE(retryPacket->header.hasToken()); encryptedRetryToken = retryPacket->header.getToken(); auto tokenBuf = folly::IOBuf::copyBuffer(encryptedRetryToken); TokenGenerator generator(tokenSecret_); RetryToken token(dstConnId, clientAddr.getIPAddress(), 0); auto maybeDecryptedTokenMs = generator.decryptToken( std::move(tokenBuf), token.genAeadAssocData()); EXPECT_TRUE(maybeDecryptedTokenMs); return buf->computeChainDataLength(); })); } // Helper method to send a client initial to the server // Will return the encrypted retry token std::string QuicServerWorkerTest::testSendRetry( ConnectionId& srcConnId, ConnectionId& dstConnId, const folly::SocketAddress& clientAddr) { // Retry packet will only be sent if rate-limiting is configured worker_->setRateLimiter( std::make_unique([]() { return 0; }, 60s)); // Send a client inital to the server - the server will respond with retry // packet RoutingData routingData( HeaderForm::Long, true, false, true, dstConnId, srcConnId); auto data = createData(kMinInitialPacketSize); std::string encryptedRetryToken; expectServerRetryPacketWrite(encryptedRetryToken, dstConnId, clientAddr); worker_->dispatchPacketData( clientAddr, std::move(routingData), NetworkData(data->clone(), Clock::now()), QuicVersion::MVFST); eventbase_.loopIgnoreKeepAlive(); return encryptedRetryToken; } std::string QuicServerWorkerTest::testSendRetryUnfinished( ConnectionId& srcConnId, ConnectionId& dstConnId, const folly::SocketAddress& clientAddr) { worker_->setUnfinishedHandshakeLimit([]() { return 0; }); // Send a client inital to the server - the server will respond with retry // packet RoutingData routingData( HeaderForm::Long, true, false, true, dstConnId, srcConnId); auto data = createData(kMinInitialPacketSize); std::string encryptedRetryToken; expectServerRetryPacketWrite(encryptedRetryToken, dstConnId, clientAddr); worker_->dispatchPacketData( clientAddr, std::move(routingData), NetworkData(data->clone(), Clock::now()), QuicVersion::MVFST); eventbase_.loopIgnoreKeepAlive(); return encryptedRetryToken; } // Helper method to send a client initial that contains the retry token to the // server in response to the retry packet void QuicServerWorkerTest::testSendInitialWithRetryToken( const std::string& retryToken, ConnectionId& srcConnId, ConnectionId& dstConnId, const folly::SocketAddress& clientAddr) { QuicVersion version = QuicVersion::MVFST; PacketNum num = 1; LongHeader initialHeader( LongHeader::Types::Initial, srcConnId, dstConnId, num, version, retryToken); RegularQuicPacketBuilder initialBuilder( kDefaultUDPSendPacketLen, std::move(initialHeader), 0); initialBuilder.encodePacketHeader(); while (initialBuilder.remainingSpaceInPkt() > 0) { writeFrame(PaddingFrame(), initialBuilder); } auto initialPacket = packetToBuf(std::move(initialBuilder).buildPacket()); RoutingData routingData( HeaderForm::Long, true, false, true, dstConnId, srcConnId); worker_->dispatchPacketData( clientAddr, std::move(routingData), NetworkData(initialPacket->clone(), Clock::now()), version); } TEST_F(QuicServerWorkerTest, HostIdMismatchTestReset) { auto data = folly::IOBuf::copyBuffer("data"); EXPECT_CALL(*socketPtr_, address()).WillRepeatedly(ReturnRef(fakeAddress_)); PacketNum num = 2; // create packet with connId with different hostId encoded ShortHeader shortHeaderConnId( ProtectionType::KeyPhaseZero, getTestConnectionId(hostId_ + 1), num); testSendReset( std::move(data), getTestConnectionId(hostId_ + 1), std::move(shortHeaderConnId), QuicTransportStatsCallback::PacketDropReason::ROUTING_ERROR_WRONG_HOST); } TEST_F(QuicServerWorkerTest, NoConnFoundTestReset) { EXPECT_CALL(*socketPtr_, address()).WillRepeatedly(ReturnRef(fakeAddress_)); auto data = folly::IOBuf::copyBuffer("data"); PacketNum num = 2; // create packet with connId with different hostId encoded worker_->stopPacketForwarding(); ShortHeader shortHeaderConnId( ProtectionType::KeyPhaseZero, getTestConnectionId(hostId_), num); testSendReset( std::move(data), getTestConnectionId(hostId_), std::move(shortHeaderConnId), QuicTransportStatsCallback::PacketDropReason::CANNOT_FORWARD_DATA); } TEST_F(QuicServerWorkerTest, RateLimit) { worker_->setRateLimiter( std::make_unique([]() { return 2; }, 60s)); EXPECT_CALL(*quicStats_, onConnectionRateLimited()).Times(1); NiceMock connSetupCb1; NiceMock connCb1; auto mockSock1 = std::make_unique>(&eventbase_); EXPECT_CALL(*mockSock1, address()).WillRepeatedly(ReturnRef(fakeAddress_)); MockQuicTransport::Ptr testTransport1 = std::make_shared( worker_->getEventBase(), std::move(mockSock1), &connSetupCb1, &connCb1, nullptr); EXPECT_CALL(*testTransport1, getEventBase()) .WillRepeatedly(Return(&eventbase_)); EXPECT_CALL(*testTransport1, getOriginalPeerAddress()) .WillRepeatedly(ReturnRef(kClientAddr)); auto connId1 = getTestConnectionId(hostId_); PacketNum num = 1; QuicVersion version = QuicVersion::MVFST; RoutingData routingData( HeaderForm::Long, true, false, true, connId1, connId1); auto data = createData(kMinInitialPacketSize + 10); EXPECT_CALL( *testTransport1, onNetworkData(kClientAddr, NetworkDataMatches(*data))); EXPECT_CALL(*factory_, _make(_, _, _, _)).WillOnce(Return(testTransport1)); worker_->dispatchPacketData( kClientAddr, std::move(routingData), NetworkData(data->clone(), Clock::now()), version); const auto& addrMap = worker_->getSrcToTransportMap(); EXPECT_EQ(addrMap.count(std::make_pair(kClientAddr, connId1)), 1); eventbase_.loopIgnoreKeepAlive(); auto caddr2 = folly::SocketAddress("2.3.4.5", 1234); NiceMock connSetupCb2; NiceMock connCb2; auto mockSock2 = std::make_unique>(&eventbase_); EXPECT_CALL(*mockSock2, address()).WillRepeatedly(ReturnRef(caddr2)); MockQuicTransport::Ptr testTransport2 = std::make_shared( worker_->getEventBase(), std::move(mockSock2), &connSetupCb2, &connCb2, nullptr); EXPECT_CALL(*testTransport2, getEventBase()) .WillRepeatedly(Return(&eventbase_)); EXPECT_CALL(*testTransport2, getOriginalPeerAddress()) .WillRepeatedly(ReturnRef(caddr2)); ConnectionId connId2({2, 4, 5, 6, 7, 8, 9, 10}); num = 1; version = QuicVersion::MVFST; RoutingData routingData2( HeaderForm::Long, true, false, true, connId2, connId2); auto data2 = createData(kMinInitialPacketSize + 10); EXPECT_CALL( *testTransport2, onNetworkData(caddr2, NetworkDataMatches(*data2))); EXPECT_CALL(*factory_, _make(_, _, _, _)).WillOnce(Return(testTransport2)); worker_->dispatchPacketData( caddr2, std::move(routingData2), NetworkData(data2->clone(), Clock::now()), version); EXPECT_EQ(addrMap.count(std::make_pair(caddr2, connId2)), 1); eventbase_.loopIgnoreKeepAlive(); auto caddr3 = folly::SocketAddress("3.3.4.5", 1234); auto mockSock3 = std::make_unique>(&eventbase_); ConnectionId connId3({8, 4, 5, 6, 7, 8, 9, 10}); num = 1; version = QuicVersion::MVFST; RoutingData routingData3( HeaderForm::Long, true, false, true, connId3, connId3); auto data3 = createData(kMinInitialPacketSize + 10); EXPECT_CALL(*factory_, _make(_, _, _, _)).Times(0); worker_->dispatchPacketData( caddr3, std::move(routingData3), NetworkData(data2->clone(), Clock::now()), version); EXPECT_EQ(addrMap.count(std::make_pair(caddr3, connId3)), 0); EXPECT_EQ(addrMap.size(), 2); eventbase_.loopIgnoreKeepAlive(); } TEST_F(QuicServerWorkerTest, UnfinishedHandshakeLimit) { worker_->setUnfinishedHandshakeLimit([]() { return 2; }); EXPECT_CALL(*quicStats_, onConnectionRateLimited()).Times(1); NiceMock connSetupCb1; NiceMock connCb1; auto mockSock1 = std::make_unique>(&eventbase_); EXPECT_CALL(*mockSock1, address()).WillRepeatedly(ReturnRef(fakeAddress_)); MockQuicTransport::Ptr testTransport1 = std::make_shared( worker_->getEventBase(), std::move(mockSock1), &connSetupCb1, &connCb1, nullptr); EXPECT_CALL(*testTransport1, getEventBase()) .WillRepeatedly(Return(&eventbase_)); EXPECT_CALL(*testTransport1, getOriginalPeerAddress()) .WillRepeatedly(ReturnRef(kClientAddr)); auto connId1 = getTestConnectionId(hostId_); QuicVersion version = QuicVersion::MVFST; RoutingData routingData( HeaderForm::Long, true, false, true, connId1, connId1); auto data = createData(kMinInitialPacketSize + 10); EXPECT_CALL( *testTransport1, onNetworkData(kClientAddr, NetworkDataMatches(*data))); EXPECT_CALL(*factory_, _make(_, _, _, _)).WillOnce(Return(testTransport1)); worker_->dispatchPacketData( kClientAddr, std::move(routingData), NetworkData(data->clone(), Clock::now()), version); const auto& addrMap = worker_->getSrcToTransportMap(); EXPECT_EQ(addrMap.count(std::make_pair(kClientAddr, connId1)), 1); eventbase_.loopIgnoreKeepAlive(); auto caddr2 = folly::SocketAddress("2.3.4.5", 1234); NiceMock connSetupCb2; NiceMock connCb2; auto mockSock2 = std::make_unique>(&eventbase_); EXPECT_CALL(*mockSock2, address()).WillRepeatedly(ReturnRef(caddr2)); MockQuicTransport::Ptr testTransport2 = std::make_shared( worker_->getEventBase(), std::move(mockSock2), &connSetupCb2, &connCb2, nullptr); EXPECT_CALL(*testTransport2, getEventBase()) .WillRepeatedly(Return(&eventbase_)); EXPECT_CALL(*testTransport2, getOriginalPeerAddress()) .WillRepeatedly(ReturnRef(caddr2)); ConnectionId connId2({2, 4, 5, 6, 7, 8, 9, 10}); version = QuicVersion::MVFST; RoutingData routingData2( HeaderForm::Long, true, false, true, connId2, connId2); auto data2 = createData(kMinInitialPacketSize + 10); EXPECT_CALL( *testTransport2, onNetworkData(caddr2, NetworkDataMatches(*data2))); EXPECT_CALL(*factory_, _make(_, _, _, _)).WillOnce(Return(testTransport2)); worker_->dispatchPacketData( caddr2, std::move(routingData2), NetworkData(data2->clone(), Clock::now()), version); EXPECT_EQ(addrMap.count(std::make_pair(caddr2, connId2)), 1); eventbase_.loopIgnoreKeepAlive(); auto caddr3 = folly::SocketAddress("3.3.4.5", 1234); auto mockSock3 = std::make_unique>(&eventbase_); ConnectionId connId3({3, 4, 5, 6, 7, 8, 9, 10}); version = QuicVersion::MVFST; RoutingData routingData3( HeaderForm::Long, true, false, true, connId3, connId3); auto data3 = createData(kMinInitialPacketSize + 10); EXPECT_CALL(*factory_, _make(_, _, _, _)).Times(0); worker_->dispatchPacketData( caddr3, std::move(routingData3), NetworkData(data3->clone(), Clock::now()), version); EXPECT_EQ(addrMap.count(std::make_pair(caddr3, connId3)), 0); EXPECT_EQ(addrMap.size(), 2); eventbase_.loopIgnoreKeepAlive(); // Finish a handshake. worker_->onHandshakeFinished(); auto caddr4 = folly::SocketAddress("4.3.4.5", 1234); NiceMock connSetupCb4; NiceMock connCb4; auto mockSock4 = std::make_unique>(&eventbase_); EXPECT_CALL(*mockSock4, address()).WillRepeatedly(ReturnRef(caddr4)); MockQuicTransport::Ptr testTransport4 = std::make_shared( worker_->getEventBase(), std::move(mockSock4), &connSetupCb4, &connCb4, nullptr); EXPECT_CALL(*testTransport4, getEventBase()) .WillRepeatedly(Return(&eventbase_)); EXPECT_CALL(*testTransport4, getOriginalPeerAddress()) .WillRepeatedly(ReturnRef(caddr4)); ConnectionId connId4({4, 4, 5, 6, 7, 8, 9, 10}); version = QuicVersion::MVFST; RoutingData routingData4( HeaderForm::Long, true, false, true, connId4, connId4); auto data4 = createData(kMinInitialPacketSize + 10); EXPECT_CALL( *testTransport4, onNetworkData(caddr4, NetworkDataMatches(*data4))); EXPECT_CALL(*factory_, _make(_, _, _, _)).WillOnce(Return(testTransport4)); worker_->dispatchPacketData( caddr4, std::move(routingData4), NetworkData(data4->clone(), Clock::now()), version); EXPECT_EQ(addrMap.count(std::make_pair(caddr4, connId4)), 1); eventbase_.loopIgnoreKeepAlive(); } // Validate that when the worker receives a valid NewToken, it invokes // transportInfo->onNewTokenReceived(). TEST_F(QuicServerWorkerTest, TestNewTokenStatsCallback) { EXPECT_CALL(*quicStats_, onNewTokenReceived()); NewToken newToken(kClientAddr.getIPAddress()); // Create the encrypted retry token TokenGenerator generator(tokenSecret_); auto encryptedToken = generator.encryptToken(newToken); CHECK(encryptedToken.has_value()); std::string encryptedTokenStr = encryptedToken.value()->moveToFbString().toStdString(); auto dstConnId = getTestConnectionId(hostId_); auto srcConnId = getTestConnectionId(0); // we piggyback the retrytoken logic with a new token testSendInitialWithRetryToken( encryptedTokenStr, srcConnId, dstConnId, kClientAddr); } TEST_F(QuicServerWorkerTest, TestRetryValidInitial) { // The second client initial packet with the retry token is valid // as the client IP is the same as the one stored in the retry token auto dstConnId = getTestConnectionId(hostId_); auto srcConnId = getTestConnectionId(0); expectConnectionCreation(kClientAddr, transport_); EXPECT_CALL(*transport_, setRoutingCallback(nullptr)); EXPECT_CALL(*transport_, setTransportStatsCallback(nullptr)); auto retryToken = testSendRetry(srcConnId, dstConnId, kClientAddr); testSendInitialWithRetryToken(retryToken, srcConnId, dstConnId, kClientAddr); } TEST_F(QuicServerWorkerTest, TestRetryUnfinishedValidInitial) { // The second client initial packet with the retry token is valid // as the client IP is the same as the one stored in the retry token auto dstConnId = getTestConnectionId(hostId_); auto srcConnId = getTestConnectionId(0); expectConnectionCreation(kClientAddr, transport_); EXPECT_CALL(*transport_, setRoutingCallback(nullptr)); EXPECT_CALL(*transport_, setTransportStatsCallback(nullptr)); auto retryToken = testSendRetryUnfinished(srcConnId, dstConnId, kClientAddr); testSendInitialWithRetryToken(retryToken, srcConnId, dstConnId, kClientAddr); } TEST_F(QuicServerWorkerTest, TestRetryInvalidInitialClientIp) { // The second client initial packet with the retry token is invalid // as the client IP is different from the one stored in the retry token auto dstConnId = getTestConnectionId(hostId_); auto srcConnId = getTestConnectionId(0); auto retryToken = testSendRetry(srcConnId, dstConnId, kClientAddr); // Rate limited servers will issue an retry packet in response to invalid // retry tokens std::string encryptedRetryToken; expectServerRetryPacketWrite(encryptedRetryToken, dstConnId, kClientAddr2); EXPECT_CALL(*quicStats_, onTokenDecryptFailure()).Times(1); testSendInitialWithRetryToken(retryToken, srcConnId, dstConnId, kClientAddr2); } TEST_F(QuicServerWorkerTest, TestRetryUnfinishedInvalidInitialClientIp) { // The second client initial packet with the retry token is invalid // as the client IP is different from the one stored in the retry token EXPECT_CALL(*quicStats_, onTokenDecryptFailure()).Times(1); auto dstConnId = getTestConnectionId(hostId_); auto srcConnId = getTestConnectionId(0); auto retryToken = testSendRetryUnfinished(srcConnId, dstConnId, kClientAddr); std::string _; expectServerRetryPacketWrite(_, dstConnId, kClientAddr2); testSendInitialWithRetryToken(retryToken, srcConnId, dstConnId, kClientAddr2); } TEST_F(QuicServerWorkerTest, TestRetryInvalidInitialDstConnId) { // Dest conn ID is invalid as it is different from the original dst conn ID EXPECT_CALL(*quicStats_, onTokenDecryptFailure()).Times(1); auto dstConnId = getTestConnectionId(hostId_); auto srcConnId = getTestConnectionId(0); auto retryToken = testSendRetry(srcConnId, dstConnId, kClientAddr); auto invalidDstConnId = getTestConnectionId(1); // Rate limited servers will issue an retry packet in response to invalid // retry tokens std::string encryptedRetryToken; expectServerRetryPacketWrite( encryptedRetryToken, invalidDstConnId, kClientAddr); testSendInitialWithRetryToken( retryToken, srcConnId, invalidDstConnId, kClientAddr); } TEST_F(QuicServerWorkerTest, QuicServerWorkerUnbindBeforeCidAvailable) { NiceMock connSetupCb; NiceMock connCb; auto mockSock = std::make_unique>(&eventbase_); EXPECT_CALL(*mockSock, address()).WillRepeatedly(ReturnRef(fakeAddress_)); MockQuicTransport::Ptr testTransport = std::make_shared( worker_->getEventBase(), std::move(mockSock), &connSetupCb, &connCb, nullptr); EXPECT_CALL(*testTransport, getEventBase()) .WillRepeatedly(Return(&eventbase_)); EXPECT_CALL(*testTransport, getOriginalPeerAddress()) .WillRepeatedly(ReturnRef(kClientAddr)); auto connId = getTestConnectionId(hostId_); createQuicConnection(kClientAddr, connId, testTransport); // Otherwise the mock of _make will hold on to a shared_ptr to the transport Mock::VerifyAndClearExpectations(factory_.get()); auto& srcAddrMap = worker_->getSrcToTransportMap(); EXPECT_EQ(1, srcAddrMap.size()); EXPECT_EQ(srcAddrMap.begin()->second.get(), testTransport.get()); auto& srcIdentity = srcAddrMap.begin()->first; auto& connIdMap = worker_->getConnectionIdMap(); EXPECT_EQ(0, connIdMap.size()); auto* rawTransport = testTransport.get(); // This is fine, server worker still has at one shared_ptr in its map. testTransport.reset(); EXPECT_CALL(*rawTransport, setRoutingCallback(nullptr)).Times(1); // Now remove it from the maps. Nothing should crash. auto cidDataOnHeap = std::make_shared>(); const auto& cidRef = *cidDataOnHeap; EXPECT_CALL(*rawTransport, customDestructor()) .Times(1) .WillOnce( Invoke([cid = std::move(cidDataOnHeap)]() mutable { cid.reset(); })); worker_->onConnectionUnbound(rawTransport, srcIdentity, cidRef); EXPECT_EQ(0, srcAddrMap.size()); } TEST_F(QuicServerWorkerTest, QuicServerMultipleConnIdsRouting) { EXPECT_CALL(*socketPtr_, address()).WillRepeatedly(ReturnRef(fakeAddress_)); auto connId = getTestConnectionId(hostId_); createQuicConnection(kClientAddr, connId); auto data = folly::IOBuf::copyBuffer("data"); PacketNum num = 2; ShortHeader shortHeaderConnId(ProtectionType::KeyPhaseZero, connId, num); EXPECT_CALL(*quicStats_, onNewConnection()); transport_->QuicServerTransport::setRoutingCallback(worker_.get()); worker_.get()->onConnectionIdAvailable(transport_, connId); const auto& connIdMap = worker_->getConnectionIdMap(); EXPECT_EQ(connIdMap.count(connId), 1); EXPECT_CALL(*transport_, getClientChosenDestConnectionId()) .WillRepeatedly(Return(connId)); worker_->onConnectionIdBound(transport_); const auto& addrMap = worker_->getSrcToTransportMap(); EXPECT_EQ(addrMap.count(std::make_pair(kClientAddr, connId)), 0); // routing by connid after connid available. EXPECT_CALL( *transport_, onNetworkData(kClientAddr, NetworkDataMatches(*data))) .Times(1); RoutingData routingData2( HeaderForm::Short, false, false, false, connId, folly::none); worker_->dispatchPacketData( kClientAddr, std::move(routingData2), NetworkData(data->clone(), Clock::now()), folly::none); eventbase_.loopIgnoreKeepAlive(); auto connId2 = connId; connId2.data()[7] ^= 0x1; worker_.get()->onConnectionIdAvailable(transport_, connId2); EXPECT_EQ(connIdMap.size(), 2); EXPECT_CALL( *transport_, onNetworkData(kClientAddr, NetworkDataMatches(*data))) .Times(1); RoutingData routingData3( HeaderForm::Short, false, false, false, connId2, folly::none); worker_->dispatchPacketData( kClientAddr, std::move(routingData3), NetworkData(data->clone(), Clock::now()), folly::none); eventbase_.loopIgnoreKeepAlive(); EXPECT_CALL(*transport_, setRoutingCallback(nullptr)); worker_->onConnectionUnbound( transport_.get(), std::make_pair(kClientAddr, connId), std::vector{ ConnectionIdData{connId, 0}, ConnectionIdData{connId2, 1}}); EXPECT_EQ(connIdMap.count(connId), 0); EXPECT_EQ(addrMap.count(std::make_pair(kClientAddr, connId)), 0); // transport_ dtor is run at the end of the test, which causes // onConnectionUnbound to be called if the routingCallback_ is // still set. transport_->QuicServerTransport::setRoutingCallback(nullptr); } TEST_F(QuicServerWorkerTest, RetireConnIds) { EXPECT_CALL(*socketPtr_, address()).WillRepeatedly(ReturnRef(fakeAddress_)); auto connId = getTestConnectionId(hostId_); createQuicConnection(kClientAddr, connId); const auto& connIdMap = worker_->getConnectionIdMap(); const auto& addrMap = worker_->getSrcToTransportMap(); EXPECT_CALL(*quicStats_, onNewConnection()); transport_->QuicServerTransport::setRoutingCallback(worker_.get()); worker_->onConnectionIdAvailable(transport_, connId); EXPECT_EQ(connIdMap.count(connId), 1); EXPECT_CALL(*transport_, getClientChosenDestConnectionId()) .WillRepeatedly(Return(connId)); worker_->onConnectionIdBound(transport_); EXPECT_EQ(addrMap.count(std::make_pair(kClientAddr, connId)), 0); // create two additional conn ids auto connId2 = connId; connId2.data()[7] ^= 0x1; worker_->onConnectionIdAvailable(transport_, connId2); auto connId3 = connId; connId3.data()[7] ^= 0x2; worker_->onConnectionIdAvailable(transport_, connId3); eventbase_.loopIgnoreKeepAlive(); EXPECT_EQ(connIdMap.size(), 3); // retiring invalid conn id should not change map size auto invalidConnId = connId; invalidConnId.data()[0] ^= 0x1; worker_->onConnectionIdRetired(*transport_, invalidConnId); EXPECT_EQ(connIdMap.size(), 3); // retiring valid conn ids should decrease map size by one each time worker_->onConnectionIdRetired(*transport_, connId2); EXPECT_EQ(connIdMap.size(), 2); worker_->onConnectionIdRetired(*transport_, connId3); EXPECT_EQ(connIdMap.size(), 1); EXPECT_CALL(*transport_, setRoutingCallback(nullptr)); worker_->onConnectionUnbound( transport_.get(), std::make_pair(kClientAddr, connId), std::vector{ConnectionIdData{connId, 0}}); EXPECT_EQ(connIdMap.count(connId), 0); EXPECT_EQ(addrMap.count(std::make_pair(kClientAddr, connId)), 0); // transport_ dtor is run at the end of the test, which causes // onConnectionUnbound to be called if the routingCallback_ is // still set. transport_->QuicServerTransport::setRoutingCallback(nullptr); } TEST_F(QuicServerWorkerTest, QuicServerNewConnection) { EXPECT_CALL(*socketPtr_, address()).WillRepeatedly(ReturnRef(fakeAddress_)); auto connId = getTestConnectionId(hostId_); LOG(ERROR) << "from test CID: " << int(connId.size()); createQuicConnection(kClientAddr, connId); auto data = folly::IOBuf::copyBuffer("data"); PacketNum num = 2; ShortHeader shortHeaderConnId( ProtectionType::KeyPhaseZero, getTestConnectionId(hostId_), num); // Routing by connid before conn id available on a short packet. EXPECT_CALL(*quicStats_, onPacketDropped(_)).Times(1); RoutingData routingData( HeaderForm::Short, false, false, false, shortHeaderConnId.getConnectionId(), folly::none); worker_->dispatchPacketData( kClientAddr, std::move(routingData), NetworkData(data->clone(), Clock::now()), folly::none); eventbase_.loopIgnoreKeepAlive(); EXPECT_CALL(*quicStats_, onNewConnection()); ConnectionId newConnId = getTestConnectionId(hostId_); transport_->QuicServerTransport::setRoutingCallback(worker_.get()); worker_.get()->onConnectionIdAvailable(transport_, newConnId); const auto& connIdMap = worker_->getConnectionIdMap(); EXPECT_EQ(connIdMap.count(getTestConnectionId(hostId_)), 1); EXPECT_CALL(*transport_, getClientChosenDestConnectionId()) .WillRepeatedly(Return(connId)); worker_->onConnectionIdBound(transport_); const auto& addrMap = worker_->getSrcToTransportMap(); EXPECT_EQ( addrMap.count(std::make_pair(kClientAddr, getTestConnectionId(hostId_))), 0); // routing by connid after connid available. EXPECT_CALL( *transport_, onNetworkData(kClientAddr, NetworkDataMatches(*data))); RoutingData routingData2( HeaderForm::Short, false, false, false, shortHeaderConnId.getConnectionId(), folly::none); worker_->dispatchPacketData( kClientAddr, std::move(routingData2), NetworkData(data->clone(), Clock::now()), folly::none); eventbase_.loopIgnoreKeepAlive(); // routing by address after transport_'s connid available, but before // transport2's connid available. ConnectionId connId2({2, 4, 5, 6, 7, 8, 9, 10}); folly::SocketAddress clientAddr2("2.3.4.5", 2345); NiceMock connSetupCb; NiceMock connCb; auto mockSock = std::make_unique>(&eventbase_); EXPECT_CALL(*mockSock, address()).WillRepeatedly(ReturnRef(fakeAddress_)); MockQuicTransport::Ptr transport2 = std::make_shared( worker_->getEventBase(), std::move(mockSock), &connSetupCb, &connCb, nullptr); EXPECT_CALL(*transport2, getEventBase()).WillRepeatedly(Return(&eventbase_)); EXPECT_CALL(*transport2, getOriginalPeerAddress()) .WillRepeatedly(ReturnRef(kClientAddr)); createQuicConnection(clientAddr2, connId2, transport2); ShortHeader shortHeaderConnId2(ProtectionType::KeyPhaseZero, connId2, num); // Will be dropped EXPECT_CALL(*quicStats_, onPacketDropped(_)).Times(1); RoutingData routingData3( HeaderForm::Short, false, false, false, shortHeaderConnId2.getConnectionId(), folly::none); worker_->dispatchPacketData( clientAddr2, std::move(routingData3), NetworkData(data->clone(), Clock::now()), folly::none); eventbase_.loopIgnoreKeepAlive(); EXPECT_CALL(*transport_, setRoutingCallback(nullptr)).Times(2); worker_->onConnectionUnbound( transport_.get(), std::make_pair(kClientAddr, getTestConnectionId(hostId_)), std::vector{ConnectionIdData{connId, 0}}); worker_->onConnectionUnbound( transport_.get(), std::make_pair(clientAddr2, connId2), std::vector{ConnectionIdData{connId2, 0}}); EXPECT_EQ(connIdMap.count(getTestConnectionId(hostId_)), 0); EXPECT_EQ( addrMap.count(std::make_pair(kClientAddr, getTestConnectionId(hostId_))), 0); // transport_ dtor is run at the end of the test, which causes // onConnectionUnbound to be called if the routingCallback_ is // still set. transport_->QuicServerTransport::setRoutingCallback(nullptr); } TEST_F(QuicServerWorkerTest, InitialPacketTooSmall) { auto data = createData(kMinInitialPacketSize - 100); auto connId = getTestConnectionId(hostId_); PacketNum num = 1; QuicVersion version = QuicVersion::MVFST; LongHeader header( LongHeader::Types::Initial, getTestConnectionId(hostId_ + 1), connId, num, version); EXPECT_CALL(*factory_, _make(_, _, _, _)).Times(0); EXPECT_CALL(*quicStats_, onPacketDropped(_)); RoutingData routingData( HeaderForm::Long, true, false, true, header.getDestinationConnId(), header.getSourceConnId()); worker_->dispatchPacketData( kClientAddr, std::move(routingData), NetworkData(data->clone(), Clock::now()), version); eventbase_.loopIgnoreKeepAlive(); } TEST_F(QuicServerWorkerTest, QuicShedTest) { auto connId = getTestConnectionId(hostId_); EXPECT_CALL( *quicStats_, onPacketDropped(PacketDropReason::CANNOT_MAKE_TRANSPORT)); folly::Optional versionPacket; EXPECT_CALL(*socketPtr_, write(_, _)) .WillOnce(Invoke([&](const SocketAddress&, const std::unique_ptr& writtenData) { auto codec = std::make_unique(QuicNodeType::Server); auto packetQueue = bufToQueue(writtenData->clone()); versionPacket = codec->tryParsingVersionNegotiation(packetQueue); return packetQueue.chainLength(); })); createQuicConnectionDuringShedding(kClientAddr, connId); EXPECT_TRUE(versionPacket.has_value()); EXPECT_EQ(versionPacket->destinationConnectionId, connId); } TEST_F(QuicServerWorkerTest, BlockedSourcePort) { folly::SocketAddress blockedSrcPort("1.2.3.4", 443); worker_->setIsBlockListedSrcPort([](uint16_t port) { return port == 443; }); auto data = createData(kDefaultUDPSendPacketLen); auto connId = ConnectionId(std::vector()); PacketNum num = 1; QuicVersion version = QuicVersion::MVFST; LongHeader header(LongHeader::Types::Initial, connId, connId, num, version); EXPECT_CALL(*quicStats_, onPacketDropped(PacketDropReason::INVALID_SRC_PORT)); RegularQuicPacketBuilder builder( kDefaultUDPSendPacketLen, std::move(header), 0 /* largestAcked */); builder.encodePacketHeader(); while (builder.remainingSpaceInPkt() > 0) { writeFrame(PaddingFrame(), builder); } auto packet = packetToBuf(std::move(builder).buildPacket()); worker_->handleNetworkData(blockedSrcPort, std::move(packet), Clock::now()); eventbase_.loopIgnoreKeepAlive(); } TEST_F(QuicServerWorkerTest, ZeroLengthConnectionId) { auto data = createData(kDefaultUDPSendPacketLen); auto connId = ConnectionId(std::vector()); PacketNum num = 1; QuicVersion version = QuicVersion::MVFST; LongHeader header(LongHeader::Types::Initial, connId, connId, num, version); EXPECT_CALL(*quicStats_, onPacketDropped(_)).Times(1); RegularQuicPacketBuilder builder( kDefaultUDPSendPacketLen, std::move(header), 0 /* largestAcked */); builder.encodePacketHeader(); while (builder.remainingSpaceInPkt() > 0) { writeFrame(PaddingFrame(), builder); } auto packet = packetToBuf(std::move(builder).buildPacket()); worker_->handleNetworkData(kClientAddr, std::move(packet), Clock::now()); eventbase_.loopIgnoreKeepAlive(); } TEST_F(QuicServerWorkerTest, ClientInitialCounting) { auto srcConnId = getTestConnectionId(0); auto destConnId = getTestConnectionId(1); QuicVersion version = QuicVersion::MVFST; PacketNum num = 1; LongHeader initialHeader( LongHeader::Types::Initial, srcConnId, destConnId, num, version); RegularQuicPacketBuilder initialBuilder( kDefaultUDPSendPacketLen, std::move(initialHeader), 0); initialBuilder.encodePacketHeader(); auto initialPacket = packetToBuf(std::move(initialBuilder).buildPacket()); EXPECT_CALL(*quicStats_, onClientInitialReceived(QuicVersion::MVFST)) .Times(1); worker_->handleNetworkData( kClientAddr, std::move(initialPacket), Clock::now()); eventbase_.loopIgnoreKeepAlive(); // Initial with any packet number should also increate the counting PacketNum bignum = 200; LongHeader initialHeaderBigNum( LongHeader::Types::Initial, srcConnId, destConnId, bignum, version); RegularQuicPacketBuilder initialBuilderBigNum( kDefaultUDPSendPacketLen, std::move(initialHeaderBigNum), 0); initialBuilderBigNum.encodePacketHeader(); auto initialPacketBigNum = packetToBuf(std::move(initialBuilderBigNum).buildPacket()); EXPECT_CALL(*quicStats_, onClientInitialReceived(QuicVersion::MVFST)) .Times(1); worker_->handleNetworkData( kClientAddr, std::move(initialPacketBigNum), Clock::now()); eventbase_.loopIgnoreKeepAlive(); LongHeader handshakeHeader( LongHeader::Types::Handshake, srcConnId, destConnId, num, version); RegularQuicPacketBuilder handshakeBuilder( kDefaultUDPSendPacketLen, std::move(handshakeHeader), 0); handshakeBuilder.encodePacketHeader(); auto handshakePacket = packetToBuf(std::move(handshakeBuilder).buildPacket()); EXPECT_CALL(*quicStats_, onClientInitialReceived(_)).Times(0); worker_->handleNetworkData( kClientAddr, std::move(handshakePacket), Clock::now()); eventbase_.loopIgnoreKeepAlive(); } TEST_F(QuicServerWorkerTest, ConnectionIdTooShort) { auto data = createData(kDefaultUDPSendPacketLen); auto connId = ConnectionId::createWithoutChecks({1}); PacketNum num = 1; QuicVersion version = QuicVersion::MVFST; LongHeader header(LongHeader::Types::Initial, connId, connId, num, version); EXPECT_CALL(*quicStats_, onPacketDropped(_)).Times(1); EXPECT_CALL(*quicStats_, onPacketProcessed()).Times(0); EXPECT_CALL(*quicStats_, onPacketSent()).Times(0); EXPECT_CALL(*quicStats_, onWrite(_)).Times(0); RegularQuicPacketBuilder builder( kDefaultUDPSendPacketLen, std::move(header), 0 /* largestAcked */); builder.encodePacketHeader(); while (builder.remainingSpaceInPkt() > 0) { writeFrame(PaddingFrame(), builder); } auto packet = packetToBuf(std::move(builder).buildPacket()); worker_->handleNetworkData(kClientAddr, std::move(packet), Clock::now()); eventbase_.loopIgnoreKeepAlive(); } TEST_F(QuicServerWorkerTest, FailToParseConnectionId) { auto data = createData(kDefaultUDPSendPacketLen); auto srcConnId = getTestConnectionId(0); auto dstConnId = getTestConnectionId(1); auto mockConnIdAlgo = std::make_unique(); auto rawConnIdAlgo = mockConnIdAlgo.get(); worker_->setConnectionIdAlgo(std::move(mockConnIdAlgo)); PacketNum num = 1; QuicVersion version = QuicVersion::MVFST; LongHeader header( LongHeader::Types::Initial, srcConnId, dstConnId, num, version); RegularQuicPacketBuilder builder( kDefaultUDPSendPacketLen, std::move(header), 0 /* largestAcked */); builder.encodePacketHeader(); while (builder.remainingSpaceInPkt() > 0) { writeFrame(PaddingFrame(), builder); } auto packet = packetToBuf(std::move(builder).buildPacket()); // To force dropping path, set initial to false RoutingData routingData( HeaderForm::Long, false /* isInitial */, false, true /* isUsingClientCid */, dstConnId, srcConnId); NetworkData networkData(std::move(packet), Clock::now()); EXPECT_CALL(*rawConnIdAlgo, canParseNonConst(_)).WillOnce(Return(true)); EXPECT_CALL(*rawConnIdAlgo, parseConnectionId(dstConnId)) .WillOnce(Return(folly::makeUnexpected(QuicInternalException( "This CID has COVID-19", LocalErrorCode::INTERNAL_ERROR)))); EXPECT_CALL(*quicStats_, onPacketDropped(_)).Times(1); EXPECT_CALL(*quicStats_, onPacketProcessed()).Times(0); EXPECT_CALL(*quicStats_, onPacketSent()).Times(0); EXPECT_CALL(*quicStats_, onWrite(_)).Times(0); worker_->dispatchPacketData( kClientAddr, std::move(routingData), std::move(networkData), version); eventbase_.loopIgnoreKeepAlive(); } TEST_F(QuicServerWorkerTest, ConnectionIdTooShortDispatch) { auto data = createData(kDefaultUDPSendPacketLen); auto dstConnId = ConnectionId::createWithoutChecks({3}); auto srcConnId = ConnectionId::createWithoutChecks({3}); PacketNum num = 1; QuicVersion version = QuicVersion::MVFST; LongHeader header( LongHeader::Types::Initial, srcConnId, dstConnId, num, version); EXPECT_CALL(*factory_, _make(_, _, _, _)).Times(0); EXPECT_CALL(*quicStats_, onPacketDropped(_)).Times(1); EXPECT_CALL(*quicStats_, onPacketProcessed()).Times(0); EXPECT_CALL(*quicStats_, onPacketSent()).Times(0); EXPECT_CALL(*quicStats_, onWrite(_)).Times(0); RegularQuicPacketBuilder builder( kDefaultUDPSendPacketLen, std::move(header), 0 /* largestAcked */); builder.encodePacketHeader(); while (builder.remainingSpaceInPkt() > 0) { writeFrame(PaddingFrame(), builder); } auto packet = packetToBuf(std::move(builder).buildPacket()); RoutingData routingData( HeaderForm::Long, true, false, true, dstConnId, srcConnId); NetworkData networkData(std::move(packet), Clock::now()); worker_->dispatchPacketData( kClientAddr, std::move(routingData), std::move(networkData), version); eventbase_.loopIgnoreKeepAlive(); } TEST_F(QuicServerWorkerTest, ConnectionIdTooLargeDispatch) { auto data = createData(kDefaultUDPSendPacketLen); auto dstConnId = ConnectionId::createWithoutChecks({21}); auto srcConnId = ConnectionId::createWithoutChecks({3}); PacketNum num = 1; QuicVersion version = QuicVersion::MVFST; LongHeader header( LongHeader::Types::Initial, srcConnId, dstConnId, num, version); EXPECT_CALL(*factory_, _make(_, _, _, _)).Times(0); EXPECT_CALL(*quicStats_, onPacketDropped(_)).Times(1); EXPECT_CALL(*quicStats_, onPacketProcessed()).Times(0); EXPECT_CALL(*quicStats_, onPacketSent()).Times(0); EXPECT_CALL(*quicStats_, onWrite(_)).Times(0); RegularQuicPacketBuilder builder( kDefaultUDPSendPacketLen, std::move(header), 0 /* largestAcked */); builder.encodePacketHeader(); while (builder.remainingSpaceInPkt() > 0) { writeFrame(PaddingFrame(), builder); } auto packet = packetToBuf(std::move(builder).buildPacket()); RoutingData routingData( HeaderForm::Long, true, false, true, dstConnId, srcConnId); NetworkData networkData(std::move(packet), Clock::now()); worker_->dispatchPacketData( kClientAddr, std::move(routingData), std::move(networkData), version); eventbase_.loopIgnoreKeepAlive(); } TEST_F(QuicServerWorkerTest, ShutdownQuicServer) { auto connId = getTestConnectionId(hostId_); createQuicConnection(kClientAddr, connId); EXPECT_CALL(*quicStats_, onNewConnection()); worker_->onConnectionIdAvailable(transport_, getTestConnectionId(hostId_)); const auto& connIdMap = worker_->getConnectionIdMap(); EXPECT_EQ(connIdMap.count(getTestConnectionId(hostId_)), 1); EXPECT_CALL(*transport_, setRoutingCallback(nullptr)).Times(2); EXPECT_CALL(*transport_, setTransportStatsCallback(nullptr)).Times(2); EXPECT_CALL(*transport_, close(_)).WillRepeatedly(Invoke([this](auto) { hasShutdown_ = true; })); std::thread t([&] { eventbase_.loopForever(); }); worker_->shutdownAllConnections(LocalErrorCode::SHUTTING_DOWN); eventbase_.terminateLoopSoon(); t.join(); } TEST_F(QuicServerWorkerTest, PacketAfterShutdown) { std::thread t([&] { eventbase_.loopForever(); }); worker_->shutdownAllConnections(LocalErrorCode::SHUTTING_DOWN); auto connId = getTestConnectionId(hostId_); PacketNum packetNum = 1; QuicVersion version = QuicVersion::MVFST; LongHeader header( LongHeader::Types::Initial, connId, connId, packetNum, version); EXPECT_CALL(*factory_, _make(_, _, _, _)).Times(0); RegularQuicPacketBuilder builder( kDefaultUDPSendPacketLen, std::move(header), 0 /* largestAcked */); builder.encodePacketHeader(); auto packet = packetToBuf(std::move(builder).buildPacket()); worker_->handleNetworkData(kClientAddr, std::move(packet), Clock::now()); eventbase_.terminateLoopSoon(); t.join(); } TEST_F(QuicServerWorkerTest, DestroyQuicServer) { auto connId = getTestConnectionId(hostId_); createQuicConnection(kClientAddr, connId); EXPECT_CALL(*quicStats_, onNewConnection()); worker_->onConnectionIdAvailable(transport_, getTestConnectionId(hostId_)); const auto& connIdMap = worker_->getConnectionIdMap(); EXPECT_EQ(connIdMap.count(getTestConnectionId(hostId_)), 1); EXPECT_CALL(*transport_, setRoutingCallback(nullptr)).Times(2); EXPECT_CALL(*transport_, setTransportStatsCallback(nullptr)).Times(2); EXPECT_CALL(*transport_, setHandshakeFinishedCallback(nullptr)).Times(2); EXPECT_CALL(*transport_, close(_)).WillRepeatedly(Invoke([this](auto) { hasShutdown_ = true; })); std::thread t([&] { eventbase_.loopForever(); }); worker_.reset(); eventbase_.terminateLoopSoon(); t.join(); } TEST_F(QuicServerWorkerTest, AssignBufAccessor) { EXPECT_CALL(*socketPtr_, address()).WillRepeatedly(ReturnRef(fakeAddress_)); auto connId = getTestConnectionId(hostId_); TransportSettings transportSettings; transportSettings.dataPathType = DataPathType::ContinuousMemory; transportSettings.batchingMode = QuicBatchingMode::BATCHING_MODE_GSO; worker_->setTransportSettings(transportSettings); EXPECT_CALL(*transport_, setBufAccessor(_)) .Times(1) .WillOnce(Invoke( [](BufAccessor* accessor) { EXPECT_TRUE(accessor != nullptr); })); createQuicConnection(kClientAddr, connId); // From shutdownAllConnections: EXPECT_CALL(*transport_, setRoutingCallback(nullptr)).Times(1); EXPECT_CALL(*transport_, setTransportStatsCallback(nullptr)).Times(1); } class MockAcceptObserver : public AcceptObserver { public: MOCK_METHOD(void, accept, (QuicTransportBase* const), (noexcept)); MOCK_METHOD(void, acceptorDestroy, (QuicServerWorker*), (noexcept)); MOCK_METHOD(void, observerAttach, (QuicServerWorker*), (noexcept)); MOCK_METHOD(void, observerDetach, (QuicServerWorker*), (noexcept)); }; TEST_F(QuicServerWorkerTest, AcceptObserver) { auto cb = std::make_unique>(); EXPECT_CALL(*cb, observerAttach(worker_.get())); worker_->addAcceptObserver(cb.get()); auto initTestSocketAndTransport = [this]() { NiceMock connSetupCb; NiceMock connCb; auto mockSock = std::make_unique>( &eventbase_); EXPECT_CALL(*mockSock, address()).WillRepeatedly(ReturnRef(fakeAddress_)); MockQuicTransport::Ptr mockTransport = std::make_shared( worker_->getEventBase(), std::move(mockSock), &connSetupCb, &connCb, nullptr); EXPECT_CALL(*mockTransport, setRoutingCallback(nullptr)); EXPECT_CALL(*mockTransport, setTransportStatsCallback(nullptr)); EXPECT_CALL(*mockTransport, getEventBase()) .WillRepeatedly(Return(&eventbase_)); EXPECT_CALL(*mockTransport, getOriginalPeerAddress()) .WillRepeatedly(ReturnRef(kClientAddr)); return std::make_pair(std::move(mockSock), std::move(mockTransport)); }; // add first connection, expect events const auto [sock1, transport1] = initTestSocketAndTransport(); EXPECT_CALL(*cb, accept(transport1.get())); createQuicConnection(kClientAddr, getTestConnectionId(hostId_), transport1); Mock::VerifyAndClearExpectations(cb.get()); // add second connection, expect events const auto [sock2, transport2] = initTestSocketAndTransport(); EXPECT_CALL(*cb, accept(transport2.get())); createQuicConnection(kClientAddr2, getTestConnectionId(hostId_), transport2); Mock::VerifyAndClearExpectations(cb.get()); // remove AcceptObserver EXPECT_CALL(*cb, observerDetach(worker_.get())); EXPECT_TRUE(worker_->removeAcceptObserver(cb.get())); Mock::VerifyAndClearExpectations(cb.get()); // add third connection, no events const auto [sock3, transport3] = initTestSocketAndTransport(); createQuicConnection(kClientAddr3, getTestConnectionId(hostId_), transport3); } TEST_F(QuicServerWorkerTest, AcceptObserverRemove) { auto cb = std::make_unique>(); EXPECT_CALL(*cb, observerAttach(worker_.get())); worker_->addAcceptObserver(cb.get()); Mock::VerifyAndClearExpectations(cb.get()); EXPECT_CALL(*cb, observerDetach(worker_.get())); EXPECT_TRUE(worker_->removeAcceptObserver(cb.get())); Mock::VerifyAndClearExpectations(cb.get()); } TEST_F(QuicServerWorkerTest, AcceptObserverRemoveMissing) { auto cb = std::make_unique>(); EXPECT_FALSE(worker_->removeAcceptObserver(cb.get())); } TEST_F(QuicServerWorkerTest, AcceptObserverDestroyWorker) { auto cb = std::make_unique>(); EXPECT_CALL(*cb, observerAttach(worker_.get())); worker_->addAcceptObserver(cb.get()); Mock::VerifyAndClearExpectations(cb.get()); // destroy the acceptor while the AcceptObserver is installed EXPECT_CALL(*cb, acceptorDestroy(worker_.get())); worker_ = nullptr; Mock::VerifyAndClearExpectations(cb.get()); } TEST_F(QuicServerWorkerTest, AcceptObserverMultipleRemove) { auto cb1 = std::make_unique>(); EXPECT_CALL(*cb1, observerAttach(worker_.get())); worker_->addAcceptObserver(cb1.get()); Mock::VerifyAndClearExpectations(cb1.get()); auto cb2 = std::make_unique>(); EXPECT_CALL(*cb2, observerAttach(worker_.get())); worker_->addAcceptObserver(cb2.get()); Mock::VerifyAndClearExpectations(cb2.get()); EXPECT_CALL(*cb1, observerDetach(worker_.get())); EXPECT_TRUE(worker_->removeAcceptObserver(cb1.get())); Mock::VerifyAndClearExpectations(cb1.get()); EXPECT_CALL(*cb2, observerDetach(worker_.get())); EXPECT_TRUE(worker_->removeAcceptObserver(cb2.get())); Mock::VerifyAndClearExpectations(cb2.get()); } TEST_F(QuicServerWorkerTest, AcceptObserverMultipleRemoveReverse) { auto cb1 = std::make_unique>(); EXPECT_CALL(*cb1, observerAttach(worker_.get())); worker_->addAcceptObserver(cb1.get()); Mock::VerifyAndClearExpectations(cb1.get()); auto cb2 = std::make_unique>(); EXPECT_CALL(*cb2, observerAttach(worker_.get())); worker_->addAcceptObserver(cb2.get()); Mock::VerifyAndClearExpectations(cb2.get()); EXPECT_CALL(*cb2, observerDetach(worker_.get())); EXPECT_TRUE(worker_->removeAcceptObserver(cb2.get())); Mock::VerifyAndClearExpectations(cb2.get()); EXPECT_CALL(*cb1, observerDetach(worker_.get())); EXPECT_TRUE(worker_->removeAcceptObserver(cb1.get())); Mock::VerifyAndClearExpectations(cb1.get()); } TEST_F(QuicServerWorkerTest, AcceptObserverMultipleAcceptorDestroyed) { auto cb1 = std::make_unique>(); EXPECT_CALL(*cb1, observerAttach(worker_.get())); worker_->addAcceptObserver(cb1.get()); Mock::VerifyAndClearExpectations(cb1.get()); auto cb2 = std::make_unique>(); EXPECT_CALL(*cb2, observerAttach(worker_.get())); worker_->addAcceptObserver(cb2.get()); Mock::VerifyAndClearExpectations(cb2.get()); // destroy the acceptor while the AcceptObserver is installed EXPECT_CALL(*cb1, acceptorDestroy(worker_.get())); EXPECT_CALL(*cb2, acceptorDestroy(worker_.get())); worker_ = nullptr; Mock::VerifyAndClearExpectations(cb1.get()); Mock::VerifyAndClearExpectations(cb2.get()); } TEST_F(QuicServerWorkerTest, AcceptObserverRemoveCallbackThenDestroyWorker) { auto cb = std::make_unique>(); EXPECT_CALL(*cb, observerAttach(worker_.get())); worker_->addAcceptObserver(cb.get()); Mock::VerifyAndClearExpectations(cb.get()); EXPECT_CALL(*cb, observerDetach(worker_.get())); EXPECT_TRUE(worker_->removeAcceptObserver(cb.get())); Mock::VerifyAndClearExpectations(cb.get()); worker_ = nullptr; } auto createInitialStream( ConnectionId srcConnId, ConnectionId destConnId, StreamId streamId, folly::IOBuf& data, QuicVersion version, LongHeader::Types pktHeaderType = LongHeader::Types::Initial) { PacketNum packetNum = 1; LongHeader header(pktHeaderType, srcConnId, destConnId, packetNum, version); LongHeader headerRetry( pktHeaderType, srcConnId, destConnId, packetNum, version, std::string("this is a retry token :)")); RegularQuicPacketBuilder builder( kDefaultUDPSendPacketLen, pktHeaderType == LongHeader::Types::Retry ? std::move(headerRetry) : std::move(header), 0 /* largestAcked */); builder.encodePacketHeader(); auto streamData = data.clone(); auto dataLen = writeStreamFrameHeader( builder, streamId, 0, streamData->computeChainDataLength(), streamData->computeChainDataLength(), true, folly::none /* skipLenHint */); EXPECT_TRUE(dataLen); writeStreamFrameData(builder, std::move(streamData), *dataLen); return packetToBuf(std::move(builder).buildPacket()); } auto createInitialStream( StreamId streamId, folly::IOBuf& data, QuicVersion version, LongHeader::Types pktHeaderType = LongHeader::Types::Initial) { return createInitialStream( getTestConnectionId(), getTestConnectionId(1), streamId, data, version, pktHeaderType); } std::unique_ptr writeTestDataOnWorkersBuf( ConnectionId srcConnId, ConnectionId destConnId, size_t& lenOut, QuicServerWorker* worker, LongHeader::Types pktHeaderType = LongHeader::Types::Initial) { StreamId id = 1; auto buf = pktHeaderType == LongHeader::Types::Initial ? createData(kMinInitialPacketSize) : folly::IOBuf::copyBuffer("hello, world!"); auto packet = createInitialStream( srcConnId, destConnId, id, *buf, MVFST1, pktHeaderType); packet->coalesce(); auto data = std::move(packet); uint8_t* workerBuf = nullptr; size_t workerBufLen = 0; worker->getReadBuffer((void**)&workerBuf, &workerBufLen); lenOut = std::min(workerBufLen, data->computeChainDataLength()); memcpy(workerBuf, data->buffer(), lenOut); return data; } ConnectionId createConnIdForServer(ProcessId server) { auto connIdAlgo = std::make_unique(); uint8_t processId = (server == ProcessId::ONE) ? 1 : 0; ServerConnectionIdParams params(0, processId, 0); return *connIdAlgo->encodeConnectionId(params); } class QuicServerWorkerTakeoverTest : public Test { public: void SetUp() override { auto sock = std::make_unique>(&evb_); DCHECK(sock->getEventBase()); EXPECT_CALL(*sock, getNetworkSocket()) .WillRepeatedly(Return(folly::NetworkSocket())); EXPECT_CALL(*sock, pauseRead()); takeoverWorkerCb_ = std::make_shared>(); takeoverWorker_ = std::make_unique(takeoverWorkerCb_); takeoverWorker_->setSupportedVersions(supportedVersions); takeoverWorker_->setSocket(std::move(sock)); takeoverSocketFactory_ = std::make_unique(); takeoverWorker_->setNewConnectionSocketFactory( takeoverSocketFactory_.get()); factory_ = std::make_unique(); takeoverWorker_->setTransportFactory(factory_.get()); auto quicStats = std::make_unique>(); takeoverWorker_->setConnectionIdAlgo( std::make_unique()); takeoverWorker_->setTransportStatsCallback(std::move(quicStats)); quicStats_ = (MockQuicStats*)takeoverWorker_->getTransportStatsCallback(); auto takeoverSock = std::make_unique>(&evb_); takeoverSocket_ = takeoverSock.get(); folly::SocketAddress takeoverAddr; EXPECT_CALL(*takeoverSocket_, bind(_, _)); EXPECT_CALL(*takeoverSocket_, resumeRead(_)); takeoverWorker_->allowBeingTakenOver(std::move(takeoverSock), takeoverAddr); } void testPacketForwarding(Buf data, size_t len, ConnectionId connId); void testNoPacketForwarding(Buf data, size_t len, ConnectionId connId); protected: std::shared_ptr takeoverWorkerCb_; folly::test::MockAsyncUDPSocket* takeoverSocket_; folly::EventBase evb_; std::unique_ptr takeoverWorker_; folly::IOBufEqualTo eq; folly::SocketAddress clientAddr{kClientAddr}; std::unique_ptr takeoverSocketFactory_; std::unique_ptr factory_; MockQuicStats* quicStats_{nullptr}; std::vector supportedVersions{QuicVersion::MVFST, MVFST1}; uint16_t clientHostId_{25}; }; TEST_F(QuicServerWorkerTakeoverTest, QuicServerTakeoverReInitHandler) { auto takeoverSock = std::make_unique>(&evb_); folly::SocketAddress takeoverAddr; EXPECT_CALL(*takeoverSocket_, pauseRead()); EXPECT_CALL(*takeoverSock, bind(_, _)); EXPECT_CALL(*takeoverSock, resumeRead(_)); EXPECT_CALL(*takeoverSock, address()).WillOnce(Invoke([&]() { return takeoverAddr; })); takeoverWorker_->overrideTakeoverHandlerAddress( std::move(takeoverSock), takeoverAddr); takeoverSocket_ = takeoverSock.get(); } TEST_F(QuicServerWorkerTakeoverTest, QuicServerTakeoverNoForwarding) { ConnectionId connId = createConnIdForServer(ProcessId::ONE), clientConnId = getTestConnectionId(clientHostId_); takeoverWorker_->setProcessId(ProcessId::ONE); size_t len{0}; auto data = writeTestDataOnWorkersBuf( clientConnId, connId, len, takeoverWorker_.get()); // enable packet forwarding takeoverWorker_->startPacketForwarding(kClientAddr); // this packet belongs to this server, so it should write unaltered packet // to the actual client. Also test different variations in header type Also // verify that the packet is not forwarded for all packet types testNoPacketForwarding( writeTestDataOnWorkersBuf( clientConnId, connId, len, takeoverWorker_.get(), LongHeader::Types::Initial), len, connId); testNoPacketForwarding( writeTestDataOnWorkersBuf( clientConnId, connId, len, takeoverWorker_.get(), LongHeader::Types::Retry), len, connId); testNoPacketForwarding( writeTestDataOnWorkersBuf( clientConnId, connId, len, takeoverWorker_.get(), LongHeader::Types::Handshake), len, connId); testNoPacketForwarding( writeTestDataOnWorkersBuf( clientConnId, connId, len, takeoverWorker_.get(), LongHeader::Types::ZeroRtt), len, connId); } void QuicServerWorkerTakeoverTest::testNoPacketForwarding( Buf /* data */, size_t len, ConnectionId /* connId */) { auto cb = [&](const folly::SocketAddress& addr, std::unique_ptr& /* routingData */, std::unique_ptr& /* networkData */, folly::Optional /* quicVersion */, bool isForwardedData) { EXPECT_EQ(addr.getIPAddress(), clientAddr.getIPAddress()); EXPECT_EQ(addr.getPort(), clientAddr.getPort()); EXPECT_FALSE(isForwardedData); }; EXPECT_CALL(*takeoverWorkerCb_, routeDataToWorkerLong(_, _, _, _, _)) .WillOnce(Invoke(cb)); EXPECT_CALL(*quicStats_, onPacketReceived()); EXPECT_CALL(*quicStats_, onRead(len)); EXPECT_CALL(*quicStats_, onPacketForwarded()).Times(0); takeoverWorker_->onDataAvailable( clientAddr, len, false, OnDataAvailableParams()); } TEST_F(QuicServerWorkerTakeoverTest, QuicServerTakeoverForwarding) { // now try for packets that belongs to different server ConnectionId connId = createConnIdForServer(ProcessId::ZERO), clientConnId = getTestConnectionId(clientHostId_); takeoverWorker_->setProcessId(ProcessId::ONE); size_t len{0}; auto pkt = writeTestDataOnWorkersBuf( clientConnId, connId, len, takeoverWorker_.get(), LongHeader::Types::Handshake); testPacketForwarding(std::move(pkt), len, connId); pkt = writeTestDataOnWorkersBuf( clientConnId, connId, len, takeoverWorker_.get(), LongHeader::Types::ZeroRtt); testNoPacketForwarding(std::move(pkt), len, connId); // verify that the Initial packet type is not forwarded even if the // server-bit is different pkt = writeTestDataOnWorkersBuf( clientConnId, connId, len, takeoverWorker_.get(), LongHeader::Types::Initial); testNoPacketForwarding(std::move(pkt), len, connId); } void QuicServerWorkerTakeoverTest::testPacketForwarding( Buf data, size_t len, ConnectionId connId) { auto writeSock = std::make_unique>(&evb_); EXPECT_CALL(*takeoverSocketFactory_, _make(_, _)) .WillOnce(Return(writeSock.get())); EXPECT_CALL(*writeSock, bind(_, _)); EXPECT_CALL(*writeSock, write(_, _)) .WillOnce(Invoke([&](const SocketAddress& /* unused */, const std::unique_ptr& writtenData) { // the writtenData contains actual client address + time of ack + data EXPECT_FALSE(eq(*data, *writtenData)); // extract and verify the encoded client address folly::io::Cursor cursor(writtenData.get()); uint32_t protocotVersion = (cursor.readBE()); EXPECT_EQ(protocotVersion, 0x0000001); uint16_t addrLen = (cursor.readBE()); struct sockaddr* sockaddr = nullptr; std::pair addrData = cursor.peek(); EXPECT_GE(addrData.second, addrLen); sockaddr = (struct sockaddr*)addrData.first; cursor.skip(addrLen); folly::SocketAddress actualClient; actualClient.setFromSockaddr(sockaddr, addrLen); EXPECT_EQ(actualClient.getIPAddress(), clientAddr.getIPAddress()); EXPECT_EQ(actualClient.getPort(), clientAddr.getPort()); auto pktReceiveEpoch = cursor.readBE(); // the encoded time should be strictly less than 'now' EXPECT_LT(pktReceiveEpoch, Clock::now().time_since_epoch().count()); // skip to the start of the packet writtenData->trimStart( sizeof(uint32_t) + sizeof(uint16_t) + addrLen + sizeof(uint64_t)); EXPECT_TRUE(eq(*data, *writtenData)); // parse header and check connId to verify the integrity of the packet auto parsedHeader = parseHeader(*writtenData); auto& header = parsedHeader->parsedHeader; LongHeader* longHeader = header->asLong(); if (longHeader) { EXPECT_EQ(connId, longHeader->getDestinationConnId()); } else { EXPECT_EQ(connId, header->asShort()->getConnectionId()); } return data->computeChainDataLength(); })); takeoverWorker_->startPacketForwarding(kClientAddr); auto cb = [&](const folly::SocketAddress& client, std::unique_ptr& routingData, std::unique_ptr& networkData, folly::Optional quicVersion, bool isForwardedData) { takeoverWorker_->dispatchPacketData( client, std::move(*routingData.get()), std::move(*networkData.get()), std::move(quicVersion), isForwardedData); }; EXPECT_CALL(*takeoverWorkerCb_, routeDataToWorkerLong(_, _, _, _, _)) .WillOnce(Invoke(cb)); EXPECT_CALL(*quicStats_, onPacketReceived()); EXPECT_CALL(*quicStats_, onRead(len)); EXPECT_CALL(*quicStats_, onPacketForwarded()).Times(1); takeoverWorker_->onDataAvailable( clientAddr, len, false, OnDataAvailableParams()); takeoverWorker_->stopPacketForwarding(); // release this resource since MockQuicUDPSocketFactory::_make() hands its // ownership to it's caller (i.e. QuicServerWorker) writeSock.release(); } TEST_F(QuicServerWorkerTakeoverTest, QuicServerTakeoverProcessForwardedPkt) { // packet belongs to different server ConnectionId connId = createConnIdForServer(ProcessId::ZERO), clientConnId = getTestConnectionId(clientHostId_); takeoverWorker_->setProcessId(ProcessId::ONE); size_t len{0}; auto data = writeTestDataOnWorkersBuf( clientConnId, connId, len, takeoverWorker_.get(), LongHeader::Types::Handshake); takeoverWorker_->startPacketForwarding(kClientAddr); // the packet will be forwarded auto writeSock = std::make_unique>(&evb_); EXPECT_CALL(*takeoverSocketFactory_, _make(_, _)) .WillOnce(Return(writeSock.get())); EXPECT_CALL(*writeSock, bind(_, _)); EXPECT_CALL(*writeSock, write(_, _)) .WillOnce(Invoke([&](const SocketAddress& client, const std::unique_ptr& writtenData) { // the writtenData contains actual client address + time of ack + data EXPECT_FALSE(eq(*data, *writtenData)); // flip the server id to 'own' the packet (else it'll keep forwarding) takeoverWorker_->setProcessId(ProcessId::ZERO); // now invoke the Takeover Handler callback folly::AsyncUDPSocket::ReadCallback* takeoverCb = takeoverWorker_->getTakeoverHandlerCallback(); uint8_t* workerBuf = nullptr; size_t workerBufLen = 0; takeoverCb->getReadBuffer((void**)&workerBuf, &workerBufLen); writtenData->coalesce(); size_t bufLen = std::min(workerBufLen, writtenData->computeChainDataLength()); memcpy(workerBuf, writtenData->buffer(), bufLen); // test processing of the forwarded packet auto cb = [&](const folly::SocketAddress& addr, std::unique_ptr& /* routingData */, std::unique_ptr& networkData, folly::Optional /* quicVersion */, bool isForwardedData) { // verify that it is the original client address EXPECT_EQ(addr.getIPAddress(), clientAddr.getIPAddress()); EXPECT_EQ(addr.getPort(), clientAddr.getPort()); // the original data should be extracted after processing takeover // protocol related information EXPECT_EQ(networkData->packets.size(), 1); EXPECT_TRUE(eq(*data, *(networkData->packets[0]))); EXPECT_TRUE(isForwardedData); }; EXPECT_CALL(*takeoverWorkerCb_, routeDataToWorkerLong(_, _, _, _, _)) .WillOnce(Invoke(cb)); takeoverCb->onDataAvailable( client, bufLen, false, OnDataAvailableParams()); return bufLen; })); auto workerCb = [&](const folly::SocketAddress& client, std::unique_ptr& routingData, std::unique_ptr& networkData, folly::Optional quicVersion, bool isForwardedData) { takeoverWorker_->dispatchPacketData( client, std::move(*routingData.get()), std::move(*networkData.get()), std::move(quicVersion), isForwardedData); }; EXPECT_CALL(*takeoverWorkerCb_, routeDataToWorkerLong(_, _, _, _, _)) .WillOnce(Invoke(workerCb)); EXPECT_CALL(*quicStats_, onPacketReceived()); EXPECT_CALL(*quicStats_, onRead(len)); EXPECT_CALL(*quicStats_, onPacketForwarded()).Times(1); EXPECT_CALL(*quicStats_, onForwardedPacketReceived()).Times(1); EXPECT_CALL(*quicStats_, onForwardedPacketProcessed()).Times(1); takeoverWorker_->onDataAvailable( clientAddr, len, false, OnDataAvailableParams()); // release this resource since MockQuicUDPSocketFactory::_make() hands its // ownership to it's caller (i.e. QuicServerWorker) writeSock.release(); } TEST_F(QuicServerWorkerTakeoverTest, QuicServerTakeoverCbReadClose) { folly::AsyncUDPSocket::ReadCallback* takeoverCb = takeoverWorker_->getTakeoverHandlerCallback(); takeoverCb->onReadClosed(); } TEST_F(QuicServerWorkerTakeoverTest, QuicServerTakeoverCbReadError) { folly::AsyncUDPSocket::ReadCallback* takeoverCb = takeoverWorker_->getTakeoverHandlerCallback(); EXPECT_CALL(*takeoverSocket_, pauseRead()); folly::AsyncSocketException ex( folly::AsyncSocketException::AsyncSocketExceptionType::UNKNOWN, ""); takeoverCb->onReadError(ex); evb_.loopIgnoreKeepAlive(); } class QuicServerTest : public Test { public: void SetUp() override { auto factory = std::make_unique(); factory_ = factory.get(); server_ = QuicServer::createQuicServer(); server_->setQuicServerTransportFactory(std::move(factory)); server_->setFizzContext(quic::test::createServerCtx()); server_->setHostId(serverHostId_); server_->setConnectionIdVersion(quic::ConnectionIdVersion::V2); transportSettings_.advertisedInitialConnectionWindowSize = kDefaultConnectionWindowSize * 2; transportSettings_.advertisedInitialBidiLocalStreamWindowSize = kDefaultStreamWindowSize * 2; transportSettings_.advertisedInitialBidiRemoteStreamWindowSize = kDefaultStreamWindowSize * 2; transportSettings_.advertisedInitialUniStreamWindowSize = kDefaultStreamWindowSize * 2; transportSettings_.statelessResetTokenSecret = getRandSecret(); server_->setTransportSettings(transportSettings_); server_->setConnectionIdAlgoFactory( std::make_unique()); } void TearDown() override { server_->shutdown(); } void setUpTransportFactoryForWorkers(std::vector evbs) { for (auto ev : evbs) { EXPECT_TRUE(server_->isInitialized()); server_->addTransportFactory(ev, factory_); } } folly::SocketAddress initializeServer( std::vector evbs, MockQuicStats* stats = nullptr) { folly::SocketAddress addr("::1", 0); // test that the transportStatsFactory works as expected auto transportStatsFactory = std::make_unique(); transportStatsFactory_ = transportStatsFactory.get(); server_->setTransportStatsCallbackFactory(std::move(transportStatsFactory)); if (stats) { CHECK_EQ(evbs.size(), 1); EXPECT_CALL(*transportStatsFactory_, make()) .WillRepeatedly(Invoke( [stats]() { return std::unique_ptr(stats); })); } else { EXPECT_CALL(*transportStatsFactory_, make()).WillRepeatedly(Invoke([&]() { auto mockInfoCb = std::make_unique>(); return mockInfoCb; })); } if (evbs.empty()) { server_->start(addr, 2); } else { server_->initialize(addr, evbs); server_->start(); setUpTransportFactoryForWorkers(evbs); } server_->waitUntilInitialized(); return server_->getAddress(); } std::shared_ptr createNewTransport( folly::EventBase* eventBase, folly::AsyncUDPSocket& client, folly::SocketAddress serverAddr) { // create payload StreamId id = 1; auto clientConnId = getTestConnectionId(clientHostId_); auto serverConnId = getTestConnectionId(serverHostId_, quic::ConnectionIdVersion::V2); auto buf = createData(kMinInitialPacketSize); auto packet = createInitialStream( clientConnId, serverConnId, id, *buf, QuicVersion::MVFST); auto data = std::move(packet); std::mutex m; std::condition_variable cv; bool calledOnNetworkData = false; // create mock transport std::shared_ptr transport; eventBase->runInEventBaseThreadAndWait([&] { NiceMock connSetupcb; NiceMock connCb; std::unique_ptr mockSock = std::make_unique>( eventBase); EXPECT_CALL(*mockSock, address()).WillRepeatedly(ReturnRef(serverAddr)); transport = std::make_shared( eventBase, std::move(mockSock), &connSetupcb, &connCb, quic::test::createServerCtx()); }); auto makeTransport = [&](folly::EventBase* evb, std::unique_ptr& /* socket */, const folly::SocketAddress&, std::shared_ptr) noexcept { // set proper expectations for the transport after its creation EXPECT_CALL(*transport, getEventBase()).WillRepeatedly(Return(evb)); EXPECT_CALL(*transport, setTransportStatsCallback(_)) .WillOnce(Invoke([&](QuicTransportStatsCallback* statsCallback) { CHECK(statsCallback); })); EXPECT_CALL(*transport, setTransportSettings(_)) .WillRepeatedly(Invoke([&](auto transportSettings) { EXPECT_EQ( transportSettings_ .advertisedInitialBidiLocalStreamWindowSize, transportSettings .advertisedInitialBidiLocalStreamWindowSize); EXPECT_EQ( transportSettings_ .advertisedInitialBidiRemoteStreamWindowSize, transportSettings .advertisedInitialBidiRemoteStreamWindowSize); EXPECT_EQ( transportSettings_.advertisedInitialUniStreamWindowSize, transportSettings.advertisedInitialUniStreamWindowSize); EXPECT_EQ( transportSettings_.advertisedInitialConnectionWindowSize, transportSettings.advertisedInitialConnectionWindowSize); })); ON_CALL(*transport, onNetworkData(_, _)) .WillByDefault(Invoke( [&, expected = std::shared_ptr(data->clone())]( auto, const auto& networkData) mutable { EXPECT_GT(networkData.packets.size(), 0); EXPECT_TRUE(folly::IOBufEqualTo()( *networkData.packets[0], *expected)); std::unique_lock lg(m); calledOnNetworkData = true; cv.notify_one(); })); return transport; }; EXPECT_CALL(*factory_, _make(_, _, _, _)).WillOnce(Invoke(makeTransport)); // send packets to the server std::unique_lock lg(m); size_t tries = 0; if (!calledOnNetworkData && tries < 3) { tries++; auto ret = client.write(serverAddr, data->clone()); CHECK_EQ(ret, data->computeChainDataLength()); cv.wait_until(lg, std::chrono::system_clock::now() + 1s, [&] { return calledOnNetworkData; }); } CHECK(calledOnNetworkData); return transport; } std::unique_ptr makeUdpClient() { folly::SocketAddress addr2("::1", 0); std::unique_ptr client; evbThread_.getEventBase()->runInEventBaseThreadAndWait([&] { client = std::make_unique(evbThread_.getEventBase()); client->bind(addr2); }); return client; } void closeUdpClient(std::unique_ptr client) { evbThread_.getEventBase()->runInEventBaseThreadAndWait( [&] { client->close(); }); } void runTest(std::vector evbs) { auto serverAddr = initializeServer(evbs); auto client = makeUdpClient(); folly::EventBase* evb = server_->getWorkerEvbs().back(); auto transport = createNewTransport(evb, *client, serverAddr); EXPECT_CALL(*transport, setTransportStatsCallback(nullptr)); EXPECT_CALL(*transport, setRoutingCallback(nullptr)); EXPECT_CALL(*transport, closeNow(_)); server_->shutdown(); closeUdpClient(std::move(client)); // cleanup transport transport->getEventBase()->runInEventBaseThreadAndWait( [&] { transport.reset(); }); } void testReset(Buf packet); protected: folly::ScopedEventBaseThread evbThread_; std::shared_ptr server_; MockQuicServerTransportFactory* factory_; TransportSettings transportSettings_; MockQuicStatsFactory* transportStatsFactory_; uint32_t clientHostId_{0}, serverHostId_{0xAABBCC}; }; // namespace test TEST_F(QuicServerTest, NetworkTest) { runTest(std::vector()); } TEST_F(QuicServerTest, OtherEvbs) { folly::ScopedEventBaseThread evbThread; auto evb = evbThread.getEventBase(); std::vector evbs{evb}; runTest(evbs); } TEST_F(QuicServerTest, DontRouteDataAfterShutdown) { folly::ScopedEventBaseThread evbThread; std::vector evbs; evbs.emplace_back(evbThread.getEventBase()); MockQuicStats* stats = new MockQuicStats(); auto serverAddr = initializeServer(evbs, stats); auto client = makeUdpClient(); auto transport = createNewTransport(evbThread.getEventBase(), *client, serverAddr); EXPECT_CALL(*transport, setTransportStatsCallback(nullptr)); EXPECT_CALL(*transport, closeNow(_)).WillOnce(InvokeWithoutArgs([&] { PacketNum packetNum = 1; QuicVersion version = QuicVersion::MVFST; ConnectionId connId = getTestConnectionId(); LongHeader header( LongHeader::Types::Initial, getTestConnectionId(1), connId, packetNum, version); // Simulate receiving a packet before the worker shutdown. EXPECT_CALL(*stats, onPacketDropped(PacketDropReason::SERVER_SHUTDOWN)); NetworkData networkData(folly::IOBuf::copyBuffer("wat"), Clock::now()); RoutingData routingData( HeaderForm::Long, true, false, true, header.getDestinationConnId(), header.getSourceConnId()); server_->routeDataToWorker( kClientAddr, std::move(routingData), std::move(networkData), version); })); std::thread t([&] { server_->shutdown(); }); t.join(); closeUdpClient(std::move(client)); // cleanup transport transport->getEventBase()->runInEventBaseThreadAndWait( [&] { transport.reset(); }); } TEST_F(QuicServerTest, RouteDataFromDifferentThread) { folly::ScopedEventBaseThread evbThread; std::vector evbs; evbs.emplace_back(evbThread.getEventBase()); MockQuicStats* stats = new MockQuicStats(); auto serverAddr = initializeServer(evbs, stats); auto client = makeUdpClient(); auto transport = createNewTransport(evbThread.getEventBase(), *client, serverAddr); EXPECT_CALL(*transport, setTransportStatsCallback(nullptr)); EXPECT_CALL(*stats, onPacketDropped(PacketDropReason::SERVER_SHUTDOWN)) .Times(0); auto clientConnId = getTestConnectionId(clientHostId_); auto serverConnId = getTestConnectionId(serverHostId_, quic::ConnectionIdVersion::V2); PacketNum packetNum = 1; QuicVersion version = QuicVersion::MVFST; LongHeader header( LongHeader::Types::Initial, clientConnId, serverConnId, packetNum, version); auto initialData = std::shared_ptr( folly::IOBuf::create(kMinInitialPacketSize)); initialData->append(kMinInitialPacketSize); memset(initialData->writableData(), 'd', kMinInitialPacketSize); NetworkData networkData(initialData->clone(), Clock::now()); RoutingData routingData( HeaderForm::Long, true, false, true, header.getDestinationConnId(), header.getSourceConnId()); EXPECT_CALL(*transport, onNetworkData(_, _)) .WillOnce(Invoke([&](auto, const auto& networkData) { EXPECT_GT(networkData.packets.size(), 0); EXPECT_TRUE( folly::IOBufEqualTo()(*networkData.packets[0], *initialData)); })); server_->routeDataToWorker( client->address(), std::move(routingData), std::move(networkData), version); // cleanup transport transport->getEventBase()->runInEventBaseThreadAndWait( [&] { transport.reset(); }); closeUdpClient(std::move(client)); std::thread t([&] { server_->shutdown(); }); t.join(); } TEST_F(QuicServerTest, OverrideTakeoverAddressTest) { folly::ScopedEventBaseThread evbThread; std::vector evbs; evbs.emplace_back(evbThread.getEventBase()); auto serverAddr = initializeServer(evbs); folly::SocketAddress takeoverAddr("::1", 0); server_->allowBeingTakenOver(takeoverAddr); uint8_t bindAttempt = 0; folly::SocketAddress boundAddr; while (bindAttempt < 5) { boundAddr = server_->overrideTakeoverHandlerAddress(takeoverAddr); bindAttempt++; } EXPECT_TRUE(boundAddr.isInitialized()); std::thread t([&] { server_->shutdown(); }); t.join(); } class QuicServerTakeoverTest : public Test { public: void SetUp() override { transportSettings_.advertisedInitialConnectionWindowSize = kDefaultConnectionWindowSize * 2; transportSettings_.advertisedInitialBidiLocalStreamWindowSize = kDefaultStreamWindowSize * 2; transportSettings_.advertisedInitialBidiRemoteStreamWindowSize = kDefaultStreamWindowSize * 2; transportSettings_.advertisedInitialUniStreamWindowSize = kDefaultStreamWindowSize * 2; setUpServer(oldServer_, ProcessId::ZERO); setUpServer(newServer_, ProcessId::ONE); } void setUpServer(std::shared_ptr& server, ProcessId id) { auto factory = std::make_unique(); if (id == ProcessId::ZERO) { oldFactory_ = factory.get(); } else { newFactory_ = factory.get(); } server = QuicServer::createQuicServer(); server->setQuicServerTransportFactory(std::move(factory)); server->setFizzContext(quic::test::createServerCtx()); server->setTransportSettings(transportSettings_); server->setProcessId(id); } std::shared_ptr initTransport( MockQuicServerTransportFactory* factory, ConnectionId& clientConnId, Buf& data, folly::Baton<>& baton) { std::shared_ptr transport; NiceMock connSetupCb; NiceMock connCb; auto makeTransport = [&](folly::EventBase* eventBase, std::unique_ptr& socket, const folly::SocketAddress&, std::shared_ptr ctx) noexcept { transport = std::make_shared( eventBase, std::move(socket), &connSetupCb, &connCb, ctx); transport->setClientConnectionId(clientConnId); // setup expectations EXPECT_CALL(*transport, getEventBase()) .WillRepeatedly(Return(eventBase)); EXPECT_CALL(*transport, setTransportSettings(_)); EXPECT_CALL(*transport, accept()); EXPECT_CALL(*transport, setSupportedVersions(_)); EXPECT_CALL(*transport, setRoutingCallback(_)); EXPECT_CALL(*transport, setOriginalPeerAddress(_)); EXPECT_CALL(*transport, setTransportStatsCallback(_)); EXPECT_CALL(*transport, setServerConnectionIdParams(_)) .WillOnce(Invoke([&](ServerConnectionIdParams params) { EXPECT_EQ(params.processId, 0); EXPECT_EQ(params.workerId, 0); })); EXPECT_CALL(*transport, onNetworkData(_, _)) .WillOnce(Invoke( [&, expected = data.get()](auto, const auto& networkData) { EXPECT_GT(networkData.packets.size(), 0); EXPECT_TRUE(folly::IOBufEqualTo()( *networkData.packets[0], *expected)); baton.post(); })); return transport; }; EXPECT_CALL(*factory, _make(_, _, _, _)).WillOnce(Invoke(makeTransport)); return transport; } void runTest( std::vector evbs1, std::vector evbs2) { folly::Baton<> b; ConnectionId clientConnId = getTestConnectionId(clientHostId_); // create a packet to send to the old server and verify that it accepts it StreamId id = 1; auto buf = createData(kMinInitialPacketSize); ConnectionId connId = createConnIdForServer(ProcessId::ZERO); auto packet = createInitialStream(clientConnId, connId, id, *buf, QuicVersion::MVFST); auto data = std::move(packet); auto transportCbForOldServer = initTransport(oldFactory_, clientConnId, data, b); folly::SocketAddress addr("::1", 0); // setup mock transport stats factory auto transportStatsFactory = std::make_unique(); auto makeCbForOldServer = [&]() { oldTransInfoCb_ = new NiceMock(); std::unique_ptr quicStats; quicStats.reset(oldTransInfoCb_); return quicStats; }; EXPECT_CALL(*transportStatsFactory, make()) .WillOnce(Invoke(makeCbForOldServer)); oldServer_->setTransportStatsCallbackFactory( std::move(transportStatsFactory)); oldServer_->initialize(addr, evbs1); oldServer_->start(); oldServer_->waitUntilInitialized(); for (auto ev : evbs1) { oldServer_->addTransportFactory(ev, oldFactory_); } auto serverAddr = oldServer_->getAddress(); folly::SocketAddress takeoverAddr("::1", 0); oldServer_->allowBeingTakenOver(takeoverAddr); folly::SocketAddress clientAddr("::1", 0); std::unique_ptr client; evbThread_.getEventBase()->runInEventBaseThreadAndWait([&] { client = std::make_unique(evbThread_.getEventBase()); client->bind(clientAddr); }); // send packet to the server and wait EXPECT_CALL(*oldTransInfoCb_, onPacketReceived()); EXPECT_CALL(*oldTransInfoCb_, onRead(_)); client->write(serverAddr, data->clone()); b.wait(); // spin another server and verify that the old server gets the packet // that is routed to the new server int takeoverListeningFd = oldServer_->getTakeoverHandlerSocketFD(); newServer_->setListeningFDs(oldServer_->getAllListeningSocketFDs()); folly::SocketAddress newAddr("::1", 0); // setup mock transport stats factory transportStatsFactory = std::make_unique(); auto makeCbForNewServer = [&]() { newTransInfoCb_ = new NiceMock(); std::unique_ptr quicStats; quicStats.reset(newTransInfoCb_); return quicStats; }; EXPECT_CALL(*transportStatsFactory, make()) .WillOnce(Invoke(makeCbForNewServer)); newServer_->setTransportStatsCallbackFactory( std::move(transportStatsFactory)); newServer_->initialize(newAddr, evbs2); newServer_->start(); newServer_->waitUntilInitialized(); for (auto ev : evbs2) { newServer_->addTransportFactory(ev, newFactory_); } folly::SocketAddress destAddr; destAddr.setFromLocalAddress( folly::NetworkSocket::fromFd(takeoverListeningFd)); newServer_->startPacketForwarding(destAddr); auto newServerAddr = newServer_->getAddress(); CHECK(newServerAddr != destAddr); packet = createInitialStream( clientConnId, connId, id, *buf, QuicVersion::MVFST, LongHeader::Types::Retry); data = std::move(packet); folly::Baton<> b1; // onNetworkData(_, _) shouldn't be called on the newServer_ transport, // but should be routed to oldServer_ EXPECT_CALL(*transportCbForOldServer, onNetworkData(_, _)) .WillOnce( Invoke([&, expected = data.get()](auto, const auto& networkData) { EXPECT_GT(networkData.packets.size(), 0); EXPECT_TRUE( folly::IOBufEqualTo()(*networkData.packets[0], *expected)); b1.post(); })); // new quic server receives the packet and forwards it EXPECT_CALL(*newTransInfoCb_, onPacketReceived()); EXPECT_CALL(*newTransInfoCb_, onRead(_)); EXPECT_CALL(*newTransInfoCb_, onPacketForwarded()); EXPECT_CALL(*newTransInfoCb_, onPacketProcessed()).Times(0); // verify takeover related counters on the old quic server EXPECT_CALL(*oldTransInfoCb_, onForwardedPacketReceived()); EXPECT_CALL(*oldTransInfoCb_, onForwardedPacketProcessed()); // the old server should then handle it as usual EXPECT_CALL(*oldTransInfoCb_, onPacketDropped(_)).Times(0); // pause the old server so that we can deterministically route to the new // server oldServer_->pauseRead(); client->write(newServerAddr, data->clone()); b1.wait(); b1.reset(); EXPECT_CALL(*transportCbForOldServer, setRoutingCallback(nullptr)); EXPECT_CALL(*transportCbForOldServer, closeNow(_)); // Disable packet forwarding on the new server and send packet // This packet should be dropped since it's not an initial packet newServer_->stopPacketForwarding(0ms); std::atomic posted{false}; EXPECT_CALL(*newTransInfoCb_, onPacketReceived()) .WillRepeatedly(Invoke([&] { if (posted) { return; } posted = true; b1.post(); })); EXPECT_CALL(*newTransInfoCb_, onRead(_)).Times(AtLeast(1)); EXPECT_CALL(*newTransInfoCb_, onPacketForwarded()).Times(0); EXPECT_CALL(*newTransInfoCb_, onPacketDropped(_)).Times(AtLeast(1)); client->write(newServerAddr, data->clone()); client->write(newServerAddr, data->clone()); b1.wait(); EXPECT_CALL(*transportCbForOldServer, setTransportStatsCallback(nullptr)); oldServer_->shutdown(); // 'transport' never gets created for the newServer_ // so no callback on closeNow() newServer_->shutdown(); evbThread_.getEventBase()->runInEventBaseThreadAndWait( [&] { client->close(); }); // cleanup transport transportCbForOldServer->getEventBase()->runInEventBaseThreadAndWait( [&] { transportCbForOldServer.reset(); }); } protected: folly::ScopedEventBaseThread evbThread_; std::shared_ptr oldServer_; std::shared_ptr newServer_; MockQuicServerTransportFactory* oldFactory_; MockQuicServerTransportFactory* newFactory_; MockQuicStats* oldTransInfoCb_; MockQuicStats* newTransInfoCb_; TransportSettings transportSettings_; MockQuicStatsFactory* transportStatsFactory_; uint16_t clientHostId_{25}; }; TEST_F(QuicServerTakeoverTest, TakeoverTest) { folly::ScopedEventBaseThread evbThread1; auto evb1 = evbThread1.getEventBase(); std::vector evbs1{evb1}; folly::ScopedEventBaseThread evbThread2; auto evb2 = evbThread2.getEventBase(); std::vector evbs2{evb2}; runTest(evbs1, evbs2); } struct UDPReader : public folly::AsyncUDPSocket::ReadCallback { UDPReader() { bufPromise_ = std::make_unique>>(); } ~UDPReader() override = default; void start(EventBase* evb, SocketAddress addr) { evb_ = evb; evb_->runInEventBaseThreadAndWait([&] { client = std::make_unique(evb_); client->bind(addr); client->resumeRead(this); }); } AsyncUDPSocket& getSocket() { return *client; } void getReadBuffer(void** buf, size_t* len) noexcept override { if (!buf_) { buf_ = IOBuf::create(kDefaultUDPReadBufferSize); } *buf = buf_->writableData(); *len = kDefaultUDPReadBufferSize; } void onDataAvailable( const folly::SocketAddress&, size_t len, bool truncated, OnDataAvailableParams /*params*/) noexcept override { std::lock_guard guard(bufLock_); if (truncated) { bufPromise_->setException(std::runtime_error("truncated buf")); return; } if (bufPromise_) { buf_->append(len); bufPromise_->setValue(std::move(buf_)); } } void onReadError(const AsyncSocketException& ex) noexcept override { std::lock_guard guard(bufLock_); if (bufPromise_) { bufPromise_->setException(ex); } } void onReadClosed() noexcept override { if (bufPromise_) { bufPromise_->setException(std::runtime_error("closed")); } } folly::Future> readOne() { std::lock_guard guard(bufLock_); if (!bufPromise_) { bufPromise_ = std::make_unique>>(); } return bufPromise_->getFuture().ensure([&]() { bufPromise_ = nullptr; }); } private: std::unique_ptr buf_; std::mutex bufLock_; std::unique_ptr>> bufPromise_; std::unique_ptr client; EventBase* evb_; }; TEST_F(QuicServerTest, NetworkTestVersionNegotiation) { folly::SocketAddress addr("::1", 0); server_->start(addr, 2); server_->waitUntilInitialized(); auto serverAddr = server_->getAddress(); folly::SocketAddress addr2("::1", 0); std::unique_ptr reader = std::make_unique(); reader->start(evbThread_.getEventBase(), addr2); SCOPE_EXIT { server_->shutdown(); evbThread_.getEventBase()->runInEventBaseThreadAndWait( [&] { reader->getSocket().close(); }); }; StreamId id = 1; auto clientConnId = getTestConnectionId(clientHostId_); auto serverConnId = getTestConnectionId(serverHostId_, quic::ConnectionIdVersion::V2); auto buf = createData(kMinInitialPacketSize + 10); auto packet = createInitialStream(clientConnId, serverConnId, id, *buf, MVFST1); auto data = std::move(packet); reader->getSocket().write(serverAddr, data->clone()); auto serverData = reader->readOne().get(200ms); auto codec = std::make_unique(QuicNodeType::Server); auto packetQueue = bufToQueue(std::move(serverData)); auto versionPacket = codec->tryParsingVersionNegotiation(packetQueue); ASSERT_TRUE(versionPacket.has_value()); EXPECT_EQ(versionPacket->destinationConnectionId, clientConnId); } TEST_F(QuicServerTest, NetworkTestVersionNegotiationMinLength) { folly::SocketAddress addr("::1", 0); server_->start(addr, 2); server_->waitUntilInitialized(); auto serverAddr = server_->getAddress(); folly::SocketAddress addr2("::1", 0); std::unique_ptr reader = std::make_unique(); reader->start(evbThread_.getEventBase(), addr2); SCOPE_EXIT { server_->shutdown(); evbThread_.getEventBase()->runInEventBaseThreadAndWait( [&] { reader->getSocket().close(); }); }; StreamId id = 1; auto clientConnId = getTestConnectionId(clientHostId_); auto serverConnId = getTestConnectionId(serverHostId_, quic::ConnectionIdVersion::V2); auto buf = createData(kMinInitialPacketSize - 100); auto packet = createInitialStream(clientConnId, serverConnId, id, *buf, MVFST1); auto data = std::move(packet); reader->getSocket().write(serverAddr, data->clone()); EXPECT_THROW(reader->readOne().get(200ms), folly::FutureTimeout); } TEST_F(QuicServerTest, NetworkTestNoVersionNegotiation) { folly::SocketAddress addr("::1", 0); server_->start(addr, 2); server_->waitUntilInitialized(); auto serverAddr = server_->getAddress(); folly::SocketAddress addr2("::1", 0); std::unique_ptr reader = std::make_unique(); reader->start(evbThread_.getEventBase(), addr2); SCOPE_EXIT { server_->shutdown(); evbThread_.getEventBase()->runInEventBaseThreadAndWait( [&] { reader->getSocket().close(); }); }; StreamId id = 1; auto clientConnId = getTestConnectionId(clientHostId_); auto serverConnId = getTestConnectionId(serverHostId_, quic::ConnectionIdVersion::V2); auto buf = folly::IOBuf::copyBuffer("hello"); auto packet = createInitialStream( clientConnId, serverConnId, id, *buf, QuicVersion::VERSION_NEGOTIATION); auto data = std::move(packet); reader->getSocket().write(serverAddr, data->clone()); EXPECT_THROW(reader->readOne().get(200ms), folly::FutureTimeout); } TEST_F(QuicServerTest, TestRejectNewConnections) { // test that Version Negotiation fails if the server is rejecting all // new connections folly::SocketAddress addr("::1", 0); server_->start(addr, 2); server_->rejectNewConnections([]() { return true; }); server_->waitUntilInitialized(); auto serverAddr = server_->getAddress(); folly::SocketAddress addr2("::1", 0); std::unique_ptr reader = std::make_unique(); reader->start(evbThread_.getEventBase(), addr2); SCOPE_EXIT { server_->shutdown(); evbThread_.getEventBase()->runInEventBaseThreadAndWait( [&] { reader->getSocket().close(); }); }; StreamId id = 1; auto clientConnId = getTestConnectionId(clientHostId_); auto serverConnId = getTestConnectionId(serverHostId_, quic::ConnectionIdVersion::V2); auto buf = createData(kMinInitialPacketSize); auto packet = createInitialStream(clientConnId, serverConnId, id, *buf, MVFST1); auto data = std::move(packet); reader->getSocket().write(serverAddr, data->clone()); auto serverData = reader->readOne().get(200ms); auto codec = std::make_unique(QuicNodeType::Server); auto packetQueue = bufToQueue(std::move(serverData)); auto versionPacket = codec->tryParsingVersionNegotiation(packetQueue); ASSERT_TRUE(versionPacket.has_value()); EXPECT_EQ(versionPacket->destinationConnectionId, clientConnId); EXPECT_EQ(versionPacket->sourceConnectionId, serverConnId); EXPECT_EQ(versionPacket->versions.size(), 1); EXPECT_EQ(versionPacket->versions.at(0), QuicVersion::MVFST_INVALID); } TEST_F(QuicServerTest, NetworkTestHealthCheck) { folly::SocketAddress addr("::1", 0); std::string healthCheckToken = "health"; std::string notHealthCheckToken = "health2"; server_->setHealthCheckToken(healthCheckToken); server_->start(addr, 2); server_->waitUntilInitialized(); auto serverAddr = server_->getAddress(); folly::SocketAddress addr2("::1", 0); std::unique_ptr reader = std::make_unique(); reader->start(evbThread_.getEventBase(), addr2); SCOPE_EXIT { server_->shutdown(); evbThread_.getEventBase()->runInEventBaseThreadAndWait( [&] { reader->getSocket().close(); }); }; reader->getSocket().write(serverAddr, IOBuf::copyBuffer(healthCheckToken)); auto serverData = reader->readOne().get(); EXPECT_EQ(serverData->moveToFbString().toStdString(), std::string("OK")); reader->getSocket().write(serverAddr, IOBuf::copyBuffer(notHealthCheckToken)); EXPECT_THROW(reader->readOne().get(200ms), folly::FutureTimeout); } void QuicServerTest::testReset(Buf packet) { folly::SocketAddress addr("::1", 0); server_->start(addr, 2); server_->waitUntilInitialized(); auto serverAddr = server_->getAddress(); folly::SocketAddress addr2("::1", 0); std::unique_ptr reader = std::make_unique(); reader->start(evbThread_.getEventBase(), addr2); SCOPE_EXIT { server_->shutdown(); evbThread_.getEventBase()->runInEventBaseThreadAndWait( [&] { reader->getSocket().close(); }); }; reader->getSocket().write(serverAddr, packet->clone()); auto serverData = reader->readOne().get(1000ms); EXPECT_LE(serverData->computeChainDataLength(), kDefaultUDPSendPacketLen); QuicReadCodec codec(QuicNodeType::Client); auto aead = createNoOpAead(); // Make the decrypt fail EXPECT_CALL(*aead, _tryDecrypt(_, _, _)) .WillRepeatedly(Invoke([&](auto&, auto, auto) { return folly::none; })); codec.setOneRttReadCipher(std::move(aead)); codec.setOneRttHeaderCipher(test::createNoOpHeaderCipher()); StatelessResetToken token = generateStatelessResetToken(); codec.setStatelessResetToken(token); AckStates ackStates; auto packetBuf = serverData->clone(); overridePacketWithToken(*packetBuf, token); auto packetQueue = bufToQueue(std::move(packetBuf)); auto res = codec.parsePacket(packetQueue, ackStates); EXPECT_NE(res.statelessReset(), nullptr); } TEST_F(QuicServerTest, NetworkTestReset) { StreamId id = 1; auto clientConnId = getTestConnectionId(clientHostId_); auto serverConnId = getTestConnectionId(serverHostId_, quic::ConnectionIdVersion::V2); PacketNum packetNum = 20; auto buf = folly::IOBuf::copyBuffer("hello"); auto packet = packetToBuf(createStreamPacket( clientConnId, serverConnId, packetNum, id, *buf, 0 /* cipherOverhead */, 0 /* largestAcked */)); auto data = std::move(packet); testReset(data->clone()); } TEST_F(QuicServerTest, NetworkTestResetLargePacket) { StreamId id = 1; auto clientConnId = getTestConnectionId(clientHostId_); auto serverConnId = getTestConnectionId(serverHostId_, quic::ConnectionIdVersion::V2); PacketNum packetNum = 20; auto buf = folly::IOBuf::create(kDefaultUDPSendPacketLen + 3); buf->append(kDefaultUDPSendPacketLen + 3); auto packet = packetToBuf(createStreamPacket( clientConnId, serverConnId, packetNum, id, *buf, 0 /* cipherOverhead */, 0 /* largestAcked */)); ASSERT_LE(packet->computeChainDataLength(), kDefaultUDPSendPacketLen); testReset(std::move(packet)); } TEST_F(QuicServerTest, NetworkTestResetLongHeader) { StreamId id = 1; auto clientConnId = getTestConnectionId(clientHostId_); auto serverConnId = getTestConnectionId(serverHostId_, quic::ConnectionIdVersion::V2); PacketNum packetNum = 20; auto buf = folly::IOBuf::copyBuffer("hello"); auto packet = packetToBuf(createStreamPacket( clientConnId, serverConnId, packetNum, id, *buf, 0 /* cipherOverhead */, 0 /* largestAcked */, std::make_pair(LongHeader::Types::ZeroRtt, QuicVersion::MVFST))); EXPECT_THROW(testReset(std::move(packet)), folly::FutureTimeout); } TEST_F(QuicServerTest, ZeroRttPacketRoute) { folly::ScopedEventBaseThread evbThread; auto evb = evbThread.getEventBase(); std::vector evbs{evb}; folly::SocketAddress addr("::1", 0); server_->start(addr, 1); server_->waitUntilInitialized(); setUpTransportFactoryForWorkers(evbs); std::shared_ptr transport; NiceMock connSetupCb; NiceMock connCb; folly::Baton<> b; // create payload StreamId id = 1; auto clientConnId = getTestConnectionId(clientHostId_); auto serverConnId = getTestConnectionId(serverHostId_, quic::ConnectionIdVersion::V2); auto buf = createData(kMinInitialPacketSize + 10); auto packet = createInitialStream( clientConnId, serverConnId, id, *buf, QuicVersion::MVFST); auto data = std::move(packet); auto makeTransport = [&](folly::EventBase* eventBase, std::unique_ptr& socket, const folly::SocketAddress&, std::shared_ptr ctx) noexcept { transport = std::make_shared( eventBase, std::move(socket), &connSetupCb, &connCb, ctx); EXPECT_CALL(*transport, getEventBase()) .WillRepeatedly(Return(eventBase)); EXPECT_CALL(*transport, setSupportedVersions(_)); EXPECT_CALL(*transport, setOriginalPeerAddress(_)); EXPECT_CALL(*transport, setTransportSettings(_)); EXPECT_CALL(*transport, setServerConnectionIdParams(_)); EXPECT_CALL(*transport, accept()); // post baton upon receiving the data EXPECT_CALL(*transport, onNetworkData(_, _)) .WillOnce(Invoke( [&, expected = data.get()](auto, const auto& networkData) { EXPECT_GT(networkData.packets.size(), 0); EXPECT_TRUE(folly::IOBufEqualTo()( *networkData.packets[0], *expected)); b.post(); })); return transport; }; EXPECT_CALL(*factory_, _make(_, _, _, _)).WillOnce(Invoke(makeTransport)); auto serverAddr = server_->getAddress(); folly::SocketAddress addr2("::1", 0); std::unique_ptr reader = std::make_unique(); reader->start(evbThread_.getEventBase(), addr2); SCOPE_EXIT { server_->shutdown(); evbThread_.getEventBase()->runInEventBaseThreadAndWait( [&] { reader->getSocket().close(); }); transport->getEventBase()->runInEventBaseThreadAndWait( [&] { transport.reset(); }); }; // send an initial packet - that should create a new 'connection' reader->getSocket().write(serverAddr, data->clone()); b.wait(); // now send 0-rtt packet, and verify that it gets routed properly PacketNum packetNum = 20; packet = packetToBuf(createStreamPacket( clientConnId, serverConnId, packetNum, id, *buf, 0 /* cipherOverhead */, 0 /* largestAcked */, std::make_pair(LongHeader::Types::ZeroRtt, QuicVersion::MVFST))); data = std::move(packet); folly::Baton<> b1; auto verifyZeroRtt = [&](const folly::SocketAddress& peer, const NetworkData& networkData) noexcept { EXPECT_GT(networkData.packets.size(), 0); EXPECT_EQ(peer, reader->getSocket().address()); EXPECT_TRUE(folly::IOBufEqualTo()(*data, *networkData.packets[0])); b1.post(); }; EXPECT_CALL(*transport, onNetworkData(_, _)).WillOnce(Invoke(verifyZeroRtt)); reader->getSocket().write(serverAddr, data->clone()); b1.wait(); } TEST_F(QuicServerTest, ZeroRttBeforeInitial) { folly::ScopedEventBaseThread evbThread; auto evb = evbThread.getEventBase(); std::vector evbs{evb}; folly::SocketAddress addr("::1", 0); server_->start(addr, 1); server_->waitUntilInitialized(); setUpTransportFactoryForWorkers(evbs); std::shared_ptr transport; NiceMock connSetupCb; NiceMock connCb; folly::Baton<> b; // create payload StreamId id = 1; auto clientConnId = getTestConnectionId(clientHostId_); auto serverConnId = getTestConnectionId(serverHostId_, quic::ConnectionIdVersion::V2); auto initialBuf = createData(kMinInitialPacketSize + 10); auto initialPacket = createInitialStream( clientConnId, serverConnId, id, *initialBuf, QuicVersion::MVFST); auto initialData = std::move(initialPacket); std::vector receivedData; auto makeTransport = [&](folly::EventBase* eventBase, std::unique_ptr& socket, const folly::SocketAddress&, std::shared_ptr ctx) noexcept { transport = std::make_shared( eventBase, std::move(socket), &connSetupCb, &connCb, ctx); EXPECT_CALL(*transport, getEventBase()) .WillRepeatedly(Return(eventBase)); EXPECT_CALL(*transport, setSupportedVersions(_)); EXPECT_CALL(*transport, setOriginalPeerAddress(_)); EXPECT_CALL(*transport, setTransportSettings(_)); EXPECT_CALL(*transport, setServerConnectionIdParams(_)); EXPECT_CALL(*transport, accept()); // post baton upon receiving the data EXPECT_CALL(*transport, onNetworkData(_, _)) .Times(2) .WillRepeatedly(Invoke([&](auto, auto& networkData) { for (auto& buf : networkData.packets) { receivedData.emplace_back(buf->clone()); } if (receivedData.size() == 2) { b.post(); } })); return transport; }; EXPECT_CALL(*factory_, _make(_, _, _, _)).WillOnce(Invoke(makeTransport)); auto serverAddr = server_->getAddress(); folly::SocketAddress addr2("::1", 0); std::unique_ptr reader = std::make_unique(); reader->start(evbThread_.getEventBase(), addr2); SCOPE_EXIT { server_->shutdown(); evbThread_.getEventBase()->runInEventBaseThreadAndWait( [&] { reader->getSocket().close(); }); transport->getEventBase()->runInEventBaseThreadAndWait( [&] { transport.reset(); }); }; // send 0-rtt packet, it should be buffered PacketNum packetNum = 20; auto zeroRttBuf = folly::IOBuf::copyBuffer("blah blah blah blah"); auto zeroRttPacket = packetToBuf(createStreamPacket( clientConnId, serverConnId, packetNum, id, *zeroRttBuf, 0 /* cipherOverhead */, 0 /* largestAcked */, std::make_pair(LongHeader::Types::ZeroRtt, QuicVersion::MVFST))); auto zeroRttData = std::move(zeroRttPacket); reader->getSocket().write(serverAddr, zeroRttData->clone()); // send an initial packet - that should create a new 'connection' reader->getSocket().write(serverAddr, initialData->clone()); b.wait(); ASSERT_EQ(receivedData.size(), 2); ASSERT_EQ( initialData->computeChainDataLength(), receivedData[0]->computeChainDataLength()); ASSERT_EQ( zeroRttData->computeChainDataLength(), receivedData[1]->computeChainDataLength()); EXPECT_TRUE(folly::IOBufEqualTo()(*initialData, *receivedData[0])); EXPECT_TRUE(folly::IOBufEqualTo()(*zeroRttData, *receivedData[1])); } TEST_F(QuicServerTest, OneEVB) { folly::EventBase evb; folly::SocketAddress addr("::1", 0); server_->initialize(addr, {&evb}, true); server_->start(); server_->waitUntilInitialized(); evb.runAfterDelay([this] { server_->shutdown(); }, 100); evb.loop(); } } // namespace test } // namespace quic