diff --git a/quic/QuicConstants.h b/quic/QuicConstants.h index e1b69698d..4e6283dc8 100644 --- a/quic/QuicConstants.h +++ b/quic/QuicConstants.h @@ -51,6 +51,9 @@ constexpr uint16_t kMaxNumCoalescedPackets = 5; // have ids with first byte being 0xff. constexpr uint16_t kCustomTransportParameterThreshold = 0xff00; +// The length of the integrity tag present in a retry packet. +constexpr uint32_t kRetryIntegrityTagLen = 16; + // If the amount of data in the buffer of a QuicSocket equals or exceeds this // threshold, then the callback registered through // notifyPendingWriteOnConnection() will not be called diff --git a/quic/client/QuicClientTransport.cpp b/quic/client/QuicClientTransport.cpp index 20be978dd..f3fcae84c 100644 --- a/quic/client/QuicClientTransport.cpp +++ b/quic/client/QuicClientTransport.cpp @@ -156,6 +156,45 @@ void QuicClientTransport::processPacketData( VLOG(4) << "Drop StatelessReset for bad connId or token " << *this; } + RetryPacket* retryPacket = parsedPacket.retryPacket(); + if (retryPacket) { + if (!clientConn_->retryToken.empty()) { + VLOG(4) << "Server sent more than one retry packet"; + return; + } + + const ConnectionId* originalDstConnId = + &(*clientConn_->initialDestinationConnectionId); + + if (!clientConn_->clientHandshakeLayer->verifyRetryIntegrityTag( + *originalDstConnId, *retryPacket)) { + VLOG(4) << "The integrity tag in the retry packet was invalid. " + << "Dropping bad retry packet."; + return; + } + + // Set the destination connection ID to be the value from the source + // connection id of the retry packet + clientConn_->initialDestinationConnectionId = + retryPacket->header.getSourceConnId(); + + auto released = static_cast(conn_.release()); + std::unique_ptr uniqueClient(released); + auto tempConn = undoAllClientStateForRetry(std::move(uniqueClient)); + + clientConn_ = tempConn.get(); + conn_.reset(tempConn.release()); + + clientConn_->retryToken = retryPacket->header.getToken(); + + // TODO (amsharma): add a "RetryPacket" QLog event, and log it here. + // TODO (amsharma): verify the "original_connection_id" parameter + // upon receiving a subsequent initial from the server. + + startCryptoHandshake(); + return; + } + RegularQuicPacket* regularOptional = parsedPacket.regularPacket(); if (!regularOptional) { if (conn_->qLogger) { @@ -172,55 +211,6 @@ void QuicClientTransport::processPacketData( LongHeader* longHeader = regularOptional->header.asLong(); ShortHeader* shortHeader = regularOptional->header.asShort(); - if (longHeader && longHeader->getHeaderType() == LongHeader::Types::Retry) { - if (!clientConn_->retryToken.empty()) { - VLOG(4) << "Server sent more than one retry packet"; - return; - } - - // TODO (amsharma): Check if we have already received an initial packet - // from the server. If so, discard it. Here are some ways in which I - // could do this: - // 1. Have a boolean flag initialPacketReceived_ that we set to true when - // we get an initial packet from the server. This seems a bit messy. - // 2. Check for the presence of the oneRttWriteCipher and/or the - // oneRttReadCipher in the handshake layer. I think this might be a - // better approach, but I don't know if it is a good indicator that we've - // received an initial packet from the server. - - const ConnectionId* dstConnId = - &(*clientConn_->initialDestinationConnectionId); - if (conn_->serverConnectionId) { - dstConnId = &(*conn_->serverConnectionId); - } - if (*longHeader->getOriginalDstConnId() != *dstConnId) { - VLOG(4) << "Original destination connection id field in the retry " - << "packet doesn't match the destination connection id from the " - << "client's initial packet"; - return; - } - - // Set the destination connection ID to be the value from the source - // connection id of the retry packet - clientConn_->initialDestinationConnectionId = longHeader->getSourceConnId(); - - auto released = static_cast(conn_.release()); - std::unique_ptr uniqueClient(released); - auto tempConn = undoAllClientStateForRetry(std::move(uniqueClient)); - - clientConn_ = tempConn.get(); - conn_.reset(tempConn.release()); - - clientConn_->retryToken = longHeader->getToken(); - - if (conn_->qLogger) { - conn_->qLogger->addPacket(*regularOptional, packetSize); - } - - startCryptoHandshake(); - return; - } - auto protectionLevel = regularOptional->header.getProtectionType(); auto encryptionLevel = protectionTypeToEncryptionLevel(protectionLevel); diff --git a/quic/codec/Decode.cpp b/quic/codec/Decode.cpp index 801bacdf1..c2226bbf1 100644 --- a/quic/codec/Decode.cpp +++ b/quic/codec/Decode.cpp @@ -930,36 +930,22 @@ folly::Expected parseLongHeaderVariants( folly::io::Cursor& cursor, QuicNodeType nodeType) { if (type == LongHeader::Types::Retry) { - if (!cursor.canAdvance(sizeof(uint8_t))) { - VLOG(5) << "Not enough bytes for ODCID length"; - return folly::makeUnexpected(TransportErrorCode::FRAME_ENCODING_ERROR); - } - uint8_t originalDstConnIdLen = cursor.readBE(); - if (originalDstConnIdLen > kMaxConnectionIdSize) { - VLOG(5) << "originalDstConnIdLen > kMaxConnectionIdSize: " - << originalDstConnIdLen; - return folly::makeUnexpected(TransportErrorCode::PROTOCOL_VIOLATION); - } - if (!cursor.canAdvance(originalDstConnIdLen)) { - VLOG(5) << "Not enough bytes for ODCID"; - return folly::makeUnexpected(TransportErrorCode::FRAME_ENCODING_ERROR); - } - ConnectionId originalDstConnId(cursor, originalDstConnIdLen); - - if (cursor.totalLength() == 0) { + // The integrity tag is kRetryIntegrityTagLen bytes in length, and the + // token must be at least one byte, so the remaining length must + // be > kRetryIntegrityTagLen. + if (cursor.totalLength() <= kRetryIntegrityTagLen) { VLOG(5) << "Not enough bytes for retry token"; return folly::makeUnexpected(TransportErrorCode::FRAME_ENCODING_ERROR); } Buf token; - cursor.clone(token, cursor.totalLength()); + cursor.clone(token, cursor.totalLength() - kRetryIntegrityTagLen); return ParsedLongHeader( LongHeader( type, std::move(parsedLongHeaderInvariant.invariant), - token ? token->moveToFbString().toStdString() : std::string(), - std::move(originalDstConnId)), + token ? token->moveToFbString().toStdString() : std::string()), PacketLength(0, 0)); } diff --git a/quic/codec/QuicReadCodec.cpp b/quic/codec/QuicReadCodec.cpp index 1565eaf9a..400a9a7ad 100644 --- a/quic/codec/QuicReadCodec.cpp +++ b/quic/codec/QuicReadCodec.cpp @@ -88,7 +88,11 @@ CodecResult QuicReadCodec::parseLongHeaderPacket( auto longHeader = std::move(parsedLongHeader->header); if (type == LongHeader::Types::Retry) { - return RegularQuicPacket(std::move(longHeader)); + Buf integrityTag; + cursor.clone(integrityTag, kRetryIntegrityTagLen); + queue.move(); + return RetryPacket( + std::move(longHeader), std::move(integrityTag), initialByte); } uint64_t packetNumberOffset = cursor.getCurrentPosition(); diff --git a/quic/codec/test/QuicPacketBuilderTest.cpp b/quic/codec/test/QuicPacketBuilderTest.cpp index 90afbd9ea..5d1b8fd9a 100644 --- a/quic/codec/test/QuicPacketBuilderTest.cpp +++ b/quic/codec/test/QuicPacketBuilderTest.cpp @@ -149,44 +149,6 @@ TEST_F(QuicPacketBuilderTest, SimpleVersionNegotiationPacket) { EXPECT_EQ(decodedVersionNegotiationPacket->versions, versions); } -TEST_P(QuicPacketBuilderTest, SimpleRetryPacket) { - LongHeader headerIn( - LongHeader::Types::Retry, - getTestConnectionId(0), - getTestConnectionId(1), - 321, - QuicVersion::MVFST, - std::string("454358"), - getTestConnectionId(2)); - - auto builderAndBuf = testBuilderProvider( - GetParam(), - kDefaultUDPSendPacketLen, - std::move(headerIn), - 0 /* largestAcked */, - 2000); - auto packet = packetToBuf(std::move(*(builderAndBuf.builder)).buildPacket()); - auto packetQueue = bufToQueue(std::move(packet)); - - // Verify the returned buf from packet builder can be decoded by read codec: - AckStates ackStates; - auto optionalDecodedPacket = - makeCodec(getTestConnectionId(1), QuicNodeType::Client) - ->parsePacket(packetQueue, ackStates); - ASSERT_NE(optionalDecodedPacket.regularPacket(), nullptr); - auto& retryPacket = *optionalDecodedPacket.regularPacket(); - - auto& headerOut = *retryPacket.header.asLong(); - - EXPECT_EQ(*headerOut.getOriginalDstConnId(), getTestConnectionId(2)); - EXPECT_EQ(headerOut.getVersion(), QuicVersion::MVFST); - EXPECT_EQ(headerOut.getSourceConnId(), getTestConnectionId(0)); - EXPECT_EQ(headerOut.getDestinationConnId(), getTestConnectionId(1)); - - auto expected = std::string("454358"); - EXPECT_EQ(headerOut.getToken(), expected); -} - TEST_F(QuicPacketBuilderTest, TooManyVersions) { std::vector versions; for (size_t i = 0; i < 1000; i++) { diff --git a/quic/codec/test/QuicReadCodecTest.cpp b/quic/codec/test/QuicReadCodecTest.cpp index 9e2dbd805..586c00399 100644 --- a/quic/codec/test/QuicReadCodecTest.cpp +++ b/quic/codec/test/QuicReadCodecTest.cpp @@ -95,33 +95,40 @@ TEST_F(QuicReadCodecTest, VersionNegotiationPacketTest) { } TEST_F(QuicReadCodecTest, RetryPacketTest) { - LongHeader headerIn( - LongHeader::Types::Retry, - getTestConnectionId(70), - getTestConnectionId(90), - 321, - static_cast(0xffff), - std::string("fluffydog"), - getTestConnectionId(110)); + uint8_t initialByte = 0xFF; + ConnectionId srcConnId = getTestConnectionId(70); + ConnectionId dstConnId = getTestConnectionId(90); + auto quicVersion = static_cast(0xffff); + std::string token = "fluffydog"; + std::string integrityTag = "MustBe16CharLong"; - RegularQuicPacketBuilder builder( - kDefaultUDPSendPacketLen, std::move(headerIn), 0 /* largestAcked */); - auto packet = packetToBuf(std::move(builder).buildPacket()); - auto packetQueue = bufToQueue(std::move(packet)); + Buf retryPacketEncoded = std::make_unique(); + BufAppender appender(retryPacketEncoded.get(), 100); + + appender.writeBE(initialByte); + appender.writeBE(static_cast(quicVersion)); + + appender.writeBE(dstConnId.size()); + appender.push(dstConnId.data(), dstConnId.size()); + appender.writeBE(srcConnId.size()); + appender.push(srcConnId.data(), srcConnId.size()); + + appender.push((const uint8_t*)token.data(), token.size()); + appender.push((const uint8_t*)integrityTag.data(), integrityTag.size()); + + auto packetQueue = bufToQueue(std::move(retryPacketEncoded)); AckStates ackStates; auto result = makeUnencryptedCodec()->parsePacket(packetQueue, ackStates); - auto& retryPacket = *result.regularPacket(); + auto retryPacket = result.retryPacket(); + EXPECT_TRUE(retryPacket); - auto headerOut = *retryPacket.header.asLong(); + auto headerOut = retryPacket->header; - EXPECT_EQ(*headerOut.getOriginalDstConnId(), getTestConnectionId(110)); EXPECT_EQ(headerOut.getVersion(), static_cast(0xffff)); - EXPECT_EQ(headerOut.getSourceConnId(), getTestConnectionId(70)); - EXPECT_EQ(headerOut.getDestinationConnId(), getTestConnectionId(90)); - - auto expected = std::string("fluffydog"); - EXPECT_EQ(headerOut.getToken(), expected); + EXPECT_EQ(headerOut.getSourceConnId(), srcConnId); + EXPECT_EQ(headerOut.getDestinationConnId(), dstConnId); + EXPECT_EQ(headerOut.getToken(), token); } TEST_F(QuicReadCodecTest, LongHeaderPacketLenMismatch) { diff --git a/quic/fizz/client/test/QuicClientTransportTest.cpp b/quic/fizz/client/test/QuicClientTransportTest.cpp index 7913d7bd2..381c83679 100644 --- a/quic/fizz/client/test/QuicClientTransportTest.cpp +++ b/quic/fizz/client/test/QuicClientTransportTest.cpp @@ -1257,7 +1257,8 @@ class FakeOneRttHandshakeLayer : public FizzClientHandshake { throw std::runtime_error("getApplicationProtocol not implemented"); } std::unique_ptr getRetryPacketCipher() override { - throw std::runtime_error("getRetryPacketCipher not implemented"); + FizzClientHandshake fizzClientHandshake(nullptr, nullptr); + return fizzClientHandshake.getRetryPacketCipher(); } void processSocketData(folly::IOBufQueue&) override { throw std::runtime_error("processSocketData not implemented"); @@ -4389,9 +4390,18 @@ TEST_F(QuicClientTransportAfterStartTest, BadStatelessResetWontCloseTransport) { } TEST_F(QuicClientTransportVersionAndRetryTest, RetryPacket) { + std::vector clientConnIdVec = {}; + ConnectionId clientConnId(clientConnIdVec); + + std::vector initialDstConnIdVec = { + 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08}; + ConnectionId initialDstConnId(initialDstConnIdVec); + // Create a stream and attempt to send some data to the server auto qLogger = std::make_shared(VantagePoint::Client); client->getNonConstConn().qLogger = qLogger; + client->getNonConstConn().readCodec->setClientConnectionId(clientConnId); + client->getNonConstConn().initialDestinationConnectionId = initialDstConnId; StreamId streamId = *client->createBidirectionalStream(); auto write = IOBuf::copyBuffer("ice cream"); @@ -4409,22 +4419,24 @@ TEST_F(QuicClientTransportVersionAndRetryTest, RetryPacket) { // Make the server send a retry packet to the client. The server chooses a // connection id that the client must use in all future initial packets. - auto serverChosenConnId = getTestConnectionId(); + std::vector serverConnIdVec = { + 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5}; + ConnectionId serverChosenConnId(serverConnIdVec); - LongHeader headerIn( - LongHeader::Types::Retry, - serverChosenConnId, - *originalConnId, - 321, - QuicVersion::MVFST, - std::string("this is a retry token :)"), - *client->getConn().initialDestinationConnectionId); + std::string retryToken = "token"; + std::string integrityTag = + "\x1e\x5e\xc5\xb0\x14\xcb\xb1\xf0\xfd\x93\xdf\x40\x48\xc4\x46\xa6"; - RegularQuicPacketBuilder builder( - kDefaultUDPSendPacketLen, std::move(headerIn), 0 /* largestAcked */); - auto packet = packetToBuf(std::move(builder).buildPacket()); - - deliverData(packet->coalesce()); + folly::IOBuf retryPacketBuf; + BufAppender appender(&retryPacketBuf, 100); + appender.writeBE(0xFF); + appender.writeBE(static_cast(0xFF000019)); + appender.writeBE(clientConnId.size()); + appender.writeBE(serverConnIdVec.size()); + appender.push(serverConnIdVec.data(), serverConnIdVec.size()); + appender.push((const uint8_t*)retryToken.data(), retryToken.size()); + appender.push((const uint8_t*)integrityTag.data(), integrityTag.size()); + deliverData(retryPacketBuf.coalesce()); ASSERT_TRUE(bytesWrittenToNetwork); @@ -4442,16 +4454,9 @@ TEST_F(QuicClientTransportVersionAndRetryTest, RetryPacket) { auto& regularQuicPacket = *codecResult.regularPacket(); auto& header = *regularQuicPacket.header.asLong(); - std::vector indices = - getQLogEventIndices(QLogEventType::PacketReceived, qLogger); - EXPECT_EQ(indices.size(), 1); - auto tmp = std::move(qLogger->logs[indices[0]]); - auto event = dynamic_cast(tmp.get()); - EXPECT_EQ(event->packetType, toString(LongHeader::Types::Retry)); - EXPECT_EQ(header.getHeaderType(), LongHeader::Types::Initial); EXPECT_TRUE(header.hasToken()); - EXPECT_EQ(header.getToken(), std::string("this is a retry token :)")); + EXPECT_EQ(header.getToken(), std::string("token")); EXPECT_EQ(header.getDestinationConnId(), serverChosenConnId); eventbase_->loopOnce();