1
0
mirror of https://github.com/facebookincubator/mvfst.git synced 2025-11-10 21:22:20 +03:00

Introduce new connection rate limits.

Summary:
This introduces a rate limit to new connections created by a worker.

Right now it will simply send a VN, but eventually this will only issue a RETRY for unverified initials.

Reviewed By: udippant

Differential Revision: D21614905

fbshipit-source-id: 1832fbdad525c53fb1cb810aa9d7bae868c267d6
This commit is contained in:
Matt Joras
2020-05-18 16:38:41 -07:00
committed by Facebook GitHub Bot
parent 99e5aedc21
commit 30bff94e85
7 changed files with 129 additions and 0 deletions

View File

@@ -15,6 +15,7 @@
#include <quic/server/QuicReusePortUDPSocketFactory.h>
#include <quic/server/QuicServerTransport.h>
#include <quic/server/QuicSharedUDPSocketFactory.h>
#include <quic/server/SlidingWindowRateLimiter.h>
namespace quic {
@@ -60,6 +61,10 @@ void QuicServer::setCongestionControllerFactory(
ccFactory_ = std::move(ccFactory);
}
void QuicServer::setRateLimit(uint64_t count, std::chrono::seconds window) {
rateLimit_ = folly::make_optional<RateLimit>(count, window);
}
void QuicServer::setSupportedVersion(const std::vector<QuicVersion>& versions) {
supportedVersions_ = versions;
}
@@ -164,6 +169,10 @@ void QuicServer::initializeWorkers(
}
worker->setConnectionIdAlgo(connIdAlgoFactory_->make());
worker->setCongestionControllerFactory(ccFactory_);
if (rateLimit_) {
worker->setRateLimiter(std::make_unique<SlidingWindowRateLimiter>(
rateLimit_->count, rateLimit_->window));
}
worker->setWorkerId(i);
worker->setTransportSettingsOverrideFn(transportSettingsOverrideFn_);
workers_.push_back(std::move(worker));

View File

@@ -99,6 +99,8 @@ class QuicServer : public QuicServerWorker::WorkerCallback,
void setCongestionControllerFactory(
std::shared_ptr<CongestionControllerFactory> ccFactory);
void setRateLimit(uint64_t count, std::chrono::seconds window);
/**
* Set list of supported QUICVersion for this server. These versions will be
* used during the 'Version-Negotiation' phase with the client.
@@ -380,6 +382,13 @@ class QuicServer : public QuicServerWorker::WorkerCallback,
// address that the server is bound to
folly::SocketAddress boundAddress_;
folly::SocketOptionMap socketOptions_;
// Rate limits
struct RateLimit {
RateLimit(uint64_t c, std::chrono::seconds w) : count(c), window(w) {}
uint64_t count;
std::chrono::seconds window;
};
folly::Optional<RateLimit> rateLimit_;
};
} // namespace quic

View File

@@ -107,6 +107,11 @@ void QuicServerWorker::setCongestionControllerFactory(
ccFactory_ = ccFactory;
}
void QuicServerWorker::setRateLimiter(
std::unique_ptr<RateLimiter> rateLimiter) {
newConnRateLimiter_ = std::move(rateLimiter);
}
void QuicServerWorker::start() {
CHECK(socket_);
if (!pacingTimer_) {
@@ -490,6 +495,19 @@ void QuicServerWorker::dispatchPacketData(
PacketDropReason::INVALID_PACKET);
return;
}
if (newConnRateLimiter_ &&
newConnRateLimiter_->check(networkData.receiveTimePoint)) {
// TODO RETRY
VersionNegotiationPacketBuilder builder(
routingData.destinationConnId,
routingData.sourceConnId.value_or(
ConnectionId(std::vector<uint8_t>())),
std::vector<QuicVersion>{QuicVersion::MVFST_INVALID});
auto versionNegotiationPacket = std::move(builder).buildPacket();
socket_->write(client, versionNegotiationPacket.second);
QUIC_STATS(statsCallback_, onConnectionRateLimited);
return;
}
// create 'accepting' transport
auto sock = makeSocket(getEventBase());
auto trans = transportFactory_->make(

View File

@@ -20,6 +20,7 @@
#include <quic/server/QuicServerPacketRouter.h>
#include <quic/server/QuicServerTransportFactory.h>
#include <quic/server/QuicUDPSocketFactory.h>
#include <quic/server/RateLimiter.h>
#include <quic/server/state/ServerConnectionIdRejector.h>
#include <quic/state/QuicTransportStatsCallback.h>
@@ -226,6 +227,11 @@ class QuicServerWorker : public folly::AsyncUDPSocket::ReadCallback,
void setCongestionControllerFactory(
std::shared_ptr<CongestionControllerFactory> factory);
/**
* Set the rate limiter which will be used to rate limit new connections.
*/
void setRateLimiter(std::unique_ptr<RateLimiter> rateLimiter);
// Read callback
void getReadBuffer(void** buf, size_t* len) noexcept override;
@@ -414,6 +420,9 @@ class QuicServerWorker : public folly::AsyncUDPSocket::ReadCallback,
// Output buffer to be used for continuous memory GSO write
std::unique_ptr<BufAccessor> bufAccessor_;
// Rate limits the creation of new connections for this worker.
std::unique_ptr<RateLimiter> newConnRateLimiter_;
};
} // namespace quic

View File

@@ -18,6 +18,7 @@
#include <quic/codec/QuicHeaderCodec.h>
#include <quic/codec/test/Mocks.h>
#include <quic/common/test/TestUtils.h>
#include <quic/server/SlidingWindowRateLimiter.h>
#include <quic/server/handshake/StatelessResetGenerator.h>
#include <quic/server/test/Mocks.h>
#include <quic/state/test/MockQuicStats.h>
@@ -378,6 +379,86 @@ TEST_F(QuicServerWorkerTest, NoConnFoundTestReset) {
QuicTransportStatsCallback::PacketDropReason::CONNECTION_NOT_FOUND);
}
TEST_F(QuicServerWorkerTest, RateLimit) {
worker_->setRateLimiter(std::make_unique<SlidingWindowRateLimiter>(2, 60s));
EXPECT_CALL(*transportInfoCb_, onConnectionRateLimited()).Times(1);
NiceMock<MockConnectionCallback> connCb1;
auto mockSock1 =
std::make_unique<NiceMock<folly::test::MockAsyncUDPSocket>>(&eventbase_);
EXPECT_CALL(*mockSock1, address()).WillRepeatedly(ReturnRef(fakeAddress_));
MockQuicTransport::Ptr testTransport1 = std::make_shared<MockQuicTransport>(
worker_->getEventBase(), std::move(mockSock1), connCb1, nullptr);
EXPECT_CALL(*testTransport1, getEventBase())
.WillRepeatedly(Return(&eventbase_));
EXPECT_CALL(*testTransport1, getOriginalPeerAddress())
.WillRepeatedly(ReturnRef(kClientAddr));
auto connId1 = getTestConnectionId(hostId_);
PacketNum num = 1;
QuicVersion version = QuicVersion::MVFST;
RoutingData routingData(HeaderForm::Long, true, true, connId1, connId1);
auto data = createData(kMinInitialPacketSize + 10);
EXPECT_CALL(
*testTransport1, onNetworkData(kClientAddr, NetworkDataMatches(*data)));
EXPECT_CALL(*factory_, _make(_, _, _, _)).WillOnce(Return(testTransport1));
worker_->dispatchPacketData(
kClientAddr,
std::move(routingData),
NetworkData(data->clone(), Clock::now()));
const auto& addrMap = worker_->getSrcToTransportMap();
EXPECT_EQ(addrMap.count(std::make_pair(kClientAddr, connId1)), 1);
eventbase_.loop();
auto caddr2 = folly::SocketAddress("2.3.4.5", 1234);
NiceMock<MockConnectionCallback> connCb2;
auto mockSock2 =
std::make_unique<NiceMock<folly::test::MockAsyncUDPSocket>>(&eventbase_);
EXPECT_CALL(*mockSock2, address()).WillRepeatedly(ReturnRef(caddr2));
MockQuicTransport::Ptr testTransport2 = std::make_shared<MockQuicTransport>(
worker_->getEventBase(), std::move(mockSock2), connCb2, nullptr);
EXPECT_CALL(*testTransport2, getEventBase())
.WillRepeatedly(Return(&eventbase_));
EXPECT_CALL(*testTransport2, getOriginalPeerAddress())
.WillRepeatedly(ReturnRef(caddr2));
ConnectionId connId2({2, 4, 5, 6});
num = 1;
version = QuicVersion::MVFST;
RoutingData routingData2(HeaderForm::Long, true, true, connId2, connId2);
auto data2 = createData(kMinInitialPacketSize + 10);
EXPECT_CALL(
*testTransport2, onNetworkData(caddr2, NetworkDataMatches(*data2)));
EXPECT_CALL(*factory_, _make(_, _, _, _)).WillOnce(Return(testTransport2));
worker_->dispatchPacketData(
caddr2,
std::move(routingData2),
NetworkData(data2->clone(), Clock::now()));
EXPECT_EQ(addrMap.count(std::make_pair(caddr2, connId2)), 1);
eventbase_.loop();
auto caddr3 = folly::SocketAddress("3.3.4.5", 1234);
auto mockSock3 =
std::make_unique<NiceMock<folly::test::MockAsyncUDPSocket>>(&eventbase_);
ConnectionId connId3({8, 4, 5, 6});
num = 1;
version = QuicVersion::MVFST;
RoutingData routingData3(HeaderForm::Long, true, true, connId3, connId3);
auto data3 = createData(kMinInitialPacketSize + 10);
EXPECT_CALL(*factory_, _make(_, _, _, _)).Times(0);
worker_->dispatchPacketData(
caddr3,
std::move(routingData3),
NetworkData(data2->clone(), Clock::now()));
EXPECT_EQ(addrMap.count(std::make_pair(caddr3, connId3)), 0);
EXPECT_EQ(addrMap.size(), 2);
eventbase_.loop();
}
TEST_F(QuicServerWorkerTest, QuicServerWorkerUnbindBeforeCidAvailable) {
NiceMock<MockConnectionCallback> connCb;
auto mockSock =

View File

@@ -87,6 +87,8 @@ class QuicTransportStatsCallback {
virtual void onClientInitialReceived() = 0;
virtual void onConnectionRateLimited() = 0;
// connection level metrics:
virtual void onNewConnection() = 0;

View File

@@ -28,6 +28,7 @@ class MockQuicStats : public QuicTransportStatsCallback {
MOCK_METHOD0(onForwardedPacketReceived, void());
MOCK_METHOD0(onForwardedPacketProcessed, void());
MOCK_METHOD0(onClientInitialReceived, void());
MOCK_METHOD0(onConnectionRateLimited, void());
MOCK_METHOD0(onNewConnection, void());
MOCK_METHOD1(onConnectionClose, void(folly::Optional<ConnectionCloseReason>));
MOCK_METHOD0(onNewQuicStream, void());