diff --git a/quic/QuicConstants.h b/quic/QuicConstants.h index 02ef8fae6..d143fdea9 100644 --- a/quic/QuicConstants.h +++ b/quic/QuicConstants.h @@ -583,6 +583,7 @@ enum class EncryptionLevel : uint8_t { Handshake, EarlyData, AppData, + MAX, }; /** diff --git a/quic/client/QuicClientTransport.cpp b/quic/client/QuicClientTransport.cpp index 967c290c8..f6064a404 100644 --- a/quic/client/QuicClientTransport.cpp +++ b/quic/client/QuicClientTransport.cpp @@ -125,6 +125,34 @@ void QuicClientTransport::processUDPData( << "Leaving " << udpData.chainLength() << " bytes unprocessed after attempting to process " << kMaxNumCoalescedPackets << " packets."; + + // Process any pending 1RTT and handshake packets if we have keys. + if (conn_->readCodec->getOneRttReadCipher() && + !clientConn_->pendingOneRttData.empty()) { + BufQueue pendingPacket; + for (auto& pendingData : clientConn_->pendingOneRttData) { + pendingPacket.append(std::move(pendingData.networkData.data)); + processPacketData( + pendingData.peer, + pendingData.networkData.receiveTimePoint, + pendingPacket); + pendingPacket.move(); + } + clientConn_->pendingOneRttData.clear(); + } + if (conn_->readCodec->getHandshakeReadCipher() && + !clientConn_->pendingHandshakeData.empty()) { + BufQueue pendingPacket; + for (auto& pendingData : clientConn_->pendingHandshakeData) { + pendingPacket.append(std::move(pendingData.networkData.data)); + processPacketData( + pendingData.peer, + pendingData.networkData.receiveTimePoint, + pendingPacket); + pendingPacket.move(); + } + clientConn_->pendingHandshakeData.clear(); + } } void QuicClientTransport::processPacketData( @@ -197,6 +225,31 @@ void QuicClientTransport::processPacketData( return; } + auto cipherUnavailable = parsedPacket.cipherUnavailable(); + if (cipherUnavailable && cipherUnavailable->packet && + !cipherUnavailable->packet->empty() && + (cipherUnavailable->protectionType == ProtectionType::KeyPhaseZero || + cipherUnavailable->protectionType == ProtectionType::Handshake) && + clientConn_->pendingOneRttData.size() + + clientConn_->pendingHandshakeData.size() < + clientConn_->transportSettings.maxPacketsToBuffer) { + auto& pendingData = + cipherUnavailable->protectionType == ProtectionType::KeyPhaseZero + ? clientConn_->pendingOneRttData + : clientConn_->pendingHandshakeData; + pendingData.emplace_back( + NetworkDataSingle( + std::move(cipherUnavailable->packet), receiveTimePoint), + peer); + if (conn_->qLogger) { + conn_->qLogger->addPacketBuffered( + cipherUnavailable->packetNum, + cipherUnavailable->protectionType, + packetSize); + } + return; + } + RegularQuicPacket* regularOptional = parsedPacket.regularPacket(); if (!regularOptional) { QUIC_STATS(statsCallback_, onPacketDropped, PacketDropReason::PARSE_ERROR); @@ -1442,6 +1495,8 @@ void QuicClientTransport::start(ConnectionCallback* cb) { } QUIC_TRACE(fst_trace, *conn_, "start"); setConnectionCallback(cb); + clientConn_->pendingOneRttData.reserve( + conn_->transportSettings.maxPacketsToBuffer); try { happyEyeballsSetUpSocket( *socket_, diff --git a/quic/client/handshake/ClientHandshake.cpp b/quic/client/handshake/ClientHandshake.cpp index b58c1ada7..063ccd4b1 100644 --- a/quic/client/handshake/ClientHandshake.cpp +++ b/quic/client/handshake/ClientHandshake.cpp @@ -78,6 +78,8 @@ void ClientHandshake::doHandshake( case EncryptionLevel::AppData: appDataReadBuf_.append(std::move(data)); break; + default: + LOG(FATAL) << "Unhandled EncryptionLevel"; } // Get the current buffer type the transport is accepting. waitForData_ = false; @@ -93,6 +95,8 @@ void ClientHandshake::doHandshake( case EncryptionLevel::AppData: processSocketData(appDataReadBuf_); break; + default: + LOG(FATAL) << "Unhandled EncryptionLevel"; } throwOnError(); } diff --git a/quic/client/state/ClientStateMachine.cpp b/quic/client/state/ClientStateMachine.cpp index b9b3fb472..b106b4b2f 100644 --- a/quic/client/state/ClientStateMachine.cpp +++ b/quic/client/state/ClientStateMachine.cpp @@ -57,6 +57,8 @@ std::unique_ptr undoAllClientStateForRetry( std::move(conn->earlyDataAppParamsValidator); newConn->earlyDataAppParamsGetter = std::move(conn->earlyDataAppParamsGetter); newConn->happyEyeballsState = std::move(conn->happyEyeballsState); + newConn->pendingOneRttData.reserve( + newConn->transportSettings.maxPacketsToBuffer); if (conn->congestionControllerFactory) { newConn->congestionControllerFactory = conn->congestionControllerFactory; if (conn->congestionController) { diff --git a/quic/client/state/ClientStateMachine.h b/quic/client/state/ClientStateMachine.h index 85063aac6..cfb0f9129 100644 --- a/quic/client/state/ClientStateMachine.h +++ b/quic/client/state/ClientStateMachine.h @@ -21,6 +21,16 @@ namespace quic { struct CachedServerTransportParameters; +struct PendingClientData { + NetworkDataSingle networkData; + folly::SocketAddress peer; + + PendingClientData( + NetworkDataSingle networkDataIn, + folly::SocketAddress peerIn) + : networkData(std::move(networkDataIn)), peer(std::move(peerIn)) {} +}; + struct QuicClientConnectionState : public QuicConnectionStateBase { ~QuicClientConnectionState() override = default; @@ -56,10 +66,10 @@ struct QuicClientConnectionState : public QuicConnectionStateBase { uint64_t peerAdvertisedInitialMaxStreamsBidi{0}; uint64_t peerAdvertisedInitialMaxStreamsUni{0}; - // Packet number in which client initial was sent. Receipt of data on the - // crypto stream from the server can implicitly ack the client initial packet. - // TODO: use this to get rid of the data in the crypto stream. - // folly::Optional clientInitialPacketNum; + // Short header packets we received but couldn't yet decrypt. + std::vector pendingOneRttData; + // Handshake packets we received but couldn't yet decrypt. + std::vector pendingHandshakeData; explicit QuicClientConnectionState( std::shared_ptr handshakeFactoryIn) diff --git a/quic/fizz/client/test/QuicClientTransportTest.cpp b/quic/fizz/client/test/QuicClientTransportTest.cpp index 44c3fef45..992ce76f9 100644 --- a/quic/fizz/client/test/QuicClientTransportTest.cpp +++ b/quic/fizz/client/test/QuicClientTransportTest.cpp @@ -1078,7 +1078,8 @@ class FakeOneRttHandshakeLayer : public FizzClientHandshake { } } - void doHandshake(std::unique_ptr, EncryptionLevel) override { + void doHandshake(std::unique_ptr buf, EncryptionLevel level) + override { EXPECT_EQ(writeBuf.get(), nullptr); QuicClientConnectionState* conn = getClientConn(); if (!conn->oneRttWriteCipher) { @@ -1095,6 +1096,7 @@ class FakeOneRttHandshakeLayer : public FizzClientHandshake { IOBuf::copyBuffer("ClientFinished")); handshakeInitiated(); } + readBuffers[level].append(std::move(buf)); } bool connectInvoked() { @@ -1121,6 +1123,7 @@ class FakeOneRttHandshakeLayer : public FizzClientHandshake { uint64_t maxInitialStreamsBidi{std::numeric_limits::max()}; uint64_t maxInitialStreamsUni{std::numeric_limits::max()}; folly::Optional params_; + EnumArray readBuffers{}; std::unique_ptr oneRttWriteCipher_; std::unique_ptr oneRttWriteHeaderCipher_; @@ -5683,6 +5686,213 @@ TEST_F(QuicProcessDataTest, ProcessDataWithGarbageAtEnd) { EXPECT_EQ(event->dropReason, kParse); } +TEST_F(QuicProcessDataTest, ProcessPendingData) { + auto params = mockClientHandshake->getServerTransportParams(); + params->parameters.push_back(encodeConnIdParameter( + TransportParameterId::initial_source_connection_id, *serverChosenConnId)); + params->parameters.push_back(encodeConnIdParameter( + TransportParameterId::original_destination_connection_id, + *client->getConn().initialDestinationConnectionId)); + mockClientHandshake->setServerTransportParams(std::move(*params)); + auto serverHello = IOBuf::copyBuffer("Fake SHLO"); + PacketNum nextPacketNum = initialPacketNum++; + auto& aead = getInitialCipher(); + auto packet = createCryptoPacket( + *serverChosenConnId, + *originalConnId, + nextPacketNum, + QuicVersion::QUIC_DRAFT, + ProtectionType::Initial, + *serverHello, + aead, + 0 /* largestAcked */); + auto packetData = packetToBufCleartext( + packet, aead, getInitialHeaderCipher(), nextPacketNum); + deliverData(serverAddr, packetData->coalesce()); + verifyTransportParameters( + kDefaultConnectionWindowSize, + kDefaultStreamWindowSize, + kDefaultIdleTimeout, + kDefaultAckDelayExponent, + mockClientHandshake->maxRecvPacketSize); + + mockClientHandshake->setOneRttReadCipher(nullptr); + mockClientHandshake->setHandshakeReadCipher(nullptr); + ASSERT_TRUE(client->getConn().pendingOneRttData.empty()); + auto streamId1 = client->createBidirectionalStream().value(); + + auto data = folly::IOBuf::copyBuffer("1RTT data!"); + auto streamPacket1 = packetToBuf(createStreamPacket( + *serverChosenConnId /* src */, + *originalConnId /* dest */, + appDataPacketNum++, + streamId1, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */)); + deliverData(streamPacket1->coalesce()); + EXPECT_EQ(client->getConn().pendingOneRttData.size(), 1); + + auto cryptoData = folly::IOBuf::copyBuffer("Crypto data!"); + auto cryptoPacket1 = packetToBuf(createCryptoPacket( + *serverChosenConnId, + *originalConnId, + handshakePacketNum++, + QuicVersion::QUIC_DRAFT, + ProtectionType::Handshake, + *cryptoData, + *createNoOpAead(), + 0 /* largestAcked */)); + deliverData(cryptoPacket1->coalesce()); + EXPECT_EQ(client->getConn().pendingOneRttData.size(), 1); + EXPECT_EQ(client->getConn().pendingHandshakeData.size(), 1); + + mockClientHandshake->setOneRttReadCipher(createNoOpAead()); + auto streamId2 = client->createBidirectionalStream().value(); + auto streamPacket2 = packetToBuf(createStreamPacket( + *serverChosenConnId /* src */, + *originalConnId /* dest */, + appDataPacketNum++, + streamId2, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */)); + deliverData(streamPacket2->coalesce()); + EXPECT_TRUE(client->getConn().pendingOneRttData.empty()); + EXPECT_EQ(client->getConn().pendingHandshakeData.size(), 1); + + // Set the oneRtt one back to nullptr to make sure we trigger it on handshake + // only. + // mockClientHandshake->setOneRttReadCipher(nullptr); + mockClientHandshake->setHandshakeReadCipher(createNoOpAead()); + auto cryptoPacket2 = packetToBuf(createCryptoPacket( + *serverChosenConnId, + *originalConnId, + handshakePacketNum++, + QuicVersion::QUIC_DRAFT, + ProtectionType::Handshake, + *cryptoData, + *createNoOpAead(), + 0, + cryptoData->length())); + deliverData(cryptoPacket2->coalesce()); + EXPECT_TRUE(client->getConn().pendingHandshakeData.empty()); + EXPECT_TRUE(client->getConn().pendingOneRttData.empty()); + + // Both stream data and crypto data should be there. + auto d1 = client->read(streamId1, 1000); + ASSERT_FALSE(d1.hasError()); + auto d2 = client->read(streamId2, 1000); + ASSERT_FALSE(d2.hasError()); + EXPECT_TRUE(folly::IOBufEqualTo()(*d1.value().first, *data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*d2.value().first, *data)); + + ASSERT_FALSE( + mockClientHandshake->readBuffers[EncryptionLevel::Handshake].empty()); + auto handshakeReadData = + mockClientHandshake->readBuffers[EncryptionLevel::Handshake].move(); + cryptoData->prependChain(cryptoData->clone()); + EXPECT_TRUE(folly::IOBufEqualTo()(*cryptoData, *handshakeReadData)); +} + +TEST_F(QuicProcessDataTest, ProcessPendingDataBufferLimit) { + auto params = mockClientHandshake->getServerTransportParams(); + params->parameters.push_back(encodeConnIdParameter( + TransportParameterId::initial_source_connection_id, *serverChosenConnId)); + params->parameters.push_back(encodeConnIdParameter( + TransportParameterId::original_destination_connection_id, + *client->getConn().initialDestinationConnectionId)); + mockClientHandshake->setServerTransportParams(std::move(*params)); + auto serverHello = IOBuf::copyBuffer("Fake SHLO"); + PacketNum nextPacketNum = initialPacketNum++; + auto& aead = getInitialCipher(); + auto packet = createCryptoPacket( + *serverChosenConnId, + *originalConnId, + nextPacketNum, + QuicVersion::QUIC_DRAFT, + ProtectionType::Initial, + *serverHello, + aead, + 0 /* largestAcked */); + auto packetData = packetToBufCleartext( + packet, aead, getInitialHeaderCipher(), nextPacketNum); + deliverData(serverAddr, packetData->coalesce()); + verifyTransportParameters( + kDefaultConnectionWindowSize, + kDefaultStreamWindowSize, + kDefaultIdleTimeout, + kDefaultAckDelayExponent, + mockClientHandshake->maxRecvPacketSize); + + client->getNonConstConn().transportSettings.maxPacketsToBuffer = 2; + auto data = folly::IOBuf::copyBuffer("1RTT data!"); + mockClientHandshake->setOneRttReadCipher(nullptr); + ASSERT_TRUE(client->getConn().pendingOneRttData.empty()); + auto streamId1 = client->createBidirectionalStream().value(); + auto streamPacket1 = packetToBuf(createStreamPacket( + *serverChosenConnId /* src */, + *originalConnId /* dest */, + appDataPacketNum++, + streamId1, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */)); + deliverData(streamPacket1->coalesce()); + EXPECT_EQ(client->getConn().pendingOneRttData.size(), 1); + + auto streamId2 = client->createBidirectionalStream().value(); + auto streamPacket2 = packetToBuf(createStreamPacket( + *serverChosenConnId /* src */, + *originalConnId /* dest */, + appDataPacketNum++, + streamId2, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */)); + deliverData(streamPacket2->coalesce()); + EXPECT_EQ(client->getConn().pendingOneRttData.size(), 2); + + auto streamId3 = client->createBidirectionalStream().value(); + auto streamPacket3 = packetToBuf(createStreamPacket( + *serverChosenConnId /* src */, + *originalConnId /* dest */, + appDataPacketNum++, + streamId3, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */)); + deliverData(streamPacket3->coalesce()); + EXPECT_EQ(client->getConn().pendingOneRttData.size(), 2); + + mockClientHandshake->setOneRttReadCipher(createNoOpAead()); + auto streamId4 = client->createBidirectionalStream().value(); + auto streamPacket4 = packetToBuf(createStreamPacket( + *serverChosenConnId /* src */, + *originalConnId /* dest */, + appDataPacketNum++, + streamId4, + *data, + 0 /* cipherOverhead */, + 0 /* largestAcked */)); + deliverData(streamPacket4->coalesce()); + EXPECT_TRUE(client->getConn().pendingOneRttData.empty()); + + // First, second, and fourht stream data should be there. + auto d1 = client->read(streamId1, 1000); + ASSERT_FALSE(d1.hasError()); + auto d2 = client->read(streamId2, 1000); + ASSERT_FALSE(d2.hasError()); + auto d3 = client->read(streamId3, 1000); + ASSERT_FALSE(d3.hasError()); + EXPECT_EQ(d3.value().first, nullptr); + auto d4 = client->read(streamId4, 1000); + ASSERT_FALSE(d4.hasError()); + EXPECT_TRUE(folly::IOBufEqualTo()(*d1.value().first, *data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*d2.value().first, *data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*d4.value().first, *data)); +} + TEST_P(QuicProcessDataTest, ProcessDataHeaderOnly) { uint8_t connIdSize = GetParam(); client->getNonConstConn().clientConnectionId = diff --git a/quic/server/handshake/ServerHandshake.cpp b/quic/server/handshake/ServerHandshake.cpp index bb0c39829..797513b0d 100644 --- a/quic/server/handshake/ServerHandshake.cpp +++ b/quic/server/handshake/ServerHandshake.cpp @@ -53,6 +53,8 @@ void ServerHandshake::doHandshake( case EncryptionLevel::AppData: appDataReadBuf_.append(std::move(data)); break; + default: + LOG(FATAL) << "Unhandled EncryptionLevel"; } processPendingEvents(); if (error_) { @@ -240,6 +242,8 @@ void ServerHandshake::processPendingEvents() { // any more. processSocketData(appDataReadBuf_); break; + default: + LOG(FATAL) << "Unhandled EncryptionLevel"; } } else if (!processPendingCryptoEvent()) { actionGuard_ = folly::DelayedDestruction::DestructorGuard(nullptr); diff --git a/quic/state/QuicStreamFunctions.cpp b/quic/state/QuicStreamFunctions.cpp index 994c2aa71..42c3383c4 100644 --- a/quic/state/QuicStreamFunctions.cpp +++ b/quic/state/QuicStreamFunctions.cpp @@ -464,6 +464,8 @@ QuicCryptoStream* getCryptoStream( return &cryptoState.handshakeStream; case EncryptionLevel::AppData: return &cryptoState.oneRttStream; + default: + LOG(FATAL) << "Unhandled EncryptionLevel"; } folly::assume_unreachable(); }