diff --git a/quic/api/QuicTransportFunctions.cpp b/quic/api/QuicTransportFunctions.cpp index f5c5d96ff..aa0a946bd 100644 --- a/quic/api/QuicTransportFunctions.cpp +++ b/quic/api/QuicTransportFunctions.cpp @@ -126,7 +126,7 @@ WriteQuicDataResult writeQuicDataToSocketImpl( uint64_t packetLimit, bool exceptCryptoStream, TimePoint writeLoopBeginTime) { - auto builder = ShortHeaderBuilder(); + auto builder = ShortHeaderBuilder(connection.oneRttWritePhase); WriteQuicDataResult result; auto& packetsWritten = result.packetsWritten; auto& probesWritten = result.probesWritten; @@ -843,6 +843,7 @@ void updateConnection( conn.lossState.totalPacketsSent++; conn.lossState.totalStreamBytesSent += streamBytesSent; conn.lossState.totalNewStreamBytesSent += newStreamBytesSent; + conn.oneRttWritePacketsSentInCurrentPhase++; if (!retransmittable && !isPing) { DCHECK(!packetEvent); @@ -1009,13 +1010,14 @@ HeaderBuilder LongHeaderBuilder(LongHeader::Types packetType) { }; } -HeaderBuilder ShortHeaderBuilder() { - return [](const ConnectionId& /* srcConnId */, - const ConnectionId& dstConnId, - PacketNum packetNum, - QuicVersion, - const std::string&) { - return ShortHeader(ProtectionType::KeyPhaseZero, dstConnId, packetNum); +HeaderBuilder ShortHeaderBuilder(ProtectionType keyPhase) { + return [keyPhase]( + const ConnectionId& /* srcConnId */, + const ConnectionId& dstConnId, + PacketNum packetNum, + QuicVersion, + const std::string&) { + return ShortHeader(keyPhase, dstConnId, packetNum); }; } @@ -1903,4 +1905,18 @@ bool toWriteAppDataAcks(const quic::QuicConnectionStateBase& conn) { hasAcksToSchedule(conn.ackStates.appDataAckState) && conn.ackStates.appDataAckState.needsToSendAckImmediately); } + +void updateOneRttWriteCipher( + quic::QuicConnectionStateBase& conn, + std::unique_ptr aead, + ProtectionType oneRttPhase) { + CHECK( + oneRttPhase == ProtectionType::KeyPhaseZero || + oneRttPhase == ProtectionType::KeyPhaseOne); + CHECK(oneRttPhase != conn.oneRttWritePhase) + << "Cannot replace cipher for current write phase"; + conn.oneRttWriteCipher = std::move(aead); + conn.oneRttWritePhase = oneRttPhase; + conn.oneRttWritePacketsSentInCurrentPhase = 0; +} } // namespace quic diff --git a/quic/api/QuicTransportFunctions.h b/quic/api/QuicTransportFunctions.h index 7547e66c6..81b42ea78 100644 --- a/quic/api/QuicTransportFunctions.h +++ b/quic/api/QuicTransportFunctions.h @@ -311,7 +311,7 @@ WriteQuicDataResult writeProbingDataToSocket( const std::string& token = std::string()); HeaderBuilder LongHeaderBuilder(LongHeader::Types packetType); -HeaderBuilder ShortHeaderBuilder(); +HeaderBuilder ShortHeaderBuilder(ProtectionType keyPhase); void maybeSendStreamLimitUpdates(QuicConnectionStateBase& conn); @@ -328,4 +328,9 @@ bool writeLoopTimeLimit( bool toWriteInitialAcks(const quic::QuicConnectionStateBase& conn); bool toWriteHandshakeAcks(const quic::QuicConnectionStateBase& conn); bool toWriteAppDataAcks(const quic::QuicConnectionStateBase& conn); + +void updateOneRttWriteCipher( + QuicConnectionStateBase& conn, + std::unique_ptr aead, + ProtectionType oneRttPhase); } // namespace quic diff --git a/quic/api/test/QuicTransportFunctionsTest.cpp b/quic/api/test/QuicTransportFunctionsTest.cpp index 2ee084fad..0118ba24d 100644 --- a/quic/api/test/QuicTransportFunctionsTest.cpp +++ b/quic/api/test/QuicTransportFunctionsTest.cpp @@ -48,7 +48,7 @@ uint64_t writeProbingDataToSocketForTest( conn, *conn.clientConnectionId, *conn.serverConnectionId, - ShortHeaderBuilder(), + ShortHeaderBuilder(conn.oneRttWritePhase), EncryptionLevel::AppData, PacketNumberSpace::AppData, scheduler, diff --git a/quic/codec/QuicReadCodec.cpp b/quic/codec/QuicReadCodec.cpp index ad2e4efff..bcf58706f 100644 --- a/quic/codec/QuicReadCodec.cpp +++ b/quic/codec/QuicReadCodec.cpp @@ -288,23 +288,61 @@ CodecResult QuicReadCodec::tryParseShortHeaderPacket( return CodecResult(Nothing()); } shortHeader->setPacketNumber(packetNum.first); - if (shortHeader->getProtectionType() == ProtectionType::KeyPhaseOne) { - VLOG(4) << nodeToString(nodeType_) << " cannot read key phase one packet " - << connIdToHex(); - return CodecResult(Nothing()); + bool peerKeyUpdateAttempt = false; + auto oneRttReadCipherToUse = [&]() -> Aead* { + if (shortHeader->getProtectionType() == currentOneRttReadPhase_) { + return currentOneRttReadCipher_.get(); + } else { + // This is a packet from a different phase. It may be encrypted using the + // next key (new peer-initiated key update) or the previous key (out of + // order packet or pending locally-initiated key update). + + if (!currentOneRttReadPhaseStartPacketNum_ || + shortHeader->getPacketSequenceNum() < + currentOneRttReadPhaseStartPacketNum_.value()) { + // There is either a pending key update or this an out-of-order packet, + // attempt to use the previous cipher + if (previousOneRttReadCipher_) { + return previousOneRttReadCipher_.get(); + } else { + // There is no previous packet. We can't decrypt this packet + VLOG(4) + << nodeToString(nodeType_) + << " cannot read packet using previous cipher. Cipher is not available"; + return nullptr; + } + } else { + // This is a key update attempt + if (nextOneRttReadCipher_) { + peerKeyUpdateAttempt = true; + return nextOneRttReadCipher_.get(); + } else { + // The next cipher is not yet available. We can't decrypt this packet + VLOG(4) + << nodeToString(nodeType_) + << " unable to process key update. Next cipher is not yet available"; + return nullptr; + } + } + } + }(); + + if (oneRttReadCipherToUse == nullptr) { + return CodecResult( + CipherUnavailable(std::move(data), shortHeader->getProtectionType())); } - // We know that the iobuf is not chained. This means that we can safely have a - // non-owning reference to the header without cloning the buffer. If we don't - // clone the buffer, the buffer will not show up as shared and we can decrypt - // in-place. + // We know that the iobuf is not chained. This means that we can safely have + // a non-owning reference to the header without cloning the buffer. If we + // don't clone the buffer, the buffer will not show up as shared and we can + // decrypt in-place. size_t aadLen = packetNumberOffset + packetNum.second; folly::IOBuf headerData = folly::IOBuf::wrapBufferAsValue(data->data(), aadLen); data->trimStart(aadLen); Buf decrypted; - auto decryptAttempt = oneRttReadCipher_->tryDecrypt( + auto decryptAttempt = oneRttReadCipherToUse->tryDecrypt( std::move(data), &headerData, packetNum.first); if (!decryptAttempt) { auto protectionType = shortHeader->getProtectionType(); @@ -319,6 +357,28 @@ CodecResult QuicReadCodec::tryParseShortHeaderPacket( decrypted = folly::IOBuf::create(0); } + if (peerKeyUpdateAttempt) { + // Peer initiated a key update and we've successfully decrypted a packet + // from the next phase. We should advance our oneRttCipher state. + currentOneRttReadPhase_ = shortHeader->getProtectionType(); + currentOneRttReadPhaseStartPacketNum_.reset(); + previousOneRttReadCipher_ = std::move(currentOneRttReadCipher_); + currentOneRttReadCipher_ = std::move(nextOneRttReadCipher_); + // nextOneRttReadCipher_ will be populated by the transport + } + + if (!currentOneRttReadPhaseStartPacketNum_.has_value() && + oneRttReadCipherToUse == currentOneRttReadCipher_.get()) { + // This is the first packet in the current phase. Record the packet + // number. This applies for both peer-initiated and self-initiated key + // updates. + currentOneRttReadPhaseStartPacketNum_ = shortHeader->getPacketSequenceNum(); + } + + // TODO: Should we discard the previous cipher at some point? Keeping it + // around avoids the timing signals mentioned in the spec, but we could also + // drop it after 3 * PTO. + return decodeRegularPacket( std::move(*shortHeader), params_, std::move(decrypted)); } @@ -340,9 +400,8 @@ CodecResult QuicReadCodec::parsePacket( if (headerForm == HeaderForm::Long) { return parseLongHeaderPacket(queue, ackStates); } - // Missing 1-rtt Cipher is the only case we wouldn't consider reset - // TODO: support key phase one. - if (!oneRttReadCipher_ || !oneRttHeaderCipher_) { + // Missing 1-rtt header cipher is the only case we wouldn't consider reset + if (!currentOneRttReadCipher_ || !oneRttHeaderCipher_) { VLOG(4) << nodeToString(nodeType_) << " cannot read key phase zero packet"; VLOG(20) << "cannot read data=" << folly::hexlify(queue.front()->clone()->moveToFbString()) << " " @@ -379,7 +438,7 @@ CodecResult QuicReadCodec::parsePacket( } const Aead* QuicReadCodec::getOneRttReadCipher() const { - return oneRttReadCipher_.get(); + return currentOneRttReadCipher_.get(); } const Aead* QuicReadCodec::getZeroRttReadCipher() const { @@ -406,7 +465,12 @@ void QuicReadCodec::setInitialReadCipher( void QuicReadCodec::setOneRttReadCipher( std::unique_ptr oneRttReadCipher) { - oneRttReadCipher_ = std::move(oneRttReadCipher); + currentOneRttReadCipher_ = std::move(oneRttReadCipher); +} + +void QuicReadCodec::setNextOneRttReadCipher( + std::unique_ptr oneRttReadCipher) { + nextOneRttReadCipher_ = std::move(oneRttReadCipher); } void QuicReadCodec::setZeroRttReadCipher( @@ -498,6 +562,10 @@ folly::Optional QuicReadCodec::getHandshakeDoneTime() { return handshakeDoneTime_; } +ProtectionType QuicReadCodec::getCurrentOneRttReadPhase() const { + return currentOneRttReadPhase_; +} + std::string QuicReadCodec::connIdToHex() const { static ConnectionId zeroConn = zeroConnId(); const auto& serverId = serverConnectionId_.value_or(zeroConn); diff --git a/quic/codec/QuicReadCodec.h b/quic/codec/QuicReadCodec.h index bcae6932c..608690766 100644 --- a/quic/codec/QuicReadCodec.h +++ b/quic/codec/QuicReadCodec.h @@ -126,10 +126,13 @@ class QuicReadCodec { const folly::Optional& getStatelessResetToken() const; + [[nodiscard]] ProtectionType getCurrentOneRttReadPhase() const; + CodecParameters getCodecParameters() const; void setInitialReadCipher(std::unique_ptr initialReadCipher); void setOneRttReadCipher(std::unique_ptr oneRttReadCipher); + void setNextOneRttReadCipher(std::unique_ptr oneRttReadCipher); void setZeroRttReadCipher(std::unique_ptr zeroRttReadCipher); void setHandshakeReadCipher(std::unique_ptr handshakeReadCipher); @@ -178,10 +181,18 @@ class QuicReadCodec { // Cipher used to decrypt handshake packets. std::unique_ptr initialReadCipher_; - std::unique_ptr oneRttReadCipher_; std::unique_ptr zeroRttReadCipher_; std::unique_ptr handshakeReadCipher_; + std::unique_ptr previousOneRttReadCipher_; + std::unique_ptr currentOneRttReadCipher_; + std::unique_ptr nextOneRttReadCipher_; + ProtectionType currentOneRttReadPhase_{ProtectionType::KeyPhaseZero}; + // The packet number of the first packet in the current 1-RTT phase + // It's not set when a key update is ongoing (i.e. the write key has been + // updated but no packets have been received with the corresponding read key) + folly::Optional currentOneRttReadPhaseStartPacketNum_; + std::unique_ptr initialHeaderCipher_; std::unique_ptr oneRttHeaderCipher_; std::unique_ptr zeroRttHeaderCipher_; diff --git a/quic/codec/test/QuicReadCodecTest.cpp b/quic/codec/test/QuicReadCodecTest.cpp index a83acd531..5e1609143 100644 --- a/quic/codec/test/QuicReadCodecTest.cpp +++ b/quic/codec/test/QuicReadCodecTest.cpp @@ -440,7 +440,8 @@ TEST_F(QuicReadCodecTest, RandomizedShortHeaderLeadsToReset) { uint8_t* packetHeaderBuffer = streamPacket.header.writableData(); uint8_t randomByte; folly::Random::secureRandom(&randomByte, 1); - *packetHeaderBuffer = 0x40 | (randomByte & 0b00111111); + // Do not randomize the HeaderForm bit, Fixed bit and Key Phase bit. + *packetHeaderBuffer = 0x40 | (randomByte & 0b00111011); AckStates ackStates; auto packetQueue = bufToQueue(packetToBuf(streamPacket)); auto packet = codec->parsePacket(packetQueue, ackStates); @@ -843,3 +844,475 @@ TEST_F(QuicReadCodecTest, parseEmptyDatagramFrame) { CodecParameters()), QuicTransportException); } + +TEST_F(QuicReadCodecTest, KeyUpdateIncomingValid) { + /* + * - Receive a packet in phase zero + * - Receive a packet in phase one --> triggers key update + * - Receive an out-of-order packet in phase zero --> uses previous cipher + * - Receive a packet in phase one --> uses current cipher + * - Receive an in-order packet in phase zero --> triggers another key update + * All packets are decrypted successfully. + */ + auto connId = getTestConnectionId(); + auto aead1 = std::make_unique(); + auto rawAead1 = aead1.get(); + + auto codec = makeEncryptedCodec( + connId, + std::move(aead1), + nullptr /* 0-rtt zead */, + nullptr /* stateless reset token*/, + QuicNodeType::Client); + + auto aead2 = std::make_unique(); + auto rawAead2 = aead2.get(); + codec->setNextOneRttReadCipher(std::move(aead2)); + + EXPECT_CALL(*rawAead1, _tryDecrypt(_, _, _)) + .Times(1) + .WillOnce(Invoke( + [](std::unique_ptr& cipherText, const auto&, auto) { + // Successful decryption + return std::move(cipherText); + })); + + // First packet in 1-rtt phase zero. + PacketNum packetNum = 2; + StreamId streamId = 2; + auto data = folly::IOBuf::create(30); + data->append(30); + auto streamPacket = createStreamPacket( + connId, + connId, + packetNum, + streamId, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */, + folly::none, + true, + ProtectionType::KeyPhaseZero); + AckStates ackStates; + auto packetQueue = bufToQueue(packetToBuf(streamPacket)); + auto packet = codec->parsePacket(packetQueue, ackStates); + // Packet should be parsed successfully. + EXPECT_TRUE(packet.regularPacket() != nullptr); + + { + // Second packet is in 1-rtt phase one and should be decrypted using aead2 + EXPECT_CALL(*rawAead2, _tryDecrypt(_, _, _)) + .Times(1) + .WillOnce(Invoke( + [](std::unique_ptr& cipherText, const auto&, auto) { + // Successful decryption + return std::move(cipherText); + })); + packetNum = 3; + streamPacket = createStreamPacket( + connId, + connId, + packetNum, + streamId, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */, + folly::none, + true, + ProtectionType::KeyPhaseOne); + packetQueue = bufToQueue(packetToBuf(streamPacket)); + packet = codec->parsePacket(packetQueue, ackStates); + + // Packet should be parsed successfully. + EXPECT_TRUE(packet.regularPacket() != nullptr); + // The read codec should advance to phase one. + EXPECT_EQ(codec->getCurrentOneRttReadPhase(), ProtectionType::KeyPhaseOne); + } + + auto aead3 = std::make_unique(); + auto rawAead3 = aead3.get(); + codec->setNextOneRttReadCipher(std::move(aead3)); + + { + // Third packet is in 1-rtt phase zero. This is an out of order packet and + // should be decrypted with the aead1 not aead3. + EXPECT_CALL(*rawAead3, _tryDecrypt(_, _, _)).Times(0); + EXPECT_CALL(*rawAead1, _tryDecrypt(_, _, _)) + .Times(1) + .WillOnce(Invoke( + [](std::unique_ptr& cipherText, const auto&, auto) { + // Successful decryption + return std::move(cipherText); + })); + packetNum = 1; + streamPacket = createStreamPacket( + connId, + connId, + packetNum, + streamId, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */, + folly::none, + true, + ProtectionType::KeyPhaseZero); + packetQueue = bufToQueue(packetToBuf(streamPacket)); + packet = codec->parsePacket(packetQueue, ackStates); + // Packet should be parsed successfully. + EXPECT_TRUE(packet.regularPacket() != nullptr); + // The read codec should not advance to phase zero for an old packet. + EXPECT_EQ(codec->getCurrentOneRttReadPhase(), ProtectionType::KeyPhaseOne); + } + + { + // Forth packet is in 1-rtt phase one. This is in the current 1-rtt phase + // and should be handled by aead2. + EXPECT_CALL(*rawAead1, _tryDecrypt(_, _, _)).Times(0); + EXPECT_CALL(*rawAead3, _tryDecrypt(_, _, _)).Times(0); + EXPECT_CALL(*rawAead2, _tryDecrypt(_, _, _)) + .Times(1) + .WillOnce(Invoke( + [](std::unique_ptr& cipherText, const auto&, auto) { + // Successful decryption + return std::move(cipherText); + })); + packetNum = 4; + streamPacket = createStreamPacket( + connId, + connId, + packetNum, + streamId, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */, + folly::none, + true, + ProtectionType::KeyPhaseOne); + packetQueue = bufToQueue(packetToBuf(streamPacket)); + packet = codec->parsePacket(packetQueue, ackStates); + // Packet should be parsed successfully. + EXPECT_TRUE(packet.regularPacket() != nullptr); + // The read codec should not advance to phase zero for an old packet. + EXPECT_EQ(codec->getCurrentOneRttReadPhase(), ProtectionType::KeyPhaseOne); + } + + { + // Fifth packet is in 1-rtt phase zero. Since it's in-order, it should + // trigger another key update + EXPECT_CALL(*rawAead1, _tryDecrypt(_, _, _)).Times(0); + EXPECT_CALL(*rawAead2, _tryDecrypt(_, _, _)).Times(0); + EXPECT_CALL(*rawAead3, _tryDecrypt(_, _, _)) + .Times(1) + .WillOnce(Invoke( + [](std::unique_ptr& cipherText, const auto&, auto) { + // Successful decryption + return std::move(cipherText); + })); + packetNum = 5; + streamPacket = createStreamPacket( + connId, + connId, + packetNum, + streamId, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */, + folly::none, + true, + ProtectionType::KeyPhaseZero); + packetQueue = bufToQueue(packetToBuf(streamPacket)); + packet = codec->parsePacket(packetQueue, ackStates); + // Packet should be parsed successfully. + EXPECT_TRUE(packet.regularPacket() != nullptr); + // The read codec should advance to phase zero. + EXPECT_EQ(codec->getCurrentOneRttReadPhase(), ProtectionType::KeyPhaseZero); + } +} + +TEST_F(QuicReadCodecTest, KeyUpdateIncomingInvalid) { + /* + * - Receive a packet in phase zero + * - Receive a packet in phase one that cannot be decrypted with next key + * --> no key update + * - Receive a decryptable packet in phase one + * --> triggers key update + * - Receive an out-of-order packet in phase one + * --> does not check previous or next cipher. + * - Receive an in-order packet in phase zero that cannot be decrypted + --> only checks next cipher, no key update + */ + auto connId = getTestConnectionId(); + auto aead1 = std::make_unique(); + auto rawAead1 = aead1.get(); + + auto codec = makeEncryptedCodec( + connId, + std::move(aead1), + nullptr /* 0-rtt zead */, + nullptr /* stateless reset token*/, + QuicNodeType::Client); + + auto aead2 = std::make_unique(); + auto rawAead2 = aead2.get(); + codec->setNextOneRttReadCipher(std::move(aead2)); + + EXPECT_CALL(*rawAead1, _tryDecrypt(_, _, _)) + .Times(1) + .WillOnce(Invoke( + [](std::unique_ptr& cipherText, const auto&, auto) { + // Successful decryption + return std::move(cipherText); + })); + + // First packet in 1-rtt phase zero. + PacketNum packetNum = 2; + StreamId streamId = 2; + auto data = folly::IOBuf::create(30); + data->append(30); + auto streamPacket = createStreamPacket( + connId, + connId, + packetNum, + streamId, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */, + folly::none, + true, + ProtectionType::KeyPhaseZero); + AckStates ackStates; + auto packetQueue = bufToQueue(packetToBuf(streamPacket)); + auto packet = codec->parsePacket(packetQueue, ackStates); + // Packet should be parsed successfully. + EXPECT_TRUE(packet.regularPacket() != nullptr); + // We're currently in phase zero. + EXPECT_EQ(codec->getCurrentOneRttReadPhase(), ProtectionType::KeyPhaseZero); + + { + // Second packet is in 1-rtt phase one. Decryption should be attempted with + // aead2 and fail. + EXPECT_CALL(*rawAead1, _tryDecrypt(_, _, _)).Times(0); + EXPECT_CALL(*rawAead2, _tryDecrypt(_, _, _)) + .Times(1) + .WillOnce(Invoke([](std::unique_ptr&, const auto&, auto) { + // Failed decryption + return folly::none; + })); + packetNum = 3; + streamPacket = createStreamPacket( + connId, + connId, + packetNum, + streamId, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */, + folly::none, + true, + ProtectionType::KeyPhaseOne); + packetQueue = bufToQueue(packetToBuf(streamPacket)); + packet = codec->parsePacket(packetQueue, ackStates); + // Codec parsing should fail. + EXPECT_TRUE(packet.nothing()); + // The read codec should stay in phase zero + EXPECT_EQ(codec->getCurrentOneRttReadPhase(), ProtectionType::KeyPhaseZero); + } + + { + // Third packet is in 1-rtt phase one. It is successfully decrypted with + // aead2 + EXPECT_CALL(*rawAead1, _tryDecrypt(_, _, _)).Times(0); + EXPECT_CALL(*rawAead2, _tryDecrypt(_, _, _)) + .Times(1) + .WillOnce(Invoke( + [](std::unique_ptr& cipherText, const auto&, auto) { + // Successful decryption + return std::move(cipherText); + })); + packetNum = 4; + streamPacket = createStreamPacket( + connId, + connId, + packetNum, + streamId, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */, + folly::none, + true, + ProtectionType::KeyPhaseOne); + packetQueue = bufToQueue(packetToBuf(streamPacket)); + packet = codec->parsePacket(packetQueue, ackStates); + // Codec parsing should succeed + EXPECT_TRUE(packet.regularPacket()); + // The read codec should advance to phase one + EXPECT_EQ(codec->getCurrentOneRttReadPhase(), ProtectionType::KeyPhaseOne); + } + + auto aead3 = std::make_unique(); + auto rawAead3 = aead3.get(); + codec->setNextOneRttReadCipher(std::move(aead3)); + + { + // Forth packet is in current phase (phase one) but it is out of order and + // not decryptable by current cipher. It should not be checked with the + // previous or next cipher. + EXPECT_CALL(*rawAead1, _tryDecrypt(_, _, _)).Times(0); + EXPECT_CALL(*rawAead3, _tryDecrypt(_, _, _)).Times(0); + EXPECT_CALL(*rawAead2, _tryDecrypt(_, _, _)) + .Times(1) + .WillOnce(Invoke([](std::unique_ptr&, const auto&, auto) { + // Failed decryption + return folly::none; + })); + packetNum = 1; + streamPacket = createStreamPacket( + connId, + connId, + packetNum, + streamId, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */, + folly::none, + true, + ProtectionType::KeyPhaseOne); + packetQueue = bufToQueue(packetToBuf(streamPacket)); + packet = codec->parsePacket(packetQueue, ackStates); + // Codec parsing should fail. + EXPECT_TRUE(packet.nothing()); + // The read codec should still be in phase one + EXPECT_EQ(codec->getCurrentOneRttReadPhase(), ProtectionType::KeyPhaseOne); + } + + { + // Fifth packet is in next phase (phase zero) and is in order but it is not + // decryptable. It should not be checked with the current or previous + // ciphers. + EXPECT_CALL(*rawAead1, _tryDecrypt(_, _, _)).Times(0); + EXPECT_CALL(*rawAead2, _tryDecrypt(_, _, _)).Times(0); + EXPECT_CALL(*rawAead3, _tryDecrypt(_, _, _)) + .Times(1) + .WillOnce(Invoke([](std::unique_ptr&, const auto&, auto) { + // Failed decryption + return folly::none; + })); + packetNum = 5; + streamPacket = createStreamPacket( + connId, + connId, + packetNum, + streamId, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */, + folly::none, + true, + ProtectionType::KeyPhaseZero); + packetQueue = bufToQueue(packetToBuf(streamPacket)); + packet = codec->parsePacket(packetQueue, ackStates); + // Codec parsing should fail. + EXPECT_TRUE(packet.nothing()); + // The read codec should still be in phase one + EXPECT_EQ(codec->getCurrentOneRttReadPhase(), ProtectionType::KeyPhaseOne); + } +} + +TEST_F(QuicReadCodecTest, KeyUpdateCipherUnavailable) { + /* + * - Receive a packet in phase zero + * - Receive an out-of-order packet in phase one without a previous cipher + * available. + * - Receive an in-order packet in phase one without a next cipher available. + */ + auto connId = getTestConnectionId(); + auto aead1 = std::make_unique(); + auto rawAead1 = aead1.get(); + + auto codec = makeEncryptedCodec( + connId, + std::move(aead1), + nullptr /* 0-rtt zead */, + nullptr /* stateless reset token*/, + QuicNodeType::Client); + + EXPECT_CALL(*rawAead1, _tryDecrypt(_, _, _)) + .Times(1) + .WillOnce(Invoke( + [](std::unique_ptr& cipherText, const auto&, auto) { + // Successful decryption + return std::move(cipherText); + })); + + // First packet in 1-rtt phase zero. + PacketNum packetNum = 2; + StreamId streamId = 2; + auto data = folly::IOBuf::create(30); + data->append(30); + auto streamPacket = createStreamPacket( + connId, + connId, + packetNum, + streamId, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */, + folly::none, + true, + ProtectionType::KeyPhaseZero); + AckStates ackStates; + auto packetQueue = bufToQueue(packetToBuf(streamPacket)); + auto packet = codec->parsePacket(packetQueue, ackStates); + // Packet should be parsed successfully. + EXPECT_TRUE(packet.regularPacket() != nullptr); + // We're currently in phase zero. + EXPECT_EQ(codec->getCurrentOneRttReadPhase(), ProtectionType::KeyPhaseZero); + + { + // Second packet is in 1-rtt phase one but is out of order and there is no + // previous cipher available. + EXPECT_CALL(*rawAead1, _tryDecrypt(_, _, _)).Times(0); + packetNum = 1; + streamPacket = createStreamPacket( + connId, + connId, + packetNum, + streamId, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */, + folly::none, + true, + ProtectionType::KeyPhaseOne); + packetQueue = bufToQueue(packetToBuf(streamPacket)); + packet = codec->parsePacket(packetQueue, ackStates); + // Codec parsing should fail with cipher unavailable. + EXPECT_TRUE(packet.cipherUnavailable()); + // The read codec should stay in phase zero + EXPECT_EQ(codec->getCurrentOneRttReadPhase(), ProtectionType::KeyPhaseZero); + } + + { + // Second packet is in 1-rtt phase one and is in-order but the next cipher + // has not been set yet. + EXPECT_CALL(*rawAead1, _tryDecrypt(_, _, _)).Times(0); + packetNum = 3; + streamPacket = createStreamPacket( + connId, + connId, + packetNum, + streamId, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */, + folly::none, + true, + ProtectionType::KeyPhaseOne); + packetQueue = bufToQueue(packetToBuf(streamPacket)); + packet = codec->parsePacket(packetQueue, ackStates); + // Codec parsing should fail with cipher unavailable. + EXPECT_TRUE(packet.cipherUnavailable()); + // The read codec should stay in phase zero + EXPECT_EQ(codec->getCurrentOneRttReadPhase(), ProtectionType::KeyPhaseZero); + } +} diff --git a/quic/common/test/TestUtils.h b/quic/common/test/TestUtils.h index 537a92846..7087b755c 100644 --- a/quic/common/test/TestUtils.h +++ b/quic/common/test/TestUtils.h @@ -535,6 +535,15 @@ class FakeServerHandshake : public FizzServerHandshake { } oneRttReadCipher_ = createNoOpAead(); oneRttReadHeaderCipher_ = createNoOpHeaderCipher(); + readTrafficSecret_ = folly::IOBuf::copyBuffer(getRandSecret()); + } + + std::unique_ptr buildAead(folly::ByteRange /*secret*/) override { + return createNoOpAead(); + } + + Buf getNextTrafficSecret(folly::ByteRange /*secret*/) const override { + return folly::IOBuf::copyBuffer(getRandSecret()); } void setHandshakeKeys() { diff --git a/quic/fizz/server/handshake/FizzServerHandshake.cpp b/quic/fizz/server/handshake/FizzServerHandshake.cpp index cb9a1aa46..c2b7a0d77 100644 --- a/quic/fizz/server/handshake/FizzServerHandshake.cpp +++ b/quic/fizz/server/handshake/FizzServerHandshake.cpp @@ -89,18 +89,26 @@ void FizzServerHandshake::processSocketData(folly::IOBufQueue& queue) { machine_.processSocketData(state_, queue, fizz::Aead::AeadOptions())); } -std::pair, std::unique_ptr> -FizzServerHandshake::buildCiphers(folly::ByteRange secret) { - auto aead = FizzAead::wrap(fizz::Protocol::deriveRecordAeadWithLabel( +std::unique_ptr FizzServerHandshake::buildAead(folly::ByteRange secret) { + return FizzAead::wrap(fizz::Protocol::deriveRecordAeadWithLabel( *state_.context()->getFactory(), *state_.keyScheduler(), *state_.cipher(), secret, kQuicKeyLabel, kQuicIVLabel)); - auto headerCipher = cryptoFactory_->makePacketNumberCipher(secret); +} +std::unique_ptr FizzServerHandshake::buildHeaderCipher( + folly::ByteRange secret) { + return cryptoFactory_->makePacketNumberCipher(secret); +} - return {std::move(aead), std::move(headerCipher)}; +Buf FizzServerHandshake::getNextTrafficSecret(folly::ByteRange secret) const { + auto deriver = + state_.context()->getFactory()->makeKeyDeriver(*state_.cipher()); + auto nextSecret = deriver->expandLabel( + secret, kQuicKULabel, folly::IOBuf::create(0), secret.size()); + return nextSecret; } void FizzServerHandshake::processAccept() { diff --git a/quic/fizz/server/handshake/FizzServerHandshake.h b/quic/fizz/server/handshake/FizzServerHandshake.h index 32ae607f3..f7bcca09d 100644 --- a/quic/fizz/server/handshake/FizzServerHandshake.h +++ b/quic/fizz/server/handshake/FizzServerHandshake.h @@ -38,8 +38,10 @@ class FizzServerHandshake : public ServerHandshake { EncryptionLevel getReadRecordLayerEncryptionLevel() override; void processSocketData(folly::IOBufQueue& queue) override; - std::pair, std::unique_ptr> - buildCiphers(folly::ByteRange secret) override; + std::unique_ptr buildAead(folly::ByteRange secret) override; + std::unique_ptr buildHeaderCipher( + folly::ByteRange secret) override; + Buf getNextTrafficSecret(folly::ByteRange secret) const override; void processAccept() override; bool processPendingCryptoEvent() override; diff --git a/quic/handshake/HandshakeLayer.h b/quic/handshake/HandshakeLayer.h index cc1143c91..92a1eaf0d 100644 --- a/quic/handshake/HandshakeLayer.h +++ b/quic/handshake/HandshakeLayer.h @@ -16,6 +16,7 @@ namespace quic { constexpr folly::StringPiece kQuicKeyLabel = "quic key"; constexpr folly::StringPiece kQuicIVLabel = "quic iv"; constexpr folly::StringPiece kQuicPNLabel = "quic hp"; +constexpr folly::StringPiece kQuicKULabel = "quic ku"; class Handshake { public: diff --git a/quic/server/handshake/ServerHandshake.cpp b/quic/server/handshake/ServerHandshake.cpp index 196932d37..a95deda01 100644 --- a/quic/server/handshake/ServerHandshake.cpp +++ b/quic/server/handshake/ServerHandshake.cpp @@ -80,20 +80,46 @@ std::unique_ptr ServerHandshake::getHandshakeReadCipher() { return std::move(handshakeReadCipher_); } -std::unique_ptr ServerHandshake::getOneRttWriteCipher() { +std::unique_ptr ServerHandshake::getFirstOneRttWriteCipher() { if (error_) { throw QuicTransportException(error_->first, error_->second); } return std::move(oneRttWriteCipher_); } -std::unique_ptr ServerHandshake::getOneRttReadCipher() { +std::unique_ptr ServerHandshake::getNextOneRttWriteCipher() { + if (error_) { + throw QuicTransportException(error_->first, error_->second); + } + CHECK(writeTrafficSecret_); + LOG_IF(WARNING, trafficSecretSync_ > 1 || trafficSecretSync_ < -1) + << "Server read and write secrets are out of sync"; + writeTrafficSecret_ = getNextTrafficSecret(writeTrafficSecret_->coalesce()); + trafficSecretSync_--; + auto cipher = buildAead(writeTrafficSecret_->coalesce()); + return cipher; +} + +std::unique_ptr ServerHandshake::getFirstOneRttReadCipher() { if (error_) { throw QuicTransportException(error_->first, error_->second); } return std::move(oneRttReadCipher_); } +std::unique_ptr ServerHandshake::getNextOneRttReadCipher() { + if (error_) { + throw QuicTransportException(error_->first, error_->second); + } + CHECK(readTrafficSecret_); + LOG_IF(WARNING, trafficSecretSync_ > 1 || trafficSecretSync_ < -1) + << "Server read and write secrets are out of sync"; + readTrafficSecret_ = getNextTrafficSecret(readTrafficSecret_->coalesce()); + trafficSecretSync_++; + auto cipher = buildAead(readTrafficSecret_->coalesce()); + return cipher; +} + std::unique_ptr ServerHandshake::getZeroRttReadCipher() { if (error_) { throw QuicTransportException(error_->first, error_->second); @@ -435,9 +461,8 @@ void ServerHandshake::processActions( } void ServerHandshake::computeCiphers(CipherKind kind, folly::ByteRange secret) { - std::unique_ptr aead; - std::unique_ptr headerCipher; - std::tie(aead, headerCipher) = buildCiphers(secret); + std::unique_ptr aead = buildAead(secret); + std::unique_ptr headerCipher = buildHeaderCipher(secret); switch (kind) { case CipherKind::HandshakeRead: handshakeReadCipher_ = std::move(aead); @@ -448,10 +473,12 @@ void ServerHandshake::computeCiphers(CipherKind kind, folly::ByteRange secret) { conn_->handshakeWriteHeaderCipher = std::move(headerCipher); break; case CipherKind::OneRttRead: + readTrafficSecret_ = folly::IOBuf::copyBuffer(secret); oneRttReadCipher_ = std::move(aead); oneRttReadHeaderCipher_ = std::move(headerCipher); break; case CipherKind::OneRttWrite: + writeTrafficSecret_ = folly::IOBuf::copyBuffer(secret); oneRttWriteCipher_ = std::move(aead); oneRttWriteHeaderCipher_ = std::move(headerCipher); break; diff --git a/quic/server/handshake/ServerHandshake.h b/quic/server/handshake/ServerHandshake.h index 3c3b76e24..6cd9156d6 100644 --- a/quic/server/handshake/ServerHandshake.h +++ b/quic/server/handshake/ServerHandshake.h @@ -113,16 +113,31 @@ class ServerHandshake : public Handshake { std::unique_ptr getHandshakeReadCipher(); /** - * An edge triggered API to get the oneRttWriteCipher. Once you receive the - * write cipher subsequent calls will return null. + * An edge triggered API to get the first oneRttWriteCipher. Once you receive + * the write cipher subsequent calls will return null. */ - std::unique_ptr getOneRttWriteCipher(); + std::unique_ptr getFirstOneRttWriteCipher(); /** - * An edge triggered API to get the oneRttReadCipher. Once you receive the - * read cipher subsequent calls will return null. + * An API to get oneRttWriteCiphers on key rotation. Each call will return a + * one rtt write cipher using the current traffic secret and advance the + * traffic secret. */ - std::unique_ptr getOneRttReadCipher(); + std::unique_ptr getNextOneRttWriteCipher(); + + /** + * An API to get oneRttReadCiphers. Each call will generate a one rtt + * read cipher using the current traffic secret and advance the traffic + * secret. + */ + std::unique_ptr getFirstOneRttReadCipher(); + + /** + * An API to get oneRttReadCiphers on key rotation. Each call will return a + * one rtt read cipher using the current traffic secret and advance the + * traffic secret. + */ + std::unique_ptr getNextOneRttReadCipher(); /** * An edge triggered API to get the zeroRttReadCipher. Once you receive the @@ -182,6 +197,12 @@ class ServerHandshake : public Handshake { */ const folly::Optional& getApplicationProtocol() const override; + /** + * Given secret_n, returns secret_n+1 to be used for generating the next Aead + * on key updates. + */ + virtual Buf getNextTrafficSecret(folly::ByteRange secret) const = 0; + ~ServerHandshake() override = default; void onError(std::pair error); @@ -253,6 +274,15 @@ class ServerHandshake : public Handshake { std::unique_ptr oneRttWriteCipher_; std::unique_ptr zeroRttReadCipher_; + Buf readTrafficSecret_; + Buf writeTrafficSecret_; + + // This variable is incremented every time a read traffic secret is rotated, + // and decremented for the write secret. Its value should be + // between -1 and 1. A value outside of this range indicates that the + // transport's read and write ciphers are likely out of sync. + int trafficSecretSync_{0}; + std::unique_ptr oneRttReadHeaderCipher_; std::unique_ptr oneRttWriteHeaderCipher_; std::unique_ptr handshakeReadHeaderCipher_; @@ -271,8 +301,9 @@ class ServerHandshake : public Handshake { virtual EncryptionLevel getReadRecordLayerEncryptionLevel() = 0; virtual void processSocketData(folly::IOBufQueue& queue) = 0; - virtual std::pair, std::unique_ptr> - buildCiphers(folly::ByteRange secret) = 0; + virtual std::unique_ptr buildAead(folly::ByteRange secret) = 0; + virtual std::unique_ptr buildHeaderCipher( + folly::ByteRange secret) = 0; virtual void processAccept() = 0; /* diff --git a/quic/server/handshake/test/ServerHandshakeTest.cpp b/quic/server/handshake/test/ServerHandshakeTest.cpp index 9616bee83..06c76d25b 100644 --- a/quic/server/handshake/test/ServerHandshakeTest.cpp +++ b/quic/server/handshake/test/ServerHandshakeTest.cpp @@ -219,8 +219,8 @@ class ServerHandshakeTest : public Test { } void setHandshakeState() { - auto oneRttWriteCipherTmp = handshake->getOneRttWriteCipher(); - auto oneRttReadCipherTmp = handshake->getOneRttReadCipher(); + auto oneRttWriteCipherTmp = handshake->getFirstOneRttWriteCipher(); + auto oneRttReadCipherTmp = handshake->getFirstOneRttReadCipher(); auto zeroRttReadCipherTmp = handshake->getZeroRttReadCipher(); auto handshakeWriteCipherTmp = std::move(conn->handshakeWriteCipher); auto handshakeReadCipherTmp = handshake->getHandshakeReadCipher(); @@ -692,7 +692,7 @@ TEST_F(ServerHandshakeAsyncErrorTest, TestAsyncError) { EXPECT_CALL(serverCallback, onCryptoEventAvailable()) .WillRepeatedly(Invoke([&] { try { - handshake->getOneRttReadCipher(); + handshake->getFirstOneRttReadCipher(); } catch (std::exception&) { error = true; } @@ -713,7 +713,7 @@ TEST_F(ServerHandshakeAsyncErrorTest, TestCancelOnAsyncError) { })); promise.setValue(); evb.loop(); - EXPECT_THROW(handshake->getOneRttReadCipher(), std::runtime_error); + EXPECT_THROW(handshake->getFirstOneRttReadCipher(), std::runtime_error); } TEST_F(ServerHandshakeAsyncErrorTest, TestCancelWhileWaitingAsyncError) { @@ -724,7 +724,7 @@ TEST_F(ServerHandshakeAsyncErrorTest, TestCancelWhileWaitingAsyncError) { promise.setValue(); evb.loop(); - EXPECT_THROW(handshake->getOneRttReadCipher(), std::runtime_error); + EXPECT_THROW(handshake->getFirstOneRttReadCipher(), std::runtime_error); } class ServerHandshakeSyncErrorTest : public ServerHandshakePskTest { @@ -743,7 +743,7 @@ TEST_F(ServerHandshakeSyncErrorTest, TestError) { // Make an async ticket decryption operation. clientServerRound(); evb.loop(); - EXPECT_THROW(handshake->getOneRttReadCipher(), std::runtime_error); + EXPECT_THROW(handshake->getFirstOneRttReadCipher(), std::runtime_error); } class ServerHandshakeZeroRttDefaultAppTokenValidatorTest diff --git a/quic/server/state/ServerStateMachine.cpp b/quic/server/state/ServerStateMachine.cpp index 3a401a008..4755d1a05 100644 --- a/quic/server/state/ServerStateMachine.cpp +++ b/quic/server/state/ServerStateMachine.cpp @@ -386,9 +386,9 @@ void updateHandshakeState(QuicServerConnectionState& conn) { // However, the cipher is only exported to QUIC if early data attempt is // accepted. Otherwise, the cipher will be available after cfin is // processed. - auto oneRttWriteCipher = handshakeLayer->getOneRttWriteCipher(); + auto oneRttWriteCipher = handshakeLayer->getFirstOneRttWriteCipher(); // One RTT read cipher is available after cfin is processed. - auto oneRttReadCipher = handshakeLayer->getOneRttReadCipher(); + auto oneRttReadCipher = handshakeLayer->getFirstOneRttReadCipher(); auto oneRttWriteHeaderCipher = handshakeLayer->getOneRttWriteHeaderCipher(); auto oneRttReadHeaderCipher = handshakeLayer->getOneRttReadHeaderCipher(); @@ -419,6 +419,7 @@ void updateHandshakeState(QuicServerConnectionState& conn) { "Duplicate 1-rtt write cipher", TransportErrorCode::CRYPTO_ERROR); } conn.oneRttWriteCipher = std::move(oneRttWriteCipher); + conn.oneRttWritePhase = ProtectionType::KeyPhaseZero; updatePacingOnKeyEstablished(conn); @@ -440,6 +441,8 @@ void updateHandshakeState(QuicServerConnectionState& conn) { conn.isClientAddrVerified = true; conn.writableBytesLimit.reset(); conn.readCodec->setOneRttReadCipher(std::move(oneRttReadCipher)); + conn.readCodec->setNextOneRttReadCipher( + handshakeLayer->getNextOneRttReadCipher()); } auto handshakeReadCipher = handshakeLayer->getHandshakeReadCipher(); auto handshakeReadHeaderCipher = @@ -1015,6 +1018,17 @@ void onServerReadDataFromOpen( } } + if (conn.readCodec->getCurrentOneRttReadPhase() != conn.oneRttWritePhase) { + // Peer has initiated a key update. + updateOneRttWriteCipher( + conn, + conn.serverHandshakeLayer->getNextOneRttWriteCipher(), + conn.readCodec->getCurrentOneRttReadPhase()); + + conn.readCodec->setNextOneRttReadCipher( + conn.serverHandshakeLayer->getNextOneRttReadCipher()); + } + auto& ackState = getAckState(conn, packetNumberSpace); uint64_t distanceFromExpectedPacketNum = addPacketToAckState( conn, ackState, packetNum, readData.udpPacket.timings); @@ -1493,6 +1507,8 @@ void onServerReadDataFromClosed( conn.qLogger->addPacket(regularPacket, packetSize); } + // TODO: Should we honor a key update from the peer on a closed connection? + // Only process the close frames in the packet for (auto& quicFrame : regularPacket.frames) { switch (quicFrame.type()) { diff --git a/quic/state/StateData.h b/quic/state/StateData.h index 9a4ef58cd..9e30e7a28 100644 --- a/quic/state/StateData.h +++ b/quic/state/StateData.h @@ -348,6 +348,8 @@ struct QuicConnectionStateBase : public folly::DelayedDestruction { // Write cipher for 1-RTT data std::unique_ptr oneRttWriteCipher; + ProtectionType oneRttWritePhase{ProtectionType::KeyPhaseZero}; + uint64_t oneRttWritePacketsSentInCurrentPhase = 0; // Write cipher for packets with initial keys. std::unique_ptr initialWriteCipher;