1
0
mirror of https://github.com/facebookincubator/mvfst.git synced 2025-08-08 09:42:06 +03:00

Exception-free Quic ConnIdAlgo

Summary: no more surprises in upper layer

Reviewed By: mjoras

Differential Revision: D19976510

fbshipit-source-id: 3487e9aa2cb28d7bc748f13bc2bbc393216b4a8a
This commit is contained in:
Yang Chi
2020-02-19 15:52:04 -08:00
committed by Facebook Github Bot
parent 6ce36df8e3
commit 5f51f4436f
14 changed files with 122 additions and 82 deletions

View File

@@ -165,7 +165,7 @@ class TestQuicTransport
while (!cursor.isAtEnd()) {
// create server chosen connId with processId = 0 and workerId = 0
ServerConnectionIdParams params(0, 0, 0);
conn_->serverConnectionId = connIdAlgo_->encodeConnectionId(params);
conn_->serverConnectionId = *connIdAlgo_->encodeConnectionId(params);
auto type = static_cast<TestFrameType>(cursor.readBE<uint8_t>());
if (type == TestFrameType::CRYPTO) {
auto cryptoBuffer = decodeCryptoBuffer(cursor);
@@ -338,7 +338,7 @@ class TestQuicTransport
void setServerConnectionId() {
// create server chosen connId with processId = 0 and workerId = 0
ServerConnectionIdParams params(0, 0, 0);
conn_->serverConnectionId = connIdAlgo_->encodeConnectionId(params);
conn_->serverConnectionId = *connIdAlgo_->encodeConnectionId(params);
}
void driveReadCallbacks() {

View File

@@ -1418,7 +1418,7 @@ class QuicClientTransportTest : public Test {
void setConnectionIds() {
originalConnId = client->getConn().clientConnectionId;
ServerConnectionIdParams params(0, 0, 0);
serverChosenConnId = connIdAlgo_->encodeConnectionId(params);
serverChosenConnId = *connIdAlgo_->encodeConnectionId(params);
}
void recvServerHello(const folly::SocketAddress& addr) {
@@ -1712,7 +1712,7 @@ TEST_F(QuicClientTransportTest, FirstPacketProcessedCallback) {
originalConnId = client->getConn().clientConnectionId;
ServerConnectionIdParams params(0, 0, 0);
client->getNonConstConn().serverConnectionId =
connIdAlgo_->encodeConnectionId(params);
*connIdAlgo_->encodeConnectionId(params);
AckBlocks acks;
acks.insert(0);
@@ -2755,7 +2755,7 @@ class QuicClientTransportVersionAndRetryTest
originalConnId = client->getConn().clientConnectionId;
// create server chosen connId with processId = 0 and workerId = 0
ServerConnectionIdParams params(0, 0, 0);
serverChosenConnId = connIdAlgo_->encodeConnectionId(params);
serverChosenConnId = *connIdAlgo_->encodeConnectionId(params);
// The tests that we do here create streams before crypto is finished,
// so we initialize the peer streams, to allow for this behavior. TODO: when
// 0-rtt support exists, remove this.

View File

@@ -8,7 +8,8 @@
#pragma once
#include <folly/Optional.h>
#include <folly/Expected.h>
#include <quic/QuicException.h>
#include <quic/codec/QuicConnectionId.h>
namespace quic {
@@ -28,19 +29,19 @@ class ConnectionIdAlgo {
/**
* Check if this implementation of algorithm can parse the given ConnectionId
*/
virtual bool canParse(const ConnectionId& id) const = 0;
virtual bool canParse(const ConnectionId& id) const noexcept = 0;
/**
* Parses ServerConnectionIdParams from the given connection id.
*/
virtual ServerConnectionIdParams parseConnectionId(
const ConnectionId& id) = 0;
virtual folly::Expected<ServerConnectionIdParams, QuicInternalException>
parseConnectionId(const ConnectionId& id) noexcept = 0;
/**
* Encodes the given ServerConnectionIdParams into connection id
*/
virtual ConnectionId encodeConnectionId(
const ServerConnectionIdParams& params) = 0;
virtual folly::Expected<ConnectionId, QuicInternalException>
encodeConnectionId(const ServerConnectionIdParams& params) noexcept = 0;
};
/**

View File

@@ -30,25 +30,28 @@ constexpr uint8_t kShortVersionBitsMask = 0xc0;
/**
* Sets the short version id bits (0 - 1) into the given ConnectionId
*/
void setVersionBitsInConnId(quic::ConnectionId& connId, uint8_t version) {
if (connId.size() < quic::kMinSelfConnectionIdSize) {
throw quic::QuicInternalException(
folly::Expected<folly::Unit, quic::QuicInternalException>
setVersionBitsInConnId(quic::ConnectionId& connId, uint8_t version) noexcept {
if (UNLIKELY(connId.size() < quic::kMinSelfConnectionIdSize)) {
return folly::makeUnexpected(quic::QuicInternalException(
"ConnectionId is too small for version",
quic::LocalErrorCode::INTERNAL_ERROR);
quic::LocalErrorCode::INTERNAL_ERROR));
}
// clear 0-1 bits
connId.data()[0] &= (~kShortVersionBitsMask);
connId.data()[0] |= (kShortVersionBitsMask & (version << 6));
return folly::unit;
}
/**
* Extract the version id bits (0 - 1) from the given ConnectionId
*/
uint8_t getVersionBitsFromConnId(const quic::ConnectionId& connId) {
if (connId.size() < quic::kMinSelfConnectionIdSize) {
throw quic::QuicInternalException(
folly::Expected<uint8_t, quic::QuicInternalException> getVersionBitsFromConnId(
const quic::ConnectionId& connId) noexcept {
if (UNLIKELY(connId.size() < quic::kMinSelfConnectionIdSize)) {
return folly::makeUnexpected(quic::QuicInternalException(
"ConnectionId is too small for version",
quic::LocalErrorCode::INTERNAL_ERROR);
quic::LocalErrorCode::INTERNAL_ERROR));
}
uint8_t version = 0;
version = (kShortVersionBitsMask & connId.data()[0]) >> 6;
@@ -58,11 +61,13 @@ uint8_t getVersionBitsFromConnId(const quic::ConnectionId& connId) {
/**
* Sets the host id bits [2 - 17] bits into the given ConnectionId
*/
void setHostIdBitsInConnId(quic::ConnectionId& connId, uint16_t hostId) {
if (connId.size() < quic::kMinSelfConnectionIdSize) {
throw quic::QuicInternalException(
folly::Expected<folly::Unit, quic::QuicInternalException> setHostIdBitsInConnId(
quic::ConnectionId& connId,
uint16_t hostId) noexcept {
if (UNLIKELY(connId.size() < quic::kMinSelfConnectionIdSize)) {
return folly::makeUnexpected(quic::QuicInternalException(
"ConnectionId is too small for hostid",
quic::LocalErrorCode::INTERNAL_ERROR);
quic::LocalErrorCode::INTERNAL_ERROR));
}
// clear 2 - 7 bits
connId.data()[0] &= ~kHostIdFirstByteMask;
@@ -77,16 +82,18 @@ void setHostIdBitsInConnId(quic::ConnectionId& connId, uint16_t hostId) {
connId.data()[1] |= (kHostIdSecondByteMask & (hostId >> 2));
// set 16 - 17 bits in the connId with the last 2 bits of the worker id
connId.data()[2] |= (kHostIdThirdByteMask & (hostId << 6));
return folly::unit;
}
/**
* Extract the host id bits [2 - 17] bits from the given ConnectionId
*/
uint16_t getHostIdBitsInConnId(const quic::ConnectionId& connId) {
if (connId.size() < quic::kMinSelfConnectionIdSize) {
throw quic::QuicInternalException(
folly::Expected<uint16_t, quic::QuicInternalException> getHostIdBitsInConnId(
const quic::ConnectionId& connId) noexcept {
if (UNLIKELY(connId.size() < quic::kMinSelfConnectionIdSize)) {
return folly::makeUnexpected(quic::QuicInternalException(
"ConnectionId is too small for hostid",
quic::LocalErrorCode::INTERNAL_ERROR);
quic::LocalErrorCode::INTERNAL_ERROR));
}
uint16_t hostId = 0;
// get 2 - 7 bits from the connId and set first 6 bits of the host id
@@ -103,11 +110,12 @@ uint16_t getHostIdBitsInConnId(const quic::ConnectionId& connId) {
/**
* Sets the given 8-bit workerId into the given connectionId's 18-25 bits
*/
void setWorkerIdBitsInConnId(quic::ConnectionId& connId, uint8_t workerId) {
if (connId.size() < quic::kMinSelfConnectionIdSize) {
throw quic::QuicInternalException(
folly::Expected<folly::Unit, quic::QuicInternalException>
setWorkerIdBitsInConnId(quic::ConnectionId& connId, uint8_t workerId) noexcept {
if (UNLIKELY(connId.size() < quic::kMinSelfConnectionIdSize)) {
return folly::makeUnexpected(quic::QuicInternalException(
"ConnectionId is too small for workerid",
quic::LocalErrorCode::INTERNAL_ERROR);
quic::LocalErrorCode::INTERNAL_ERROR));
}
// clear 18-23 bits
connId.data()[2] &= 0xc0;
@@ -117,16 +125,18 @@ void setWorkerIdBitsInConnId(quic::ConnectionId& connId, uint8_t workerId) {
connId.data()[2] |= (kWorkerIdFirstByteMask & workerId) >> 2;
// set 24 - 25 bits in the connId with the last 2 bits of the worker id
connId.data()[3] |= (kWorkerIdSecondByteMask & workerId) << 6;
return folly::unit;
}
/**
* Extracts the 'workerId' bits from the given ConnectionId
*/
uint8_t getWorkerIdFromConnId(const quic::ConnectionId& connId) {
if (connId.size() < quic::kMinSelfConnectionIdSize) {
throw quic::QuicInternalException(
folly::Expected<uint8_t, quic::QuicInternalException> getWorkerIdFromConnId(
const quic::ConnectionId& connId) noexcept {
if (UNLIKELY(connId.size() < quic::kMinSelfConnectionIdSize)) {
return folly::makeUnexpected(quic::QuicInternalException(
"ConnectionId is too small for workerid",
quic::LocalErrorCode::INTERNAL_ERROR);
quic::LocalErrorCode::INTERNAL_ERROR));
}
// get 18 - 23 bits from the connId
uint8_t workerId = connId.data()[2] << 2;
@@ -138,25 +148,30 @@ uint8_t getWorkerIdFromConnId(const quic::ConnectionId& connId) {
/**
* Sets the server id bit (at 26th bit) into the given ConnectionId
*/
void setProcessIdBitsInConnId(quic::ConnectionId& connId, uint8_t processId) {
if (connId.size() < quic::kMinSelfConnectionIdSize) {
throw quic::QuicInternalException(
folly::Expected<folly::Unit, quic::QuicInternalException>
setProcessIdBitsInConnId(
quic::ConnectionId& connId,
uint8_t processId) noexcept {
if (UNLIKELY(connId.size() < quic::kMinSelfConnectionIdSize)) {
return folly::makeUnexpected(quic::QuicInternalException(
"ConnectionId is too small for processid",
quic::LocalErrorCode::INTERNAL_ERROR);
quic::LocalErrorCode::INTERNAL_ERROR));
}
// clear the 26th bit
connId.data()[3] &= (~kProcessIdBitMask);
connId.data()[3] |= (kProcessIdBitMask & (processId << 5));
return folly::unit;
}
/**
* Extract the server id bit (at 26th bit) from the given ConnectionId
*/
uint8_t getProcessIdBitsFromConnId(const quic::ConnectionId& connId) {
folly::Expected<uint8_t, quic::QuicInternalException>
getProcessIdBitsFromConnId(const quic::ConnectionId& connId) noexcept {
if (connId.size() < quic::kMinSelfConnectionIdSize) {
throw quic::QuicInternalException(
return folly::makeUnexpected(quic::QuicInternalException(
"ConnectionId is too small for processid",
quic::LocalErrorCode::INTERNAL_ERROR);
quic::LocalErrorCode::INTERNAL_ERROR));
}
uint8_t processId = 0;
processId = (kProcessIdBitMask & connId.data()[3]) >> 5;
@@ -166,33 +181,53 @@ uint8_t getProcessIdBitsFromConnId(const quic::ConnectionId& connId) {
namespace quic {
bool DefaultConnectionIdAlgo::canParse(const ConnectionId& id) const {
bool DefaultConnectionIdAlgo::canParse(const ConnectionId& id) const noexcept {
if (id.size() < kMinSelfConnectionIdSize) {
return false;
}
return getVersionBitsFromConnId(id) == kShortVersionId;
return *getVersionBitsFromConnId(id) == kShortVersionId;
}
ServerConnectionIdParams DefaultConnectionIdAlgo::parseConnectionId(
const ConnectionId& id) {
folly::Expected<ServerConnectionIdParams, QuicInternalException>
DefaultConnectionIdAlgo::parseConnectionId(const ConnectionId& id) noexcept {
auto expectingVersion = getVersionBitsFromConnId(id);
if (UNLIKELY(!expectingVersion)) {
return folly::makeUnexpected(expectingVersion.error());
}
auto expectingHost = getHostIdBitsInConnId(id);
if (UNLIKELY(!expectingHost)) {
return folly::makeUnexpected(expectingHost.error());
}
auto expectingProcess = getProcessIdBitsFromConnId(id);
if (UNLIKELY(!expectingProcess)) {
return folly::makeUnexpected(expectingProcess.error());
}
auto expectingWorker = getWorkerIdFromConnId(id);
if (UNLIKELY(!expectingWorker)) {
return folly::makeUnexpected(expectingWorker.error());
}
ServerConnectionIdParams serverConnIdParams(
getVersionBitsFromConnId(id),
getHostIdBitsInConnId(id),
getProcessIdBitsFromConnId(id),
getWorkerIdFromConnId(id));
*expectingVersion, *expectingHost, *expectingProcess, *expectingWorker);
return serverConnIdParams;
}
ConnectionId DefaultConnectionIdAlgo::encodeConnectionId(
const ServerConnectionIdParams& params) {
folly::Expected<ConnectionId, QuicInternalException>
DefaultConnectionIdAlgo::encodeConnectionId(
const ServerConnectionIdParams& params) noexcept {
// In case there is no client cid, create a random connection id.
std::vector<uint8_t> connIdData(kDefaultConnectionIdSize);
folly::Random::secureRandom(connIdData.data(), connIdData.size());
ConnectionId connId = ConnectionId(std::move(connIdData));
setVersionBitsInConnId(connId, params.version);
setHostIdBitsInConnId(connId, params.hostId);
setProcessIdBitsInConnId(connId, params.processId);
setWorkerIdBitsInConnId(connId, params.workerId);
auto expected =
setVersionBitsInConnId(connId, params.version)
.then([&](auto) { setHostIdBitsInConnId(connId, params.hostId); })
.then(
[&](auto) { setProcessIdBitsInConnId(connId, params.processId); })
.then(
[&](auto) { setWorkerIdBitsInConnId(connId, params.workerId); });
if (UNLIKELY(expected.hasError())) {
return folly::makeUnexpected(expected.error());
}
return connId;
}

View File

@@ -8,7 +8,8 @@
#pragma once
#include <folly/Optional.h>
#include <folly/Expected.h>
#include <quic/QuicException.h>
#include <quic/codec/ConnectionIdAlgo.h>
#include <quic/codec/QuicConnectionId.h>
@@ -37,18 +38,19 @@ class DefaultConnectionIdAlgo : public ConnectionIdAlgo {
/**
* Check if this implementation of algorithm can parse the given ConnectionId
*/
bool canParse(const ConnectionId& id) const override;
bool canParse(const ConnectionId& id) const noexcept override;
/**
* Parses ServerConnectionIdParams from the given connection id.
*/
ServerConnectionIdParams parseConnectionId(const ConnectionId& id) override;
folly::Expected<ServerConnectionIdParams, QuicInternalException>
parseConnectionId(const ConnectionId& id) noexcept override;
/**
* Encodes the given ServerConnectionIdParams into connection id
*/
ConnectionId encodeConnectionId(
const ServerConnectionIdParams& params) override;
folly::Expected<ConnectionId, QuicInternalException> encodeConnectionId(
const ServerConnectionIdParams& params) noexcept override;
};
/**

View File

@@ -201,14 +201,14 @@ TEST_F(TypesTest, TestConnIdWorkerId) {
uint16_t hostId = folly::Random::rand32() % 4095;
ServerConnectionIdParams params(hostId, processId, i);
auto paramsAfterEncode =
connIdAlgo->parseConnectionId(connIdAlgo->encodeConnectionId(params));
EXPECT_TRUE(connIdAlgo->canParse(connIdAlgo->encodeConnectionId(params)));
EXPECT_EQ(paramsAfterEncode.hostId, hostId);
EXPECT_EQ(paramsAfterEncode.workerId, i);
EXPECT_EQ(paramsAfterEncode.processId, processId);
connIdAlgo->parseConnectionId(*connIdAlgo->encodeConnectionId(params));
EXPECT_TRUE(connIdAlgo->canParse(*connIdAlgo->encodeConnectionId(params)));
EXPECT_EQ(paramsAfterEncode->hostId, hostId);
EXPECT_EQ(paramsAfterEncode->workerId, i);
EXPECT_EQ(paramsAfterEncode->processId, processId);
}
ServerConnectionIdParams vParam(0x2, 7, 7, 7);
EXPECT_FALSE(connIdAlgo->canParse(connIdAlgo->encodeConnectionId(vParam)));
EXPECT_FALSE(connIdAlgo->canParse(*connIdAlgo->encodeConnectionId(vParam)));
}
TEST_F(TypesTest, ShortHeaderPacketNumberSpace) {

View File

@@ -448,7 +448,7 @@ uint64_t computeExpectedDelay(
ConnectionId getTestConnectionId(uint16_t hostId) {
ServerConnectionIdParams params(hostId, 0, 0);
DefaultConnectionIdAlgo connIdAlgo;
auto connId = connIdAlgo.encodeConnectionId(params);
auto connId = *connIdAlgo.encodeConnectionId(params);
connId.data()[3] = 3;
connId.data()[4] = 4;
connId.data()[5] = 5;

View File

@@ -82,7 +82,7 @@ class QuicLossFunctionsTest : public TestWithParam<PacketNumberSpace> {
// with bits for processId and workerId set to 0
ServerConnectionIdParams params(0, 0, 0);
conn->connIdAlgo = connIdAlgo_.get();
conn->serverConnectionId = connIdAlgo_->encodeConnectionId(params);
conn->serverConnectionId = *connIdAlgo_->encodeConnectionId(params);
// for canSetLossTimerForAppData()
conn->oneRttWriteCipher = createNoOpAead();
return conn;
@@ -108,7 +108,7 @@ class QuicLossFunctionsTest : public TestWithParam<PacketNumberSpace> {
// create a serverConnectionId that is different from the client connId
// with bits for processId and workerId set to 0
ServerConnectionIdParams params(0, 0, 0);
conn->serverConnectionId = connIdAlgo_.get()->encodeConnectionId(params);
conn->serverConnectionId = *connIdAlgo_.get()->encodeConnectionId(params);
return conn;
}

View File

@@ -26,7 +26,8 @@ size_t getWorkerToRouteTo(
const RoutingData& routingData,
size_t numWorkers,
ConnectionIdAlgo* connIdAlgo) {
return connIdAlgo->parseConnectionId(routingData.destinationConnId).workerId %
return connIdAlgo->parseConnectionId(routingData.destinationConnId)
->workerId %
numWorkers;
}
} // namespace

View File

@@ -450,7 +450,7 @@ void QuicServerWorker::dispatchPacketData(
return;
}
ServerConnectionIdParams connIdParam =
connIdAlgo_->parseConnectionId(routingData.destinationConnId);
*connIdAlgo_->parseConnectionId(routingData.destinationConnId);
if (connIdParam.hostId != hostId_) {
VLOG(3) << "Dropping packet routed to wrong host, CID="
<< routingData.destinationConnId.hex()

View File

@@ -1193,9 +1193,12 @@ QuicServerConnectionState::createAndAddNewSelfConnId() {
// TODO Possibly change this mechanism later
// The default connectionId algo has 36 bits of randomness.
auto encodedCid = connIdAlgo->encodeConnectionId(*serverConnIdParams);
if (encodedCid.hasError()) {
return folly::none;
}
auto newConnIdData =
ConnectionIdData{connIdAlgo->encodeConnectionId(*serverConnIdParams),
nextSelfConnectionIdSequence++};
ConnectionIdData{std::move(*encodedCid), nextSelfConnectionIdSequence++};
newConnIdData.token = generator.generateToken(newConnIdData.connId);
selfConnectionIds.push_back(newConnIdData);
return newConnIdData;

View File

@@ -810,7 +810,7 @@ ConnectionId createConnIdForServer(ProcessId server) {
auto connIdAlgo = std::make_unique<DefaultConnectionIdAlgo>();
uint8_t processId = (server == ProcessId::ONE) ? 1 : 0;
ServerConnectionIdParams params(0, processId, 0);
return connIdAlgo->encodeConnectionId(params);
return *connIdAlgo->encodeConnectionId(params);
}
class QuicServerWorkerTakeoverTest : public Test {

View File

@@ -51,9 +51,9 @@ TEST(ServerStateMachineTest, TestAddConnId) {
EXPECT_EQ(newConnId2->token->size(), kStatelessResetTokenLength);
EXPECT_EQ(newConnId3->token->size(), kStatelessResetTokenLength);
auto params1 = serverState.connIdAlgo->parseConnectionId(newConnId1->connId);
auto params2 = serverState.connIdAlgo->parseConnectionId(newConnId2->connId);
auto params3 = serverState.connIdAlgo->parseConnectionId(newConnId3->connId);
auto params1 = *serverState.connIdAlgo->parseConnectionId(newConnId1->connId);
auto params2 = *serverState.connIdAlgo->parseConnectionId(newConnId2->connId);
auto params3 = *serverState.connIdAlgo->parseConnectionId(newConnId3->connId);
// Server connection id params are correctly encoded/decoded.
assertServerConnIdParamsEq(originalParams, params1);

View File

@@ -526,8 +526,7 @@ TEST_F(QuicHalfClosedRemoteStateTest, AckStream) {
// create server chosen connId with processId = 0 and workerId = 0
ServerConnectionIdParams params(0, 0, 0);
auto connIdAlgo = std::make_unique<DefaultConnectionIdAlgo>();
folly::Optional<ConnectionId> serverChosenConnId =
connIdAlgo->encodeConnectionId(params);
auto serverChosenConnId = connIdAlgo->encodeConnectionId(params);
auto stream = conn->streamManager->createNextBidirectionalStream().value();
stream->sendState = StreamSendState::Open_E;
stream->recvState = StreamRecvState::Closed_E;
@@ -565,8 +564,7 @@ TEST_F(QuicHalfClosedRemoteStateTest, AckStreamAfterSkip) {
// create server chosen connId with processId = 0 and workerId = 0
ServerConnectionIdParams params(0, 0, 0);
auto connIdAlgo = std::make_unique<DefaultConnectionIdAlgo>();
folly::Optional<ConnectionId> serverChosenConnId =
connIdAlgo->encodeConnectionId(params);
auto serverChosenConnId = connIdAlgo->encodeConnectionId(params);
auto stream = conn->streamManager->createNextBidirectionalStream().value();
stream->sendState = StreamSendState::Open_E;
stream->recvState = StreamRecvState::Closed_E;