1
0
mirror of https://github.com/facebookincubator/mvfst.git synced 2025-08-08 09:42:06 +03:00

Remove throws from CryptoFactory

Summary: Continuing the theme, removing them from the CryptoFactory and translating to Expected.

Reviewed By: kvtsoy, jbeshay

Differential Revision: D74676120

fbshipit-source-id: 715b497e68a4e3004811038cba479c443d5398fd
This commit is contained in:
Matt Joras
2025-05-16 14:19:45 -07:00
committed by Facebook GitHub Bot
parent c088421ecf
commit bf71d17f2c
36 changed files with 617 additions and 300 deletions

View File

@@ -242,7 +242,7 @@ class TestQuicTransport
transportConn = conn.get();
conn_.reset(conn.release());
aead = test::createNoOpAead();
headerCipher = test::createNoOpHeaderCipher();
headerCipher = test::createNoOpHeaderCipher().value();
connIdAlgo_ = std::make_unique<DefaultConnectionIdAlgo>();
setConnectionSetupCallback(connSetupCb);
setConnectionCallbackFromCtor(connCb);

View File

@@ -171,7 +171,7 @@ class QuicTransportFunctionsTest : public Test {
public:
void SetUp() override {
aead = test::createNoOpAead();
headerCipher = test::createNoOpHeaderCipher();
headerCipher = test::createNoOpHeaderCipher().value();
quicStats_ = std::make_unique<NiceMock<MockQuicStats>>();
}
@@ -191,7 +191,7 @@ class QuicTransportFunctionsTest : public Test {
kDefaultConnectionFlowControlWindow * 1000;
conn->statsCallback = quicStats_.get();
conn->initialWriteCipher = createNoOpAead();
conn->initialHeaderCipher = createNoOpHeaderCipher();
conn->initialHeaderCipher = createNoOpHeaderCipher().value();
CHECK(
!conn->streamManager
->setMaxLocalBidirectionalStreams(kDefaultMaxStreamsBidirectional)
@@ -3665,7 +3665,7 @@ TEST_F(QuicTransportFunctionsTest, ResetNumProbePackets) {
EXPECT_EQ(0, writeRes1->bytesWritten);
conn->handshakeWriteCipher = createNoOpAead();
conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher();
conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher().value();
conn->pendingEvents.numProbePackets[PacketNumberSpace::Handshake] = 2;
auto writeRes2 = writeCryptoAndAckDataToSocket(
*rawSocket,
@@ -3682,7 +3682,7 @@ TEST_F(QuicTransportFunctionsTest, ResetNumProbePackets) {
EXPECT_EQ(0, writeRes2->bytesWritten);
conn->oneRttWriteCipher = createNoOpAead();
conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher();
conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher().value();
conn->pendingEvents.numProbePackets[PacketNumberSpace::AppData] = 2;
auto writeRes3 = writeQuicDataToSocket(
*rawSocket,
@@ -4474,13 +4474,13 @@ TEST_F(QuicTransportFunctionsTest, HandshakeConfirmedDropCipher) {
ASSERT_NE(nullptr, conn->initialWriteCipher);
conn->handshakeWriteCipher = createNoOpAead();
conn->readCodec->setInitialReadCipher(createNoOpAead());
conn->readCodec->setInitialHeaderCipher(createNoOpHeaderCipher());
conn->readCodec->setInitialHeaderCipher(createNoOpHeaderCipher().value());
conn->readCodec->setHandshakeReadCipher(createNoOpAead());
conn->readCodec->setHandshakeHeaderCipher(createNoOpHeaderCipher());
conn->readCodec->setHandshakeHeaderCipher(createNoOpHeaderCipher().value());
conn->oneRttWriteCipher = createNoOpAead();
conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher();
conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher().value();
conn->readCodec->setOneRttReadCipher(createNoOpAead());
conn->readCodec->setOneRttHeaderCipher(createNoOpHeaderCipher());
conn->readCodec->setOneRttHeaderCipher(createNoOpHeaderCipher().value());
writeCryptoDataProbesToSocketForTest(
*socket,
*conn,

View File

@@ -105,7 +105,7 @@ class QuicTransportTest : public Test {
.WillRepeatedly(
Invoke([&](auto& buf, auto, auto) { return buf->clone(); }));
transport_->getConnectionState().oneRttWriteCipher = std::move(aead);
auto tempHeaderCipher = test::createNoOpHeaderCipher();
auto tempHeaderCipher = test::createNoOpHeaderCipher().value();
tempHeaderCipher->setDefaultKey();
transport_->getConnectionState().oneRttWriteHeaderCipher =
std::move(tempHeaderCipher);

View File

@@ -32,7 +32,7 @@ class TestQuicTransport
conn_->version = QuicVersion::MVFST;
conn_->observerContainer = observerContainer_;
aead = test::createNoOpAead();
headerCipher = test::createNoOpHeaderCipher();
headerCipher = test::createNoOpHeaderCipher().value();
setConnectionSetupCallback(connSetupCb);
setConnectionCallbackFromCtor(connCb);
}

View File

@@ -1164,15 +1164,35 @@ QuicClientTransportLite::startCryptoHandshake() {
auto& cryptoFactory = handshakeLayer->getCryptoFactory();
auto version = conn_->originalVersion.value();
conn_->initialWriteCipher = cryptoFactory.getClientInitialCipher(
auto initialWriteCipherResult = cryptoFactory.getClientInitialCipher(
*clientConn_->initialDestinationConnectionId, version);
conn_->readCodec->setInitialReadCipher(cryptoFactory.getServerInitialCipher(
*clientConn_->initialDestinationConnectionId, version));
if (initialWriteCipherResult.hasError()) {
return folly::makeUnexpected(initialWriteCipherResult.error());
}
conn_->initialWriteCipher = std::move(initialWriteCipherResult.value());
auto serverInitialCipherResult = cryptoFactory.getServerInitialCipher(
*clientConn_->initialDestinationConnectionId, version);
if (serverInitialCipherResult.hasError()) {
return folly::makeUnexpected(serverInitialCipherResult.error());
}
conn_->readCodec->setInitialReadCipher(
std::move(serverInitialCipherResult.value()));
auto serverHeaderCipherResult = cryptoFactory.makeServerInitialHeaderCipher(
*clientConn_->initialDestinationConnectionId, version);
if (serverHeaderCipherResult.hasError()) {
return folly::makeUnexpected(serverHeaderCipherResult.error());
}
conn_->readCodec->setInitialHeaderCipher(
cryptoFactory.makeServerInitialHeaderCipher(
*clientConn_->initialDestinationConnectionId, version));
conn_->initialHeaderCipher = cryptoFactory.makeClientInitialHeaderCipher(
std::move(serverHeaderCipherResult.value()));
auto clientHeaderCipherResult = cryptoFactory.makeClientInitialHeaderCipher(
*clientConn_->initialDestinationConnectionId, version);
if (clientHeaderCipherResult.hasError()) {
return folly::makeUnexpected(clientHeaderCipherResult.error());
}
conn_->initialHeaderCipher = std::move(clientHeaderCipherResult.value());
customTransportParameters_ = getSupportedExtTransportParams(*conn_);

View File

@@ -159,9 +159,21 @@ bool ClientHandshake::waitingForData() const {
}
void ClientHandshake::computeCiphers(CipherKind kind, ByteRange secret) {
std::unique_ptr<Aead> aead = buildAead(kind, secret);
auto aeadResult = buildAead(kind, secret);
if (aeadResult.hasError()) {
error_ = folly::makeUnexpected(std::move(aeadResult.error()));
return;
}
auto packetNumberCipherResult = buildHeaderCipher(secret);
if (packetNumberCipherResult.hasError()) {
error_ = folly::makeUnexpected(std::move(packetNumberCipherResult.error()));
return;
}
std::unique_ptr<Aead> aead = std::move(aeadResult.value());
std::unique_ptr<PacketNumberCipher> packetNumberCipher =
buildHeaderCipher(secret);
std::move(packetNumberCipherResult.value());
switch (kind) {
case CipherKind::HandshakeWrite:
conn_->handshakeWriteCipher = std::move(aead);
@@ -208,11 +220,15 @@ ClientHandshake::getNextOneRttWriteCipher() {
CHECK(writeTrafficSecret_);
LOG_IF(WARNING, trafficSecretSync_ > 1 || trafficSecretSync_ < -1)
<< "Client read and write secrets are out of sync";
writeTrafficSecret_ = getNextTrafficSecret(writeTrafficSecret_->coalesce());
auto nextSecretResult = getNextTrafficSecret(writeTrafficSecret_->coalesce());
if (nextSecretResult.hasError()) {
return folly::makeUnexpected(std::move(nextSecretResult.error()));
}
writeTrafficSecret_ = std::move(nextSecretResult.value());
trafficSecretSync_--;
auto cipher =
buildAead(CipherKind::OneRttWrite, writeTrafficSecret_->coalesce());
return cipher;
return buildAead(CipherKind::OneRttWrite, writeTrafficSecret_->coalesce());
}
folly::Expected<std::unique_ptr<Aead>, QuicError>
@@ -224,11 +240,15 @@ ClientHandshake::getNextOneRttReadCipher() {
CHECK(readTrafficSecret_);
LOG_IF(WARNING, trafficSecretSync_ > 1 || trafficSecretSync_ < -1)
<< "Client read and write secrets are out of sync";
readTrafficSecret_ = getNextTrafficSecret(readTrafficSecret_->coalesce());
auto nextSecretResult = getNextTrafficSecret(readTrafficSecret_->coalesce());
if (nextSecretResult.hasError()) {
return folly::makeUnexpected(std::move(nextSecretResult.error()));
}
readTrafficSecret_ = std::move(nextSecretResult.value());
trafficSecretSync_++;
auto cipher =
buildAead(CipherKind::OneRttRead, readTrafficSecret_->coalesce());
return cipher;
return buildAead(CipherKind::OneRttRead, readTrafficSecret_->coalesce());
}
void ClientHandshake::waitForData() {

View File

@@ -102,7 +102,8 @@ class ClientHandshake : public Handshake {
* API used to verify that the integrity token present in the retry packet
* matches what we would expect
*/
virtual bool verifyRetryIntegrityTag(
[[nodiscard]] virtual folly::Expected<bool, QuicError>
verifyRetryIntegrityTag(
const ConnectionId& originalDstConnId,
const RetryPacket& retryPacket) = 0;
@@ -176,7 +177,8 @@ class ClientHandshake : public Handshake {
* Given secret_n, returns secret_n+1 to be used for generating the next Aead
* on key updates.
*/
virtual BufPtr getNextTrafficSecret(ByteRange secret) const = 0;
[[nodiscard]] virtual folly::Expected<BufPtr, QuicError> getNextTrafficSecret(
ByteRange secret) const = 0;
BufPtr readTrafficSecret_;
BufPtr writeTrafficSecret_;
@@ -185,16 +187,17 @@ class ClientHandshake : public Handshake {
Optional<bool> canResendZeroRtt_;
private:
virtual folly::Expected<Optional<CachedServerTransportParameters>, QuicError>
[[nodiscard]] virtual folly::
Expected<Optional<CachedServerTransportParameters>, QuicError>
connectImpl(Optional<std::string> hostname) = 0;
virtual void processSocketData(folly::IOBufQueue& queue) = 0;
virtual bool matchEarlyParameters() = 0;
virtual std::unique_ptr<Aead> buildAead(
CipherKind kind,
ByteRange secret) = 0;
virtual std::unique_ptr<PacketNumberCipher> buildHeaderCipher(
ByteRange secret) = 0;
[[nodiscard]] virtual folly::Expected<std::unique_ptr<Aead>, QuicError>
buildAead(CipherKind kind, ByteRange secret) = 0;
[[nodiscard]] virtual folly::
Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
buildHeaderCipher(ByteRange secret) = 0;
// Represents the packet type that should be used to write the data currently
// in the stream.

View File

@@ -57,7 +57,7 @@ class ClientStateMachineTest : public Test {
public:
void SetUp() override {
mockFactory_ = std::make_shared<MockClientHandshakeFactory>();
EXPECT_CALL(*mockFactory_, _makeClientHandshake(_))
EXPECT_CALL(*mockFactory_, makeClientHandshakeImpl(_))
.WillRepeatedly(Invoke(
[&](QuicClientConnectionState* conn)
-> std::unique_ptr<quic::ClientHandshake> {

View File

@@ -22,29 +22,29 @@ namespace quic::test {
class MockClientHandshakeFactory : public ClientHandshakeFactory {
public:
MOCK_METHOD(
std::unique_ptr<ClientHandshake>,
_makeClientHandshake,
(QuicClientConnectionState*));
std::unique_ptr<ClientHandshake>
makeClientHandshake(QuicClientConnectionState* conn) && override {
return _makeClientHandshake(conn);
return std::move(*this).makeClientHandshakeImpl(conn);
}
MOCK_METHOD(
std::unique_ptr<ClientHandshake>,
makeClientHandshakeImpl,
(QuicClientConnectionState*));
};
class MockClientHandshake : public ClientHandshake {
class MockClientHandshakeBase : public ClientHandshake {
public:
MockClientHandshake(QuicClientConnectionState* conn)
MockClientHandshakeBase(QuicClientConnectionState* conn)
: ClientHandshake(conn) {}
~MockClientHandshake() override {
~MockClientHandshakeBase() override {
destroy();
}
// Legacy workaround for move-only types
folly::Expected<folly::Unit, QuicError> doHandshake(
std::unique_ptr<folly::IOBuf> data,
BufPtr data,
EncryptionLevel encryptionLevel) override {
doHandshakeImpl(data.get(), encryptionLevel);
return folly::unit;
@@ -52,53 +52,133 @@ class MockClientHandshake : public ClientHandshake {
MOCK_METHOD(void, doHandshakeImpl, (folly::IOBuf*, EncryptionLevel));
MOCK_METHOD(
bool,
(folly::Expected<bool, QuicError>),
verifyRetryIntegrityTag,
(const ConnectionId&, const RetryPacket&));
(const ConnectionId&, const RetryPacket&),
(override));
MOCK_METHOD(void, removePsk, (const Optional<std::string>&));
MOCK_METHOD(const CryptoFactory&, getCryptoFactory, (), (const));
MOCK_METHOD(bool, isTLSResumed, (), (const));
MOCK_METHOD(const CryptoFactory&, getCryptoFactory, (), (const, override));
MOCK_METHOD(bool, isTLSResumed, (), (const, override));
MOCK_METHOD(
Optional<std::vector<uint8_t>>,
getExportedKeyingMaterial,
(const std::string& label,
const Optional<ByteRange>& context,
uint16_t keyLength),
());
(override));
MOCK_METHOD(Optional<bool>, getZeroRttRejected, ());
MOCK_METHOD(Optional<bool>, getCanResendZeroRtt, (), (const));
MOCK_METHOD(
const Optional<ServerTransportParameters>&,
getServerTransportParams,
());
(),
(override));
MOCK_METHOD(void, destroy, ());
MOCK_METHOD(
(folly::Expected<std::unique_ptr<Aead>, QuicError>),
getNextOneRttWriteCipher,
(),
(override));
MOCK_METHOD(
(folly::Expected<std::unique_ptr<Aead>, QuicError>),
getNextOneRttReadCipher,
(),
(override));
void handshakeConfirmed() override {
handshakeConfirmedImpl();
}
MOCK_METHOD(void, handshakeConfirmedImpl, ());
Handshake::TLSSummary getTLSSummary() const override {
return getTLSSummaryImpl();
}
MOCK_METHOD(Handshake::TLSSummary, getTLSSummaryImpl, (), (const));
// Mock the public connect method
folly::Expected<folly::Unit, QuicError> connect(
Optional<std::string> hostname,
std::shared_ptr<ClientTransportParametersExtension> transportParams) {
return mockConnect(std::move(hostname), std::move(transportParams));
}
MOCK_METHOD(
(folly::Expected<Optional<CachedServerTransportParameters>, QuicError>),
connectImpl,
(Optional<std::string>));
MOCK_METHOD(EncryptionLevel, getReadRecordLayerEncryptionLevel, ());
(folly::Expected<folly::Unit, QuicError>),
mockConnect,
(Optional<std::string>,
std::shared_ptr<ClientTransportParametersExtension>));
MOCK_METHOD(
EncryptionLevel,
getReadRecordLayerEncryptionLevel,
(),
(override));
MOCK_METHOD(void, processSocketData, (folly::IOBufQueue & queue));
MOCK_METHOD(bool, matchEarlyParameters, ());
MOCK_METHOD(
std::unique_ptr<Aead>,
(folly::Expected<std::unique_ptr<Aead>, QuicError>),
buildAead,
(ClientHandshake::CipherKind kind, ByteRange secret));
MOCK_METHOD(
std::unique_ptr<PacketNumberCipher>,
(folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>),
buildHeaderCipher,
(ByteRange secret));
MOCK_METHOD(BufPtr, getNextTrafficSecret, (ByteRange secret), (const));
MOCK_METHOD(
(folly::Expected<BufPtr, QuicError>),
getNextTrafficSecret,
(ByteRange secret),
(const));
MOCK_METHOD(
const Optional<std::string>&,
getApplicationProtocol,
(),
(const));
(const, override));
MOCK_METHOD(
const std::shared_ptr<const folly::AsyncTransportCertificate>,
getPeerCertificate,
(),
(const));
MOCK_METHOD(Handshake::TLSSummary, getTLSSummary, (), (const));
(const, override));
MOCK_METHOD(Phase, getPhase, (), (const));
MOCK_METHOD(bool, waitingForData, (), (const));
};
class MockClientHandshake : public MockClientHandshakeBase {
public:
MockClientHandshake(QuicClientConnectionState* conn)
: MockClientHandshakeBase(conn) {}
private:
// Implement the private pure virtual methods from ClientHandshake
folly::Expected<Optional<CachedServerTransportParameters>, QuicError>
connectImpl(Optional<std::string> /* hostname */) override {
return Optional<CachedServerTransportParameters>(std::nullopt);
}
void processSocketData(folly::IOBufQueue& /* queue */) override {}
bool matchEarlyParameters() override {
return false;
}
folly::Expected<std::unique_ptr<Aead>, QuicError> buildAead(
CipherKind /* kind */,
ByteRange /* secret */) override {
return folly::makeUnexpected(
QuicError(TransportErrorCode::INTERNAL_ERROR, "Not implemented"));
}
folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
buildHeaderCipher(ByteRange /* secret */) override {
return folly::makeUnexpected(
QuicError(TransportErrorCode::INTERNAL_ERROR, "Not implemented"));
}
folly::Expected<BufPtr, QuicError> getNextTrafficSecret(
ByteRange /* secret */) const override {
return folly::makeUnexpected(
QuicError(TransportErrorCode::INTERNAL_ERROR, "Not implemented"));
}
};
class MockQuicConnectorCallback : public quic::QuicConnector::Callback {

View File

@@ -51,7 +51,7 @@ class QuicClientTransportLiteTest : public Test {
ON_CALL(*socket, setCmsgs(_)).WillByDefault(Return(folly::unit));
ON_CALL(*socket, appendCmsgs(_)).WillByDefault(Return(folly::unit));
auto mockFactory = std::make_shared<MockClientHandshakeFactory>();
EXPECT_CALL(*mockFactory, _makeClientHandshake(_))
EXPECT_CALL(*mockFactory, makeClientHandshakeImpl(_))
.WillRepeatedly(Invoke(
[&](QuicClientConnectionState* conn)
-> std::unique_ptr<quic::ClientHandshake> {
@@ -61,7 +61,7 @@ class QuicClientTransportLiteTest : public Test {
qEvb_, std::move(socket), mockFactory);
quicClient_->getConn()->oneRttWriteCipher = test::createNoOpAead();
quicClient_->getConn()->oneRttWriteHeaderCipher =
test::createNoOpHeaderCipher();
test::createNoOpHeaderCipher().value();
ASSERT_FALSE(quicClient_->getState()
->streamManager->setMaxLocalBidirectionalStreams(128)
.hasError());

View File

@@ -109,7 +109,7 @@ class QuicClientTransportTest : public Test {
.WillByDefault(testing::Return(folly::unit));
mockFactory_ = std::make_shared<MockClientHandshakeFactory>();
EXPECT_CALL(*mockFactory_, _makeClientHandshake(_))
EXPECT_CALL(*mockFactory_, makeClientHandshakeImpl(_))
.WillRepeatedly(Invoke(
[&](QuicClientConnectionState* conn)
-> std::unique_ptr<quic::ClientHandshake> {

View File

@@ -64,23 +64,25 @@ std::unique_ptr<QuicReadCodec> makeCodec(
auto codec = std::make_unique<QuicReadCodec>(nodeType);
if (nodeType != QuicNodeType::Client) {
codec->setZeroRttReadCipher(std::move(zeroRttCipher));
codec->setZeroRttHeaderCipher(test::createNoOpHeaderCipher());
codec->setZeroRttHeaderCipher(test::createNoOpHeaderCipher().value());
}
codec->setOneRttReadCipher(std::move(oneRttCipher));
codec->setOneRttHeaderCipher(test::createNoOpHeaderCipher());
codec->setOneRttHeaderCipher(test::createNoOpHeaderCipher().value());
codec->setHandshakeReadCipher(test::createNoOpAead());
codec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher());
codec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher().value());
codec->setClientConnectionId(clientConnId);
if (nodeType == QuicNodeType::Client) {
codec->setInitialReadCipher(
cryptoFactory.getServerInitialCipher(clientConnId, version));
cryptoFactory.getServerInitialCipher(clientConnId, version).value());
codec->setInitialHeaderCipher(
cryptoFactory.makeServerInitialHeaderCipher(clientConnId, version));
cryptoFactory.makeServerInitialHeaderCipher(clientConnId, version)
.value());
} else {
codec->setInitialReadCipher(
cryptoFactory.getClientInitialCipher(clientConnId, version));
cryptoFactory.getClientInitialCipher(clientConnId, version).value());
codec->setInitialHeaderCipher(
cryptoFactory.makeClientInitialHeaderCipher(clientConnId, version));
cryptoFactory.makeClientInitialHeaderCipher(clientConnId, version)
.value());
}
return codec;
}
@@ -181,9 +183,10 @@ TEST_P(QuicPacketBuilderTest, LongHeaderRegularPacket) {
QuicVersion ver = QuicVersion::MVFST;
// create a server cleartext write codec.
FizzCryptoFactory cryptoFactory;
auto cleartextAead = cryptoFactory.getClientInitialCipher(serverConnId, ver);
auto cleartextAead =
cryptoFactory.getClientInitialCipher(serverConnId, ver).value();
auto headerCipher =
cryptoFactory.makeClientInitialHeaderCipher(serverConnId, ver);
cryptoFactory.makeClientInitialHeaderCipher(serverConnId, ver).value();
std::unique_ptr<PacketBuilderInterface> builderOwner;
auto builderProvider = [&](PacketHeader header, PacketNum largestAcked) {

View File

@@ -44,15 +44,18 @@ std::unique_ptr<QuicReadCodec> makeEncryptedCodec(
auto codec = std::make_unique<QuicReadCodec>(nodeType);
codec->setClientConnectionId(clientConnId);
codec->setInitialReadCipher(
cryptoFactory.getClientInitialCipher(clientConnId, QuicVersion::MVFST));
codec->setInitialHeaderCipher(cryptoFactory.makeClientInitialHeaderCipher(
clientConnId, QuicVersion::MVFST));
cryptoFactory.getClientInitialCipher(clientConnId, QuicVersion::MVFST)
.value());
codec->setInitialHeaderCipher(
cryptoFactory
.makeClientInitialHeaderCipher(clientConnId, QuicVersion::MVFST)
.value());
if (zeroRttAead) {
codec->setZeroRttReadCipher(std::move(zeroRttAead));
}
codec->setZeroRttHeaderCipher(test::createNoOpHeaderCipher());
codec->setZeroRttHeaderCipher(test::createNoOpHeaderCipher().value());
codec->setOneRttReadCipher(std::move(oneRttAead));
codec->setOneRttHeaderCipher(test::createNoOpHeaderCipher());
codec->setOneRttHeaderCipher(test::createNoOpHeaderCipher().value());
if (sourceToken) {
codec->setStatelessResetToken(*sourceToken);
}
@@ -217,7 +220,7 @@ TEST_F(QuicReadCodecTest, LongHeaderPacketLenMismatch) {
AckStates ackStates;
auto codec = makeUnencryptedCodec();
codec->setInitialReadCipher(createNoOpAead());
codec->setInitialHeaderCipher(test::createNoOpHeaderCipher());
codec->setInitialHeaderCipher(test::createNoOpHeaderCipher().value());
auto result = codec->parsePacket(packetQueue, ackStates);
auto nothing = result.nothing();
EXPECT_NE(nothing, nullptr);
@@ -667,9 +670,11 @@ TEST_F(QuicReadCodecTest, TestInitialPacket) {
FizzCryptoFactory cryptoFactory;
PacketNum packetNum = 1;
uint64_t offset = 0;
auto aead = cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST);
auto aead =
cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST).value();
auto headerCipher =
cryptoFactory.makeClientInitialHeaderCipher(connId, QuicVersion::MVFST);
cryptoFactory.makeClientInitialHeaderCipher(connId, QuicVersion::MVFST)
.value();
auto initialCryptoPacketBuf = folly::IOBuf::copyBuffer("CHLO");
ChainedByteRangeHead initialCryptoPacketRch(initialCryptoPacketBuf);
auto packet = createInitialCryptoPacket(
@@ -682,7 +687,8 @@ TEST_F(QuicReadCodecTest, TestInitialPacket) {
offset);
auto codec = makeEncryptedCodec(connId, std::move(aead), nullptr);
aead = cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST);
aead =
cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST).value();
AckStates ackStates;
auto packetQueue =
bufToQueue(packetToBufCleartext(packet, *aead, *headerCipher, packetNum));
@@ -703,9 +709,11 @@ TEST_F(QuicReadCodecTest, TestInitialPacketExtractToken) {
FizzCryptoFactory cryptoFactory;
PacketNum packetNum = 1;
uint64_t offset = 0;
auto aead = cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST);
auto aead =
cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST).value();
auto headerCipher =
cryptoFactory.makeClientInitialHeaderCipher(connId, QuicVersion::MVFST);
cryptoFactory.makeClientInitialHeaderCipher(connId, QuicVersion::MVFST)
.value();
std::string token = "aswerdfewdewrgetg";
auto initialCryptoPacketBuf = folly::IOBuf::copyBuffer("CHLO");
ChainedByteRangeHead initialCryptoPacketRch(initialCryptoPacketBuf);
@@ -721,7 +729,8 @@ TEST_F(QuicReadCodecTest, TestInitialPacketExtractToken) {
token);
auto codec = makeEncryptedCodec(connId, std::move(aead), nullptr);
aead = cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST);
aead =
cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST).value();
auto packetQueue =
bufToQueue(packetToBufCleartext(packet, *aead, *headerCipher, packetNum));
@@ -740,9 +749,11 @@ TEST_F(QuicReadCodecTest, TestHandshakeDone) {
FizzCryptoFactory cryptoFactory;
PacketNum packetNum = 1;
uint64_t offset = 0;
auto aead = cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST);
auto aead =
cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST).value();
auto headerCipher =
cryptoFactory.makeClientInitialHeaderCipher(connId, QuicVersion::MVFST);
cryptoFactory.makeClientInitialHeaderCipher(connId, QuicVersion::MVFST)
.value();
auto initialCryptoPacketBuf = folly::IOBuf::copyBuffer("CHLO");
ChainedByteRangeHead initialCryptoPacketRch(initialCryptoPacketBuf);
auto packet = createInitialCryptoPacket(
@@ -755,7 +766,8 @@ TEST_F(QuicReadCodecTest, TestHandshakeDone) {
offset);
auto codec = makeEncryptedCodec(connId, std::move(aead), nullptr);
aead = cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST);
aead =
cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST).value();
AckStates ackStates;
auto packetQueue =
bufToQueue(packetToBufCleartext(packet, *aead, *headerCipher, packetNum));

View File

@@ -39,9 +39,11 @@ const RegularQuicWritePacket& writeQuicPacket(
bool eof) {
auto version = conn.version.value_or(*conn.originalVersion);
auto aead = createNoOpAead();
auto headerCipher = createNoOpHeaderCipher();
auto headerCipherResult = createNoOpHeaderCipher();
CHECK(!headerCipherResult.hasError()) << "Failed to create header cipher";
auto headerCipher = std::move(headerCipherResult.value());
CHECK(!writeDataToQuicStream(stream, data.clone(), eof).hasError());
CHECK(!writeQuicDataToSocket(
auto result = writeQuicDataToSocket(
sock,
conn,
srcConnId,
@@ -49,8 +51,8 @@ const RegularQuicWritePacket& writeQuicPacket(
*aead,
*headerCipher,
version,
conn.transportSettings.writeConnectionDataPacketsLimit)
.hasError());
conn.transportSettings.writeConnectionDataPacketsLimit);
CHECK(!result.hasError());
CHECK(
conn.outstandings.packets.rend() !=
getLastOutstandingPacket(conn, PacketNumberSpace::AppData));
@@ -63,10 +65,12 @@ PacketNum rstStreamAndSendPacket(
QuicStreamState& stream,
ApplicationErrorCode errorCode) {
auto aead = createNoOpAead();
auto headerCipher = createNoOpHeaderCipher();
auto headerCipherResult = createNoOpHeaderCipher();
CHECK(!headerCipherResult.hasError()) << "Failed to create header cipher";
auto headerCipher = std::move(headerCipherResult.value());
auto version = conn.version.value_or(*conn.originalVersion);
CHECK(!sendRstSMHandler(stream, errorCode).hasError());
CHECK(!writeQuicDataToSocket(
auto result = writeQuicDataToSocket(
sock,
conn,
*conn.clientConnectionId,
@@ -74,8 +78,9 @@ PacketNum rstStreamAndSendPacket(
*aead,
*headerCipher,
version,
conn.transportSettings.writeConnectionDataPacketsLimit)
.hasError());
conn.transportSettings.writeConnectionDataPacketsLimit);
CHECK(!result.hasError());
CHECK(!result.hasError());
for (const auto& packet : conn.outstandings.packets) {
for (const auto& frame : packet.packet.frames) {
@@ -269,7 +274,8 @@ std::unique_ptr<MockAead> createNoOpAead(uint64_t cipherOverhead) {
return createNoOpAeadImpl<MockAead>(cipherOverhead);
}
std::unique_ptr<MockPacketNumberCipher> createNoOpHeaderCipher() {
folly::Expected<std::unique_ptr<MockPacketNumberCipher>, QuicError>
createNoOpHeaderCipher() {
auto headerCipher = std::make_unique<NiceMock<MockPacketNumberCipher>>();
ON_CALL(*headerCipher, mask(_)).WillByDefault(Return(HeaderProtectionMask{}));
ON_CALL(*headerCipher, keyLength()).WillByDefault(Return(16));
@@ -776,12 +782,12 @@ bool writableContains(QuicStreamManager& streamManager, StreamId streamId) {
streamManager.controlWriteQueue().count(streamId) > 0;
}
std::unique_ptr<PacketNumberCipher>
folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
FizzCryptoTestFactory::makePacketNumberCipher(fizz::CipherSuite) const {
return std::move(packetNumberCipher_);
}
std::unique_ptr<PacketNumberCipher>
folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
FizzCryptoTestFactory::makePacketNumberCipher(ByteRange secret) const {
return _makePacketNumberCipher(secret);
}
@@ -793,7 +799,10 @@ void FizzCryptoTestFactory::setMockPacketNumberCipher(
void FizzCryptoTestFactory::setDefault() {
ON_CALL(*this, _makePacketNumberCipher(_))
.WillByDefault(Invoke([&](ByteRange secret) {
.WillByDefault(Invoke(
[&](ByteRange secret) -> folly::Expected<
std::unique_ptr<PacketNumberCipher>,
QuicError> {
return FizzCryptoFactory::makePacketNumberCipher(secret);
}));
}
@@ -831,15 +840,18 @@ TrafficKey getQuicTestKey() {
std::unique_ptr<folly::IOBuf> getProtectionKey() {
FizzCryptoFactory factory;
auto secret = getRandSecret();
auto pnCipher =
auto pnCipherResult =
factory.makePacketNumberCipher(fizz::CipherSuite::TLS_AES_128_GCM_SHA256);
CHECK(!pnCipherResult.hasError()) << "Failed to make packet number cipher";
auto& pnCipher = pnCipherResult.value();
auto deriver = factory.getFizzFactory()->makeKeyDeriver(
fizz::CipherSuite::TLS_AES_128_GCM_SHA256);
return deriver->expandLabel(
auto pnKey = deriver->expandLabel(
folly::range(secret),
kQuicPNLabel,
folly::IOBuf::create(0),
pnCipher->keyLength());
(*pnCipher).keyLength());
return pnKey;
}
size_t getTotalIovecLen(const struct iovec* vec, size_t iovec_len) {

View File

@@ -189,7 +189,15 @@ std::unique_ptr<T> createNoOpAeadImpl(uint64_t cipherOverhead = 0) {
std::unique_ptr<MockAead> createNoOpAead(uint64_t cipherOverhead = 0);
std::unique_ptr<MockPacketNumberCipher> createNoOpHeaderCipher();
folly::Expected<std::unique_ptr<MockPacketNumberCipher>, QuicError>
createNoOpHeaderCipher();
// For backward compatibility with existing code
inline std::unique_ptr<MockPacketNumberCipher> createNoOpHeaderCipherNoThrow() {
auto result = createNoOpHeaderCipher();
CHECK(!result.hasError()) << "Failed to create header cipher";
return std::move(result.value());
}
uint64_t computeExpectedDelay(
std::chrono::microseconds ackDelay,
@@ -348,17 +356,17 @@ class FizzCryptoTestFactory : public FizzCryptoFactory {
~FizzCryptoTestFactory() override = default;
using FizzCryptoFactory::makePacketNumberCipher;
std::unique_ptr<PacketNumberCipher> makePacketNumberCipher(
fizz::CipherSuite) const override;
folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
makePacketNumberCipher(fizz::CipherSuite) const override;
MOCK_METHOD(
std::unique_ptr<PacketNumberCipher>,
(folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>),
_makePacketNumberCipher,
(ByteRange),
(const));
std::unique_ptr<PacketNumberCipher> makePacketNumberCipher(
ByteRange secret) const override;
folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
makePacketNumberCipher(ByteRange secret) const override;
void setMockPacketNumberCipher(
std::unique_ptr<PacketNumberCipher> packetNumberCipher);
@@ -564,9 +572,16 @@ class FakeServerHandshake : public FizzServerHandshake {
void setEarlyKeys() {
oneRttWriteCipher_ = createNoOpAead();
oneRttWriteHeaderCipher_ = createNoOpHeaderCipher();
auto oneRttWriteHeaderCipherResult = createNoOpHeaderCipher();
CHECK(!oneRttWriteHeaderCipherResult.hasError())
<< "Failed to create header cipher";
oneRttWriteHeaderCipher_ = std::move(oneRttWriteHeaderCipherResult.value());
zeroRttReadCipher_ = createNoOpAead();
zeroRttReadHeaderCipher_ = createNoOpHeaderCipher();
auto zeroRttReadHeaderCipherResult = createNoOpHeaderCipher();
CHECK(!zeroRttReadHeaderCipherResult.hasError())
<< "Failed to create header cipher";
zeroRttReadHeaderCipher_ = std::move(zeroRttReadHeaderCipherResult.value());
}
void setOneRttKeys() {
@@ -577,12 +592,19 @@ class FakeServerHandshake : public FizzServerHandshake {
ON_CALL(*mockOneRttWriteCipher, getKey())
.WillByDefault(testing::Invoke([]() { return getQuicTestKey(); }));
oneRttWriteCipher_ = std::move(mockOneRttWriteCipher);
auto mockOneRttWriteHeaderCipher = createNoOpHeaderCipher();
auto mockOneRttWriteHeaderCipherResult = createNoOpHeaderCipher();
CHECK(!mockOneRttWriteHeaderCipherResult.hasError())
<< "Failed to create header cipher";
auto& mockOneRttWriteHeaderCipher =
mockOneRttWriteHeaderCipherResult.value();
mockOneRttWriteHeaderCipher->setDefaultKey();
oneRttWriteHeaderCipher_ = std::move(mockOneRttWriteHeaderCipher);
}
oneRttReadCipher_ = createNoOpAead();
oneRttReadHeaderCipher_ = createNoOpHeaderCipher();
auto oneRttReadHeaderCipherResult = createNoOpHeaderCipher();
CHECK(!oneRttReadHeaderCipherResult.hasError())
<< "Failed to create header cipher";
oneRttReadHeaderCipher_ = std::move(oneRttReadHeaderCipherResult.value());
readTrafficSecret_ = folly::IOBuf::copyBuffer(getRandSecret());
writeTrafficSecret_ = folly::IOBuf::copyBuffer(getRandSecret());
}
@@ -597,9 +619,17 @@ class FakeServerHandshake : public FizzServerHandshake {
void setHandshakeKeys() {
conn_.handshakeWriteCipher = createNoOpAead();
conn_.handshakeWriteHeaderCipher = createNoOpHeaderCipher();
auto handshakeWriteHeaderCipherResult = createNoOpHeaderCipher();
CHECK(!handshakeWriteHeaderCipherResult.hasError())
<< "Failed to create header cipher";
conn_.handshakeWriteHeaderCipher =
std::move(handshakeWriteHeaderCipherResult.value());
handshakeReadCipher_ = createNoOpAead();
handshakeReadHeaderCipher_ = createNoOpHeaderCipher();
auto handshakeReadHeaderCipherResult = createNoOpHeaderCipher();
CHECK(!handshakeReadHeaderCipherResult.hasError())
<< "Failed to create header cipher";
handshakeReadHeaderCipher_ =
std::move(handshakeReadHeaderCipherResult.value());
}
void setHandshakeDone(bool done) {

View File

@@ -62,8 +62,12 @@ class CipherBuilder {
*quicFizzCryptoFactory_.getFizzFactory(),
std::move(trafficKey),
cipherSuite));
auto headerCipher = quicFizzCryptoFactory_.makePacketNumberCipher(
auto headerCipherResult = quicFizzCryptoFactory_.makePacketNumberCipher(
fizz::CipherSuite::TLS_AES_128_GCM_SHA256);
if (headerCipherResult.hasError()) {
throw std::runtime_error("Failed to create header cipher");
}
auto headerCipher = std::move(headerCipherResult.value());
headerCipher->setKey(packetProtectionKey->coalesce());
return {std::move(aead), std::move(headerCipher)};

View File

@@ -45,7 +45,7 @@ class DSRPacketizerSingleWriteTest : public Test {
protected:
void SetUp() override {
aead = test::createNoOpAead();
headerCipher = test::createNoOpHeaderCipher();
headerCipher = test::createNoOpHeaderCipher().value();
qEvb_ = std::make_shared<FollyQuicEventBase>(&evb);
}

View File

@@ -149,9 +149,10 @@ const Optional<std::string>& FizzClientHandshake::getApplicationProtocol()
}
}
bool FizzClientHandshake::verifyRetryIntegrityTag(
folly::Expected<bool, QuicError> FizzClientHandshake::verifyRetryIntegrityTag(
const ConnectionId& originalDstConnId,
const RetryPacket& retryPacket) {
try {
PseudoRetryPacketBuilder pseudoRetryPacketBuilder(
retryPacket.initialByte,
retryPacket.header.getSourceConnId(),
@@ -160,7 +161,8 @@ bool FizzClientHandshake::verifyRetryIntegrityTag(
retryPacket.header.getVersion(),
BufHelpers::copyBuffer(retryPacket.header.getToken()));
BufPtr pseudoRetryPacket = std::move(pseudoRetryPacketBuilder).buildPacket();
BufPtr pseudoRetryPacket =
std::move(pseudoRetryPacketBuilder).buildPacket();
FizzRetryIntegrityTagGenerator retryIntegrityTagGenerator;
auto expectedIntegrityTag = retryIntegrityTagGenerator.getRetryIntegrityTag(
@@ -169,6 +171,10 @@ bool FizzClientHandshake::verifyRetryIntegrityTag(
folly::IOBuf integrityTagWrapper = BufHelpers::wrapBufferAsValue(
retryPacket.integrityTag.data(), retryPacket.integrityTag.size());
return BufEq()(*expectedIntegrityTag, integrityTagWrapper);
} catch (const std::exception& ex) {
return folly::makeUnexpected(
QuicError(TransportErrorCode::INTERNAL_ERROR, ex.what()));
}
}
bool FizzClientHandshake::isTLSResumed() const {
@@ -212,9 +218,9 @@ bool FizzClientHandshake::matchEarlyParameters() {
return fizz::client::earlyParametersMatch(state_);
}
std::unique_ptr<Aead> FizzClientHandshake::buildAead(
CipherKind kind,
ByteRange secret) {
folly::Expected<std::unique_ptr<Aead>, QuicError>
FizzClientHandshake::buildAead(CipherKind kind, ByteRange secret) {
try {
bool isEarlyTraffic = kind == CipherKind::ZeroRttWrite;
fizz::CipherSuite cipher =
isEarlyTraffic ? state_.earlyDataParams()->cipher : *state_.cipher();
@@ -233,19 +239,29 @@ std::unique_ptr<Aead> FizzClientHandshake::buildAead(
kQuicIVLabel));
return aead;
} catch (const std::exception& ex) {
return folly::makeUnexpected(
QuicError(TransportErrorCode::INTERNAL_ERROR, ex.what()));
}
}
std::unique_ptr<PacketNumberCipher> FizzClientHandshake::buildHeaderCipher(
ByteRange secret) {
folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
FizzClientHandshake::buildHeaderCipher(ByteRange secret) {
return cryptoFactory_->makePacketNumberCipher(secret);
}
BufPtr FizzClientHandshake::getNextTrafficSecret(ByteRange secret) const {
folly::Expected<BufPtr, QuicError> FizzClientHandshake::getNextTrafficSecret(
ByteRange secret) const {
try {
auto deriver =
state_.context()->getFactory()->makeKeyDeriver(*state_.cipher());
auto nextSecret = deriver->expandLabel(
secret, kQuicKULabel, BufHelpers::create(0), secret.size());
return nextSecret;
} catch (const std::exception& ex) {
return folly::makeUnexpected(
QuicError(TransportErrorCode::INTERNAL_ERROR, ex.what()));
}
}
void FizzClientHandshake::onNewCachedPsk(

View File

@@ -33,7 +33,7 @@ class FizzClientHandshake : public ClientHandshake {
const Optional<std::string>& getApplicationProtocol() const override;
bool verifyRetryIntegrityTag(
[[nodiscard]] folly::Expected<bool, QuicError> verifyRetryIntegrityTag(
const ConnectionId& originalDstConnId,
const RetryPacket& retryPacket) override;
@@ -70,16 +70,20 @@ class FizzClientHandshake : public ClientHandshake {
void echRetryAvailable(fizz::client::ECHRetryAvailable& retry);
private:
folly::Expected<Optional<CachedServerTransportParameters>, QuicError>
[[nodiscard]] folly::
Expected<Optional<CachedServerTransportParameters>, QuicError>
connectImpl(Optional<std::string> hostname) override;
EncryptionLevel getReadRecordLayerEncryptionLevel() override;
void processSocketData(folly::IOBufQueue& queue) override;
bool matchEarlyParameters() override;
std::unique_ptr<Aead> buildAead(CipherKind kind, ByteRange secret) override;
std::unique_ptr<PacketNumberCipher> buildHeaderCipher(
[[nodiscard]] folly::Expected<std::unique_ptr<Aead>, QuicError> buildAead(
CipherKind kind,
ByteRange secret) override;
BufPtr getNextTrafficSecret(ByteRange secret) const override;
[[nodiscard]] folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
buildHeaderCipher(ByteRange secret) override;
[[nodiscard]] folly::Expected<BufPtr, QuicError> getNextTrafficSecret(
ByteRange secret) const override;
class ActionMoveVisitor;
void processActions(fizz::client::Actions actions);

View File

@@ -2912,8 +2912,10 @@ TEST_P(QuicClientTransportAfterStartTest, ReadStreamCoalesced) {
FizzCryptoFactory cryptoFactory;
auto garbage = IOBuf::copyBuffer("garbage");
auto initialCipher = cryptoFactory.getServerInitialCipher(
auto initialCipherResult = cryptoFactory.getServerInitialCipher(
*serverChosenConnId, QuicVersion::MVFST);
ASSERT_FALSE(initialCipherResult.hasError());
auto& initialCipher = initialCipherResult.value();
auto firstPacketNum = appDataPacketNum++;
auto packet1 = packetToBufCleartext(
createStreamPacket(
@@ -2963,8 +2965,11 @@ TEST_F(QuicClientTransportAfterStartTest, ReadStreamCoalescedMany) {
BufQueue packets;
for (int i = 0; i < kMaxNumCoalescedPackets; i++) {
auto garbage = IOBuf::copyBuffer("garbage");
auto initialCipher = cryptoFactory.getServerInitialCipher(
auto initialCipherResult = cryptoFactory.getServerInitialCipher(
*serverChosenConnId, QuicVersion::MVFST);
ASSERT_FALSE(initialCipherResult.hasError());
auto& initialCipher = initialCipherResult.value();
auto packetNum = appDataPacketNum++;
auto packet1 = packetToBufCleartext(
createStreamPacket(
@@ -2973,7 +2978,7 @@ TEST_F(QuicClientTransportAfterStartTest, ReadStreamCoalescedMany) {
packetNum,
streamId,
*garbage,
initialCipher->getCipherOverhead(),
initialCipher.get()->getCipherOverhead(),
0 /* largestAcked */,
std::make_pair(LongHeader::Types::Initial, QuicVersion::MVFST)),
*initialCipher,
@@ -3651,8 +3656,10 @@ TEST_F(QuicClientTransportAfterStartTest, WrongCleartextCipher) {
// throws on getting unencrypted stream data.
PacketNum nextPacketNum = appDataPacketNum++;
auto initialCipher = cryptoFactory.getServerInitialCipher(
auto initialCipherResult = cryptoFactory.getServerInitialCipher(
*serverChosenConnId, QuicVersion::MVFST);
ASSERT_FALSE(initialCipherResult.hasError());
auto& initialCipher = initialCipherResult.value();
auto packet = packetToBufCleartext(
createStreamPacket(
*serverChosenConnId /* src */,
@@ -3660,7 +3667,7 @@ TEST_F(QuicClientTransportAfterStartTest, WrongCleartextCipher) {
nextPacketNum,
streamId,
*expected,
initialCipher->getCipherOverhead(),
initialCipher.get()->getCipherOverhead(),
0 /* largestAcked */,
std::make_pair(LongHeader::Types::Initial, QuicVersion::MVFST)),
*initialCipher,
@@ -5148,15 +5155,15 @@ class QuicZeroRttClientTest : public QuicClientTransportAfterStartTestBase {
mockClientHandshake->setHandshakeWriteCipher(std::move(handshakeWriteAead));
mockClientHandshake->setHandshakeReadHeaderCipher(
test::createNoOpHeaderCipher());
test::createNoOpHeaderCipherNoThrow());
mockClientHandshake->setHandshakeWriteHeaderCipher(
test::createNoOpHeaderCipher());
test::createNoOpHeaderCipherNoThrow());
mockClientHandshake->setOneRttWriteHeaderCipher(
test::createNoOpHeaderCipher());
test::createNoOpHeaderCipherNoThrow());
mockClientHandshake->setOneRttReadHeaderCipher(
test::createNoOpHeaderCipher());
test::createNoOpHeaderCipherNoThrow());
mockClientHandshake->setZeroRttWriteHeaderCipher(
test::createNoOpHeaderCipher());
test::createNoOpHeaderCipherNoThrow());
}
std::shared_ptr<QuicPskCache> getPskCache() override {

View File

@@ -336,9 +336,10 @@ class FakeOneRttHandshakeLayer : public FizzClientHandshake {
}
if (getPhase() == Phase::Initial) {
conn->handshakeWriteCipher = test::createNoOpAead();
conn->handshakeWriteHeaderCipher = test::createNoOpHeaderCipher();
conn->handshakeWriteHeaderCipher = test::createNoOpHeaderCipherNoThrow();
conn->readCodec->setHandshakeReadCipher(test::createNoOpAead());
conn->readCodec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher());
conn->readCodec->setHandshakeHeaderCipher(
test::createNoOpHeaderCipherNoThrow());
writeDataToQuicStream(
conn->cryptoState->handshakeStream,
folly::IOBuf::copyBuffer("ClientFinished"));
@@ -357,7 +358,8 @@ class FakeOneRttHandshakeLayer : public FizzClientHandshake {
return params_;
}
BufPtr getNextTrafficSecret(ByteRange /*secret*/) const override {
folly::Expected<BufPtr, QuicError> getNextTrafficSecret(
ByteRange /*secret*/) const override {
return folly::IOBuf::copyBuffer(getRandSecret());
}
@@ -410,12 +412,15 @@ class FakeOneRttHandshakeLayer : public FizzClientHandshake {
throw std::runtime_error("matchEarlyParameters not implemented");
}
std::unique_ptr<Aead> buildAead(CipherKind, ByteRange) override {
folly::Expected<std::unique_ptr<Aead>, QuicError> buildAead(
CipherKind,
ByteRange) override {
return createNoOpAead();
}
std::unique_ptr<PacketNumberCipher> buildHeaderCipher(ByteRange) override {
throw std::runtime_error("buildHeaderCipher not implemented");
folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
buildHeaderCipher(ByteRange) override {
return createNoOpHeaderCipher();
}
};
@@ -650,13 +655,13 @@ class QuicClientTransportTestBase : public virtual testing::Test {
mockClientHandshake->setOneRttWriteCipher(std::move(writeAead));
mockClientHandshake->setHandshakeReadHeaderCipher(
test::createNoOpHeaderCipher());
test::createNoOpHeaderCipherNoThrow());
mockClientHandshake->setHandshakeWriteHeaderCipher(
test::createNoOpHeaderCipher());
test::createNoOpHeaderCipherNoThrow());
mockClientHandshake->setOneRttWriteHeaderCipher(
test::createNoOpHeaderCipher());
test::createNoOpHeaderCipherNoThrow());
mockClientHandshake->setOneRttReadHeaderCipher(
test::createNoOpHeaderCipher());
test::createNoOpHeaderCipherNoThrow());
}
virtual void setUpSocketExpectations() {
@@ -1002,12 +1007,20 @@ class QuicClientTransportTestBase : public virtual testing::Test {
FizzCryptoFactory cryptoFactory;
auto codec = std::make_unique<QuicReadCodec>(QuicNodeType::Server);
codec->setClientConnectionId(*originalConnId);
codec->setInitialReadCipher(cryptoFactory.getClientInitialCipher(
*client->getConn().initialDestinationConnectionId, QuicVersion::MVFST));
codec->setInitialHeaderCipher(cryptoFactory.makeClientInitialHeaderCipher(
*client->getConn().initialDestinationConnectionId, QuicVersion::MVFST));
codec->setInitialReadCipher(
cryptoFactory
.getClientInitialCipher(
*client->getConn().initialDestinationConnectionId,
QuicVersion::MVFST)
.value());
codec->setInitialHeaderCipher(
cryptoFactory
.makeClientInitialHeaderCipher(
*client->getConn().initialDestinationConnectionId,
QuicVersion::MVFST)
.value());
codec->setHandshakeReadCipher(test::createNoOpAead());
codec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher());
codec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher().value());
return codec;
}
@@ -1018,18 +1031,24 @@ class QuicClientTransportTestBase : public virtual testing::Test {
std::unique_ptr<Aead> handshakeReadCipher;
codec->setClientConnectionId(*originalConnId);
codec->setOneRttReadCipher(test::createNoOpAead());
codec->setOneRttHeaderCipher(test::createNoOpHeaderCipher());
codec->setOneRttHeaderCipher(test::createNoOpHeaderCipher().value());
codec->setZeroRttReadCipher(test::createNoOpAead());
codec->setZeroRttHeaderCipher(test::createNoOpHeaderCipher());
codec->setZeroRttHeaderCipher(test::createNoOpHeaderCipher().value());
if (handshakeCipher) {
codec->setInitialReadCipher(cryptoFactory.getClientInitialCipher(
auto initialReadCipher = cryptoFactory.getClientInitialCipher(
*client->getConn().initialDestinationConnectionId,
QuicVersion::MVFST));
codec->setInitialHeaderCipher(cryptoFactory.makeClientInitialHeaderCipher(
QuicVersion::MVFST);
CHECK(initialReadCipher.hasValue());
codec->setInitialReadCipher(std::move(initialReadCipher.value()));
auto initialHeaderCipher = cryptoFactory.makeClientInitialHeaderCipher(
*client->getConn().initialDestinationConnectionId,
QuicVersion::MVFST));
QuicVersion::MVFST);
CHECK(initialHeaderCipher.hasValue());
codec->setInitialHeaderCipher(std::move(initialHeaderCipher.value()));
codec->setHandshakeReadCipher(test::createNoOpAead());
codec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher());
codec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher().value());
}
return codec;
}

View File

@@ -14,7 +14,7 @@
namespace quic {
BufPtr FizzCryptoFactory::makeInitialTrafficSecret(
folly::Expected<BufPtr, QuicError> FizzCryptoFactory::makeInitialTrafficSecret(
folly::StringPiece label,
const ConnectionId& clientDestinationConnId,
QuicVersion version) const {
@@ -31,12 +31,18 @@ BufPtr FizzCryptoFactory::makeInitialTrafficSecret(
return trafficSecret;
}
std::unique_ptr<Aead> FizzCryptoFactory::makeInitialAead(
folly::Expected<std::unique_ptr<Aead>, QuicError>
FizzCryptoFactory::makeInitialAead(
folly::StringPiece label,
const ConnectionId& clientDestinationConnId,
QuicVersion version) const {
auto trafficSecret =
auto trafficSecretResult =
makeInitialTrafficSecret(label, clientDestinationConnId, version);
if (trafficSecretResult.hasError()) {
return folly::makeUnexpected(trafficSecretResult.error());
}
auto& trafficSecret = trafficSecretResult.value();
auto deriver =
fizzFactory_->makeKeyDeriver(fizz::CipherSuite::TLS_AES_128_GCM_SHA256);
auto aead = fizzFactory_->makeAead(fizz::CipherSuite::TLS_AES_128_GCM_SHA256);
@@ -51,15 +57,22 @@ std::unique_ptr<Aead> FizzCryptoFactory::makeInitialAead(
BufHelpers::create(0),
aead->ivLength());
fizz::TrafficKey trafficKey = {std::move(key), std::move(iv)};
fizz::TrafficKey trafficKey;
trafficKey.key = std::move(key);
trafficKey.iv = std::move(iv);
aead->setKey(std::move(trafficKey));
return FizzAead::wrap(std::move(aead));
}
std::unique_ptr<PacketNumberCipher> FizzCryptoFactory::makePacketNumberCipher(
ByteRange baseSecret) const {
auto pnCipher =
folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
FizzCryptoFactory::makePacketNumberCipher(ByteRange baseSecret) const {
auto pnCipherResult =
makePacketNumberCipher(fizz::CipherSuite::TLS_AES_128_GCM_SHA256);
if (pnCipherResult.hasError()) {
return folly::makeUnexpected(pnCipherResult.error());
}
auto pnCipher = std::move(pnCipherResult.value());
auto deriver =
fizzFactory_->makeKeyDeriver(fizz::CipherSuite::TLS_AES_128_GCM_SHA256);
auto pnKey = deriver->expandLabel(
@@ -68,15 +81,17 @@ std::unique_ptr<PacketNumberCipher> FizzCryptoFactory::makePacketNumberCipher(
return pnCipher;
}
std::unique_ptr<PacketNumberCipher> FizzCryptoFactory::makePacketNumberCipher(
fizz::CipherSuite cipher) const {
folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
FizzCryptoFactory::makePacketNumberCipher(fizz::CipherSuite cipher) const {
switch (cipher) {
case fizz::CipherSuite::TLS_AES_128_GCM_SHA256:
return std::make_unique<Aes128PacketNumberCipher>();
case fizz::CipherSuite::TLS_AES_256_GCM_SHA384:
return std::make_unique<Aes256PacketNumberCipher>();
default:
throw std::runtime_error("Packet number cipher not implemented");
return folly::makeUnexpected(QuicError(
TransportErrorCode::INTERNAL_ERROR,
"Packet number cipher not implemented"));
}
}

View File

@@ -16,21 +16,23 @@ class FizzCryptoFactory : public CryptoFactory {
public:
FizzCryptoFactory() : fizzFactory_{std::make_shared<QuicFizzFactory>()} {}
BufPtr makeInitialTrafficSecret(
[[nodiscard]] folly::Expected<BufPtr, QuicError> makeInitialTrafficSecret(
folly::StringPiece label,
const ConnectionId& clientDestinationConnId,
QuicVersion version) const override;
std::unique_ptr<Aead> makeInitialAead(
[[nodiscard]] folly::Expected<std::unique_ptr<Aead>, QuicError>
makeInitialAead(
folly::StringPiece label,
const ConnectionId& clientDestinationConnId,
QuicVersion version) const override;
std::unique_ptr<PacketNumberCipher> makePacketNumberCipher(
ByteRange baseSecret) const override;
[[nodiscard]] folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
makePacketNumberCipher(ByteRange baseSecret) const override;
virtual std::unique_ptr<PacketNumberCipher> makePacketNumberCipher(
fizz::CipherSuite cipher) const;
[[nodiscard]] virtual folly::
Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
makePacketNumberCipher(fizz::CipherSuite cipher) const;
[[nodiscard]] std::function<bool(ByteRange, ByteRange)>
getCryptoEqualFunction() const override;

View File

@@ -61,7 +61,9 @@ struct CipherBytes {
TEST_P(LongPacketNumberCipherTest, TestEncryptDecrypt) {
FizzCryptoFactory cryptoFactory;
auto cipher = cryptoFactory.makePacketNumberCipher(GetParam().cipher);
auto cipherResult = cryptoFactory.makePacketNumberCipher(GetParam().cipher);
ASSERT_FALSE(cipherResult.hasError());
auto cipher = std::move(cipherResult.value());
auto key = folly::unhexlify(GetParam().key);
EXPECT_EQ(cipher->keyLength(), key.size());
cipher->setKey(folly::range(key));

View File

@@ -97,8 +97,8 @@ std::unique_ptr<Aead> FizzServerHandshake::buildAead(ByteRange secret) {
kQuicIVLabel));
}
std::unique_ptr<PacketNumberCipher> FizzServerHandshake::buildHeaderCipher(
ByteRange secret) {
folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
FizzServerHandshake::buildHeaderCipher(ByteRange secret) {
return cryptoFactory_->makePacketNumberCipher(secret);
}

View File

@@ -40,8 +40,8 @@ class FizzServerHandshake : public ServerHandshake {
EncryptionLevel getReadRecordLayerEncryptionLevel() override;
void processSocketData(folly::IOBufQueue& queue) override;
std::unique_ptr<Aead> buildAead(ByteRange secret) override;
std::unique_ptr<PacketNumberCipher> buildHeaderCipher(
ByteRange secret) override;
[[nodiscard]] folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
buildHeaderCipher(ByteRange secret) override;
BufPtr getNextTrafficSecret(ByteRange secret) const override;
void processAccept() override;

View File

@@ -11,47 +11,59 @@
namespace quic {
std::unique_ptr<Aead> CryptoFactory::getClientInitialCipher(
folly::Expected<std::unique_ptr<Aead>, QuicError>
CryptoFactory::getClientInitialCipher(
const ConnectionId& clientDestinationConnId,
QuicVersion version) const {
return makeInitialAead(kClientInitialLabel, clientDestinationConnId, version);
}
std::unique_ptr<Aead> CryptoFactory::getServerInitialCipher(
folly::Expected<std::unique_ptr<Aead>, QuicError>
CryptoFactory::getServerInitialCipher(
const ConnectionId& clientDestinationConnId,
QuicVersion version) const {
return makeInitialAead(kServerInitialLabel, clientDestinationConnId, version);
}
BufPtr CryptoFactory::makeServerInitialTrafficSecret(
folly::Expected<BufPtr, QuicError>
CryptoFactory::makeServerInitialTrafficSecret(
const ConnectionId& clientDestinationConnId,
QuicVersion version) const {
return makeInitialTrafficSecret(
kServerInitialLabel, clientDestinationConnId, version);
}
BufPtr CryptoFactory::makeClientInitialTrafficSecret(
folly::Expected<BufPtr, QuicError>
CryptoFactory::makeClientInitialTrafficSecret(
const ConnectionId& clientDestinationConnId,
QuicVersion version) const {
return makeInitialTrafficSecret(
kClientInitialLabel, clientDestinationConnId, version);
}
std::unique_ptr<PacketNumberCipher>
folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
CryptoFactory::makeClientInitialHeaderCipher(
const ConnectionId& initialDestinationConnectionId,
QuicVersion version) const {
auto clientInitialTrafficSecret =
auto clientInitialTrafficSecretResult =
makeClientInitialTrafficSecret(initialDestinationConnectionId, version);
if (clientInitialTrafficSecretResult.hasError()) {
return folly::makeUnexpected(clientInitialTrafficSecretResult.error());
}
auto& clientInitialTrafficSecret = clientInitialTrafficSecretResult.value();
return makePacketNumberCipher(clientInitialTrafficSecret->coalesce());
}
std::unique_ptr<PacketNumberCipher>
folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
CryptoFactory::makeServerInitialHeaderCipher(
const ConnectionId& initialDestinationConnectionId,
QuicVersion version) const {
auto serverInitialTrafficSecret =
auto serverInitialTrafficSecretResult =
makeServerInitialTrafficSecret(initialDestinationConnectionId, version);
if (serverInitialTrafficSecretResult.hasError()) {
return folly::makeUnexpected(serverInitialTrafficSecretResult.error());
}
auto& serverInitialTrafficSecret = serverInitialTrafficSecretResult.value();
return makePacketNumberCipher(serverInitialTrafficSecret->coalesce());
}

View File

@@ -7,7 +7,9 @@
#pragma once
#include <folly/Expected.h>
#include <quic/QuicConstants.h>
#include <quic/QuicException.h>
#include <quic/codec/PacketNumberCipher.h>
#include <quic/codec/QuicConnectionId.h>
#include <quic/codec/Types.h>
@@ -19,50 +21,59 @@ namespace quic {
class CryptoFactory {
public:
std::unique_ptr<Aead> getClientInitialCipher(
[[nodiscard]] folly::Expected<std::unique_ptr<Aead>, QuicError>
getClientInitialCipher(
const ConnectionId& clientDestinationConnId,
QuicVersion version) const;
std::unique_ptr<Aead> getServerInitialCipher(
[[nodiscard]] folly::Expected<std::unique_ptr<Aead>, QuicError>
getServerInitialCipher(
const ConnectionId& clientDestinationConnId,
QuicVersion version) const;
BufPtr makeServerInitialTrafficSecret(
[[nodiscard]] folly::Expected<BufPtr, QuicError>
makeServerInitialTrafficSecret(
const ConnectionId& clientDestinationConnId,
QuicVersion version) const;
BufPtr makeClientInitialTrafficSecret(
[[nodiscard]] folly::Expected<BufPtr, QuicError>
makeClientInitialTrafficSecret(
const ConnectionId& clientDestinationConnId,
QuicVersion version) const;
/**
* Makes the header cipher for writing client initial packets.
*/
std::unique_ptr<PacketNumberCipher> makeClientInitialHeaderCipher(
[[nodiscard]] folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
makeClientInitialHeaderCipher(
const ConnectionId& initialDestinationConnectionId,
QuicVersion version) const;
/**
* Makes the header cipher for writing server initial packets.
*/
std::unique_ptr<PacketNumberCipher> makeServerInitialHeaderCipher(
[[nodiscard]] folly::Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
makeServerInitialHeaderCipher(
const ConnectionId& initialDestinationConnectionId,
QuicVersion version) const;
/**
* Crypto layer specific methods.
*/
virtual BufPtr makeInitialTrafficSecret(
[[nodiscard]] virtual folly::Expected<BufPtr, QuicError>
makeInitialTrafficSecret(
folly::StringPiece label,
const ConnectionId& clientDestinationConnId,
QuicVersion version) const = 0;
virtual std::unique_ptr<Aead> makeInitialAead(
[[nodiscard]] virtual folly::Expected<std::unique_ptr<Aead>, QuicError>
makeInitialAead(
folly::StringPiece label,
const ConnectionId& clientDestinationConnId,
QuicVersion version) const = 0;
virtual std::unique_ptr<PacketNumberCipher> makePacketNumberCipher(
ByteRange baseSecret) const = 0;
[[nodiscard]] virtual folly::
Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
makePacketNumberCipher(ByteRange baseSecret) const = 0;
[[nodiscard]] virtual std::function<bool(ByteRange, ByteRange)>
getCryptoEqualFunction() const = 0;

View File

@@ -82,7 +82,7 @@ class QuicLossFunctionsTest : public TestWithParam<PacketNumberSpace> {
public:
void SetUp() override {
aead = createNoOpAead();
headerCipher = createNoOpHeaderCipher();
headerCipher = createNoOpHeaderCipher().value();
quicStats_ = std::make_unique<MockQuicStats>();
connIdAlgo_ = std::make_unique<DefaultConnectionIdAlgo>();
socket_ = std::make_unique<MockQuicSocket>();
@@ -1287,7 +1287,7 @@ TEST_F(QuicLossFunctionsTest, PTONoLongerMarksPacketsToBeRetransmitted) {
TEST_F(QuicLossFunctionsTest, PTOWithHandshakePackets) {
auto conn = createConn();
conn->handshakeWriteCipher = createNoOpAead();
conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher();
conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher().value();
auto mockQLogger = std::make_shared<MockQLogger>(VantagePoint::Server);
conn->qLogger = mockQLogger;
auto mockCongestionController = std::make_unique<MockCongestionController>();
@@ -1338,11 +1338,11 @@ TEST_F(QuicLossFunctionsTest, PTOWithLostInitialData) {
auto conn = createConn();
conn->initialWriteCipher = createNoOpAead();
conn->initialHeaderCipher = createNoOpHeaderCipher();
conn->initialHeaderCipher = createNoOpHeaderCipher().value();
conn->handshakeWriteCipher = createNoOpAead();
conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher();
conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher().value();
conn->oneRttWriteCipher = createNoOpAead();
conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher();
conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher().value();
auto buf = buildRandomInputData(20);
WriteStreamBuffer initialData(ChainedByteRangeHead(buf), 0);
@@ -1368,11 +1368,11 @@ TEST_F(QuicLossFunctionsTest, PTOWithLostHandshakeData) {
auto conn = createConn();
conn->initialWriteCipher = createNoOpAead();
conn->initialHeaderCipher = createNoOpHeaderCipher();
conn->initialHeaderCipher = createNoOpHeaderCipher().value();
conn->handshakeWriteCipher = createNoOpAead();
conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher();
conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher().value();
conn->oneRttWriteCipher = createNoOpAead();
conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher();
conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher().value();
auto buf = buildRandomInputData(20);
WriteStreamBuffer handshakeData(ChainedByteRangeHead(buf), 0);
@@ -1399,11 +1399,11 @@ TEST_F(QuicLossFunctionsTest, PTOWithLostAppData) {
auto conn = createConn();
conn->initialWriteCipher = createNoOpAead();
conn->initialHeaderCipher = createNoOpHeaderCipher();
conn->initialHeaderCipher = createNoOpHeaderCipher().value();
conn->handshakeWriteCipher = createNoOpAead();
conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher();
conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher().value();
conn->oneRttWriteCipher = createNoOpAead();
conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher();
conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher().value();
auto buf = buildRandomInputData(20);
WriteStreamBuffer appData(ChainedByteRangeHead(buf), 0);
@@ -1428,13 +1428,13 @@ TEST_F(QuicLossFunctionsTest, PTOAvoidPointless) {
auto conn = createConn();
conn->initialWriteCipher = createNoOpAead();
conn->initialHeaderCipher = createNoOpHeaderCipher();
conn->initialHeaderCipher = createNoOpHeaderCipher().value();
conn->handshakeWriteCipher = createNoOpAead();
conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher();
conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher().value();
conn->oneRttWriteCipher = createNoOpAead();
conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher();
conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher().value();
conn->outstandings.packetCount[PacketNumberSpace::Initial] = 1;
conn->outstandings.packetCount[PacketNumberSpace::Handshake] = 1;

View File

@@ -535,7 +535,16 @@ void ServerHandshake::processActions(
void ServerHandshake::computeCiphers(CipherKind kind, ByteRange secret) {
std::unique_ptr<Aead> aead = buildAead(secret);
std::unique_ptr<PacketNumberCipher> headerCipher = buildHeaderCipher(secret);
auto headerCipherResult = buildHeaderCipher(secret);
if (headerCipherResult.hasError()) {
LOG(ERROR) << "Failed to build header cipher";
onError(std::make_pair(
"Failed to build header cipher", TransportErrorCode::INTERNAL_ERROR));
return;
}
std::unique_ptr<PacketNumberCipher> headerCipher =
std::move(headerCipherResult.value());
switch (kind) {
case CipherKind::HandshakeRead:
handshakeReadCipher_ = std::move(aead);

View File

@@ -322,8 +322,9 @@ class ServerHandshake : public Handshake {
virtual EncryptionLevel getReadRecordLayerEncryptionLevel() = 0;
virtual void processSocketData(folly::IOBufQueue& queue) = 0;
virtual std::unique_ptr<Aead> buildAead(ByteRange secret) = 0;
virtual std::unique_ptr<PacketNumberCipher> buildHeaderCipher(
ByteRange secret) = 0;
[[nodiscard]] virtual folly::
Expected<std::unique_ptr<PacketNumberCipher>, QuicError>
buildHeaderCipher(ByteRange secret) = 0;
virtual void processAccept() = 0;
/*

View File

@@ -1037,8 +1037,13 @@ folly::Expected<folly::Unit, QuicError> onServerReadDataFromOpen(
conn.serverHandshakeLayer->getCryptoFactory();
conn.readCodec = std::make_unique<QuicReadCodec>(QuicNodeType::Server);
conn.readCodec->setConnectionStatsCallback(conn.statsCallback);
conn.readCodec->setInitialReadCipher(cryptoFactory.getClientInitialCipher(
initialDestinationConnectionId, version));
auto clientInitialCipherResult = cryptoFactory.getClientInitialCipher(
initialDestinationConnectionId, version);
if (clientInitialCipherResult.hasError()) {
return folly::makeUnexpected(clientInitialCipherResult.error());
}
conn.readCodec->setInitialReadCipher(
std::move(clientInitialCipherResult.value()));
conn.readCodec->setClientConnectionId(clientConnectionId);
conn.readCodec->setServerConnectionId(*conn.serverConnectionId);
if (conn.qLogger) {
@@ -1050,14 +1055,30 @@ folly::Expected<folly::Unit, QuicError> onServerReadDataFromOpen(
version,
conn.transportSettings.maybeAckReceiveTimestampsConfigSentToPeer,
conn.transportSettings.advertisedExtendedAckFeatures));
conn.initialWriteCipher = cryptoFactory.getServerInitialCipher(
auto serverInitialCipherResult = cryptoFactory.getServerInitialCipher(
initialDestinationConnectionId, version);
if (serverInitialCipherResult.hasError()) {
return folly::makeUnexpected(serverInitialCipherResult.error());
}
conn.initialWriteCipher = std::move(serverInitialCipherResult.value());
conn.readCodec->setInitialHeaderCipher(
auto clientInitialHeaderCipherResult =
cryptoFactory.makeClientInitialHeaderCipher(
initialDestinationConnectionId, version));
conn.initialHeaderCipher = cryptoFactory.makeServerInitialHeaderCipher(
initialDestinationConnectionId, version);
if (clientInitialHeaderCipherResult.hasError()) {
return folly::makeUnexpected(clientInitialHeaderCipherResult.error());
}
conn.readCodec->setInitialHeaderCipher(
std::move(clientInitialHeaderCipherResult.value()));
auto serverInitialHeaderCipherResult =
cryptoFactory.makeServerInitialHeaderCipher(
initialDestinationConnectionId, version);
if (serverInitialHeaderCipherResult.hasError()) {
return folly::makeUnexpected(serverInitialHeaderCipherResult.error());
}
conn.initialHeaderCipher =
std::move(serverInitialHeaderCipherResult.value());
conn.peerAddress = conn.originalPeerAddress;
}
BufQueue& udpData = readData.udpPacket.buf;

View File

@@ -346,7 +346,7 @@ void QuicServerWorkerTest::testSendReset(
.WillRepeatedly(
Invoke([&](auto&, auto, auto) { return std::nullopt; }));
codec.setOneRttReadCipher(std::move(aead));
codec.setOneRttHeaderCipher(test::createNoOpHeaderCipher());
codec.setOneRttHeaderCipher(test::createNoOpHeaderCipher().value());
StatelessResetToken token = generateStatelessResetToken();
codec.setStatelessResetToken(token);
FizzCryptoFactory cryptoFactory;
@@ -2966,7 +2966,7 @@ void QuicServerTest::testReset(BufPtr packet) {
EXPECT_CALL(*aead, _tryDecrypt(_, _, _))
.WillRepeatedly(Invoke([&](auto&, auto, auto) { return std::nullopt; }));
codec.setOneRttReadCipher(std::move(aead));
codec.setOneRttHeaderCipher(test::createNoOpHeaderCipher());
codec.setOneRttHeaderCipher(test::createNoOpHeaderCipher().value());
StatelessResetToken token = generateStatelessResetToken();
codec.setStatelessResetToken(token);
FizzCryptoFactory cryptoFactory;

View File

@@ -3491,7 +3491,7 @@ TEST_F(QuicServerTransportTest, InvokeTxCallbacksSingleByteDSR) {
EXPECT_CALL(dsrByteTxCb, onByteEvent(getTxMatcher(stream, 1))).Times(1);
EXPECT_CALL(lastByteTxCb, onByteEvent(getTxMatcher(stream, 1))).Times(1);
server->getConnectionState().oneRttWriteCipher = test::createNoOpAead();
auto temp = test::createNoOpHeaderCipher();
auto temp = test::createNoOpHeaderCipher().value();
temp->setDefaultKey();
server->getConnectionState().oneRttWriteHeaderCipher = std::move(temp);
CHECK(server->getConnectionState().oneRttWriteCipher->getKey().has_value());
@@ -3700,8 +3700,10 @@ TEST_F(QuicUnencryptedServerTransportTest, TestBadPacketProtectionLevel) {
TEST_F(QuicUnencryptedServerTransportTest, TestBadCleartextEncryption) {
FizzCryptoFactory cryptoFactory;
PacketNum nextPacket = clientNextInitialPacketNum++;
auto aead = cryptoFactory.getServerInitialCipher(
*clientConnectionId, QuicVersion::MVFST);
auto aead =
cryptoFactory
.getServerInitialCipher(*clientConnectionId, QuicVersion::MVFST)
.value();
auto chloBuf = IOBuf::copyBuffer("CHLO");
ChainedByteRangeHead chloRch(chloBuf);
auto packetData = packetToBufCleartext(

View File

@@ -304,15 +304,17 @@ class QuicServerTransportTestBase : public virtual testing::Test {
std::unique_ptr<Aead> getInitialCipher(
QuicVersion version = QuicVersion::MVFST) {
FizzCryptoFactory cryptoFactory;
return cryptoFactory.getClientInitialCipher(
*initialDestinationConnectionId, version);
return cryptoFactory
.getClientInitialCipher(*initialDestinationConnectionId, version)
.value();
}
std::unique_ptr<PacketNumberCipher> getInitialHeaderCipher(
QuicVersion version = QuicVersion::MVFST) {
FizzCryptoFactory cryptoFactory;
return cryptoFactory.makeClientInitialHeaderCipher(
*initialDestinationConnectionId, version);
return cryptoFactory
.makeClientInitialHeaderCipher(*initialDestinationConnectionId, version)
.value();
}
BufPtr recvEncryptedStream(
@@ -367,7 +369,7 @@ class QuicServerTransportTestBase : public virtual testing::Test {
QuicVersion version = QuicVersion::MVFST) {
auto finished = folly::IOBuf::copyBuffer("FINISHED");
auto nextPacketNum = clientNextHandshakePacketNum++;
auto headerCipher = test::createNoOpHeaderCipher();
auto headerCipher = test::createNoOpHeaderCipher().value();
uint64_t offset =
getCryptoStream(
*server->getConn().cryptoState, EncryptionLevel::Handshake)
@@ -395,11 +397,16 @@ class QuicServerTransportTestBase : public virtual testing::Test {
FizzCryptoFactory cryptoFactory;
clientReadCodec = std::make_unique<QuicReadCodec>(QuicNodeType::Client);
clientReadCodec->setClientConnectionId(*clientConnectionId);
clientReadCodec->setInitialReadCipher(cryptoFactory.getServerInitialCipher(
*initialDestinationConnectionId, QuicVersion::MVFST));
clientReadCodec->setInitialReadCipher(
cryptoFactory
.getServerInitialCipher(
*initialDestinationConnectionId, QuicVersion::MVFST)
.value());
clientReadCodec->setInitialHeaderCipher(
cryptoFactory.makeServerInitialHeaderCipher(
*initialDestinationConnectionId, QuicVersion::MVFST));
cryptoFactory
.makeServerInitialHeaderCipher(
*initialDestinationConnectionId, QuicVersion::MVFST)
.value());
clientReadCodec->setCodecParameters(
CodecParameters(kDefaultAckDelayExponent, QuicVersion::MVFST));
}
@@ -601,18 +608,23 @@ class QuicServerTransportTestBase : public virtual testing::Test {
FizzCryptoFactory cryptoFactory;
auto readCodec = std::make_unique<QuicReadCodec>(QuicNodeType::Client);
readCodec->setOneRttReadCipher(test::createNoOpAead());
readCodec->setOneRttHeaderCipher(test::createNoOpHeaderCipher());
readCodec->setOneRttHeaderCipher(test::createNoOpHeaderCipher().value());
readCodec->setHandshakeReadCipher(test::createNoOpAead());
readCodec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher());
readCodec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher().value());
readCodec->setClientConnectionId(*clientConnectionId);
readCodec->setCodecParameters(
CodecParameters(kDefaultAckDelayExponent, QuicVersion::MVFST));
if (handshakeCipher) {
readCodec->setInitialReadCipher(cryptoFactory.getServerInitialCipher(
*initialDestinationConnectionId, QuicVersion::MVFST));
readCodec->setInitialReadCipher(
cryptoFactory
.getServerInitialCipher(
*initialDestinationConnectionId, QuicVersion::MVFST)
.value());
readCodec->setInitialHeaderCipher(
cryptoFactory.makeServerInitialHeaderCipher(
*initialDestinationConnectionId, QuicVersion::MVFST));
cryptoFactory
.makeServerInitialHeaderCipher(
*initialDestinationConnectionId, QuicVersion::MVFST)
.value());
}
return readCodec;
}

View File

@@ -4426,7 +4426,7 @@ class AckEventForAppDataTest : public Test {
public:
void SetUp() override {
aead_ = test::createNoOpAead();
headerCipher_ = test::createNoOpHeaderCipher();
headerCipher_ = test::createNoOpHeaderCipher().value();
conn_ = createConn();
}
@@ -4445,7 +4445,7 @@ class AckEventForAppDataTest : public Test {
conn->flowControlState.peerAdvertisedMaxOffset =
kDefaultConnectionFlowControlWindow;
conn->initialWriteCipher = createNoOpAead();
conn->initialHeaderCipher = createNoOpHeaderCipher();
conn->initialHeaderCipher = createNoOpHeaderCipher().value();
CHECK(
!conn->streamManager
->setMaxLocalBidirectionalStreams(kDefaultMaxStreamsBidirectional)