From 30bff94e8510daa7e9e4feacadba4b4228f1beba Mon Sep 17 00:00:00 2001 From: Matt Joras Date: Mon, 18 May 2020 16:38:41 -0700 Subject: [PATCH] Introduce new connection rate limits. Summary: This introduces a rate limit to new connections created by a worker. Right now it will simply send a VN, but eventually this will only issue a RETRY for unverified initials. Reviewed By: udippant Differential Revision: D21614905 fbshipit-source-id: 1832fbdad525c53fb1cb810aa9d7bae868c267d6 --- quic/server/QuicServer.cpp | 9 +++ quic/server/QuicServer.h | 9 +++ quic/server/QuicServerWorker.cpp | 18 ++++++ quic/server/QuicServerWorker.h | 9 +++ quic/server/test/QuicServerTest.cpp | 81 +++++++++++++++++++++++++ quic/state/QuicTransportStatsCallback.h | 2 + quic/state/test/MockQuicStats.h | 1 + 7 files changed, 129 insertions(+) diff --git a/quic/server/QuicServer.cpp b/quic/server/QuicServer.cpp index d437fcb64..7ec0fae28 100644 --- a/quic/server/QuicServer.cpp +++ b/quic/server/QuicServer.cpp @@ -15,6 +15,7 @@ #include #include #include +#include namespace quic { @@ -60,6 +61,10 @@ void QuicServer::setCongestionControllerFactory( ccFactory_ = std::move(ccFactory); } +void QuicServer::setRateLimit(uint64_t count, std::chrono::seconds window) { + rateLimit_ = folly::make_optional(count, window); +} + void QuicServer::setSupportedVersion(const std::vector& versions) { supportedVersions_ = versions; } @@ -164,6 +169,10 @@ void QuicServer::initializeWorkers( } worker->setConnectionIdAlgo(connIdAlgoFactory_->make()); worker->setCongestionControllerFactory(ccFactory_); + if (rateLimit_) { + worker->setRateLimiter(std::make_unique( + rateLimit_->count, rateLimit_->window)); + } worker->setWorkerId(i); worker->setTransportSettingsOverrideFn(transportSettingsOverrideFn_); workers_.push_back(std::move(worker)); diff --git a/quic/server/QuicServer.h b/quic/server/QuicServer.h index 25a9ce17f..409437f48 100644 --- a/quic/server/QuicServer.h +++ b/quic/server/QuicServer.h @@ -99,6 +99,8 @@ class QuicServer : public QuicServerWorker::WorkerCallback, void setCongestionControllerFactory( std::shared_ptr ccFactory); + void setRateLimit(uint64_t count, std::chrono::seconds window); + /** * Set list of supported QUICVersion for this server. These versions will be * used during the 'Version-Negotiation' phase with the client. @@ -380,6 +382,13 @@ class QuicServer : public QuicServerWorker::WorkerCallback, // address that the server is bound to folly::SocketAddress boundAddress_; folly::SocketOptionMap socketOptions_; + // Rate limits + struct RateLimit { + RateLimit(uint64_t c, std::chrono::seconds w) : count(c), window(w) {} + uint64_t count; + std::chrono::seconds window; + }; + folly::Optional rateLimit_; }; } // namespace quic diff --git a/quic/server/QuicServerWorker.cpp b/quic/server/QuicServerWorker.cpp index 447bcaf66..97115bb5b 100644 --- a/quic/server/QuicServerWorker.cpp +++ b/quic/server/QuicServerWorker.cpp @@ -107,6 +107,11 @@ void QuicServerWorker::setCongestionControllerFactory( ccFactory_ = ccFactory; } +void QuicServerWorker::setRateLimiter( + std::unique_ptr rateLimiter) { + newConnRateLimiter_ = std::move(rateLimiter); +} + void QuicServerWorker::start() { CHECK(socket_); if (!pacingTimer_) { @@ -490,6 +495,19 @@ void QuicServerWorker::dispatchPacketData( PacketDropReason::INVALID_PACKET); return; } + if (newConnRateLimiter_ && + newConnRateLimiter_->check(networkData.receiveTimePoint)) { + // TODO RETRY + VersionNegotiationPacketBuilder builder( + routingData.destinationConnId, + routingData.sourceConnId.value_or( + ConnectionId(std::vector())), + std::vector{QuicVersion::MVFST_INVALID}); + auto versionNegotiationPacket = std::move(builder).buildPacket(); + socket_->write(client, versionNegotiationPacket.second); + QUIC_STATS(statsCallback_, onConnectionRateLimited); + return; + } // create 'accepting' transport auto sock = makeSocket(getEventBase()); auto trans = transportFactory_->make( diff --git a/quic/server/QuicServerWorker.h b/quic/server/QuicServerWorker.h index 5c954c48b..24f3ee0eb 100644 --- a/quic/server/QuicServerWorker.h +++ b/quic/server/QuicServerWorker.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -226,6 +227,11 @@ class QuicServerWorker : public folly::AsyncUDPSocket::ReadCallback, void setCongestionControllerFactory( std::shared_ptr factory); + /** + * Set the rate limiter which will be used to rate limit new connections. + */ + void setRateLimiter(std::unique_ptr rateLimiter); + // Read callback void getReadBuffer(void** buf, size_t* len) noexcept override; @@ -414,6 +420,9 @@ class QuicServerWorker : public folly::AsyncUDPSocket::ReadCallback, // Output buffer to be used for continuous memory GSO write std::unique_ptr bufAccessor_; + + // Rate limits the creation of new connections for this worker. + std::unique_ptr newConnRateLimiter_; }; } // namespace quic diff --git a/quic/server/test/QuicServerTest.cpp b/quic/server/test/QuicServerTest.cpp index 3755c65c9..6d869777a 100644 --- a/quic/server/test/QuicServerTest.cpp +++ b/quic/server/test/QuicServerTest.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -378,6 +379,86 @@ TEST_F(QuicServerWorkerTest, NoConnFoundTestReset) { QuicTransportStatsCallback::PacketDropReason::CONNECTION_NOT_FOUND); } +TEST_F(QuicServerWorkerTest, RateLimit) { + worker_->setRateLimiter(std::make_unique(2, 60s)); + EXPECT_CALL(*transportInfoCb_, onConnectionRateLimited()).Times(1); + + NiceMock connCb1; + auto mockSock1 = + std::make_unique>(&eventbase_); + EXPECT_CALL(*mockSock1, address()).WillRepeatedly(ReturnRef(fakeAddress_)); + MockQuicTransport::Ptr testTransport1 = std::make_shared( + worker_->getEventBase(), std::move(mockSock1), connCb1, nullptr); + EXPECT_CALL(*testTransport1, getEventBase()) + .WillRepeatedly(Return(&eventbase_)); + EXPECT_CALL(*testTransport1, getOriginalPeerAddress()) + .WillRepeatedly(ReturnRef(kClientAddr)); + auto connId1 = getTestConnectionId(hostId_); + PacketNum num = 1; + QuicVersion version = QuicVersion::MVFST; + RoutingData routingData(HeaderForm::Long, true, true, connId1, connId1); + + auto data = createData(kMinInitialPacketSize + 10); + EXPECT_CALL( + *testTransport1, onNetworkData(kClientAddr, NetworkDataMatches(*data))); + EXPECT_CALL(*factory_, _make(_, _, _, _)).WillOnce(Return(testTransport1)); + + worker_->dispatchPacketData( + kClientAddr, + std::move(routingData), + NetworkData(data->clone(), Clock::now())); + + const auto& addrMap = worker_->getSrcToTransportMap(); + EXPECT_EQ(addrMap.count(std::make_pair(kClientAddr, connId1)), 1); + eventbase_.loop(); + + auto caddr2 = folly::SocketAddress("2.3.4.5", 1234); + NiceMock connCb2; + auto mockSock2 = + std::make_unique>(&eventbase_); + EXPECT_CALL(*mockSock2, address()).WillRepeatedly(ReturnRef(caddr2)); + MockQuicTransport::Ptr testTransport2 = std::make_shared( + worker_->getEventBase(), std::move(mockSock2), connCb2, nullptr); + EXPECT_CALL(*testTransport2, getEventBase()) + .WillRepeatedly(Return(&eventbase_)); + EXPECT_CALL(*testTransport2, getOriginalPeerAddress()) + .WillRepeatedly(ReturnRef(caddr2)); + ConnectionId connId2({2, 4, 5, 6}); + num = 1; + version = QuicVersion::MVFST; + RoutingData routingData2(HeaderForm::Long, true, true, connId2, connId2); + + auto data2 = createData(kMinInitialPacketSize + 10); + EXPECT_CALL( + *testTransport2, onNetworkData(caddr2, NetworkDataMatches(*data2))); + EXPECT_CALL(*factory_, _make(_, _, _, _)).WillOnce(Return(testTransport2)); + worker_->dispatchPacketData( + caddr2, + std::move(routingData2), + NetworkData(data2->clone(), Clock::now())); + + EXPECT_EQ(addrMap.count(std::make_pair(caddr2, connId2)), 1); + eventbase_.loop(); + + auto caddr3 = folly::SocketAddress("3.3.4.5", 1234); + auto mockSock3 = + std::make_unique>(&eventbase_); + ConnectionId connId3({8, 4, 5, 6}); + num = 1; + version = QuicVersion::MVFST; + RoutingData routingData3(HeaderForm::Long, true, true, connId3, connId3); + auto data3 = createData(kMinInitialPacketSize + 10); + EXPECT_CALL(*factory_, _make(_, _, _, _)).Times(0); + worker_->dispatchPacketData( + caddr3, + std::move(routingData3), + NetworkData(data2->clone(), Clock::now())); + + EXPECT_EQ(addrMap.count(std::make_pair(caddr3, connId3)), 0); + EXPECT_EQ(addrMap.size(), 2); + eventbase_.loop(); +} + TEST_F(QuicServerWorkerTest, QuicServerWorkerUnbindBeforeCidAvailable) { NiceMock connCb; auto mockSock = diff --git a/quic/state/QuicTransportStatsCallback.h b/quic/state/QuicTransportStatsCallback.h index f0b36879b..61a64393f 100644 --- a/quic/state/QuicTransportStatsCallback.h +++ b/quic/state/QuicTransportStatsCallback.h @@ -87,6 +87,8 @@ class QuicTransportStatsCallback { virtual void onClientInitialReceived() = 0; + virtual void onConnectionRateLimited() = 0; + // connection level metrics: virtual void onNewConnection() = 0; diff --git a/quic/state/test/MockQuicStats.h b/quic/state/test/MockQuicStats.h index eaa536753..62a53ec2d 100644 --- a/quic/state/test/MockQuicStats.h +++ b/quic/state/test/MockQuicStats.h @@ -28,6 +28,7 @@ class MockQuicStats : public QuicTransportStatsCallback { MOCK_METHOD0(onForwardedPacketReceived, void()); MOCK_METHOD0(onForwardedPacketProcessed, void()); MOCK_METHOD0(onClientInitialReceived, void()); + MOCK_METHOD0(onConnectionRateLimited, void()); MOCK_METHOD0(onNewConnection, void()); MOCK_METHOD1(onConnectionClose, void(folly::Optional)); MOCK_METHOD0(onNewQuicStream, void());