1
0
mirror of https://github.com/facebookincubator/mvfst.git synced 2025-11-10 21:22:20 +03:00

add more checks for internally forwarded data

Summary:
This adds checks for forwarded data on all the possible branches that lead to
forwarding of packets to another process

Reviewed By: mjoras

Differential Revision: D18416971

fbshipit-source-id: 22dc3fd63de615904a411f90164a138bf0ef56e0
This commit is contained in:
Udip Pant
2019-11-18 11:03:25 -08:00
committed by Facebook Github Bot
parent 024bbbba29
commit 80b3a9f393
6 changed files with 60 additions and 31 deletions

View File

@@ -311,7 +311,8 @@ void QuicServer::pauseRead() {
void QuicServer::routeDataToWorker( void QuicServer::routeDataToWorker(
const folly::SocketAddress& client, const folly::SocketAddress& client,
RoutingData&& routingData, RoutingData&& routingData,
NetworkData&& networkData) { NetworkData&& networkData,
bool isForwardedData) {
// figure out worker idx // figure out worker idx
if (!initialized_) { if (!initialized_) {
// drop the packet if we are not initialized. This is a janky memory // drop the packet if we are not initialized. This is a janky memory
@@ -344,7 +345,10 @@ void QuicServer::routeDataToWorker(
if (routingData.isUsingClientConnId && workerPtr_) { if (routingData.isUsingClientConnId && workerPtr_) {
CHECK(workerPtr_->getEventBase()->isInEventBaseThread()); CHECK(workerPtr_->getEventBase()->isInEventBaseThread());
workerPtr_->dispatchPacketData( workerPtr_->dispatchPacketData(
client, std::move(routingData), std::move(networkData)); client,
std::move(routingData),
std::move(networkData),
isForwardedData);
return; return;
} }
@@ -358,11 +362,13 @@ void QuicServer::routeDataToWorker(
cl = client, cl = client,
routingData = std::move(routingData), routingData = std::move(routingData),
w = worker.get(), w = worker.get(),
buf = std::move(networkData)]() mutable { buf = std::move(networkData),
isForwarded = isForwardedData]() mutable {
if (server->shutdown_) { if (server->shutdown_) {
return; return;
} }
w->dispatchPacketData(cl, std::move(routingData), std::move(buf)); w->dispatchPacketData(
cl, std::move(routingData), std::move(buf), isForwarded);
}); });
} }

View File

@@ -185,7 +185,7 @@ class QuicServer : public QuicServerWorker::WorkerCallback,
*/ */
void waitUntilInitialized(); void waitUntilInitialized();
void handleWorkerError(LocalErrorCode error); void handleWorkerError(LocalErrorCode error) override;
/** /**
* Routes the given data for the given client to the correct worker that may * Routes the given data for the given client to the correct worker that may
@@ -194,7 +194,8 @@ class QuicServer : public QuicServerWorker::WorkerCallback,
void routeDataToWorker( void routeDataToWorker(
const folly::SocketAddress& client, const folly::SocketAddress& client,
RoutingData&& routingData, RoutingData&& routingData,
NetworkData&& networkData); NetworkData&& networkData,
bool isForwardedData = false) override;
/** /**
* Set an EventBaseObserver for server and all its workers. This only works * Set an EventBaseObserver for server and all its workers. This only works

View File

@@ -313,7 +313,7 @@ void QuicServerWorker::forwardNetworkData(
return; return;
} }
callback_->routeDataToWorker( callback_->routeDataToWorker(
client, std::move(routingData), std::move(networkData)); client, std::move(routingData), std::move(networkData), isForwardedData);
} }
void QuicServerWorker::setPacingTimer( void QuicServerWorker::setPacingTimer(
@@ -324,7 +324,8 @@ void QuicServerWorker::setPacingTimer(
void QuicServerWorker::dispatchPacketData( void QuicServerWorker::dispatchPacketData(
const folly::SocketAddress& client, const folly::SocketAddress& client,
RoutingData&& routingData, RoutingData&& routingData,
NetworkData&& networkData) noexcept { NetworkData&& networkData,
bool isForwardedData) noexcept {
DCHECK(socket_); DCHECK(socket_);
QuicServerTransport::Ptr transport; QuicServerTransport::Ptr transport;
bool dropPacket = false; bool dropPacket = false;
@@ -447,7 +448,7 @@ void QuicServerWorker::dispatchPacketData(
routingData.destinationConnId); routingData.destinationConnId);
} }
if (!packetForwardingEnabled_) { if (!packetForwardingEnabled_ || isForwardedData) {
QUIC_STATS( QUIC_STATS(
infoCallback_, onPacketDropped, PacketDropReason::CONNECTION_NOT_FOUND); infoCallback_, onPacketDropped, PacketDropReason::CONNECTION_NOT_FOUND);
return sendResetPacket( return sendResetPacket(

View File

@@ -38,7 +38,8 @@ class QuicServerWorker : public folly::AsyncUDPSocket::ReadCallback,
virtual void routeDataToWorker( virtual void routeDataToWorker(
const folly::SocketAddress& client, const folly::SocketAddress& client,
RoutingData&& routingData, RoutingData&& routingData,
NetworkData&& networkData) = 0; NetworkData&& networkData,
bool isForwardedData) = 0;
}; };
explicit QuicServerWorker(std::shared_ptr<WorkerCallback> callback); explicit QuicServerWorker(std::shared_ptr<WorkerCallback> callback);
@@ -246,7 +247,8 @@ class QuicServerWorker : public folly::AsyncUDPSocket::ReadCallback,
void dispatchPacketData( void dispatchPacketData(
const folly::SocketAddress& client, const folly::SocketAddress& client,
RoutingData&& routingData, RoutingData&& routingData,
NetworkData&& networkData) noexcept; NetworkData&& networkData,
bool isForwardedData = false) noexcept;
using ConnIdToTransportMap = std:: using ConnIdToTransportMap = std::
unordered_map<ConnectionId, QuicServerTransport::Ptr, ConnectionIdHash>; unordered_map<ConnectionId, QuicServerTransport::Ptr, ConnectionIdHash>;

View File

@@ -47,30 +47,33 @@ class MockWorkerCallback : public QuicServerWorker::WorkerCallback {
~MockWorkerCallback() = default; ~MockWorkerCallback() = default;
MOCK_METHOD1(handleWorkerError, void(LocalErrorCode)); MOCK_METHOD1(handleWorkerError, void(LocalErrorCode));
MOCK_METHOD3( MOCK_METHOD4(
routeDataToWorkerLong, routeDataToWorkerLong,
void( void(
const folly::SocketAddress&, const folly::SocketAddress&,
std::unique_ptr<RoutingData>&, std::unique_ptr<RoutingData>&,
std::unique_ptr<NetworkData>&)); std::unique_ptr<NetworkData>&,
bool isForwardedData));
MOCK_METHOD3( MOCK_METHOD4(
routeDataToWorkerShort, routeDataToWorkerShort,
void( void(
const folly::SocketAddress&, const folly::SocketAddress&,
std::unique_ptr<RoutingData>&, std::unique_ptr<RoutingData>&,
std::unique_ptr<NetworkData>&)); std::unique_ptr<NetworkData>&,
bool isForwardedData));
void routeDataToWorker( void routeDataToWorker(
const folly::SocketAddress& client, const folly::SocketAddress& client,
RoutingData&& routingDataIn, RoutingData&& routingDataIn,
NetworkData&& networkDataIn) { NetworkData&& networkDataIn,
bool isForwardedData = false) {
auto routingData = std::make_unique<RoutingData>(std::move(routingDataIn)); auto routingData = std::make_unique<RoutingData>(std::move(routingDataIn));
auto networkData = std::make_unique<NetworkData>(std::move(networkDataIn)); auto networkData = std::make_unique<NetworkData>(std::move(networkDataIn));
if (routingData->headerForm == HeaderForm::Long) { if (routingData->headerForm == HeaderForm::Long) {
routeDataToWorkerLong(client, routingData, networkData); routeDataToWorkerLong(client, routingData, networkData, isForwardedData);
} else { } else {
routeDataToWorkerShort(client, routingData, networkData); routeDataToWorkerShort(client, routingData, networkData, isForwardedData);
} }
} }
}; };

View File

@@ -114,12 +114,16 @@ class QuicServerWorkerTest : public Test {
auto cb = [&](const folly::SocketAddress& addr, auto cb = [&](const folly::SocketAddress& addr,
std::unique_ptr<RoutingData>& routingData, std::unique_ptr<RoutingData>& routingData,
std::unique_ptr<NetworkData>& networkData) { std::unique_ptr<NetworkData>& networkData,
bool isForwardedData) {
worker_->dispatchPacketData( worker_->dispatchPacketData(
addr, std::move(*routingData.get()), std::move(*networkData.get())); addr,
std::move(*routingData.get()),
std::move(*networkData.get()),
isForwardedData);
}; };
EXPECT_CALL(*workerCb_, routeDataToWorkerLong(_, _, _)) EXPECT_CALL(*workerCb_, routeDataToWorkerLong(_, _, _, _))
.WillRepeatedly(Invoke(cb)); .WillRepeatedly(Invoke(cb));
socketFactory_ = std::make_unique<MockQuicUDPSocketFactory>(); socketFactory_ = std::make_unique<MockQuicUDPSocketFactory>();
@@ -832,11 +836,13 @@ void QuicServerWorkerTakeoverTest::testNoPacketForwarding(
ConnectionId /* connId */) { ConnectionId /* connId */) {
auto cb = [&](const folly::SocketAddress& addr, auto cb = [&](const folly::SocketAddress& addr,
std::unique_ptr<RoutingData>& /* routingData */, std::unique_ptr<RoutingData>& /* routingData */,
std::unique_ptr<NetworkData>& /* networkData */) { std::unique_ptr<NetworkData>& /* networkData */,
bool isForwardedData) {
EXPECT_EQ(addr.getIPAddress(), clientAddr.getIPAddress()); EXPECT_EQ(addr.getIPAddress(), clientAddr.getIPAddress());
EXPECT_EQ(addr.getPort(), clientAddr.getPort()); EXPECT_EQ(addr.getPort(), clientAddr.getPort());
EXPECT_FALSE(isForwardedData);
}; };
EXPECT_CALL(*takeoverWorkerCb_, routeDataToWorkerLong(_, _, _)) EXPECT_CALL(*takeoverWorkerCb_, routeDataToWorkerLong(_, _, _, _))
.WillOnce(Invoke(cb)); .WillOnce(Invoke(cb));
EXPECT_CALL(*transportInfoCb_, onPacketReceived()); EXPECT_CALL(*transportInfoCb_, onPacketReceived());
EXPECT_CALL(*transportInfoCb_, onRead(len)); EXPECT_CALL(*transportInfoCb_, onRead(len));
@@ -937,11 +943,15 @@ void QuicServerWorkerTakeoverTest::testPacketForwarding(
auto cb = [&](const folly::SocketAddress& client, auto cb = [&](const folly::SocketAddress& client,
std::unique_ptr<RoutingData>& routingData, std::unique_ptr<RoutingData>& routingData,
std::unique_ptr<NetworkData>& networkData) { std::unique_ptr<NetworkData>& networkData,
bool isForwardedData) {
takeoverWorker_->dispatchPacketData( takeoverWorker_->dispatchPacketData(
client, std::move(*routingData.get()), std::move(*networkData.get())); client,
std::move(*routingData.get()),
std::move(*networkData.get()),
isForwardedData);
}; };
EXPECT_CALL(*takeoverWorkerCb_, routeDataToWorkerLong(_, _, _)) EXPECT_CALL(*takeoverWorkerCb_, routeDataToWorkerLong(_, _, _, _))
.WillOnce(Invoke(cb)); .WillOnce(Invoke(cb));
EXPECT_CALL(*transportInfoCb_, onPacketReceived()); EXPECT_CALL(*transportInfoCb_, onPacketReceived());
EXPECT_CALL(*transportInfoCb_, onRead(len)); EXPECT_CALL(*transportInfoCb_, onRead(len));
@@ -995,15 +1005,17 @@ TEST_F(QuicServerWorkerTakeoverTest, QuicServerTakeoverProcessForwardedPkt) {
// test processing of the forwarded packet // test processing of the forwarded packet
auto cb = [&](const folly::SocketAddress& addr, auto cb = [&](const folly::SocketAddress& addr,
std::unique_ptr<RoutingData>& /* routingData */, std::unique_ptr<RoutingData>& /* routingData */,
std::unique_ptr<NetworkData>& networkData) { std::unique_ptr<NetworkData>& networkData,
bool isForwardedData) {
// verify that it is the original client address // verify that it is the original client address
EXPECT_EQ(addr.getIPAddress(), clientAddr.getIPAddress()); EXPECT_EQ(addr.getIPAddress(), clientAddr.getIPAddress());
EXPECT_EQ(addr.getPort(), clientAddr.getPort()); EXPECT_EQ(addr.getPort(), clientAddr.getPort());
// the original data should be extracted after processing takeover // the original data should be extracted after processing takeover
// protocol related information // protocol related information
EXPECT_TRUE(eq(*data, *(networkData->data))); EXPECT_TRUE(eq(*data, *(networkData->data)));
EXPECT_TRUE(isForwardedData);
}; };
EXPECT_CALL(*takeoverWorkerCb_, routeDataToWorkerLong(_, _, _)) EXPECT_CALL(*takeoverWorkerCb_, routeDataToWorkerLong(_, _, _, _))
.WillOnce(Invoke(cb)); .WillOnce(Invoke(cb));
takeoverCb->onDataAvailable(client, bufLen, false); takeoverCb->onDataAvailable(client, bufLen, false);
@@ -1011,11 +1023,15 @@ TEST_F(QuicServerWorkerTakeoverTest, QuicServerTakeoverProcessForwardedPkt) {
})); }));
auto workerCb = [&](const folly::SocketAddress& client, auto workerCb = [&](const folly::SocketAddress& client,
std::unique_ptr<RoutingData>& routingData, std::unique_ptr<RoutingData>& routingData,
std::unique_ptr<NetworkData>& networkData) { std::unique_ptr<NetworkData>& networkData,
bool isForwardedData) {
takeoverWorker_->dispatchPacketData( takeoverWorker_->dispatchPacketData(
client, std::move(*routingData.get()), std::move(*networkData.get())); client,
std::move(*routingData.get()),
std::move(*networkData.get()),
isForwardedData);
}; };
EXPECT_CALL(*takeoverWorkerCb_, routeDataToWorkerLong(_, _, _)) EXPECT_CALL(*takeoverWorkerCb_, routeDataToWorkerLong(_, _, _, _))
.WillOnce(Invoke(workerCb)); .WillOnce(Invoke(workerCb));
EXPECT_CALL(*transportInfoCb_, onPacketReceived()); EXPECT_CALL(*transportInfoCb_, onPacketReceived());
EXPECT_CALL(*transportInfoCb_, onRead(len)); EXPECT_CALL(*transportInfoCb_, onRead(len));