diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index 716d78fb0..142f25eed 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -242,7 +242,7 @@ class TestQuicTransport transportConn = conn.get(); conn_.reset(conn.release()); aead = test::createNoOpAead(); - headerCipher = test::createNoOpHeaderCipher(); + headerCipher = test::createNoOpHeaderCipher().value(); connIdAlgo_ = std::make_unique(); setConnectionSetupCallback(connSetupCb); setConnectionCallbackFromCtor(connCb); diff --git a/quic/api/test/QuicTransportFunctionsTest.cpp b/quic/api/test/QuicTransportFunctionsTest.cpp index 37da62d86..f317cc769 100644 --- a/quic/api/test/QuicTransportFunctionsTest.cpp +++ b/quic/api/test/QuicTransportFunctionsTest.cpp @@ -171,7 +171,7 @@ class QuicTransportFunctionsTest : public Test { public: void SetUp() override { aead = test::createNoOpAead(); - headerCipher = test::createNoOpHeaderCipher(); + headerCipher = test::createNoOpHeaderCipher().value(); quicStats_ = std::make_unique>(); } @@ -191,7 +191,7 @@ class QuicTransportFunctionsTest : public Test { kDefaultConnectionFlowControlWindow * 1000; conn->statsCallback = quicStats_.get(); conn->initialWriteCipher = createNoOpAead(); - conn->initialHeaderCipher = createNoOpHeaderCipher(); + conn->initialHeaderCipher = createNoOpHeaderCipher().value(); CHECK( !conn->streamManager ->setMaxLocalBidirectionalStreams(kDefaultMaxStreamsBidirectional) @@ -3665,7 +3665,7 @@ TEST_F(QuicTransportFunctionsTest, ResetNumProbePackets) { EXPECT_EQ(0, writeRes1->bytesWritten); conn->handshakeWriteCipher = createNoOpAead(); - conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher(); + conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher().value(); conn->pendingEvents.numProbePackets[PacketNumberSpace::Handshake] = 2; auto writeRes2 = writeCryptoAndAckDataToSocket( *rawSocket, @@ -3682,7 +3682,7 @@ TEST_F(QuicTransportFunctionsTest, ResetNumProbePackets) { EXPECT_EQ(0, writeRes2->bytesWritten); conn->oneRttWriteCipher = createNoOpAead(); - conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher(); + conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher().value(); conn->pendingEvents.numProbePackets[PacketNumberSpace::AppData] = 2; auto writeRes3 = writeQuicDataToSocket( *rawSocket, @@ -4474,13 +4474,13 @@ TEST_F(QuicTransportFunctionsTest, HandshakeConfirmedDropCipher) { ASSERT_NE(nullptr, conn->initialWriteCipher); conn->handshakeWriteCipher = createNoOpAead(); conn->readCodec->setInitialReadCipher(createNoOpAead()); - conn->readCodec->setInitialHeaderCipher(createNoOpHeaderCipher()); + conn->readCodec->setInitialHeaderCipher(createNoOpHeaderCipher().value()); conn->readCodec->setHandshakeReadCipher(createNoOpAead()); - conn->readCodec->setHandshakeHeaderCipher(createNoOpHeaderCipher()); + conn->readCodec->setHandshakeHeaderCipher(createNoOpHeaderCipher().value()); conn->oneRttWriteCipher = createNoOpAead(); - conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher(); + conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher().value(); conn->readCodec->setOneRttReadCipher(createNoOpAead()); - conn->readCodec->setOneRttHeaderCipher(createNoOpHeaderCipher()); + conn->readCodec->setOneRttHeaderCipher(createNoOpHeaderCipher().value()); writeCryptoDataProbesToSocketForTest( *socket, *conn, diff --git a/quic/api/test/QuicTransportTest.cpp b/quic/api/test/QuicTransportTest.cpp index a9b2beca5..70eb12ee6 100644 --- a/quic/api/test/QuicTransportTest.cpp +++ b/quic/api/test/QuicTransportTest.cpp @@ -105,7 +105,7 @@ class QuicTransportTest : public Test { .WillRepeatedly( Invoke([&](auto& buf, auto, auto) { return buf->clone(); })); transport_->getConnectionState().oneRttWriteCipher = std::move(aead); - auto tempHeaderCipher = test::createNoOpHeaderCipher(); + auto tempHeaderCipher = test::createNoOpHeaderCipher().value(); tempHeaderCipher->setDefaultKey(); transport_->getConnectionState().oneRttWriteHeaderCipher = std::move(tempHeaderCipher); diff --git a/quic/api/test/TestQuicTransport.h b/quic/api/test/TestQuicTransport.h index 173cc48ab..0b5316102 100644 --- a/quic/api/test/TestQuicTransport.h +++ b/quic/api/test/TestQuicTransport.h @@ -32,7 +32,7 @@ class TestQuicTransport conn_->version = QuicVersion::MVFST; conn_->observerContainer = observerContainer_; aead = test::createNoOpAead(); - headerCipher = test::createNoOpHeaderCipher(); + headerCipher = test::createNoOpHeaderCipher().value(); setConnectionSetupCallback(connSetupCb); setConnectionCallbackFromCtor(connCb); } diff --git a/quic/client/QuicClientTransportLite.cpp b/quic/client/QuicClientTransportLite.cpp index a10701f10..838452e56 100644 --- a/quic/client/QuicClientTransportLite.cpp +++ b/quic/client/QuicClientTransportLite.cpp @@ -1164,15 +1164,35 @@ QuicClientTransportLite::startCryptoHandshake() { auto& cryptoFactory = handshakeLayer->getCryptoFactory(); auto version = conn_->originalVersion.value(); - conn_->initialWriteCipher = cryptoFactory.getClientInitialCipher( + auto initialWriteCipherResult = cryptoFactory.getClientInitialCipher( *clientConn_->initialDestinationConnectionId, version); - conn_->readCodec->setInitialReadCipher(cryptoFactory.getServerInitialCipher( - *clientConn_->initialDestinationConnectionId, version)); + if (initialWriteCipherResult.hasError()) { + return folly::makeUnexpected(initialWriteCipherResult.error()); + } + conn_->initialWriteCipher = std::move(initialWriteCipherResult.value()); + + auto serverInitialCipherResult = cryptoFactory.getServerInitialCipher( + *clientConn_->initialDestinationConnectionId, version); + if (serverInitialCipherResult.hasError()) { + return folly::makeUnexpected(serverInitialCipherResult.error()); + } + conn_->readCodec->setInitialReadCipher( + std::move(serverInitialCipherResult.value())); + + auto serverHeaderCipherResult = cryptoFactory.makeServerInitialHeaderCipher( + *clientConn_->initialDestinationConnectionId, version); + if (serverHeaderCipherResult.hasError()) { + return folly::makeUnexpected(serverHeaderCipherResult.error()); + } conn_->readCodec->setInitialHeaderCipher( - cryptoFactory.makeServerInitialHeaderCipher( - *clientConn_->initialDestinationConnectionId, version)); - conn_->initialHeaderCipher = cryptoFactory.makeClientInitialHeaderCipher( + std::move(serverHeaderCipherResult.value())); + + auto clientHeaderCipherResult = cryptoFactory.makeClientInitialHeaderCipher( *clientConn_->initialDestinationConnectionId, version); + if (clientHeaderCipherResult.hasError()) { + return folly::makeUnexpected(clientHeaderCipherResult.error()); + } + conn_->initialHeaderCipher = std::move(clientHeaderCipherResult.value()); customTransportParameters_ = getSupportedExtTransportParams(*conn_); diff --git a/quic/client/handshake/ClientHandshake.cpp b/quic/client/handshake/ClientHandshake.cpp index 1e909cde5..8de1e1caa 100644 --- a/quic/client/handshake/ClientHandshake.cpp +++ b/quic/client/handshake/ClientHandshake.cpp @@ -159,9 +159,21 @@ bool ClientHandshake::waitingForData() const { } void ClientHandshake::computeCiphers(CipherKind kind, ByteRange secret) { - std::unique_ptr aead = buildAead(kind, secret); + auto aeadResult = buildAead(kind, secret); + if (aeadResult.hasError()) { + error_ = folly::makeUnexpected(std::move(aeadResult.error())); + return; + } + auto packetNumberCipherResult = buildHeaderCipher(secret); + if (packetNumberCipherResult.hasError()) { + error_ = folly::makeUnexpected(std::move(packetNumberCipherResult.error())); + return; + } + + std::unique_ptr aead = std::move(aeadResult.value()); std::unique_ptr packetNumberCipher = - buildHeaderCipher(secret); + std::move(packetNumberCipherResult.value()); + switch (kind) { case CipherKind::HandshakeWrite: conn_->handshakeWriteCipher = std::move(aead); @@ -208,11 +220,15 @@ ClientHandshake::getNextOneRttWriteCipher() { CHECK(writeTrafficSecret_); LOG_IF(WARNING, trafficSecretSync_ > 1 || trafficSecretSync_ < -1) << "Client read and write secrets are out of sync"; - writeTrafficSecret_ = getNextTrafficSecret(writeTrafficSecret_->coalesce()); + + auto nextSecretResult = getNextTrafficSecret(writeTrafficSecret_->coalesce()); + if (nextSecretResult.hasError()) { + return folly::makeUnexpected(std::move(nextSecretResult.error())); + } + writeTrafficSecret_ = std::move(nextSecretResult.value()); trafficSecretSync_--; - auto cipher = - buildAead(CipherKind::OneRttWrite, writeTrafficSecret_->coalesce()); - return cipher; + + return buildAead(CipherKind::OneRttWrite, writeTrafficSecret_->coalesce()); } folly::Expected, QuicError> @@ -224,11 +240,15 @@ ClientHandshake::getNextOneRttReadCipher() { CHECK(readTrafficSecret_); LOG_IF(WARNING, trafficSecretSync_ > 1 || trafficSecretSync_ < -1) << "Client read and write secrets are out of sync"; - readTrafficSecret_ = getNextTrafficSecret(readTrafficSecret_->coalesce()); + + auto nextSecretResult = getNextTrafficSecret(readTrafficSecret_->coalesce()); + if (nextSecretResult.hasError()) { + return folly::makeUnexpected(std::move(nextSecretResult.error())); + } + readTrafficSecret_ = std::move(nextSecretResult.value()); trafficSecretSync_++; - auto cipher = - buildAead(CipherKind::OneRttRead, readTrafficSecret_->coalesce()); - return cipher; + + return buildAead(CipherKind::OneRttRead, readTrafficSecret_->coalesce()); } void ClientHandshake::waitForData() { diff --git a/quic/client/handshake/ClientHandshake.h b/quic/client/handshake/ClientHandshake.h index be70d2d01..dce20b6f9 100644 --- a/quic/client/handshake/ClientHandshake.h +++ b/quic/client/handshake/ClientHandshake.h @@ -102,7 +102,8 @@ class ClientHandshake : public Handshake { * API used to verify that the integrity token present in the retry packet * matches what we would expect */ - virtual bool verifyRetryIntegrityTag( + [[nodiscard]] virtual folly::Expected + verifyRetryIntegrityTag( const ConnectionId& originalDstConnId, const RetryPacket& retryPacket) = 0; @@ -176,7 +177,8 @@ class ClientHandshake : public Handshake { * Given secret_n, returns secret_n+1 to be used for generating the next Aead * on key updates. */ - virtual BufPtr getNextTrafficSecret(ByteRange secret) const = 0; + [[nodiscard]] virtual folly::Expected getNextTrafficSecret( + ByteRange secret) const = 0; BufPtr readTrafficSecret_; BufPtr writeTrafficSecret_; @@ -185,16 +187,17 @@ class ClientHandshake : public Handshake { Optional canResendZeroRtt_; private: - virtual folly::Expected, QuicError> - connectImpl(Optional hostname) = 0; + [[nodiscard]] virtual folly:: + Expected, QuicError> + connectImpl(Optional hostname) = 0; virtual void processSocketData(folly::IOBufQueue& queue) = 0; virtual bool matchEarlyParameters() = 0; - virtual std::unique_ptr buildAead( - CipherKind kind, - ByteRange secret) = 0; - virtual std::unique_ptr buildHeaderCipher( - ByteRange secret) = 0; + [[nodiscard]] virtual folly::Expected, QuicError> + buildAead(CipherKind kind, ByteRange secret) = 0; + [[nodiscard]] virtual folly:: + Expected, QuicError> + buildHeaderCipher(ByteRange secret) = 0; // Represents the packet type that should be used to write the data currently // in the stream. diff --git a/quic/client/test/ClientStateMachineTest.cpp b/quic/client/test/ClientStateMachineTest.cpp index 19e342ee7..48eeed025 100644 --- a/quic/client/test/ClientStateMachineTest.cpp +++ b/quic/client/test/ClientStateMachineTest.cpp @@ -57,7 +57,7 @@ class ClientStateMachineTest : public Test { public: void SetUp() override { mockFactory_ = std::make_shared(); - EXPECT_CALL(*mockFactory_, _makeClientHandshake(_)) + EXPECT_CALL(*mockFactory_, makeClientHandshakeImpl(_)) .WillRepeatedly(Invoke( [&](QuicClientConnectionState* conn) -> std::unique_ptr { diff --git a/quic/client/test/Mocks.h b/quic/client/test/Mocks.h index 7afc72e4e..cc292d078 100644 --- a/quic/client/test/Mocks.h +++ b/quic/client/test/Mocks.h @@ -22,29 +22,29 @@ namespace quic::test { class MockClientHandshakeFactory : public ClientHandshakeFactory { public: - MOCK_METHOD( - std::unique_ptr, - _makeClientHandshake, - (QuicClientConnectionState*)); - std::unique_ptr makeClientHandshake(QuicClientConnectionState* conn) && override { - return _makeClientHandshake(conn); + return std::move(*this).makeClientHandshakeImpl(conn); } + + MOCK_METHOD( + std::unique_ptr, + makeClientHandshakeImpl, + (QuicClientConnectionState*)); }; -class MockClientHandshake : public ClientHandshake { +class MockClientHandshakeBase : public ClientHandshake { public: - MockClientHandshake(QuicClientConnectionState* conn) + MockClientHandshakeBase(QuicClientConnectionState* conn) : ClientHandshake(conn) {} - ~MockClientHandshake() override { + ~MockClientHandshakeBase() override { destroy(); } // Legacy workaround for move-only types folly::Expected doHandshake( - std::unique_ptr data, + BufPtr data, EncryptionLevel encryptionLevel) override { doHandshakeImpl(data.get(), encryptionLevel); return folly::unit; @@ -52,53 +52,133 @@ class MockClientHandshake : public ClientHandshake { MOCK_METHOD(void, doHandshakeImpl, (folly::IOBuf*, EncryptionLevel)); MOCK_METHOD( - bool, + (folly::Expected), verifyRetryIntegrityTag, - (const ConnectionId&, const RetryPacket&)); + (const ConnectionId&, const RetryPacket&), + (override)); MOCK_METHOD(void, removePsk, (const Optional&)); - MOCK_METHOD(const CryptoFactory&, getCryptoFactory, (), (const)); - MOCK_METHOD(bool, isTLSResumed, (), (const)); + MOCK_METHOD(const CryptoFactory&, getCryptoFactory, (), (const, override)); + MOCK_METHOD(bool, isTLSResumed, (), (const, override)); MOCK_METHOD( Optional>, getExportedKeyingMaterial, (const std::string& label, const Optional& context, uint16_t keyLength), - ()); + (override)); MOCK_METHOD(Optional, getZeroRttRejected, ()); + MOCK_METHOD(Optional, getCanResendZeroRtt, (), (const)); MOCK_METHOD( const Optional&, getServerTransportParams, - ()); + (), + (override)); MOCK_METHOD(void, destroy, ()); + MOCK_METHOD( + (folly::Expected, QuicError>), + getNextOneRttWriteCipher, + (), + (override)); + MOCK_METHOD( + (folly::Expected, QuicError>), + getNextOneRttReadCipher, + (), + (override)); + + void handshakeConfirmed() override { + handshakeConfirmedImpl(); + } + + MOCK_METHOD(void, handshakeConfirmedImpl, ()); + + Handshake::TLSSummary getTLSSummary() const override { + return getTLSSummaryImpl(); + } + + MOCK_METHOD(Handshake::TLSSummary, getTLSSummaryImpl, (), (const)); + + // Mock the public connect method + folly::Expected connect( + Optional hostname, + std::shared_ptr transportParams) { + return mockConnect(std::move(hostname), std::move(transportParams)); + } MOCK_METHOD( - (folly::Expected, QuicError>), - connectImpl, - (Optional)); - MOCK_METHOD(EncryptionLevel, getReadRecordLayerEncryptionLevel, ()); + (folly::Expected), + mockConnect, + (Optional, + std::shared_ptr)); + MOCK_METHOD( + EncryptionLevel, + getReadRecordLayerEncryptionLevel, + (), + (override)); MOCK_METHOD(void, processSocketData, (folly::IOBufQueue & queue)); MOCK_METHOD(bool, matchEarlyParameters, ()); MOCK_METHOD( - std::unique_ptr, + (folly::Expected, QuicError>), buildAead, (ClientHandshake::CipherKind kind, ByteRange secret)); MOCK_METHOD( - std::unique_ptr, + (folly::Expected, QuicError>), buildHeaderCipher, (ByteRange secret)); - MOCK_METHOD(BufPtr, getNextTrafficSecret, (ByteRange secret), (const)); + MOCK_METHOD( + (folly::Expected), + getNextTrafficSecret, + (ByteRange secret), + (const)); MOCK_METHOD( const Optional&, getApplicationProtocol, (), - (const)); + (const, override)); MOCK_METHOD( const std::shared_ptr, getPeerCertificate, (), - (const)); - MOCK_METHOD(Handshake::TLSSummary, getTLSSummary, (), (const)); + (const, override)); + MOCK_METHOD(Phase, getPhase, (), (const)); + MOCK_METHOD(bool, waitingForData, (), (const)); +}; + +class MockClientHandshake : public MockClientHandshakeBase { + public: + MockClientHandshake(QuicClientConnectionState* conn) + : MockClientHandshakeBase(conn) {} + + private: + // Implement the private pure virtual methods from ClientHandshake + folly::Expected, QuicError> + connectImpl(Optional /* hostname */) override { + return Optional(std::nullopt); + } + + void processSocketData(folly::IOBufQueue& /* queue */) override {} + + bool matchEarlyParameters() override { + return false; + } + + folly::Expected, QuicError> buildAead( + CipherKind /* kind */, + ByteRange /* secret */) override { + return folly::makeUnexpected( + QuicError(TransportErrorCode::INTERNAL_ERROR, "Not implemented")); + } + + folly::Expected, QuicError> + buildHeaderCipher(ByteRange /* secret */) override { + return folly::makeUnexpected( + QuicError(TransportErrorCode::INTERNAL_ERROR, "Not implemented")); + } + + folly::Expected getNextTrafficSecret( + ByteRange /* secret */) const override { + return folly::makeUnexpected( + QuicError(TransportErrorCode::INTERNAL_ERROR, "Not implemented")); + } }; class MockQuicConnectorCallback : public quic::QuicConnector::Callback { diff --git a/quic/client/test/QuicClientTransportLiteTest.cpp b/quic/client/test/QuicClientTransportLiteTest.cpp index fac35d1ff..09e4d42ac 100644 --- a/quic/client/test/QuicClientTransportLiteTest.cpp +++ b/quic/client/test/QuicClientTransportLiteTest.cpp @@ -51,7 +51,7 @@ class QuicClientTransportLiteTest : public Test { ON_CALL(*socket, setCmsgs(_)).WillByDefault(Return(folly::unit)); ON_CALL(*socket, appendCmsgs(_)).WillByDefault(Return(folly::unit)); auto mockFactory = std::make_shared(); - EXPECT_CALL(*mockFactory, _makeClientHandshake(_)) + EXPECT_CALL(*mockFactory, makeClientHandshakeImpl(_)) .WillRepeatedly(Invoke( [&](QuicClientConnectionState* conn) -> std::unique_ptr { @@ -61,7 +61,7 @@ class QuicClientTransportLiteTest : public Test { qEvb_, std::move(socket), mockFactory); quicClient_->getConn()->oneRttWriteCipher = test::createNoOpAead(); quicClient_->getConn()->oneRttWriteHeaderCipher = - test::createNoOpHeaderCipher(); + test::createNoOpHeaderCipher().value(); ASSERT_FALSE(quicClient_->getState() ->streamManager->setMaxLocalBidirectionalStreams(128) .hasError()); diff --git a/quic/client/test/QuicClientTransportTest.cpp b/quic/client/test/QuicClientTransportTest.cpp index a30d65852..ba990ac8d 100644 --- a/quic/client/test/QuicClientTransportTest.cpp +++ b/quic/client/test/QuicClientTransportTest.cpp @@ -109,7 +109,7 @@ class QuicClientTransportTest : public Test { .WillByDefault(testing::Return(folly::unit)); mockFactory_ = std::make_shared(); - EXPECT_CALL(*mockFactory_, _makeClientHandshake(_)) + EXPECT_CALL(*mockFactory_, makeClientHandshakeImpl(_)) .WillRepeatedly(Invoke( [&](QuicClientConnectionState* conn) -> std::unique_ptr { diff --git a/quic/codec/test/QuicPacketBuilderTest.cpp b/quic/codec/test/QuicPacketBuilderTest.cpp index b530d904b..9ab5d2b79 100644 --- a/quic/codec/test/QuicPacketBuilderTest.cpp +++ b/quic/codec/test/QuicPacketBuilderTest.cpp @@ -64,23 +64,25 @@ std::unique_ptr makeCodec( auto codec = std::make_unique(nodeType); if (nodeType != QuicNodeType::Client) { codec->setZeroRttReadCipher(std::move(zeroRttCipher)); - codec->setZeroRttHeaderCipher(test::createNoOpHeaderCipher()); + codec->setZeroRttHeaderCipher(test::createNoOpHeaderCipher().value()); } codec->setOneRttReadCipher(std::move(oneRttCipher)); - codec->setOneRttHeaderCipher(test::createNoOpHeaderCipher()); + codec->setOneRttHeaderCipher(test::createNoOpHeaderCipher().value()); codec->setHandshakeReadCipher(test::createNoOpAead()); - codec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher()); + codec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher().value()); codec->setClientConnectionId(clientConnId); if (nodeType == QuicNodeType::Client) { codec->setInitialReadCipher( - cryptoFactory.getServerInitialCipher(clientConnId, version)); + cryptoFactory.getServerInitialCipher(clientConnId, version).value()); codec->setInitialHeaderCipher( - cryptoFactory.makeServerInitialHeaderCipher(clientConnId, version)); + cryptoFactory.makeServerInitialHeaderCipher(clientConnId, version) + .value()); } else { codec->setInitialReadCipher( - cryptoFactory.getClientInitialCipher(clientConnId, version)); + cryptoFactory.getClientInitialCipher(clientConnId, version).value()); codec->setInitialHeaderCipher( - cryptoFactory.makeClientInitialHeaderCipher(clientConnId, version)); + cryptoFactory.makeClientInitialHeaderCipher(clientConnId, version) + .value()); } return codec; } @@ -181,9 +183,10 @@ TEST_P(QuicPacketBuilderTest, LongHeaderRegularPacket) { QuicVersion ver = QuicVersion::MVFST; // create a server cleartext write codec. FizzCryptoFactory cryptoFactory; - auto cleartextAead = cryptoFactory.getClientInitialCipher(serverConnId, ver); + auto cleartextAead = + cryptoFactory.getClientInitialCipher(serverConnId, ver).value(); auto headerCipher = - cryptoFactory.makeClientInitialHeaderCipher(serverConnId, ver); + cryptoFactory.makeClientInitialHeaderCipher(serverConnId, ver).value(); std::unique_ptr builderOwner; auto builderProvider = [&](PacketHeader header, PacketNum largestAcked) { diff --git a/quic/codec/test/QuicReadCodecTest.cpp b/quic/codec/test/QuicReadCodecTest.cpp index 64afd0517..946982c3f 100644 --- a/quic/codec/test/QuicReadCodecTest.cpp +++ b/quic/codec/test/QuicReadCodecTest.cpp @@ -44,15 +44,18 @@ std::unique_ptr makeEncryptedCodec( auto codec = std::make_unique(nodeType); codec->setClientConnectionId(clientConnId); codec->setInitialReadCipher( - cryptoFactory.getClientInitialCipher(clientConnId, QuicVersion::MVFST)); - codec->setInitialHeaderCipher(cryptoFactory.makeClientInitialHeaderCipher( - clientConnId, QuicVersion::MVFST)); + cryptoFactory.getClientInitialCipher(clientConnId, QuicVersion::MVFST) + .value()); + codec->setInitialHeaderCipher( + cryptoFactory + .makeClientInitialHeaderCipher(clientConnId, QuicVersion::MVFST) + .value()); if (zeroRttAead) { codec->setZeroRttReadCipher(std::move(zeroRttAead)); } - codec->setZeroRttHeaderCipher(test::createNoOpHeaderCipher()); + codec->setZeroRttHeaderCipher(test::createNoOpHeaderCipher().value()); codec->setOneRttReadCipher(std::move(oneRttAead)); - codec->setOneRttHeaderCipher(test::createNoOpHeaderCipher()); + codec->setOneRttHeaderCipher(test::createNoOpHeaderCipher().value()); if (sourceToken) { codec->setStatelessResetToken(*sourceToken); } @@ -217,7 +220,7 @@ TEST_F(QuicReadCodecTest, LongHeaderPacketLenMismatch) { AckStates ackStates; auto codec = makeUnencryptedCodec(); codec->setInitialReadCipher(createNoOpAead()); - codec->setInitialHeaderCipher(test::createNoOpHeaderCipher()); + codec->setInitialHeaderCipher(test::createNoOpHeaderCipher().value()); auto result = codec->parsePacket(packetQueue, ackStates); auto nothing = result.nothing(); EXPECT_NE(nothing, nullptr); @@ -667,9 +670,11 @@ TEST_F(QuicReadCodecTest, TestInitialPacket) { FizzCryptoFactory cryptoFactory; PacketNum packetNum = 1; uint64_t offset = 0; - auto aead = cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST); + auto aead = + cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST).value(); auto headerCipher = - cryptoFactory.makeClientInitialHeaderCipher(connId, QuicVersion::MVFST); + cryptoFactory.makeClientInitialHeaderCipher(connId, QuicVersion::MVFST) + .value(); auto initialCryptoPacketBuf = folly::IOBuf::copyBuffer("CHLO"); ChainedByteRangeHead initialCryptoPacketRch(initialCryptoPacketBuf); auto packet = createInitialCryptoPacket( @@ -682,7 +687,8 @@ TEST_F(QuicReadCodecTest, TestInitialPacket) { offset); auto codec = makeEncryptedCodec(connId, std::move(aead), nullptr); - aead = cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST); + aead = + cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST).value(); AckStates ackStates; auto packetQueue = bufToQueue(packetToBufCleartext(packet, *aead, *headerCipher, packetNum)); @@ -703,9 +709,11 @@ TEST_F(QuicReadCodecTest, TestInitialPacketExtractToken) { FizzCryptoFactory cryptoFactory; PacketNum packetNum = 1; uint64_t offset = 0; - auto aead = cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST); + auto aead = + cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST).value(); auto headerCipher = - cryptoFactory.makeClientInitialHeaderCipher(connId, QuicVersion::MVFST); + cryptoFactory.makeClientInitialHeaderCipher(connId, QuicVersion::MVFST) + .value(); std::string token = "aswerdfewdewrgetg"; auto initialCryptoPacketBuf = folly::IOBuf::copyBuffer("CHLO"); ChainedByteRangeHead initialCryptoPacketRch(initialCryptoPacketBuf); @@ -721,7 +729,8 @@ TEST_F(QuicReadCodecTest, TestInitialPacketExtractToken) { token); auto codec = makeEncryptedCodec(connId, std::move(aead), nullptr); - aead = cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST); + aead = + cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST).value(); auto packetQueue = bufToQueue(packetToBufCleartext(packet, *aead, *headerCipher, packetNum)); @@ -740,9 +749,11 @@ TEST_F(QuicReadCodecTest, TestHandshakeDone) { FizzCryptoFactory cryptoFactory; PacketNum packetNum = 1; uint64_t offset = 0; - auto aead = cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST); + auto aead = + cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST).value(); auto headerCipher = - cryptoFactory.makeClientInitialHeaderCipher(connId, QuicVersion::MVFST); + cryptoFactory.makeClientInitialHeaderCipher(connId, QuicVersion::MVFST) + .value(); auto initialCryptoPacketBuf = folly::IOBuf::copyBuffer("CHLO"); ChainedByteRangeHead initialCryptoPacketRch(initialCryptoPacketBuf); auto packet = createInitialCryptoPacket( @@ -755,7 +766,8 @@ TEST_F(QuicReadCodecTest, TestHandshakeDone) { offset); auto codec = makeEncryptedCodec(connId, std::move(aead), nullptr); - aead = cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST); + aead = + cryptoFactory.getClientInitialCipher(connId, QuicVersion::MVFST).value(); AckStates ackStates; auto packetQueue = bufToQueue(packetToBufCleartext(packet, *aead, *headerCipher, packetNum)); diff --git a/quic/common/test/TestUtils.cpp b/quic/common/test/TestUtils.cpp index b690d3118..2c5b6c0e3 100644 --- a/quic/common/test/TestUtils.cpp +++ b/quic/common/test/TestUtils.cpp @@ -39,18 +39,20 @@ const RegularQuicWritePacket& writeQuicPacket( bool eof) { auto version = conn.version.value_or(*conn.originalVersion); auto aead = createNoOpAead(); - auto headerCipher = createNoOpHeaderCipher(); + auto headerCipherResult = createNoOpHeaderCipher(); + CHECK(!headerCipherResult.hasError()) << "Failed to create header cipher"; + auto headerCipher = std::move(headerCipherResult.value()); CHECK(!writeDataToQuicStream(stream, data.clone(), eof).hasError()); - CHECK(!writeQuicDataToSocket( - sock, - conn, - srcConnId, - dstConnId, - *aead, - *headerCipher, - version, - conn.transportSettings.writeConnectionDataPacketsLimit) - .hasError()); + auto result = writeQuicDataToSocket( + sock, + conn, + srcConnId, + dstConnId, + *aead, + *headerCipher, + version, + conn.transportSettings.writeConnectionDataPacketsLimit); + CHECK(!result.hasError()); CHECK( conn.outstandings.packets.rend() != getLastOutstandingPacket(conn, PacketNumberSpace::AppData)); @@ -63,19 +65,22 @@ PacketNum rstStreamAndSendPacket( QuicStreamState& stream, ApplicationErrorCode errorCode) { auto aead = createNoOpAead(); - auto headerCipher = createNoOpHeaderCipher(); + auto headerCipherResult = createNoOpHeaderCipher(); + CHECK(!headerCipherResult.hasError()) << "Failed to create header cipher"; + auto headerCipher = std::move(headerCipherResult.value()); auto version = conn.version.value_or(*conn.originalVersion); CHECK(!sendRstSMHandler(stream, errorCode).hasError()); - CHECK(!writeQuicDataToSocket( - sock, - conn, - *conn.clientConnectionId, - *conn.serverConnectionId, - *aead, - *headerCipher, - version, - conn.transportSettings.writeConnectionDataPacketsLimit) - .hasError()); + auto result = writeQuicDataToSocket( + sock, + conn, + *conn.clientConnectionId, + *conn.serverConnectionId, + *aead, + *headerCipher, + version, + conn.transportSettings.writeConnectionDataPacketsLimit); + CHECK(!result.hasError()); + CHECK(!result.hasError()); for (const auto& packet : conn.outstandings.packets) { for (const auto& frame : packet.packet.frames) { @@ -269,7 +274,8 @@ std::unique_ptr createNoOpAead(uint64_t cipherOverhead) { return createNoOpAeadImpl(cipherOverhead); } -std::unique_ptr createNoOpHeaderCipher() { +folly::Expected, QuicError> +createNoOpHeaderCipher() { auto headerCipher = std::make_unique>(); ON_CALL(*headerCipher, mask(_)).WillByDefault(Return(HeaderProtectionMask{})); ON_CALL(*headerCipher, keyLength()).WillByDefault(Return(16)); @@ -776,12 +782,12 @@ bool writableContains(QuicStreamManager& streamManager, StreamId streamId) { streamManager.controlWriteQueue().count(streamId) > 0; } -std::unique_ptr +folly::Expected, QuicError> FizzCryptoTestFactory::makePacketNumberCipher(fizz::CipherSuite) const { return std::move(packetNumberCipher_); } -std::unique_ptr +folly::Expected, QuicError> FizzCryptoTestFactory::makePacketNumberCipher(ByteRange secret) const { return _makePacketNumberCipher(secret); } @@ -793,9 +799,12 @@ void FizzCryptoTestFactory::setMockPacketNumberCipher( void FizzCryptoTestFactory::setDefault() { ON_CALL(*this, _makePacketNumberCipher(_)) - .WillByDefault(Invoke([&](ByteRange secret) { - return FizzCryptoFactory::makePacketNumberCipher(secret); - })); + .WillByDefault(Invoke( + [&](ByteRange secret) -> folly::Expected< + std::unique_ptr, + QuicError> { + return FizzCryptoFactory::makePacketNumberCipher(secret); + })); } void TestPacketBatchWriter::reset() { @@ -831,15 +840,18 @@ TrafficKey getQuicTestKey() { std::unique_ptr getProtectionKey() { FizzCryptoFactory factory; auto secret = getRandSecret(); - auto pnCipher = + auto pnCipherResult = factory.makePacketNumberCipher(fizz::CipherSuite::TLS_AES_128_GCM_SHA256); + CHECK(!pnCipherResult.hasError()) << "Failed to make packet number cipher"; + auto& pnCipher = pnCipherResult.value(); auto deriver = factory.getFizzFactory()->makeKeyDeriver( fizz::CipherSuite::TLS_AES_128_GCM_SHA256); - return deriver->expandLabel( + auto pnKey = deriver->expandLabel( folly::range(secret), kQuicPNLabel, folly::IOBuf::create(0), - pnCipher->keyLength()); + (*pnCipher).keyLength()); + return pnKey; } size_t getTotalIovecLen(const struct iovec* vec, size_t iovec_len) { diff --git a/quic/common/test/TestUtils.h b/quic/common/test/TestUtils.h index e7b716e51..cc44ca092 100644 --- a/quic/common/test/TestUtils.h +++ b/quic/common/test/TestUtils.h @@ -189,7 +189,15 @@ std::unique_ptr createNoOpAeadImpl(uint64_t cipherOverhead = 0) { std::unique_ptr createNoOpAead(uint64_t cipherOverhead = 0); -std::unique_ptr createNoOpHeaderCipher(); +folly::Expected, QuicError> +createNoOpHeaderCipher(); + +// For backward compatibility with existing code +inline std::unique_ptr createNoOpHeaderCipherNoThrow() { + auto result = createNoOpHeaderCipher(); + CHECK(!result.hasError()) << "Failed to create header cipher"; + return std::move(result.value()); +} uint64_t computeExpectedDelay( std::chrono::microseconds ackDelay, @@ -348,17 +356,17 @@ class FizzCryptoTestFactory : public FizzCryptoFactory { ~FizzCryptoTestFactory() override = default; using FizzCryptoFactory::makePacketNumberCipher; - std::unique_ptr makePacketNumberCipher( - fizz::CipherSuite) const override; + folly::Expected, QuicError> + makePacketNumberCipher(fizz::CipherSuite) const override; MOCK_METHOD( - std::unique_ptr, + (folly::Expected, QuicError>), _makePacketNumberCipher, (ByteRange), (const)); - std::unique_ptr makePacketNumberCipher( - ByteRange secret) const override; + folly::Expected, QuicError> + makePacketNumberCipher(ByteRange secret) const override; void setMockPacketNumberCipher( std::unique_ptr packetNumberCipher); @@ -564,9 +572,16 @@ class FakeServerHandshake : public FizzServerHandshake { void setEarlyKeys() { oneRttWriteCipher_ = createNoOpAead(); - oneRttWriteHeaderCipher_ = createNoOpHeaderCipher(); + auto oneRttWriteHeaderCipherResult = createNoOpHeaderCipher(); + CHECK(!oneRttWriteHeaderCipherResult.hasError()) + << "Failed to create header cipher"; + oneRttWriteHeaderCipher_ = std::move(oneRttWriteHeaderCipherResult.value()); + zeroRttReadCipher_ = createNoOpAead(); - zeroRttReadHeaderCipher_ = createNoOpHeaderCipher(); + auto zeroRttReadHeaderCipherResult = createNoOpHeaderCipher(); + CHECK(!zeroRttReadHeaderCipherResult.hasError()) + << "Failed to create header cipher"; + zeroRttReadHeaderCipher_ = std::move(zeroRttReadHeaderCipherResult.value()); } void setOneRttKeys() { @@ -577,12 +592,19 @@ class FakeServerHandshake : public FizzServerHandshake { ON_CALL(*mockOneRttWriteCipher, getKey()) .WillByDefault(testing::Invoke([]() { return getQuicTestKey(); })); oneRttWriteCipher_ = std::move(mockOneRttWriteCipher); - auto mockOneRttWriteHeaderCipher = createNoOpHeaderCipher(); + auto mockOneRttWriteHeaderCipherResult = createNoOpHeaderCipher(); + CHECK(!mockOneRttWriteHeaderCipherResult.hasError()) + << "Failed to create header cipher"; + auto& mockOneRttWriteHeaderCipher = + mockOneRttWriteHeaderCipherResult.value(); mockOneRttWriteHeaderCipher->setDefaultKey(); oneRttWriteHeaderCipher_ = std::move(mockOneRttWriteHeaderCipher); } oneRttReadCipher_ = createNoOpAead(); - oneRttReadHeaderCipher_ = createNoOpHeaderCipher(); + auto oneRttReadHeaderCipherResult = createNoOpHeaderCipher(); + CHECK(!oneRttReadHeaderCipherResult.hasError()) + << "Failed to create header cipher"; + oneRttReadHeaderCipher_ = std::move(oneRttReadHeaderCipherResult.value()); readTrafficSecret_ = folly::IOBuf::copyBuffer(getRandSecret()); writeTrafficSecret_ = folly::IOBuf::copyBuffer(getRandSecret()); } @@ -597,9 +619,17 @@ class FakeServerHandshake : public FizzServerHandshake { void setHandshakeKeys() { conn_.handshakeWriteCipher = createNoOpAead(); - conn_.handshakeWriteHeaderCipher = createNoOpHeaderCipher(); + auto handshakeWriteHeaderCipherResult = createNoOpHeaderCipher(); + CHECK(!handshakeWriteHeaderCipherResult.hasError()) + << "Failed to create header cipher"; + conn_.handshakeWriteHeaderCipher = + std::move(handshakeWriteHeaderCipherResult.value()); handshakeReadCipher_ = createNoOpAead(); - handshakeReadHeaderCipher_ = createNoOpHeaderCipher(); + auto handshakeReadHeaderCipherResult = createNoOpHeaderCipher(); + CHECK(!handshakeReadHeaderCipherResult.hasError()) + << "Failed to create header cipher"; + handshakeReadHeaderCipher_ = + std::move(handshakeReadHeaderCipherResult.value()); } void setHandshakeDone(bool done) { diff --git a/quic/dsr/backend/DSRPacketizer.h b/quic/dsr/backend/DSRPacketizer.h index 84f44c9c7..701061fbd 100644 --- a/quic/dsr/backend/DSRPacketizer.h +++ b/quic/dsr/backend/DSRPacketizer.h @@ -62,8 +62,12 @@ class CipherBuilder { *quicFizzCryptoFactory_.getFizzFactory(), std::move(trafficKey), cipherSuite)); - auto headerCipher = quicFizzCryptoFactory_.makePacketNumberCipher( + auto headerCipherResult = quicFizzCryptoFactory_.makePacketNumberCipher( fizz::CipherSuite::TLS_AES_128_GCM_SHA256); + if (headerCipherResult.hasError()) { + throw std::runtime_error("Failed to create header cipher"); + } + auto headerCipher = std::move(headerCipherResult.value()); headerCipher->setKey(packetProtectionKey->coalesce()); return {std::move(aead), std::move(headerCipher)}; diff --git a/quic/dsr/backend/test/DSRPacketizerTest.cpp b/quic/dsr/backend/test/DSRPacketizerTest.cpp index e1344eec5..368d15c2c 100644 --- a/quic/dsr/backend/test/DSRPacketizerTest.cpp +++ b/quic/dsr/backend/test/DSRPacketizerTest.cpp @@ -45,7 +45,7 @@ class DSRPacketizerSingleWriteTest : public Test { protected: void SetUp() override { aead = test::createNoOpAead(); - headerCipher = test::createNoOpHeaderCipher(); + headerCipher = test::createNoOpHeaderCipher().value(); qEvb_ = std::make_shared(&evb); } diff --git a/quic/fizz/client/handshake/FizzClientHandshake.cpp b/quic/fizz/client/handshake/FizzClientHandshake.cpp index c1d8eb1cc..5a6f0b979 100644 --- a/quic/fizz/client/handshake/FizzClientHandshake.cpp +++ b/quic/fizz/client/handshake/FizzClientHandshake.cpp @@ -149,26 +149,32 @@ const Optional& FizzClientHandshake::getApplicationProtocol() } } -bool FizzClientHandshake::verifyRetryIntegrityTag( +folly::Expected FizzClientHandshake::verifyRetryIntegrityTag( const ConnectionId& originalDstConnId, const RetryPacket& retryPacket) { - PseudoRetryPacketBuilder pseudoRetryPacketBuilder( - retryPacket.initialByte, - retryPacket.header.getSourceConnId(), - retryPacket.header.getDestinationConnId(), - originalDstConnId, - retryPacket.header.getVersion(), - BufHelpers::copyBuffer(retryPacket.header.getToken())); + try { + PseudoRetryPacketBuilder pseudoRetryPacketBuilder( + retryPacket.initialByte, + retryPacket.header.getSourceConnId(), + retryPacket.header.getDestinationConnId(), + originalDstConnId, + retryPacket.header.getVersion(), + BufHelpers::copyBuffer(retryPacket.header.getToken())); - BufPtr pseudoRetryPacket = std::move(pseudoRetryPacketBuilder).buildPacket(); + BufPtr pseudoRetryPacket = + std::move(pseudoRetryPacketBuilder).buildPacket(); - FizzRetryIntegrityTagGenerator retryIntegrityTagGenerator; - auto expectedIntegrityTag = retryIntegrityTagGenerator.getRetryIntegrityTag( - retryPacket.header.getVersion(), pseudoRetryPacket.get()); + FizzRetryIntegrityTagGenerator retryIntegrityTagGenerator; + auto expectedIntegrityTag = retryIntegrityTagGenerator.getRetryIntegrityTag( + retryPacket.header.getVersion(), pseudoRetryPacket.get()); - folly::IOBuf integrityTagWrapper = BufHelpers::wrapBufferAsValue( - retryPacket.integrityTag.data(), retryPacket.integrityTag.size()); - return BufEq()(*expectedIntegrityTag, integrityTagWrapper); + folly::IOBuf integrityTagWrapper = BufHelpers::wrapBufferAsValue( + retryPacket.integrityTag.data(), retryPacket.integrityTag.size()); + return BufEq()(*expectedIntegrityTag, integrityTagWrapper); + } catch (const std::exception& ex) { + return folly::makeUnexpected( + QuicError(TransportErrorCode::INTERNAL_ERROR, ex.what())); + } } bool FizzClientHandshake::isTLSResumed() const { @@ -212,40 +218,50 @@ bool FizzClientHandshake::matchEarlyParameters() { return fizz::client::earlyParametersMatch(state_); } -std::unique_ptr FizzClientHandshake::buildAead( - CipherKind kind, - ByteRange secret) { - bool isEarlyTraffic = kind == CipherKind::ZeroRttWrite; - fizz::CipherSuite cipher = - isEarlyTraffic ? state_.earlyDataParams()->cipher : *state_.cipher(); - std::unique_ptr keySchedulerPtr = isEarlyTraffic - ? state_.context()->getFactory()->makeKeyScheduler(cipher) - : nullptr; - fizz::KeyScheduler& keyScheduler = - isEarlyTraffic ? *keySchedulerPtr : *state_.keyScheduler(); +folly::Expected, QuicError> +FizzClientHandshake::buildAead(CipherKind kind, ByteRange secret) { + try { + bool isEarlyTraffic = kind == CipherKind::ZeroRttWrite; + fizz::CipherSuite cipher = + isEarlyTraffic ? state_.earlyDataParams()->cipher : *state_.cipher(); + std::unique_ptr keySchedulerPtr = isEarlyTraffic + ? state_.context()->getFactory()->makeKeyScheduler(cipher) + : nullptr; + fizz::KeyScheduler& keyScheduler = + isEarlyTraffic ? *keySchedulerPtr : *state_.keyScheduler(); - auto aead = FizzAead::wrap(fizz::Protocol::deriveRecordAeadWithLabel( - *state_.context()->getFactory(), - keyScheduler, - cipher, - secret, - kQuicKeyLabel, - kQuicIVLabel)); + auto aead = FizzAead::wrap(fizz::Protocol::deriveRecordAeadWithLabel( + *state_.context()->getFactory(), + keyScheduler, + cipher, + secret, + kQuicKeyLabel, + kQuicIVLabel)); - return aead; + return aead; + } catch (const std::exception& ex) { + return folly::makeUnexpected( + QuicError(TransportErrorCode::INTERNAL_ERROR, ex.what())); + } } -std::unique_ptr FizzClientHandshake::buildHeaderCipher( - ByteRange secret) { +folly::Expected, QuicError> +FizzClientHandshake::buildHeaderCipher(ByteRange secret) { return cryptoFactory_->makePacketNumberCipher(secret); } -BufPtr FizzClientHandshake::getNextTrafficSecret(ByteRange secret) const { - auto deriver = - state_.context()->getFactory()->makeKeyDeriver(*state_.cipher()); - auto nextSecret = deriver->expandLabel( - secret, kQuicKULabel, BufHelpers::create(0), secret.size()); - return nextSecret; +folly::Expected FizzClientHandshake::getNextTrafficSecret( + ByteRange secret) const { + try { + auto deriver = + state_.context()->getFactory()->makeKeyDeriver(*state_.cipher()); + auto nextSecret = deriver->expandLabel( + secret, kQuicKULabel, BufHelpers::create(0), secret.size()); + return nextSecret; + } catch (const std::exception& ex) { + return folly::makeUnexpected( + QuicError(TransportErrorCode::INTERNAL_ERROR, ex.what())); + } } void FizzClientHandshake::onNewCachedPsk( diff --git a/quic/fizz/client/handshake/FizzClientHandshake.h b/quic/fizz/client/handshake/FizzClientHandshake.h index 64aee0418..3efb77ae6 100644 --- a/quic/fizz/client/handshake/FizzClientHandshake.h +++ b/quic/fizz/client/handshake/FizzClientHandshake.h @@ -33,7 +33,7 @@ class FizzClientHandshake : public ClientHandshake { const Optional& getApplicationProtocol() const override; - bool verifyRetryIntegrityTag( + [[nodiscard]] folly::Expected verifyRetryIntegrityTag( const ConnectionId& originalDstConnId, const RetryPacket& retryPacket) override; @@ -70,16 +70,20 @@ class FizzClientHandshake : public ClientHandshake { void echRetryAvailable(fizz::client::ECHRetryAvailable& retry); private: - folly::Expected, QuicError> - connectImpl(Optional hostname) override; + [[nodiscard]] folly:: + Expected, QuicError> + connectImpl(Optional hostname) override; EncryptionLevel getReadRecordLayerEncryptionLevel() override; void processSocketData(folly::IOBufQueue& queue) override; bool matchEarlyParameters() override; - std::unique_ptr buildAead(CipherKind kind, ByteRange secret) override; - std::unique_ptr buildHeaderCipher( + [[nodiscard]] folly::Expected, QuicError> buildAead( + CipherKind kind, ByteRange secret) override; - BufPtr getNextTrafficSecret(ByteRange secret) const override; + [[nodiscard]] folly::Expected, QuicError> + buildHeaderCipher(ByteRange secret) override; + [[nodiscard]] folly::Expected getNextTrafficSecret( + ByteRange secret) const override; class ActionMoveVisitor; void processActions(fizz::client::Actions actions); diff --git a/quic/fizz/client/test/QuicClientTransportTest.cpp b/quic/fizz/client/test/QuicClientTransportTest.cpp index 20770316b..dbeb7330b 100644 --- a/quic/fizz/client/test/QuicClientTransportTest.cpp +++ b/quic/fizz/client/test/QuicClientTransportTest.cpp @@ -2912,8 +2912,10 @@ TEST_P(QuicClientTransportAfterStartTest, ReadStreamCoalesced) { FizzCryptoFactory cryptoFactory; auto garbage = IOBuf::copyBuffer("garbage"); - auto initialCipher = cryptoFactory.getServerInitialCipher( + auto initialCipherResult = cryptoFactory.getServerInitialCipher( *serverChosenConnId, QuicVersion::MVFST); + ASSERT_FALSE(initialCipherResult.hasError()); + auto& initialCipher = initialCipherResult.value(); auto firstPacketNum = appDataPacketNum++; auto packet1 = packetToBufCleartext( createStreamPacket( @@ -2963,8 +2965,11 @@ TEST_F(QuicClientTransportAfterStartTest, ReadStreamCoalescedMany) { BufQueue packets; for (int i = 0; i < kMaxNumCoalescedPackets; i++) { auto garbage = IOBuf::copyBuffer("garbage"); - auto initialCipher = cryptoFactory.getServerInitialCipher( + auto initialCipherResult = cryptoFactory.getServerInitialCipher( *serverChosenConnId, QuicVersion::MVFST); + ASSERT_FALSE(initialCipherResult.hasError()); + auto& initialCipher = initialCipherResult.value(); + auto packetNum = appDataPacketNum++; auto packet1 = packetToBufCleartext( createStreamPacket( @@ -2973,7 +2978,7 @@ TEST_F(QuicClientTransportAfterStartTest, ReadStreamCoalescedMany) { packetNum, streamId, *garbage, - initialCipher->getCipherOverhead(), + initialCipher.get()->getCipherOverhead(), 0 /* largestAcked */, std::make_pair(LongHeader::Types::Initial, QuicVersion::MVFST)), *initialCipher, @@ -3651,8 +3656,10 @@ TEST_F(QuicClientTransportAfterStartTest, WrongCleartextCipher) { // throws on getting unencrypted stream data. PacketNum nextPacketNum = appDataPacketNum++; - auto initialCipher = cryptoFactory.getServerInitialCipher( + auto initialCipherResult = cryptoFactory.getServerInitialCipher( *serverChosenConnId, QuicVersion::MVFST); + ASSERT_FALSE(initialCipherResult.hasError()); + auto& initialCipher = initialCipherResult.value(); auto packet = packetToBufCleartext( createStreamPacket( *serverChosenConnId /* src */, @@ -3660,7 +3667,7 @@ TEST_F(QuicClientTransportAfterStartTest, WrongCleartextCipher) { nextPacketNum, streamId, *expected, - initialCipher->getCipherOverhead(), + initialCipher.get()->getCipherOverhead(), 0 /* largestAcked */, std::make_pair(LongHeader::Types::Initial, QuicVersion::MVFST)), *initialCipher, @@ -5148,15 +5155,15 @@ class QuicZeroRttClientTest : public QuicClientTransportAfterStartTestBase { mockClientHandshake->setHandshakeWriteCipher(std::move(handshakeWriteAead)); mockClientHandshake->setHandshakeReadHeaderCipher( - test::createNoOpHeaderCipher()); + test::createNoOpHeaderCipherNoThrow()); mockClientHandshake->setHandshakeWriteHeaderCipher( - test::createNoOpHeaderCipher()); + test::createNoOpHeaderCipherNoThrow()); mockClientHandshake->setOneRttWriteHeaderCipher( - test::createNoOpHeaderCipher()); + test::createNoOpHeaderCipherNoThrow()); mockClientHandshake->setOneRttReadHeaderCipher( - test::createNoOpHeaderCipher()); + test::createNoOpHeaderCipherNoThrow()); mockClientHandshake->setZeroRttWriteHeaderCipher( - test::createNoOpHeaderCipher()); + test::createNoOpHeaderCipherNoThrow()); } std::shared_ptr getPskCache() override { diff --git a/quic/fizz/client/test/QuicClientTransportTestUtil.h b/quic/fizz/client/test/QuicClientTransportTestUtil.h index 62cd9777d..dd453b9fd 100644 --- a/quic/fizz/client/test/QuicClientTransportTestUtil.h +++ b/quic/fizz/client/test/QuicClientTransportTestUtil.h @@ -336,9 +336,10 @@ class FakeOneRttHandshakeLayer : public FizzClientHandshake { } if (getPhase() == Phase::Initial) { conn->handshakeWriteCipher = test::createNoOpAead(); - conn->handshakeWriteHeaderCipher = test::createNoOpHeaderCipher(); + conn->handshakeWriteHeaderCipher = test::createNoOpHeaderCipherNoThrow(); conn->readCodec->setHandshakeReadCipher(test::createNoOpAead()); - conn->readCodec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher()); + conn->readCodec->setHandshakeHeaderCipher( + test::createNoOpHeaderCipherNoThrow()); writeDataToQuicStream( conn->cryptoState->handshakeStream, folly::IOBuf::copyBuffer("ClientFinished")); @@ -357,7 +358,8 @@ class FakeOneRttHandshakeLayer : public FizzClientHandshake { return params_; } - BufPtr getNextTrafficSecret(ByteRange /*secret*/) const override { + folly::Expected getNextTrafficSecret( + ByteRange /*secret*/) const override { return folly::IOBuf::copyBuffer(getRandSecret()); } @@ -410,12 +412,15 @@ class FakeOneRttHandshakeLayer : public FizzClientHandshake { throw std::runtime_error("matchEarlyParameters not implemented"); } - std::unique_ptr buildAead(CipherKind, ByteRange) override { + folly::Expected, QuicError> buildAead( + CipherKind, + ByteRange) override { return createNoOpAead(); } - std::unique_ptr buildHeaderCipher(ByteRange) override { - throw std::runtime_error("buildHeaderCipher not implemented"); + folly::Expected, QuicError> + buildHeaderCipher(ByteRange) override { + return createNoOpHeaderCipher(); } }; @@ -650,13 +655,13 @@ class QuicClientTransportTestBase : public virtual testing::Test { mockClientHandshake->setOneRttWriteCipher(std::move(writeAead)); mockClientHandshake->setHandshakeReadHeaderCipher( - test::createNoOpHeaderCipher()); + test::createNoOpHeaderCipherNoThrow()); mockClientHandshake->setHandshakeWriteHeaderCipher( - test::createNoOpHeaderCipher()); + test::createNoOpHeaderCipherNoThrow()); mockClientHandshake->setOneRttWriteHeaderCipher( - test::createNoOpHeaderCipher()); + test::createNoOpHeaderCipherNoThrow()); mockClientHandshake->setOneRttReadHeaderCipher( - test::createNoOpHeaderCipher()); + test::createNoOpHeaderCipherNoThrow()); } virtual void setUpSocketExpectations() { @@ -1002,12 +1007,20 @@ class QuicClientTransportTestBase : public virtual testing::Test { FizzCryptoFactory cryptoFactory; auto codec = std::make_unique(QuicNodeType::Server); codec->setClientConnectionId(*originalConnId); - codec->setInitialReadCipher(cryptoFactory.getClientInitialCipher( - *client->getConn().initialDestinationConnectionId, QuicVersion::MVFST)); - codec->setInitialHeaderCipher(cryptoFactory.makeClientInitialHeaderCipher( - *client->getConn().initialDestinationConnectionId, QuicVersion::MVFST)); + codec->setInitialReadCipher( + cryptoFactory + .getClientInitialCipher( + *client->getConn().initialDestinationConnectionId, + QuicVersion::MVFST) + .value()); + codec->setInitialHeaderCipher( + cryptoFactory + .makeClientInitialHeaderCipher( + *client->getConn().initialDestinationConnectionId, + QuicVersion::MVFST) + .value()); codec->setHandshakeReadCipher(test::createNoOpAead()); - codec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher()); + codec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher().value()); return codec; } @@ -1018,18 +1031,24 @@ class QuicClientTransportTestBase : public virtual testing::Test { std::unique_ptr handshakeReadCipher; codec->setClientConnectionId(*originalConnId); codec->setOneRttReadCipher(test::createNoOpAead()); - codec->setOneRttHeaderCipher(test::createNoOpHeaderCipher()); + codec->setOneRttHeaderCipher(test::createNoOpHeaderCipher().value()); codec->setZeroRttReadCipher(test::createNoOpAead()); - codec->setZeroRttHeaderCipher(test::createNoOpHeaderCipher()); + codec->setZeroRttHeaderCipher(test::createNoOpHeaderCipher().value()); if (handshakeCipher) { - codec->setInitialReadCipher(cryptoFactory.getClientInitialCipher( + auto initialReadCipher = cryptoFactory.getClientInitialCipher( *client->getConn().initialDestinationConnectionId, - QuicVersion::MVFST)); - codec->setInitialHeaderCipher(cryptoFactory.makeClientInitialHeaderCipher( + QuicVersion::MVFST); + CHECK(initialReadCipher.hasValue()); + codec->setInitialReadCipher(std::move(initialReadCipher.value())); + + auto initialHeaderCipher = cryptoFactory.makeClientInitialHeaderCipher( *client->getConn().initialDestinationConnectionId, - QuicVersion::MVFST)); + QuicVersion::MVFST); + CHECK(initialHeaderCipher.hasValue()); + codec->setInitialHeaderCipher(std::move(initialHeaderCipher.value())); + codec->setHandshakeReadCipher(test::createNoOpAead()); - codec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher()); + codec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher().value()); } return codec; } diff --git a/quic/fizz/handshake/FizzCryptoFactory.cpp b/quic/fizz/handshake/FizzCryptoFactory.cpp index 8a4088e41..3ef53fd47 100644 --- a/quic/fizz/handshake/FizzCryptoFactory.cpp +++ b/quic/fizz/handshake/FizzCryptoFactory.cpp @@ -14,7 +14,7 @@ namespace quic { -BufPtr FizzCryptoFactory::makeInitialTrafficSecret( +folly::Expected FizzCryptoFactory::makeInitialTrafficSecret( folly::StringPiece label, const ConnectionId& clientDestinationConnId, QuicVersion version) const { @@ -31,12 +31,18 @@ BufPtr FizzCryptoFactory::makeInitialTrafficSecret( return trafficSecret; } -std::unique_ptr FizzCryptoFactory::makeInitialAead( +folly::Expected, QuicError> +FizzCryptoFactory::makeInitialAead( folly::StringPiece label, const ConnectionId& clientDestinationConnId, QuicVersion version) const { - auto trafficSecret = + auto trafficSecretResult = makeInitialTrafficSecret(label, clientDestinationConnId, version); + if (trafficSecretResult.hasError()) { + return folly::makeUnexpected(trafficSecretResult.error()); + } + auto& trafficSecret = trafficSecretResult.value(); + auto deriver = fizzFactory_->makeKeyDeriver(fizz::CipherSuite::TLS_AES_128_GCM_SHA256); auto aead = fizzFactory_->makeAead(fizz::CipherSuite::TLS_AES_128_GCM_SHA256); @@ -51,15 +57,22 @@ std::unique_ptr FizzCryptoFactory::makeInitialAead( BufHelpers::create(0), aead->ivLength()); - fizz::TrafficKey trafficKey = {std::move(key), std::move(iv)}; + fizz::TrafficKey trafficKey; + trafficKey.key = std::move(key); + trafficKey.iv = std::move(iv); aead->setKey(std::move(trafficKey)); return FizzAead::wrap(std::move(aead)); } -std::unique_ptr FizzCryptoFactory::makePacketNumberCipher( - ByteRange baseSecret) const { - auto pnCipher = +folly::Expected, QuicError> +FizzCryptoFactory::makePacketNumberCipher(ByteRange baseSecret) const { + auto pnCipherResult = makePacketNumberCipher(fizz::CipherSuite::TLS_AES_128_GCM_SHA256); + if (pnCipherResult.hasError()) { + return folly::makeUnexpected(pnCipherResult.error()); + } + auto pnCipher = std::move(pnCipherResult.value()); + auto deriver = fizzFactory_->makeKeyDeriver(fizz::CipherSuite::TLS_AES_128_GCM_SHA256); auto pnKey = deriver->expandLabel( @@ -68,15 +81,17 @@ std::unique_ptr FizzCryptoFactory::makePacketNumberCipher( return pnCipher; } -std::unique_ptr FizzCryptoFactory::makePacketNumberCipher( - fizz::CipherSuite cipher) const { +folly::Expected, QuicError> +FizzCryptoFactory::makePacketNumberCipher(fizz::CipherSuite cipher) const { switch (cipher) { case fizz::CipherSuite::TLS_AES_128_GCM_SHA256: return std::make_unique(); case fizz::CipherSuite::TLS_AES_256_GCM_SHA384: return std::make_unique(); default: - throw std::runtime_error("Packet number cipher not implemented"); + return folly::makeUnexpected(QuicError( + TransportErrorCode::INTERNAL_ERROR, + "Packet number cipher not implemented")); } } diff --git a/quic/fizz/handshake/FizzCryptoFactory.h b/quic/fizz/handshake/FizzCryptoFactory.h index 3a10c49aa..c022af1f6 100644 --- a/quic/fizz/handshake/FizzCryptoFactory.h +++ b/quic/fizz/handshake/FizzCryptoFactory.h @@ -16,21 +16,23 @@ class FizzCryptoFactory : public CryptoFactory { public: FizzCryptoFactory() : fizzFactory_{std::make_shared()} {} - BufPtr makeInitialTrafficSecret( + [[nodiscard]] folly::Expected makeInitialTrafficSecret( folly::StringPiece label, const ConnectionId& clientDestinationConnId, QuicVersion version) const override; - std::unique_ptr makeInitialAead( + [[nodiscard]] folly::Expected, QuicError> + makeInitialAead( folly::StringPiece label, const ConnectionId& clientDestinationConnId, QuicVersion version) const override; - std::unique_ptr makePacketNumberCipher( - ByteRange baseSecret) const override; + [[nodiscard]] folly::Expected, QuicError> + makePacketNumberCipher(ByteRange baseSecret) const override; - virtual std::unique_ptr makePacketNumberCipher( - fizz::CipherSuite cipher) const; + [[nodiscard]] virtual folly:: + Expected, QuicError> + makePacketNumberCipher(fizz::CipherSuite cipher) const; [[nodiscard]] std::function getCryptoEqualFunction() const override; diff --git a/quic/fizz/handshake/test/FizzPacketNumberCipherTest.cpp b/quic/fizz/handshake/test/FizzPacketNumberCipherTest.cpp index 7972e259a..fd66a8e1c 100644 --- a/quic/fizz/handshake/test/FizzPacketNumberCipherTest.cpp +++ b/quic/fizz/handshake/test/FizzPacketNumberCipherTest.cpp @@ -61,7 +61,9 @@ struct CipherBytes { TEST_P(LongPacketNumberCipherTest, TestEncryptDecrypt) { FizzCryptoFactory cryptoFactory; - auto cipher = cryptoFactory.makePacketNumberCipher(GetParam().cipher); + auto cipherResult = cryptoFactory.makePacketNumberCipher(GetParam().cipher); + ASSERT_FALSE(cipherResult.hasError()); + auto cipher = std::move(cipherResult.value()); auto key = folly::unhexlify(GetParam().key); EXPECT_EQ(cipher->keyLength(), key.size()); cipher->setKey(folly::range(key)); diff --git a/quic/fizz/server/handshake/FizzServerHandshake.cpp b/quic/fizz/server/handshake/FizzServerHandshake.cpp index eb27720b2..b1ceae184 100644 --- a/quic/fizz/server/handshake/FizzServerHandshake.cpp +++ b/quic/fizz/server/handshake/FizzServerHandshake.cpp @@ -97,8 +97,8 @@ std::unique_ptr FizzServerHandshake::buildAead(ByteRange secret) { kQuicIVLabel)); } -std::unique_ptr FizzServerHandshake::buildHeaderCipher( - ByteRange secret) { +folly::Expected, QuicError> +FizzServerHandshake::buildHeaderCipher(ByteRange secret) { return cryptoFactory_->makePacketNumberCipher(secret); } diff --git a/quic/fizz/server/handshake/FizzServerHandshake.h b/quic/fizz/server/handshake/FizzServerHandshake.h index 0a7cacef1..2638849ef 100644 --- a/quic/fizz/server/handshake/FizzServerHandshake.h +++ b/quic/fizz/server/handshake/FizzServerHandshake.h @@ -40,8 +40,8 @@ class FizzServerHandshake : public ServerHandshake { EncryptionLevel getReadRecordLayerEncryptionLevel() override; void processSocketData(folly::IOBufQueue& queue) override; std::unique_ptr buildAead(ByteRange secret) override; - std::unique_ptr buildHeaderCipher( - ByteRange secret) override; + [[nodiscard]] folly::Expected, QuicError> + buildHeaderCipher(ByteRange secret) override; BufPtr getNextTrafficSecret(ByteRange secret) const override; void processAccept() override; diff --git a/quic/handshake/CryptoFactory.cpp b/quic/handshake/CryptoFactory.cpp index 765690275..d1b018f82 100644 --- a/quic/handshake/CryptoFactory.cpp +++ b/quic/handshake/CryptoFactory.cpp @@ -11,47 +11,59 @@ namespace quic { -std::unique_ptr CryptoFactory::getClientInitialCipher( +folly::Expected, QuicError> +CryptoFactory::getClientInitialCipher( const ConnectionId& clientDestinationConnId, QuicVersion version) const { return makeInitialAead(kClientInitialLabel, clientDestinationConnId, version); } -std::unique_ptr CryptoFactory::getServerInitialCipher( +folly::Expected, QuicError> +CryptoFactory::getServerInitialCipher( const ConnectionId& clientDestinationConnId, QuicVersion version) const { return makeInitialAead(kServerInitialLabel, clientDestinationConnId, version); } -BufPtr CryptoFactory::makeServerInitialTrafficSecret( +folly::Expected +CryptoFactory::makeServerInitialTrafficSecret( const ConnectionId& clientDestinationConnId, QuicVersion version) const { return makeInitialTrafficSecret( kServerInitialLabel, clientDestinationConnId, version); } -BufPtr CryptoFactory::makeClientInitialTrafficSecret( +folly::Expected +CryptoFactory::makeClientInitialTrafficSecret( const ConnectionId& clientDestinationConnId, QuicVersion version) const { return makeInitialTrafficSecret( kClientInitialLabel, clientDestinationConnId, version); } -std::unique_ptr +folly::Expected, QuicError> CryptoFactory::makeClientInitialHeaderCipher( const ConnectionId& initialDestinationConnectionId, QuicVersion version) const { - auto clientInitialTrafficSecret = + auto clientInitialTrafficSecretResult = makeClientInitialTrafficSecret(initialDestinationConnectionId, version); + if (clientInitialTrafficSecretResult.hasError()) { + return folly::makeUnexpected(clientInitialTrafficSecretResult.error()); + } + auto& clientInitialTrafficSecret = clientInitialTrafficSecretResult.value(); return makePacketNumberCipher(clientInitialTrafficSecret->coalesce()); } -std::unique_ptr +folly::Expected, QuicError> CryptoFactory::makeServerInitialHeaderCipher( const ConnectionId& initialDestinationConnectionId, QuicVersion version) const { - auto serverInitialTrafficSecret = + auto serverInitialTrafficSecretResult = makeServerInitialTrafficSecret(initialDestinationConnectionId, version); + if (serverInitialTrafficSecretResult.hasError()) { + return folly::makeUnexpected(serverInitialTrafficSecretResult.error()); + } + auto& serverInitialTrafficSecret = serverInitialTrafficSecretResult.value(); return makePacketNumberCipher(serverInitialTrafficSecret->coalesce()); } diff --git a/quic/handshake/CryptoFactory.h b/quic/handshake/CryptoFactory.h index dc8f2d004..6317188c2 100644 --- a/quic/handshake/CryptoFactory.h +++ b/quic/handshake/CryptoFactory.h @@ -7,7 +7,9 @@ #pragma once +#include #include +#include #include #include #include @@ -19,50 +21,59 @@ namespace quic { class CryptoFactory { public: - std::unique_ptr getClientInitialCipher( + [[nodiscard]] folly::Expected, QuicError> + getClientInitialCipher( const ConnectionId& clientDestinationConnId, QuicVersion version) const; - std::unique_ptr getServerInitialCipher( + [[nodiscard]] folly::Expected, QuicError> + getServerInitialCipher( const ConnectionId& clientDestinationConnId, QuicVersion version) const; - BufPtr makeServerInitialTrafficSecret( + [[nodiscard]] folly::Expected + makeServerInitialTrafficSecret( const ConnectionId& clientDestinationConnId, QuicVersion version) const; - BufPtr makeClientInitialTrafficSecret( + [[nodiscard]] folly::Expected + makeClientInitialTrafficSecret( const ConnectionId& clientDestinationConnId, QuicVersion version) const; /** * Makes the header cipher for writing client initial packets. */ - std::unique_ptr makeClientInitialHeaderCipher( + [[nodiscard]] folly::Expected, QuicError> + makeClientInitialHeaderCipher( const ConnectionId& initialDestinationConnectionId, QuicVersion version) const; /** * Makes the header cipher for writing server initial packets. */ - std::unique_ptr makeServerInitialHeaderCipher( + [[nodiscard]] folly::Expected, QuicError> + makeServerInitialHeaderCipher( const ConnectionId& initialDestinationConnectionId, QuicVersion version) const; /** * Crypto layer specific methods. */ - virtual BufPtr makeInitialTrafficSecret( + [[nodiscard]] virtual folly::Expected + makeInitialTrafficSecret( folly::StringPiece label, const ConnectionId& clientDestinationConnId, QuicVersion version) const = 0; - virtual std::unique_ptr makeInitialAead( + [[nodiscard]] virtual folly::Expected, QuicError> + makeInitialAead( folly::StringPiece label, const ConnectionId& clientDestinationConnId, QuicVersion version) const = 0; - virtual std::unique_ptr makePacketNumberCipher( - ByteRange baseSecret) const = 0; + [[nodiscard]] virtual folly:: + Expected, QuicError> + makePacketNumberCipher(ByteRange baseSecret) const = 0; [[nodiscard]] virtual std::function getCryptoEqualFunction() const = 0; diff --git a/quic/loss/test/QuicLossFunctionsTest.cpp b/quic/loss/test/QuicLossFunctionsTest.cpp index 89b5e479a..63526e444 100644 --- a/quic/loss/test/QuicLossFunctionsTest.cpp +++ b/quic/loss/test/QuicLossFunctionsTest.cpp @@ -82,7 +82,7 @@ class QuicLossFunctionsTest : public TestWithParam { public: void SetUp() override { aead = createNoOpAead(); - headerCipher = createNoOpHeaderCipher(); + headerCipher = createNoOpHeaderCipher().value(); quicStats_ = std::make_unique(); connIdAlgo_ = std::make_unique(); socket_ = std::make_unique(); @@ -1287,7 +1287,7 @@ TEST_F(QuicLossFunctionsTest, PTONoLongerMarksPacketsToBeRetransmitted) { TEST_F(QuicLossFunctionsTest, PTOWithHandshakePackets) { auto conn = createConn(); conn->handshakeWriteCipher = createNoOpAead(); - conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher(); + conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher().value(); auto mockQLogger = std::make_shared(VantagePoint::Server); conn->qLogger = mockQLogger; auto mockCongestionController = std::make_unique(); @@ -1338,11 +1338,11 @@ TEST_F(QuicLossFunctionsTest, PTOWithLostInitialData) { auto conn = createConn(); conn->initialWriteCipher = createNoOpAead(); - conn->initialHeaderCipher = createNoOpHeaderCipher(); + conn->initialHeaderCipher = createNoOpHeaderCipher().value(); conn->handshakeWriteCipher = createNoOpAead(); - conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher(); + conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher().value(); conn->oneRttWriteCipher = createNoOpAead(); - conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher(); + conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher().value(); auto buf = buildRandomInputData(20); WriteStreamBuffer initialData(ChainedByteRangeHead(buf), 0); @@ -1368,11 +1368,11 @@ TEST_F(QuicLossFunctionsTest, PTOWithLostHandshakeData) { auto conn = createConn(); conn->initialWriteCipher = createNoOpAead(); - conn->initialHeaderCipher = createNoOpHeaderCipher(); + conn->initialHeaderCipher = createNoOpHeaderCipher().value(); conn->handshakeWriteCipher = createNoOpAead(); - conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher(); + conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher().value(); conn->oneRttWriteCipher = createNoOpAead(); - conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher(); + conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher().value(); auto buf = buildRandomInputData(20); WriteStreamBuffer handshakeData(ChainedByteRangeHead(buf), 0); @@ -1399,11 +1399,11 @@ TEST_F(QuicLossFunctionsTest, PTOWithLostAppData) { auto conn = createConn(); conn->initialWriteCipher = createNoOpAead(); - conn->initialHeaderCipher = createNoOpHeaderCipher(); + conn->initialHeaderCipher = createNoOpHeaderCipher().value(); conn->handshakeWriteCipher = createNoOpAead(); - conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher(); + conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher().value(); conn->oneRttWriteCipher = createNoOpAead(); - conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher(); + conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher().value(); auto buf = buildRandomInputData(20); WriteStreamBuffer appData(ChainedByteRangeHead(buf), 0); @@ -1428,13 +1428,13 @@ TEST_F(QuicLossFunctionsTest, PTOAvoidPointless) { auto conn = createConn(); conn->initialWriteCipher = createNoOpAead(); - conn->initialHeaderCipher = createNoOpHeaderCipher(); + conn->initialHeaderCipher = createNoOpHeaderCipher().value(); conn->handshakeWriteCipher = createNoOpAead(); - conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher(); + conn->handshakeWriteHeaderCipher = createNoOpHeaderCipher().value(); conn->oneRttWriteCipher = createNoOpAead(); - conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher(); + conn->oneRttWriteHeaderCipher = createNoOpHeaderCipher().value(); conn->outstandings.packetCount[PacketNumberSpace::Initial] = 1; conn->outstandings.packetCount[PacketNumberSpace::Handshake] = 1; diff --git a/quic/server/handshake/ServerHandshake.cpp b/quic/server/handshake/ServerHandshake.cpp index 8c38d4651..de15fa15d 100644 --- a/quic/server/handshake/ServerHandshake.cpp +++ b/quic/server/handshake/ServerHandshake.cpp @@ -535,7 +535,16 @@ void ServerHandshake::processActions( void ServerHandshake::computeCiphers(CipherKind kind, ByteRange secret) { std::unique_ptr aead = buildAead(secret); - std::unique_ptr headerCipher = buildHeaderCipher(secret); + auto headerCipherResult = buildHeaderCipher(secret); + if (headerCipherResult.hasError()) { + LOG(ERROR) << "Failed to build header cipher"; + onError(std::make_pair( + "Failed to build header cipher", TransportErrorCode::INTERNAL_ERROR)); + return; + } + std::unique_ptr headerCipher = + std::move(headerCipherResult.value()); + switch (kind) { case CipherKind::HandshakeRead: handshakeReadCipher_ = std::move(aead); diff --git a/quic/server/handshake/ServerHandshake.h b/quic/server/handshake/ServerHandshake.h index 2b942051f..649cc6916 100644 --- a/quic/server/handshake/ServerHandshake.h +++ b/quic/server/handshake/ServerHandshake.h @@ -322,8 +322,9 @@ class ServerHandshake : public Handshake { virtual EncryptionLevel getReadRecordLayerEncryptionLevel() = 0; virtual void processSocketData(folly::IOBufQueue& queue) = 0; virtual std::unique_ptr buildAead(ByteRange secret) = 0; - virtual std::unique_ptr buildHeaderCipher( - ByteRange secret) = 0; + [[nodiscard]] virtual folly:: + Expected, QuicError> + buildHeaderCipher(ByteRange secret) = 0; virtual void processAccept() = 0; /* diff --git a/quic/server/state/ServerStateMachine.cpp b/quic/server/state/ServerStateMachine.cpp index da046acd1..977dc862e 100644 --- a/quic/server/state/ServerStateMachine.cpp +++ b/quic/server/state/ServerStateMachine.cpp @@ -1037,8 +1037,13 @@ folly::Expected onServerReadDataFromOpen( conn.serverHandshakeLayer->getCryptoFactory(); conn.readCodec = std::make_unique(QuicNodeType::Server); conn.readCodec->setConnectionStatsCallback(conn.statsCallback); - conn.readCodec->setInitialReadCipher(cryptoFactory.getClientInitialCipher( - initialDestinationConnectionId, version)); + auto clientInitialCipherResult = cryptoFactory.getClientInitialCipher( + initialDestinationConnectionId, version); + if (clientInitialCipherResult.hasError()) { + return folly::makeUnexpected(clientInitialCipherResult.error()); + } + conn.readCodec->setInitialReadCipher( + std::move(clientInitialCipherResult.value())); conn.readCodec->setClientConnectionId(clientConnectionId); conn.readCodec->setServerConnectionId(*conn.serverConnectionId); if (conn.qLogger) { @@ -1050,14 +1055,30 @@ folly::Expected onServerReadDataFromOpen( version, conn.transportSettings.maybeAckReceiveTimestampsConfigSentToPeer, conn.transportSettings.advertisedExtendedAckFeatures)); - conn.initialWriteCipher = cryptoFactory.getServerInitialCipher( + auto serverInitialCipherResult = cryptoFactory.getServerInitialCipher( initialDestinationConnectionId, version); + if (serverInitialCipherResult.hasError()) { + return folly::makeUnexpected(serverInitialCipherResult.error()); + } + conn.initialWriteCipher = std::move(serverInitialCipherResult.value()); - conn.readCodec->setInitialHeaderCipher( + auto clientInitialHeaderCipherResult = cryptoFactory.makeClientInitialHeaderCipher( - initialDestinationConnectionId, version)); - conn.initialHeaderCipher = cryptoFactory.makeServerInitialHeaderCipher( - initialDestinationConnectionId, version); + initialDestinationConnectionId, version); + if (clientInitialHeaderCipherResult.hasError()) { + return folly::makeUnexpected(clientInitialHeaderCipherResult.error()); + } + conn.readCodec->setInitialHeaderCipher( + std::move(clientInitialHeaderCipherResult.value())); + + auto serverInitialHeaderCipherResult = + cryptoFactory.makeServerInitialHeaderCipher( + initialDestinationConnectionId, version); + if (serverInitialHeaderCipherResult.hasError()) { + return folly::makeUnexpected(serverInitialHeaderCipherResult.error()); + } + conn.initialHeaderCipher = + std::move(serverInitialHeaderCipherResult.value()); conn.peerAddress = conn.originalPeerAddress; } BufQueue& udpData = readData.udpPacket.buf; diff --git a/quic/server/test/QuicServerTest.cpp b/quic/server/test/QuicServerTest.cpp index 013079bc4..b51bddbe8 100644 --- a/quic/server/test/QuicServerTest.cpp +++ b/quic/server/test/QuicServerTest.cpp @@ -346,7 +346,7 @@ void QuicServerWorkerTest::testSendReset( .WillRepeatedly( Invoke([&](auto&, auto, auto) { return std::nullopt; })); codec.setOneRttReadCipher(std::move(aead)); - codec.setOneRttHeaderCipher(test::createNoOpHeaderCipher()); + codec.setOneRttHeaderCipher(test::createNoOpHeaderCipher().value()); StatelessResetToken token = generateStatelessResetToken(); codec.setStatelessResetToken(token); FizzCryptoFactory cryptoFactory; @@ -2966,7 +2966,7 @@ void QuicServerTest::testReset(BufPtr packet) { EXPECT_CALL(*aead, _tryDecrypt(_, _, _)) .WillRepeatedly(Invoke([&](auto&, auto, auto) { return std::nullopt; })); codec.setOneRttReadCipher(std::move(aead)); - codec.setOneRttHeaderCipher(test::createNoOpHeaderCipher()); + codec.setOneRttHeaderCipher(test::createNoOpHeaderCipher().value()); StatelessResetToken token = generateStatelessResetToken(); codec.setStatelessResetToken(token); FizzCryptoFactory cryptoFactory; diff --git a/quic/server/test/QuicServerTransportTest.cpp b/quic/server/test/QuicServerTransportTest.cpp index 8f4323bc4..e35b4eb0d 100644 --- a/quic/server/test/QuicServerTransportTest.cpp +++ b/quic/server/test/QuicServerTransportTest.cpp @@ -3491,7 +3491,7 @@ TEST_F(QuicServerTransportTest, InvokeTxCallbacksSingleByteDSR) { EXPECT_CALL(dsrByteTxCb, onByteEvent(getTxMatcher(stream, 1))).Times(1); EXPECT_CALL(lastByteTxCb, onByteEvent(getTxMatcher(stream, 1))).Times(1); server->getConnectionState().oneRttWriteCipher = test::createNoOpAead(); - auto temp = test::createNoOpHeaderCipher(); + auto temp = test::createNoOpHeaderCipher().value(); temp->setDefaultKey(); server->getConnectionState().oneRttWriteHeaderCipher = std::move(temp); CHECK(server->getConnectionState().oneRttWriteCipher->getKey().has_value()); @@ -3700,8 +3700,10 @@ TEST_F(QuicUnencryptedServerTransportTest, TestBadPacketProtectionLevel) { TEST_F(QuicUnencryptedServerTransportTest, TestBadCleartextEncryption) { FizzCryptoFactory cryptoFactory; PacketNum nextPacket = clientNextInitialPacketNum++; - auto aead = cryptoFactory.getServerInitialCipher( - *clientConnectionId, QuicVersion::MVFST); + auto aead = + cryptoFactory + .getServerInitialCipher(*clientConnectionId, QuicVersion::MVFST) + .value(); auto chloBuf = IOBuf::copyBuffer("CHLO"); ChainedByteRangeHead chloRch(chloBuf); auto packetData = packetToBufCleartext( diff --git a/quic/server/test/QuicServerTransportTestUtil.h b/quic/server/test/QuicServerTransportTestUtil.h index d5227341a..c2d8059db 100644 --- a/quic/server/test/QuicServerTransportTestUtil.h +++ b/quic/server/test/QuicServerTransportTestUtil.h @@ -304,15 +304,17 @@ class QuicServerTransportTestBase : public virtual testing::Test { std::unique_ptr getInitialCipher( QuicVersion version = QuicVersion::MVFST) { FizzCryptoFactory cryptoFactory; - return cryptoFactory.getClientInitialCipher( - *initialDestinationConnectionId, version); + return cryptoFactory + .getClientInitialCipher(*initialDestinationConnectionId, version) + .value(); } std::unique_ptr getInitialHeaderCipher( QuicVersion version = QuicVersion::MVFST) { FizzCryptoFactory cryptoFactory; - return cryptoFactory.makeClientInitialHeaderCipher( - *initialDestinationConnectionId, version); + return cryptoFactory + .makeClientInitialHeaderCipher(*initialDestinationConnectionId, version) + .value(); } BufPtr recvEncryptedStream( @@ -367,7 +369,7 @@ class QuicServerTransportTestBase : public virtual testing::Test { QuicVersion version = QuicVersion::MVFST) { auto finished = folly::IOBuf::copyBuffer("FINISHED"); auto nextPacketNum = clientNextHandshakePacketNum++; - auto headerCipher = test::createNoOpHeaderCipher(); + auto headerCipher = test::createNoOpHeaderCipher().value(); uint64_t offset = getCryptoStream( *server->getConn().cryptoState, EncryptionLevel::Handshake) @@ -395,11 +397,16 @@ class QuicServerTransportTestBase : public virtual testing::Test { FizzCryptoFactory cryptoFactory; clientReadCodec = std::make_unique(QuicNodeType::Client); clientReadCodec->setClientConnectionId(*clientConnectionId); - clientReadCodec->setInitialReadCipher(cryptoFactory.getServerInitialCipher( - *initialDestinationConnectionId, QuicVersion::MVFST)); + clientReadCodec->setInitialReadCipher( + cryptoFactory + .getServerInitialCipher( + *initialDestinationConnectionId, QuicVersion::MVFST) + .value()); clientReadCodec->setInitialHeaderCipher( - cryptoFactory.makeServerInitialHeaderCipher( - *initialDestinationConnectionId, QuicVersion::MVFST)); + cryptoFactory + .makeServerInitialHeaderCipher( + *initialDestinationConnectionId, QuicVersion::MVFST) + .value()); clientReadCodec->setCodecParameters( CodecParameters(kDefaultAckDelayExponent, QuicVersion::MVFST)); } @@ -601,18 +608,23 @@ class QuicServerTransportTestBase : public virtual testing::Test { FizzCryptoFactory cryptoFactory; auto readCodec = std::make_unique(QuicNodeType::Client); readCodec->setOneRttReadCipher(test::createNoOpAead()); - readCodec->setOneRttHeaderCipher(test::createNoOpHeaderCipher()); + readCodec->setOneRttHeaderCipher(test::createNoOpHeaderCipher().value()); readCodec->setHandshakeReadCipher(test::createNoOpAead()); - readCodec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher()); + readCodec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher().value()); readCodec->setClientConnectionId(*clientConnectionId); readCodec->setCodecParameters( CodecParameters(kDefaultAckDelayExponent, QuicVersion::MVFST)); if (handshakeCipher) { - readCodec->setInitialReadCipher(cryptoFactory.getServerInitialCipher( - *initialDestinationConnectionId, QuicVersion::MVFST)); + readCodec->setInitialReadCipher( + cryptoFactory + .getServerInitialCipher( + *initialDestinationConnectionId, QuicVersion::MVFST) + .value()); readCodec->setInitialHeaderCipher( - cryptoFactory.makeServerInitialHeaderCipher( - *initialDestinationConnectionId, QuicVersion::MVFST)); + cryptoFactory + .makeServerInitialHeaderCipher( + *initialDestinationConnectionId, QuicVersion::MVFST) + .value()); } return readCodec; } diff --git a/quic/state/test/AckHandlersTest.cpp b/quic/state/test/AckHandlersTest.cpp index 9e42ddf92..f2812e12f 100644 --- a/quic/state/test/AckHandlersTest.cpp +++ b/quic/state/test/AckHandlersTest.cpp @@ -4426,7 +4426,7 @@ class AckEventForAppDataTest : public Test { public: void SetUp() override { aead_ = test::createNoOpAead(); - headerCipher_ = test::createNoOpHeaderCipher(); + headerCipher_ = test::createNoOpHeaderCipher().value(); conn_ = createConn(); } @@ -4445,7 +4445,7 @@ class AckEventForAppDataTest : public Test { conn->flowControlState.peerAdvertisedMaxOffset = kDefaultConnectionFlowControlWindow; conn->initialWriteCipher = createNoOpAead(); - conn->initialHeaderCipher = createNoOpHeaderCipher(); + conn->initialHeaderCipher = createNoOpHeaderCipher().value(); CHECK( !conn->streamManager ->setMaxLocalBidirectionalStreams(kDefaultMaxStreamsBidirectional)