diff --git a/quic/api/QuicPacketScheduler-inl.h b/quic/api/QuicPacketScheduler-inl.h index 13c3d7d8d..48ddcdc22 100644 --- a/quic/api/QuicPacketScheduler-inl.h +++ b/quic/api/QuicPacketScheduler-inl.h @@ -43,10 +43,10 @@ folly::Optional AckScheduler::writeAcksImpl( // Use default ack delay for long headers. Usually long headers are sent // before crypto negotiation, so the peer might not know about the ack delay // exponent yet, so we use the default. - uint8_t ackDelayExponentToUse = folly::variant_match( - builder.getPacketHeader(), - [](const LongHeader&) { return kDefaultAckDelayExponent; }, - [&](const auto&) { return conn_.transportSettings.ackDelayExponent; }); + uint8_t ackDelayExponentToUse = + builder.getPacketHeader().getHeaderForm() == HeaderForm::Long + ? kDefaultAckDelayExponent + : conn_.transportSettings.ackDelayExponent; auto largestAckedPacketNum = *largestAckToSend(ackState_); auto ackingTime = ClockType::now(); DCHECK(ackState_.largestRecvdPacketTime.hasValue()) diff --git a/quic/api/QuicPacketScheduler.cpp b/quic/api/QuicPacketScheduler.cpp index a23bf982d..daac028a6 100644 --- a/quic/api/QuicPacketScheduler.cpp +++ b/quic/api/QuicPacketScheduler.cpp @@ -444,12 +444,9 @@ bool CryptoStreamScheduler::writeCryptoData(PacketBuilderInterface& builder) { } } if (cryptoDataWritten && conn_.nodeType == QuicNodeType::Client) { - bool initialPacket = folly::variant_match( - builder.getPacketHeader(), - [](const LongHeader& header) { - return header.getHeaderType() == LongHeader::Types::Initial; - }, - [](const auto&) { return false; }); + const LongHeader* longHeader = builder.getPacketHeader().asLong(); + bool initialPacket = + longHeader && longHeader->getHeaderType() == LongHeader::Types::Initial; if (initialPacket) { // This is the initial packet, we need to fill er up. while (builder.remainingSpaceInPkt() > 0) { @@ -521,9 +518,7 @@ CloningScheduler::scheduleFramesForPacket( for (auto iter = conn_.outstandingPackets.rbegin(); iter != conn_.outstandingPackets.rend(); ++iter) { - auto opPnSpace = folly::variant_match( - iter->packet.header, - [](const auto& h) { return h.getPacketNumberSpace(); }); + auto opPnSpace = iter->packet.header.getPacketNumberSpace(); if (opPnSpace != PacketNumberSpace::AppData) { continue; } @@ -532,9 +527,7 @@ CloningScheduler::scheduleFramesForPacket( // clone packet. So re-create a RegularQuicPacketBuilder every time. // TODO: We can avoid the copy & rebuild of the header by creating an // independent header builder. - auto builderPnSpace = folly::variant_match( - builder.getPacketHeader(), - [](const auto& h) { return h.getPacketNumberSpace(); }); + auto builderPnSpace = builder.getPacketHeader().getPacketNumberSpace(); CHECK_EQ(builderPnSpace, PacketNumberSpace::AppData); RegularQuicPacketBuilder regularBuilder( conn_.udpSendPacketLen, diff --git a/quic/api/QuicTransportFunctions.cpp b/quic/api/QuicTransportFunctions.cpp index b85fbe69b..3099d7422 100644 --- a/quic/api/QuicTransportFunctions.cpp +++ b/quic/api/QuicTransportFunctions.cpp @@ -195,14 +195,12 @@ void updateConnection( RegularQuicWritePacket packet, TimePoint sentTime, uint32_t encodedSize) { - auto packetNum = folly::variant_match( - packet.header, [](const auto& h) { return h.getPacketSequenceNum(); }); + auto packetNum = packet.header.getPacketSequenceNum(); bool retransmittable = false; // AckFrame and PaddingFrame are not retx-able. bool isHandshake = false; uint32_t connWindowUpdateSent = 0; uint32_t ackFrameCounter = 0; - auto packetNumberSpace = folly::variant_match( - packet.header, [](const auto& h) { return h.getPacketNumberSpace(); }); + auto packetNumberSpace = packet.header.getPacketNumberSpace(); VLOG(10) << nodeToString(conn.nodeType) << " sent packetNum=" << packetNum << " in space=" << packetNumberSpace << " size=" << encodedSize << " " << conn; @@ -233,9 +231,7 @@ void updateConnection( }, [&](const WriteCryptoFrame& writeCryptoFrame) { retransmittable = true; - auto protectionType = folly::variant_match( - packet.header, - [](const auto& h) { return h.getProtectionType(); }); + auto protectionType = packet.header.getProtectionType(); // NewSessionTicket is sent in crypto frame encrypted with 1-rtt key, // however, it is not part of handshake isHandshake = @@ -406,9 +402,7 @@ void updateConnection( conn.outstandingPackets.end(), packetNum, [&](const auto& packetWithTime, const auto& val) { - return folly::variant_match( - packetWithTime.packet.header, - [&val](const auto& h) { return h.getPacketSequenceNum() < val; }); + return packetWithTime.packet.header.getPacketSequenceNum() < val; }); conn.outstandingPackets.insert(packetIt, std::move(pkt)); @@ -453,14 +447,9 @@ HeaderBuilder LongHeaderBuilder(LongHeader::Types packetType) { const ConnectionId& dstConnId, PacketNum packetNum, QuicVersion version, - Buf token) { + const std::string& token) { return LongHeader( - packetType, - srcConnId, - dstConnId, - packetNum, - version, - token ? std::move(token) : nullptr); + packetType, srcConnId, dstConnId, packetNum, version, token); }; } @@ -469,7 +458,7 @@ HeaderBuilder ShortHeaderBuilder() { const ConnectionId& dstConnId, PacketNum packetNum, QuicVersion, - Buf) { + const std::string&) { return ShortHeader(ProtectionType::KeyPhaseZero, dstConnId, packetNum); }; } @@ -557,7 +546,7 @@ uint64_t writeCryptoAndAckDataToSocket( const PacketNumberCipher& headerCipher, QuicVersion version, uint64_t packetLimit, - Buf token) { + const std::string& token) { auto encryptionLevel = protectionTypeToEncryptionLevel( longHeaderTypeToProtectionType(packetType)); FrameScheduler scheduler = @@ -584,7 +573,7 @@ uint64_t writeCryptoAndAckDataToSocket( cleartextCipher, headerCipher, version, - token ? std::move(token) : nullptr); + token); VLOG_IF(10, written > 0) << nodeToString(connection.nodeType) << " written crypto and acks data type=" << packetType << " packets=" << written << " " @@ -718,10 +707,9 @@ void writeCloseCommon( const PacketNumberCipher& headerCipher) { // close is special, we're going to bypass all the packet sent logic for all // packets we send with a connection close frame. - auto pnSpace = folly::variant_match( - header, [](const auto& h) { return h.getPacketNumberSpace(); }); - PacketNum packetNum = folly::variant_match( - header, [](const auto& h) { return h.getPacketSequenceNum(); }); + PacketNumberSpace pnSpace = header.getPacketNumberSpace(); + HeaderForm headerForm = header.getHeaderForm(); + PacketNum packetNum = header.getPacketSequenceNum(); RegularQuicPacketBuilder packetBuilder( connection.udpSendPacketLen, std::move(header), @@ -760,10 +748,6 @@ void writeCloseCommon( auto packet = std::move(packetBuilder).buildPacket(); auto body = aead.encrypt(std::move(packet.body), packet.header.get(), packetNum); - HeaderForm headerForm = folly::variant_match( - header, - [](const ShortHeader&) { return HeaderForm::Short; }, - [](const LongHeader&) { return HeaderForm::Long; }); encryptPacketHeader(headerForm, *packet.header, *body, headerCipher); auto packetBuf = std::move(packet.header); packetBuf->prependChain(std::move(body)); @@ -891,7 +875,7 @@ uint64_t writeConnectionDataToSocket( const Aead& aead, const PacketNumberCipher& headerCipher, QuicVersion version, - Buf token) { + const std::string& token) { VLOG(10) << nodeToString(connection.nodeType) << " writing data using scheduler=" << scheduler.name() << " " << connection; @@ -917,12 +901,7 @@ uint64_t writeConnectionDataToSocket( } while (scheduler.hasData() && ioBufBatch.getPktSent() < packetLimit) { auto packetNum = getNextPacketNum(connection, pnSpace); - auto header = builder( - srcConnId, - dstConnId, - packetNum, - version, - token ? token->clone() : nullptr); + auto header = builder(srcConnId, dstConnId, packetNum, version, token); uint32_t writableBytes = folly::to(std::min( connection.udpSendPacketLen, writableBytesFunc(connection))); uint64_t cipherOverhead = aead.getCipherOverhead(); @@ -954,10 +933,7 @@ uint64_t writeConnectionDataToSocket( auto body = aead.encrypt(std::move(packet->body), packet->header.get(), packetNum); - HeaderForm headerForm = folly::variant_match( - packet->packet.header, - [](const LongHeader&) { return HeaderForm::Long; }, - [](const ShortHeader&) { return HeaderForm::Short; }); + HeaderForm headerForm = packet->packet.header.getHeaderForm(); encryptPacketHeader(headerForm, *packet->header, *body, headerCipher); auto packetBuf = std::move(packet->header); diff --git a/quic/api/QuicTransportFunctions.h b/quic/api/QuicTransportFunctions.h index 5cb02a281..c80c241a0 100644 --- a/quic/api/QuicTransportFunctions.h +++ b/quic/api/QuicTransportFunctions.h @@ -25,7 +25,7 @@ using HeaderBuilder = std::function; + const std::string& token)>; using WritableBytesFunc = std::function; @@ -59,7 +59,7 @@ uint64_t writeCryptoAndAckDataToSocket( const PacketNumberCipher& headerCipher, QuicVersion version, uint64_t packetLimit, - Buf token = nullptr); + const std::string& token = std::string()); /** * Writes out all the data streams without writing out crypto streams. @@ -216,7 +216,7 @@ uint64_t writeConnectionDataToSocket( const Aead& aead, const PacketNumberCipher& headerCipher, QuicVersion version, - Buf token = nullptr); + const std::string& token = std::string()); uint64_t writeProbingDataToSocket( folly::AsyncUDPSocket& sock, diff --git a/quic/api/test/QuicPacketSchedulerTest.cpp b/quic/api/test/QuicPacketSchedulerTest.cpp index 60b466faf..a04e02691 100644 --- a/quic/api/test/QuicPacketSchedulerTest.cpp +++ b/quic/api/test/QuicPacketSchedulerTest.cpp @@ -376,7 +376,7 @@ TEST_F(QuicPacketSchedulerTest, WriteOnlyOutstandingPacketsTest) { EXPECT_EQ(packetNum, *result.first); // written packet (result.second) should not have any frame in the builder auto& writtenPacket = *result.second; - auto shortHeader = boost::get(&writtenPacket.packet.header); + auto shortHeader = writtenPacket.packet.header.asShort(); CHECK(shortHeader); EXPECT_EQ(ProtectionType::KeyPhaseOne, shortHeader->getProtectionType()); EXPECT_EQ( @@ -510,10 +510,10 @@ TEST_F(QuicPacketSchedulerTest, CloneSchedulerUseNormalSchedulerFirst) { EXPECT_CALL(mockScheduler, _scheduleFramesForPacket(_, _)) .Times(1) - .WillOnce( - Invoke([&, headerCopy = header]( - std::unique_ptr&, uint32_t) { - RegularQuicWritePacket packet(headerCopy); + .WillOnce(Invoke( + [&, headerCopy = header]( + std::unique_ptr&, uint32_t) mutable { + RegularQuicWritePacket packet(std::move(headerCopy)); packet.frames.push_back(MaxDataFrame(2832)); RegularQuicPacketBuilder::Packet builtPacket( std::move(packet), @@ -528,17 +528,12 @@ TEST_F(QuicPacketSchedulerTest, CloneSchedulerUseNormalSchedulerFirst) { auto result = cloningScheduler.scheduleFramesForPacket( std::move(builder), kDefaultUDPSendPacketLen); EXPECT_EQ(folly::none, result.first); - folly::variant_match( - result.second->packet.header, - [&](const ShortHeader& shortHeader) { - EXPECT_EQ(ProtectionType::KeyPhaseOne, shortHeader.getProtectionType()); - EXPECT_EQ( - conn.ackStates.appDataAckState.nextPacketNum, - shortHeader.getPacketSequenceNum()); - }, - [&](const LongHeader&) { - ASSERT_FALSE(true); // should not happen - }); + EXPECT_EQ(result.second->packet.header.getHeaderForm(), HeaderForm::Short); + ShortHeader& shortHeader = *result.second->packet.header.asShort(); + EXPECT_EQ(ProtectionType::KeyPhaseOne, shortHeader.getProtectionType()); + EXPECT_EQ( + conn.ackStates.appDataAckState.nextPacketNum, + shortHeader.getPacketSequenceNum()); EXPECT_EQ(1, result.second->packet.frames.size()); folly::variant_match( result.second->packet.frames.front(), diff --git a/quic/api/test/QuicTransportFunctionsTest.cpp b/quic/api/test/QuicTransportFunctionsTest.cpp index fae3beb59..c2fbee6e6 100644 --- a/quic/api/test/QuicTransportFunctionsTest.cpp +++ b/quic/api/test/QuicTransportFunctionsTest.cpp @@ -385,28 +385,20 @@ TEST_F(QuicTransportFunctionsTest, TestUpdateConnectionPacketSorting) { EXPECT_EQ(3, conn->outstandingPackets.size()); auto& firstHeader = conn->outstandingPackets.front().packet.header; - auto firstPacketNum = folly::variant_match( - firstHeader, [](const auto& h) { return h.getPacketSequenceNum(); }); + auto firstPacketNum = firstHeader.getPacketSequenceNum(); EXPECT_EQ(0, firstPacketNum); EXPECT_EQ(1, event1->packetNum); - EXPECT_EQ( - PacketNumberSpace::Initial, - folly::variant_match( - firstHeader, [](const auto& h) { return h.getPacketNumberSpace(); })); + EXPECT_EQ(PacketNumberSpace::Initial, firstHeader.getPacketNumberSpace()); auto& lastHeader = conn->outstandingPackets.back().packet.header; - auto lastPacketNum = folly::variant_match( - lastHeader, [](const auto& h) { return h.getPacketSequenceNum(); }); + auto lastPacketNum = lastHeader.getPacketSequenceNum(); EXPECT_EQ(2, lastPacketNum); EXPECT_EQ(2, event3->packetNum); - EXPECT_EQ( - PacketNumberSpace::AppData, - folly::variant_match( - lastHeader, [](const auto& h) { return h.getPacketNumberSpace(); })); + EXPECT_EQ(PacketNumberSpace::AppData, lastHeader.getPacketNumberSpace()); } TEST_F(QuicTransportFunctionsTest, TestUpdateConnectionFinOnly) { @@ -891,9 +883,7 @@ TEST_F(QuicTransportFunctionsTest, TestUpdateConnectionStreamWindowUpdate) { auto conn = createConn(); conn->qLogger = std::make_shared(); auto packet = buildEmptyPacket(*conn, PacketNumberSpace::Handshake); - auto packetNum = folly::variant_match( - packet.packet.header, - [](const auto& h) { return h.getPacketSequenceNum(); }); + auto packetNum = packet.packet.header.getPacketSequenceNum(); auto stream = conn->streamManager->createNextBidirectionalStream().value(); MaxStreamDataFrame streamWindowUpdate(stream->id, 0); conn->streamManager->queueWindowUpdate(stream->id); @@ -928,9 +918,7 @@ TEST_F(QuicTransportFunctionsTest, TestUpdateConnectionConnWindowUpdate) { auto conn = createConn(); conn->qLogger = std::make_shared(); auto packet = buildEmptyPacket(*conn, PacketNumberSpace::Handshake); - auto packetNum = folly::variant_match( - packet.packet.header, - [](const auto& h) { return h.getPacketSequenceNum(); }); + auto packetNum = packet.packet.header.getPacketSequenceNum(); conn->pendingEvents.connWindowUpdate = true; auto stream = conn->streamManager->createNextBidirectionalStream().value(); MaxDataFrame connWindowUpdate(conn->flowControlState.advertisedMaxOffset); diff --git a/quic/api/test/QuicTransportTest.cpp b/quic/api/test/QuicTransportTest.cpp index 66e4fdd0b..ae5ef6c3a 100644 --- a/quic/api/test/QuicTransportTest.cpp +++ b/quic/api/test/QuicTransportTest.cpp @@ -845,8 +845,7 @@ TEST_F(QuicTransportTest, WritePendingAckIfHavingData) { EXPECT_EQ(conn.ackStates.appDataAckState.largestAckScheduled, end); // Verify ack state after writing - auto pnSpace = folly::variant_match( - packet.header, [](const auto& h) { return h.getPacketNumberSpace(); }); + auto pnSpace = packet.header.getPacketNumberSpace(); auto ackState = getAckState(conn, pnSpace); EXPECT_EQ(ackState.largestAckScheduled, end); EXPECT_FALSE(ackState.needsToSendAckImmediately); diff --git a/quic/client/QuicClientTransport.cpp b/quic/client/QuicClientTransport.cpp index 7b1ed4372..1a1315e99 100644 --- a/quic/client/QuicClientTransport.cpp +++ b/quic/client/QuicClientTransport.cpp @@ -152,15 +152,11 @@ void QuicClientTransport::processPacketData( return; } - bool longHeader = folly::variant_match( - regularOptional->header, - [](const LongHeader&) { return true; }, - [](const ShortHeader&) { return false; }); + LongHeader* longHeader = regularOptional->header.asLong(); + ShortHeader* shortHeader = regularOptional->header.asShort(); - if (longHeader && - boost::get(regularOptional->header).getHeaderType() == - LongHeader::Types::Retry) { - if (clientConn_->retryToken_) { + if (longHeader && longHeader->getHeaderType() == LongHeader::Types::Retry) { + if (!clientConn_->retryToken.empty()) { VLOG(4) << "Server sent more than one retry packet"; return; } @@ -175,14 +171,12 @@ void QuicClientTransport::processPacketData( // better approach, but I don't know if it is a good indicator that we've // received an initial packet from the server. - auto header = boost::get(regularOptional->header); - const ConnectionId* dstConnId = &(*clientConn_->initialDestinationConnectionId); if (conn_->serverConnectionId) { dstConnId = &(*conn_->serverConnectionId); } - if (*header.getOriginalDstConnId() != *dstConnId) { + if (*longHeader->getOriginalDstConnId() != *dstConnId) { VLOG(4) << "Original destination connection id field in the retry " << "packet doesn't match the destination connection id from the " << "client's initial packet"; @@ -191,7 +185,7 @@ void QuicClientTransport::processPacketData( // Set the destination connection ID to be the value from the source // connection id of the retry packet - clientConn_->initialDestinationConnectionId = header.getSourceConnId(); + clientConn_->initialDestinationConnectionId = longHeader->getSourceConnId(); auto released = static_cast(conn_.release()); std::unique_ptr uniqueClient(released); @@ -200,7 +194,7 @@ void QuicClientTransport::processPacketData( clientConn_ = tempConn.get(); conn_ = std::move(tempConn); - clientConn_->retryToken_ = header.getToken()->clone(); + clientConn_->retryToken = longHeader->getToken(); if (conn_->qLogger) { conn_->qLogger->addPacket(*regularOptional, packetSize); @@ -210,18 +204,11 @@ void QuicClientTransport::processPacketData( return; } - auto protectionLevel = folly::variant_match( - regularOptional->header, - [](auto& header) { return header.getProtectionType(); }); - + auto protectionLevel = regularOptional->header.getProtectionType(); auto encryptionLevel = protectionTypeToEncryptionLevel(protectionLevel); - auto packetNum = folly::variant_match( - regularOptional->header, - [](const auto& h) { return h.getPacketSequenceNum(); }); - auto pnSpace = folly::variant_match( - regularOptional->header, - [](auto& header) { return header.getPacketNumberSpace(); }); + auto packetNum = regularOptional->header.getPacketSequenceNum(); + auto pnSpace = regularOptional->header.getPacketNumberSpace(); bool isProtectedPacket = protectionLevel == ProtectionType::KeyPhaseZero || protectionLevel == ProtectionType::KeyPhaseOne; @@ -255,31 +242,21 @@ void QuicClientTransport::processPacketData( } if (!conn_->serverConnectionId && longHeader) { - folly::Optional receivedSrcConnId(folly::variant_match( - regularOptional->header, - [&](const LongHeader& h) -> folly::Optional { - return h.getSourceConnId(); - }, - [](const ShortHeader&) -> folly::Optional { - return folly::none; - })); - // Assign the conn id to the server chosen connid. - if (!receivedSrcConnId) { - throw QuicTransportException( - "Expected long header with connection-id", - TransportErrorCode::PROTOCOL_VIOLATION); - } - conn_->serverConnectionId = std::move(receivedSrcConnId); + conn_->serverConnectionId = longHeader->getSourceConnId(); conn_->readCodec->setServerConnectionId(*conn_->serverConnectionId); } // Error out if the connection id on the packet is not the one that is // expected. - if (folly::variant_match( - regularOptional->header, - [](const LongHeader& h) { return h.getDestinationConnId(); }, - [](const ShortHeader& h) { return h.getConnectionId(); }) != - *conn_->clientConnectionId) { + bool connidMatched = true; + if (longHeader && longHeader->getDestinationConnId() != *conn_->clientConnectionId) { + connidMatched = false; + } else if ( + shortHeader && + shortHeader->getConnectionId() != *conn_->clientConnectionId) { + connidMatched = false; + } + if (!connidMatched) { throw QuicTransportException( "Invalid connection id", TransportErrorCode::PROTOCOL_VIOLATION); } @@ -303,9 +280,8 @@ void QuicClientTransport::processPacketData( [&](const OutstandingPacket& outstandingPacket, const QuicWriteFrame& packetFrame, const ReadAckFrame&) { - auto outstandingProtectionType = folly::variant_match( - outstandingPacket.packet.header, - [](const auto& h) { return h.getProtectionType(); }); + auto outstandingProtectionType = + outstandingPacket.packet.header.getProtectionType(); if (outstandingProtectionType == ProtectionType::KeyPhaseZero) { // If we received an ack for data that we sent in 1-rtt from // the server, we can assume that the server had successfully @@ -734,7 +710,7 @@ void QuicClientTransport::writeData() { *conn_->initialHeaderCipher, version, packetLimit, - clientConn_->retryToken_ ? clientConn_->retryToken_->clone() : nullptr); + clientConn_->retryToken); } if (!packetLimit) { return; diff --git a/quic/client/state/ClientStateMachine.h b/quic/client/state/ClientStateMachine.h index 493d26fce..69cd20c93 100644 --- a/quic/client/state/ClientStateMachine.h +++ b/quic/client/state/ClientStateMachine.h @@ -43,7 +43,7 @@ struct QuicClientConnectionState : public QuicConnectionStateBase { folly::Optional statelessResetToken; // The retry token sent by the server. - Buf retryToken_{nullptr}; + std::string retryToken; // Initial destination connection id. folly::Optional initialDestinationConnectionId; diff --git a/quic/client/test/QuicClientTransportTest.cpp b/quic/client/test/QuicClientTransportTest.cpp index c2106bd6d..d8981fb78 100644 --- a/quic/client/test/QuicClientTransportTest.cpp +++ b/quic/client/test/QuicClientTransportTest.cpp @@ -1514,9 +1514,7 @@ class QuicClientTransportTest : public Test { if (!parsedPacket) { continue; } - PacketNum packetNumSent = folly::variant_match( - parsedPacket->header, - [](auto& h) { return h.getPacketSequenceNum(); }); + PacketNum packetNumSent = parsedPacket->header.getPacketSequenceNum(); sentPackets.insert(packetNumSent); verifyShortHeader(*write); } @@ -1533,7 +1531,7 @@ class QuicClientTransportTest : public Test { if (!parsedPacket) { return false; } - auto longHeader = boost::get(&parsedPacket->header); + auto longHeader = parsedPacket->header.asLong(); return longHeader && longHeader->getHeaderType() == headerType; } @@ -1546,7 +1544,7 @@ class QuicClientTransportTest : public Test { if (!parsedPacket) { return false; } - return boost::get(&parsedPacket->header) != nullptr; + return parsedPacket->header.asShort(); } std::unique_ptr makeHandshakeCodec() { @@ -2716,16 +2714,12 @@ TEST_F(QuicClientTransportAfterStartTest, CloseConnectionWithStreamPending) { ASSERT_FALSE(client->getConn().outstandingPackets.empty()); IntervalSet acks; - auto start = folly::variant_match( - getFirstOutstandingPacket( - client->getNonConstConn(), PacketNumberSpace::AppData) - ->packet.header, - [](auto& h) { return h.getPacketSequenceNum(); }); - auto end = folly::variant_match( - getLastOutstandingPacket( - client->getNonConstConn(), PacketNumberSpace::AppData) - ->packet.header, - [](auto& h) { return h.getPacketSequenceNum(); }); + auto start = getFirstOutstandingPacket( + client->getNonConstConn(), PacketNumberSpace::AppData) + ->packet.header.getPacketSequenceNum(); + auto end = getLastOutstandingPacket( + client->getNonConstConn(), PacketNumberSpace::AppData) + ->packet.header.getPacketSequenceNum(); acks.insert(start, end); auto ackPacket = packetToBuf(createAckPacket( @@ -2795,16 +2789,12 @@ TEST_F(QuicClientTransportAfterStartTest, CloseConnectionWithNoStreamPending) { ASSERT_FALSE(client->getConn().outstandingPackets.empty()); IntervalSet acks; - auto start = folly::variant_match( - getFirstOutstandingPacket( - client->getNonConstConn(), PacketNumberSpace::AppData) - ->packet.header, - [](auto& h) { return h.getPacketSequenceNum(); }); - auto end = folly::variant_match( - getLastOutstandingPacket( - client->getNonConstConn(), PacketNumberSpace::AppData) - ->packet.header, - [](auto& h) { return h.getPacketSequenceNum(); }); + auto start = getFirstOutstandingPacket( + client->getNonConstConn(), PacketNumberSpace::AppData) + ->packet.header.getPacketSequenceNum(); + auto end = getLastOutstandingPacket( + client->getNonConstConn(), PacketNumberSpace::AppData) + ->packet.header.getPacketSequenceNum(); acks.insert(start, end); auto ackPacket = packetToBuf(createAckPacket( @@ -2925,16 +2915,12 @@ TEST_F(QuicClientTransportAfterStartTest, RecvAckOfCryptoStream) { // initial { IntervalSet acks; - auto start = folly::variant_match( - getFirstOutstandingPacket( - client->getNonConstConn(), PacketNumberSpace::Initial) - ->packet.header, - [](auto& h) { return h.getPacketSequenceNum(); }); - auto end = folly::variant_match( - getLastOutstandingPacket( - client->getNonConstConn(), PacketNumberSpace::Initial) - ->packet.header, - [](auto& h) { return h.getPacketSequenceNum(); }); + auto start = getFirstOutstandingPacket( + client->getNonConstConn(), PacketNumberSpace::Initial) + ->packet.header.getPacketSequenceNum(); + auto end = getLastOutstandingPacket( + client->getNonConstConn(), PacketNumberSpace::Initial) + ->packet.header.getPacketSequenceNum(); acks.insert(start, end); auto pn = initialPacketNum++; auto ackPkt = createAckPacket( @@ -2948,16 +2934,12 @@ TEST_F(QuicClientTransportAfterStartTest, RecvAckOfCryptoStream) { // handshake { IntervalSet acks; - auto start = folly::variant_match( - getFirstOutstandingPacket( - client->getNonConstConn(), PacketNumberSpace::Handshake) - ->packet.header, - [](auto& h) { return h.getPacketSequenceNum(); }); - auto end = folly::variant_match( - getLastOutstandingPacket( - client->getNonConstConn(), PacketNumberSpace::Handshake) - ->packet.header, - [](auto& h) { return h.getPacketSequenceNum(); }); + auto start = getFirstOutstandingPacket( + client->getNonConstConn(), PacketNumberSpace::Handshake) + ->packet.header.getPacketSequenceNum(); + auto end = getLastOutstandingPacket( + client->getNonConstConn(), PacketNumberSpace::Handshake) + ->packet.header.getPacketSequenceNum(); acks.insert(start, end); auto pn = handshakePacketNum++; auto ackPkt = createAckPacket( @@ -3209,8 +3191,7 @@ TEST_F(QuicClientTransportAfterStartTest, IdleTimerResetNoOutstandingPackets) { // This will clear out all the outstanding packets IntervalSet sentPackets; for (auto& packet : client->getNonConstConn().outstandingPackets) { - auto packetNum = folly::variant_match( - packet.packet.header, [](auto& h) { return h.getPacketSequenceNum(); }); + auto packetNum = packet.packet.header.getPacketSequenceNum(); sentPackets.insert(packetNum); } auto ackPacket = packetToBuf(createAckPacket( @@ -3659,7 +3640,7 @@ TEST_F(QuicClientTransportVersionAndRetryTest, RetryPacket) { *originalConnId, 321, QuicVersion::MVFST, - IOBuf::copyBuffer("this is a retry token :)"), + std::string("this is a retry token :)"), *client->getConn().initialDestinationConnectionId); RegularQuicPacketBuilder builder( @@ -3683,7 +3664,7 @@ TEST_F(QuicClientTransportVersionAndRetryTest, RetryPacket) { auto quicPacket = boost::get(&codecResult); auto regularQuicPacket = boost::get(quicPacket); - auto header = boost::get(regularQuicPacket->header); + auto& header = *regularQuicPacket->header.asLong(); std::vector indices = getQLogEventIndices(QLogEventType::PacketReceived, qLogger); @@ -3694,10 +3675,7 @@ TEST_F(QuicClientTransportVersionAndRetryTest, RetryPacket) { EXPECT_EQ(header.getHeaderType(), LongHeader::Types::Initial); EXPECT_TRUE(header.hasToken()); - folly::IOBufEqualTo eq; - EXPECT_TRUE( - eq(header.getToken()->clone(), - IOBuf::copyBuffer("this is a retry token :)"))); + EXPECT_EQ(header.getToken(), std::string("this is a retry token :)")); EXPECT_EQ(header.getDestinationConnId(), serverChosenConnId); eventbase_->loopOnce(); @@ -3903,9 +3881,7 @@ TEST_F(QuicClientTransportAfterStartTest, ResetClearsPendingLoss) { RegularQuicWritePacket* forceLossPacket = CHECK_NOTNULL(findPacketWithStream(client->getNonConstConn(), streamId)); - auto packetNum = folly::variant_match( - forceLossPacket->header, - [](const auto& h) { return h.getPacketSequenceNum(); }); + auto packetNum = forceLossPacket->header.getPacketSequenceNum(); markPacketLoss(client->getNonConstConn(), *forceLossPacket, false, packetNum); auto& pendingLossStreams = client->getConn().streamManager->lossStreams(); auto it = @@ -3932,9 +3908,7 @@ TEST_F(QuicClientTransportAfterStartTest, LossAfterResetStream) { RegularQuicWritePacket* forceLossPacket = CHECK_NOTNULL(findPacketWithStream(client->getNonConstConn(), streamId)); - auto packetNum = folly::variant_match( - forceLossPacket->header, - [](const auto& h) { return h.getPacketSequenceNum(); }); + auto packetNum = forceLossPacket->header.getPacketSequenceNum(); markPacketLoss(client->getNonConstConn(), *forceLossPacket, false, packetNum); auto stream = CHECK_NOTNULL( client->getNonConstConn().streamManager->getStream(streamId)); @@ -4427,12 +4401,8 @@ class QuicZeroRttClientTest : public QuicClientTransportAfterStartTest { bool zeroRttPacketsOutstanding() { for (auto& packet : client->getNonConstConn().outstandingPackets) { - bool isZeroRtt = folly::variant_match( - packet.packet.header, - [](const LongHeader& h) { - return h.getProtectionType() == ProtectionType::ZeroRtt; - }, - [](const ShortHeader&) { return false; }); + bool isZeroRtt = + packet.packet.header.getProtectionType() == ProtectionType::ZeroRtt; if (isZeroRtt) { return true; } diff --git a/quic/codec/Decode.cpp b/quic/codec/Decode.cpp index d66585470..edffdac03 100644 --- a/quic/codec/Decode.cpp +++ b/quic/codec/Decode.cpp @@ -104,10 +104,9 @@ ReadAckFrame decodeAckFrame( // and ack delay, the sender has to use something, so they use the default // ack delay. To keep it consistent the protocol specifies using the same // ack delay for all the long header packets. - uint8_t ackDelayExponentToUse = folly::variant_match( - header, - [](const LongHeader&) { return kDefaultAckDelayExponent; }, - [¶ms](auto&) { return params.peerAckDelayExponent; }); + uint8_t ackDelayExponentToUse = (header.getHeaderForm() == HeaderForm::Long) + ? kDefaultAckDelayExponent + : params.peerAckDelayExponent; DCHECK_LT(ackDelayExponentToUse, sizeof(ackDelay->first) * 8); // ackDelayExponentToUse is guaranteed to be less than the size of uint64_t uint64_t delayOverflowMask = 0xFFFFFFFFFFFFFFFF; @@ -986,7 +985,7 @@ folly::Expected parseLongHeaderVariants( LongHeader( type, std::move(parsedLongHeaderInvariant.invariant), - std::move(token), + token ? token->moveToFbString().toStdString() : std::string(), std::move(originalDstConnId)), PacketLength(0, 0)); } @@ -1037,7 +1036,7 @@ folly::Expected parseLongHeaderVariants( LongHeader( type, std::move(parsedLongHeaderInvariant.invariant), - std::move(token)), + token ? token->moveToFbString().toStdString() : std::string()), PacketLength(pktLen->first, pktLen->second)); } diff --git a/quic/codec/QuicHeaderCodec.h b/quic/codec/QuicHeaderCodec.h index d00aa64a9..2b2796d1b 100644 --- a/quic/codec/QuicHeaderCodec.h +++ b/quic/codec/QuicHeaderCodec.h @@ -17,6 +17,7 @@ namespace quic { struct ParsedHeaderResult { bool isVersionNegotiation; folly::Optional parsedHeader; + ParsedHeaderResult( bool isVersionNegotiationIn, folly::Optional parsedHeaderIn); diff --git a/quic/codec/QuicPacketBuilder.cpp b/quic/codec/QuicPacketBuilder.cpp index a17459342..394b52fb5 100644 --- a/quic/codec/QuicPacketBuilder.cpp +++ b/quic/codec/QuicPacketBuilder.cpp @@ -47,9 +47,9 @@ PacketNumEncodingResult encodeLongHeaderHelper( appender.writeBE(initialByte); bool isInitial = longHeader.getHeaderType() == LongHeader::Types::Initial; uint64_t tokenHeaderLength = 0; - auto token = longHeader.getToken(); + const std::string& token = longHeader.getToken(); if (isInitial) { - uint64_t tokenLength = token ? token->coalesce().size() : 0; + uint64_t tokenLength = token.size(); QuicInteger tokenLengthInt(tokenLength); tokenHeaderLength = tokenLengthInt.getSize() + tokenLength; } @@ -99,11 +99,11 @@ PacketNumEncodingResult encodeLongHeaderHelper( } if (isInitial) { - uint64_t tokenLength = token ? token->coalesce().size() : 0; + uint64_t tokenLength = token.size(); QuicInteger tokenLengthInt(tokenLength); tokenLengthInt.encode(appender); if (tokenLength > 0) { - appender.push(token->coalesce()); + appender.push(folly::StringPiece(token.data(), token.size())); } } @@ -113,8 +113,8 @@ PacketNumEncodingResult encodeLongHeaderHelper( appender.push(originalDstConnId->data(), originalDstConnId->size()); // Write the retry token - CHECK(token) << "Retry packet must contain a token"; - appender.insert(*token); + CHECK(!token.empty()) << "Retry packet must contain a token"; + appender.push(folly::StringPiece(token.data(), token.size())); } // defer write of the packet num and length till payload has been computed return encodedPacketNum; @@ -188,10 +188,7 @@ void RegularQuicPacketBuilder::appendFrame(QuicWriteFrame frame) { RegularQuicPacketBuilder::Packet RegularQuicPacketBuilder::buildPacket() && { // at this point everything should been set in the packet_ - bool isLongHeader = folly::variant_match( - packet_.header, - [](const LongHeader&) { return true; }, - [](const ShortHeader&) { return false; }); + LongHeader* longHeader = packet_.header.asLong(); size_t minBodySize = kMaxPacketNumEncodingSize - packetNumberEncoding_->length + sizeof(Sample); while (outputQueue_.chainLength() + cipherOverhead_ < minBodySize && @@ -201,9 +198,7 @@ RegularQuicPacketBuilder::Packet RegularQuicPacketBuilder::buildPacket() && { write(paddingType); } packet_.frames = std::move(quicFrames_); - if (isLongHeader && - boost::get(packet_.header).getHeaderType() != - LongHeader::Types::Retry) { + if (longHeader && longHeader->getHeaderType() != LongHeader::Types::Retry) { QuicInteger pktLen( packetNumberEncoding_->length + outputQueue_.chainLength() + cipherOverhead_); @@ -218,11 +213,11 @@ RegularQuicPacketBuilder::Packet RegularQuicPacketBuilder::buildPacket() && { void RegularQuicPacketBuilder::writeHeaderBytes( PacketNum largestAckedPacketNum) { - if (packet_.header.type() == typeid(LongHeader)) { - LongHeader& longHeader = boost::get(packet_.header); + if (packet_.header.getHeaderForm() == HeaderForm::Long) { + LongHeader& longHeader = *packet_.header.asLong(); encodeLongHeader(longHeader, largestAckedPacketNum); } else { - ShortHeader& shortHeader = boost::get(packet_.header); + ShortHeader& shortHeader = *packet_.header.asShort(); encodeShortHeader(shortHeader, largestAckedPacketNum); } } diff --git a/quic/codec/QuicPacketRebuilder.cpp b/quic/codec/QuicPacketRebuilder.cpp index 20bd7318f..104e8cf11 100644 --- a/quic/codec/QuicPacketRebuilder.cpp +++ b/quic/codec/QuicPacketRebuilder.cpp @@ -30,8 +30,7 @@ PacketEvent PacketRebuilder::cloneOutstandingPacket(OutstandingPacket& packet) { !packet.associatedEvent || conn_.outstandingPacketEvents.count(*packet.associatedEvent)); if (!packet.associatedEvent) { - auto packetNum = folly::variant_match( - packet.packet.header, [](auto& h) { return h.getPacketSequenceNum(); }); + auto packetNum = packet.packet.header.getPacketSequenceNum(); DCHECK(!conn_.outstandingPacketEvents.count(packetNum)); packet.associatedEvent = packetNum; conn_.outstandingPacketEvents.insert(packetNum); @@ -57,12 +56,11 @@ folly::Optional PacketRebuilder::rebuildFromPacket( writeSuccess = folly::variant_match( frame, [&](const WriteAckFrame& ackFrame) { - uint64_t ackDelayExponent = folly::variant_match( - builder_.getPacketHeader(), - [](const LongHeader&) { return kDefaultAckDelayExponent; }, - [&](const auto&) { - return conn_.transportSettings.ackDelayExponent; - }); + auto& packetHeader = builder_.getPacketHeader(); + uint64_t ackDelayExponent = + (packetHeader.getHeaderForm() == HeaderForm::Long) + ? kDefaultAckDelayExponent + : conn_.transportSettings.ackDelayExponent; AckFrameMetaData meta( ackFrame.ackBlocks, ackFrame.ackDelay, ackDelayExponent); auto ackWriteResult = writeAckFrame(meta, builder_); @@ -99,10 +97,10 @@ folly::Optional PacketRebuilder::rebuildFromPacket( // initialStream and handshakeStream can only be in handshake packet, // so they are not clonable CHECK(!packet.isHandshake); - folly::variant_match(packet.packet.header, [](const auto& header) { - // key update not supported - CHECK(header.getProtectionType() == ProtectionType::KeyPhaseZero); - }); + // key update not supported + DCHECK( + packet.packet.header.getProtectionType() == + ProtectionType::KeyPhaseZero); auto& stream = conn_.cryptoState->oneRttStream; auto buf = cloneCryptoRetransmissionBuffer(cryptoFrame, stream); diff --git a/quic/codec/Types.cpp b/quic/codec/Types.cpp index 65b4de27f..7a1fcc549 100644 --- a/quic/codec/Types.cpp +++ b/quic/codec/Types.cpp @@ -24,15 +24,158 @@ HeaderForm getHeaderForm(uint8_t headerValue) { return HeaderForm::Short; } +PacketHeader::PacketHeader(ShortHeader&& shortHeaderIn) + : headerForm_(HeaderForm::Short) { + new (&shortHeader) ShortHeader(std::move(shortHeaderIn)); +} + +PacketHeader::PacketHeader(LongHeader&& longHeaderIn) + : headerForm_(HeaderForm::Long) { + new (&longHeader) LongHeader(std::move(longHeaderIn)); +} + +PacketHeader::PacketHeader(const PacketHeader& other) + : headerForm_(other.headerForm_) { + switch (other.headerForm_) { + case HeaderForm::Long: + new (&longHeader) LongHeader(other.longHeader); + break; + case HeaderForm::Short: + new (&shortHeader) ShortHeader(other.shortHeader); + break; + } +} + +PacketHeader::PacketHeader(PacketHeader&& other) noexcept + : headerForm_(other.headerForm_) { + switch (other.headerForm_) { + case HeaderForm::Long: + new (&longHeader) LongHeader(std::move(other.longHeader)); + break; + case HeaderForm::Short: + new (&shortHeader) ShortHeader(std::move(other.shortHeader)); + break; + } +} + +PacketHeader& PacketHeader::operator=(PacketHeader&& other) noexcept { + destroyHeader(); + switch (other.headerForm_) { + case HeaderForm::Long: + new (&longHeader) LongHeader(std::move(other.longHeader)); + break; + case HeaderForm::Short: + new (&shortHeader) ShortHeader(std::move(other.shortHeader)); + break; + } + headerForm_ = other.headerForm_; + return *this; +} + +PacketHeader& PacketHeader::operator=(const PacketHeader& other) { + destroyHeader(); + switch (other.headerForm_) { + case HeaderForm::Long: + new (&longHeader) LongHeader(other.longHeader); + break; + case HeaderForm::Short: + new (&shortHeader) ShortHeader(other.shortHeader); + break; + } + headerForm_ = other.headerForm_; + return *this; +} + +PacketHeader::~PacketHeader() { + destroyHeader(); +} + +void PacketHeader::destroyHeader() { + switch (headerForm_) { + case HeaderForm::Long: + longHeader.~LongHeader(); + break; + case HeaderForm::Short: + shortHeader.~ShortHeader(); + break; + } +} + +LongHeader* PacketHeader::asLong() { + switch (headerForm_) { + case HeaderForm::Long: + return &longHeader; + case HeaderForm::Short: + return nullptr; + } +} + +ShortHeader* PacketHeader::asShort() { + switch (headerForm_) { + case HeaderForm::Long: + return nullptr; + case HeaderForm::Short: + return &shortHeader; + } +} + +const LongHeader* PacketHeader::asLong() const { + switch (headerForm_) { + case HeaderForm::Long: + return &longHeader; + case HeaderForm::Short: + return nullptr; + } +} + +const ShortHeader* PacketHeader::asShort() const { + switch (headerForm_) { + case HeaderForm::Long: + return nullptr; + case HeaderForm::Short: + return &shortHeader; + } +} + +PacketNum PacketHeader::getPacketSequenceNum() const { + switch (headerForm_) { + case HeaderForm::Long: + return longHeader.getPacketSequenceNum(); + case HeaderForm::Short: + return shortHeader.getPacketSequenceNum(); + } +} + +HeaderForm PacketHeader::getHeaderForm() const { + return headerForm_; +} + +ProtectionType PacketHeader::getProtectionType() const { + switch (headerForm_) { + case HeaderForm::Long: + return longHeader.getProtectionType(); + case HeaderForm::Short: + return shortHeader.getProtectionType(); + } +} + +PacketNumberSpace PacketHeader::getPacketNumberSpace() const { + switch (headerForm_) { + case HeaderForm::Long: + return longHeader.getPacketNumberSpace(); + case HeaderForm::Short: + return shortHeader.getPacketNumberSpace(); + } +} + LongHeader::LongHeader( Types type, LongHeaderInvariant invariant, - Buf token, + const std::string& token, folly::Optional originalDstConnId) - : headerForm_(HeaderForm::Long), - longHeaderType_(type), + : longHeaderType_(type), invariant_(std::move(invariant)), - token_(std::move(token)), + token_(token), originalDstConnId_(originalDstConnId) {} LongHeader::LongHeader( @@ -41,40 +184,13 @@ LongHeader::LongHeader( const ConnectionId& dstConnId, PacketNum packetNum, QuicVersion version, - Buf token, + const std::string& token, folly::Optional originalDstConnId) - : headerForm_(HeaderForm::Long), - longHeaderType_(type), + : longHeaderType_(type), invariant_(LongHeaderInvariant(version, srcConnId, dstConnId)), - packetSequenceNum_(packetNum), - token_(token ? std::move(token) : nullptr), - originalDstConnId_(originalDstConnId) {} - -LongHeader::LongHeader(const LongHeader& other) - : headerForm_(other.headerForm_), - longHeaderType_(other.longHeaderType_), - invariant_(other.invariant_), - packetSequenceNum_(other.packetSequenceNum_), - originalDstConnId_(other.originalDstConnId_) { - if (other.token_) { - token_ = other.token_->clone(); - } -} - -void LongHeader::setPacketNumber(PacketNum packetNum) { - packetSequenceNum_ = packetNum; -} - -LongHeader& LongHeader::operator=(const LongHeader& other) { - headerForm_ = other.headerForm_; - longHeaderType_ = other.longHeaderType_; - invariant_ = other.invariant_; - packetSequenceNum_ = other.packetSequenceNum_; - originalDstConnId_ = other.originalDstConnId_; - if (other.token_) { - token_ = other.token_->clone(); - } - return *this; + token_(token), + originalDstConnId_(originalDstConnId) { + setPacketNumber(packetNum); } LongHeader::Types LongHeader::getHeaderType() const noexcept { @@ -93,27 +209,31 @@ const folly::Optional& LongHeader::getOriginalDstConnId() const { return originalDstConnId_; } -PacketNum LongHeader::getPacketSequenceNum() const { - return *packetSequenceNum_; -} - QuicVersion LongHeader::getVersion() const { return invariant_.version; } bool LongHeader::hasToken() const { - return token_ ? true : false; + return !token_.empty(); } -folly::IOBuf* LongHeader::getToken() const { - return token_.get(); +const std::string& LongHeader::getToken() const { + return token_; +} + +PacketNum LongHeader::getPacketSequenceNum() const { + return packetSequenceNum_; +} + +void LongHeader::setPacketNumber(PacketNum packetNum) { + packetSequenceNum_ = packetNum; } ProtectionType LongHeader::getProtectionType() const { return longHeaderTypeToProtectionType(getHeaderType()); } -PacketNumberSpace LongHeader::getPacketNumberSpace() const noexcept { +PacketNumberSpace LongHeader::getPacketNumberSpace() const { return longHeaderTypeToPacketNumberSpace(getHeaderType()); } @@ -152,21 +272,17 @@ ShortHeader::ShortHeader( ProtectionType protectionType, ConnectionId connId, PacketNum packetNum) - : headerForm_(HeaderForm::Short), - protectionType_(protectionType), - connectionId_(std::move(connId)), - packetSequenceNum_(packetNum) { + : protectionType_(protectionType), connectionId_(std::move(connId)) { if (protectionType_ != ProtectionType::KeyPhaseZero && protectionType_ != ProtectionType::KeyPhaseOne) { throw QuicInternalException( "bad short header protection type", LocalErrorCode::CODEC_ERROR); } + setPacketNumber(packetNum); } ShortHeader::ShortHeader(ProtectionType protectionType, ConnectionId connId) - : headerForm_(HeaderForm::Short), - protectionType_(protectionType), - connectionId_(std::move(connId)) { + : protectionType_(protectionType), connectionId_(std::move(connId)) { if (protectionType_ != ProtectionType::KeyPhaseZero && protectionType_ != ProtectionType::KeyPhaseOne) { throw QuicInternalException( @@ -174,11 +290,11 @@ ShortHeader::ShortHeader(ProtectionType protectionType, ConnectionId connId) } } -ProtectionType ShortHeader::getProtectionType() const noexcept { +ProtectionType ShortHeader::getProtectionType() const { return protectionType_; } -PacketNumberSpace ShortHeader::getPacketNumberSpace() const noexcept { +PacketNumberSpace ShortHeader::getPacketNumberSpace() const { return PacketNumberSpace::AppData; } @@ -187,7 +303,7 @@ const ConnectionId& ShortHeader::getConnectionId() const { } PacketNum ShortHeader::getPacketSequenceNum() const { - return *packetSequenceNum_; + return packetSequenceNum_; } void ShortHeader::setPacketNumber(PacketNum packetNum) { diff --git a/quic/codec/Types.h b/quic/codec/Types.h index 946abda02..21813e638 100644 --- a/quic/codec/Types.h +++ b/quic/codec/Types.h @@ -627,6 +627,8 @@ struct LongHeaderInvariant { // TODO: split this into read and write types. struct LongHeader { public: + virtual ~LongHeader() = default; + static constexpr uint8_t kFixedBitMask = 0x40; static constexpr uint8_t kPacketTypeMask = 0x30; static constexpr uint8_t kReservedBitsMask = 0x0c; @@ -647,40 +649,33 @@ struct LongHeader { const ConnectionId& dstConnId, PacketNum packetNum, QuicVersion version, - Buf token = nullptr, + const std::string& token = std::string(), folly::Optional originalDstConnId = folly::none); LongHeader( Types type, LongHeaderInvariant invariant, - Buf token = nullptr, + const std::string& token = std::string(), folly::Optional originalDstConnId = folly::none); - void setPacketNumber(PacketNum packetNum); - - // Stuff stored in a variant type needs to be copyable. - // TODO: can we make this copyable only by the variant, but not - // by anyone else. - LongHeader(const LongHeader& other); - LongHeader& operator=(const LongHeader& other); - Types getHeaderType() const noexcept; const ConnectionId& getSourceConnId() const; const ConnectionId& getDestinationConnId() const; const folly::Optional& getOriginalDstConnId() const; - PacketNum getPacketSequenceNum() const; QuicVersion getVersion() const; + PacketNumberSpace getPacketNumberSpace() const; ProtectionType getProtectionType() const; - PacketNumberSpace getPacketNumberSpace() const noexcept; bool hasToken() const; - folly::IOBuf* getToken() const; + const std::string& getToken() const; + PacketNum getPacketSequenceNum() const; + + void setPacketNumber(PacketNum packetNum); private: - HeaderForm headerForm_; + PacketNum packetSequenceNum_{0}; Types longHeaderType_; LongHeaderInvariant invariant_; - folly::Optional packetSequenceNum_; // at most 32 bits on wire - Buf token_; + std::string token_; folly::Optional originalDstConnId_; }; @@ -692,6 +687,8 @@ struct ShortHeaderInvariant { struct ShortHeader { public: + virtual ~ShortHeader() = default; + // There is also a spin bit which is 0x20 that we don't currently implement. static constexpr uint8_t kFixedBitMask = 0x40; static constexpr uint8_t kReservedBitsMask = 0x18; @@ -712,11 +709,10 @@ struct ShortHeader { ConnectionId connId, PacketNum packetNum); - ProtectionType getProtectionType() const noexcept; - PacketNumberSpace getPacketNumberSpace() const noexcept; - - const ConnectionId& getConnectionId() const; + ProtectionType getProtectionType() const; + PacketNumberSpace getPacketNumberSpace() const; PacketNum getPacketSequenceNum() const; + const ConnectionId& getConnectionId() const; void setPacketNumber(PacketNum packetNum); @@ -729,16 +725,48 @@ struct ShortHeader { folly::io::Cursor& cursor); private: - HeaderForm headerForm_; + PacketNum packetSequenceNum_{0}; ProtectionType protectionType_; ConnectionId connectionId_; - folly::Optional packetSequenceNum_; // var-size 8/16/24/32 bits +}; + +struct PacketHeader { + ~PacketHeader(); + + /* implicit */ PacketHeader(LongHeader&& longHeader); + /* implicit */ PacketHeader(ShortHeader&& shortHeader); + + PacketHeader(PacketHeader&& other) noexcept; + PacketHeader(const PacketHeader& other); + + PacketHeader& operator=(PacketHeader&& other) noexcept; + PacketHeader& operator=(const PacketHeader& other); + + LongHeader* asLong(); + ShortHeader* asShort(); + + const LongHeader* asLong() const; + const ShortHeader* asShort() const; + + PacketNum getPacketSequenceNum() const; + HeaderForm getHeaderForm() const; + ProtectionType getProtectionType() const; + PacketNumberSpace getPacketNumberSpace() const; + + private: + void destroyHeader(); + + union { + LongHeader longHeader; + ShortHeader shortHeader; + }; + + HeaderForm headerForm_; }; ProtectionType longHeaderTypeToProtectionType(LongHeader::Types type); -PacketNumberSpace longHeaderTypeToPacketNumberSpace(LongHeader::Types type); -using PacketHeader = boost::variant; +PacketNumberSpace longHeaderTypeToPacketNumberSpace(LongHeader::Types type); struct StreamTypeField { public: @@ -791,8 +819,7 @@ struct VersionNegotiationPacket { struct RegularPacket { PacketHeader header; - explicit RegularPacket(PacketHeader&& headerIn) - : header(std::move(headerIn)) {} + explicit RegularPacket(PacketHeader&& headerIn) : header(std::move(headerIn)) {} }; /** @@ -842,17 +869,16 @@ inline std::ostream& operator<<( } inline std::ostream& operator<<(std::ostream& os, const PacketHeader& header) { - folly::variant_match( - header, - [&os](const LongHeader& h) { - os << "header=long" - << " protectionType=" << (int)h.getProtectionType() - << " type=" << std::hex << (int)h.getHeaderType(); - }, - [&os](const ShortHeader& h) { - os << "header=short" - << " protectionType=" << (int)h.getProtectionType(); - }); + auto shortHeader = header.asShort(); + if (shortHeader) { + os << "header=short" + << " protectionType=" << (int)shortHeader->getProtectionType(); + } else { + auto longHeader = header.asLong(); + os << "header=long" + << " protectionType=" << (int)longHeader->getProtectionType() + << " type=" << std::hex << (int)longHeader->getHeaderType(); + } return os; } diff --git a/quic/codec/test/DecodeTest.cpp b/quic/codec/test/DecodeTest.cpp index 5bb30fcbd..74ea52dd4 100644 --- a/quic/codec/test/DecodeTest.cpp +++ b/quic/codec/test/DecodeTest.cpp @@ -361,7 +361,6 @@ TEST_F(DecodeTest, AckFrameMissingFields) { ackBlocks.emplace_back(QuicInteger(10), QuicInteger(10)); ackBlocks.emplace_back(QuicInteger(10), QuicInteger(10)); - auto header = makeHeader(); auto result1 = createAckFrame( largestAcked, folly::none, @@ -373,7 +372,7 @@ TEST_F(DecodeTest, AckFrameMissingFields) { EXPECT_THROW( decodeAckFrame( cursor1, - header, + makeHeader(), CodecParameters(kDefaultAckDelayExponent, QuicVersion::MVFST)), QuicTransportException); @@ -383,7 +382,7 @@ TEST_F(DecodeTest, AckFrameMissingFields) { EXPECT_THROW( decodeAckFrame( cursor2, - header, + makeHeader(), CodecParameters(kDefaultAckDelayExponent, QuicVersion::MVFST)), QuicTransportException); @@ -393,7 +392,7 @@ TEST_F(DecodeTest, AckFrameMissingFields) { EXPECT_THROW( decodeAckFrame( cursor3, - header, + makeHeader(), CodecParameters(kDefaultAckDelayExponent, QuicVersion::MVFST)), QuicTransportException); @@ -403,7 +402,7 @@ TEST_F(DecodeTest, AckFrameMissingFields) { EXPECT_THROW( decodeAckFrame( cursor4, - header, + makeHeader(), CodecParameters(kDefaultAckDelayExponent, QuicVersion::MVFST)), QuicTransportException); @@ -413,7 +412,7 @@ TEST_F(DecodeTest, AckFrameMissingFields) { EXPECT_THROW( decodeAckFrame( cursor5, - header, + makeHeader(), CodecParameters(kDefaultAckDelayExponent, QuicVersion::MVFST)), QuicTransportException); } diff --git a/quic/codec/test/QuicHeaderCodecTest.cpp b/quic/codec/test/QuicHeaderCodecTest.cpp index c20baa13a..bf537ec5d 100644 --- a/quic/codec/test/QuicHeaderCodecTest.cpp +++ b/quic/codec/test/QuicHeaderCodecTest.cpp @@ -56,17 +56,12 @@ TEST_F(QuicHeaderCodecTest, ShortHeaderTest) { auto packet = std::move(builder).buildPacket(); auto result = parseHeader(*packet.header); auto& header = result->parsedHeader; - - EXPECT_EQ( - getTestConnectionId(), - folly::variant_match( - header.value(), - [](const LongHeader& longHeader) { - return longHeader.getDestinationConnId(); - }, - [](const ShortHeader& shortHeader) { - return shortHeader.getConnectionId(); - })); + LongHeader* longHeader = header->asLong(); + if (longHeader) { + EXPECT_EQ(getTestConnectionId(), longHeader->getDestinationConnId()); + } else { + EXPECT_EQ(getTestConnectionId(), header->asShort()->getConnectionId()); + } } } // namespace test } // namespace quic diff --git a/quic/codec/test/QuicPacketBuilderTest.cpp b/quic/codec/test/QuicPacketBuilderTest.cpp index 2bd490448..5f3c7e54f 100644 --- a/quic/codec/test/QuicPacketBuilderTest.cpp +++ b/quic/codec/test/QuicPacketBuilderTest.cpp @@ -121,7 +121,7 @@ TEST_F(QuicPacketBuilderTest, SimpleRetryPacket) { getTestConnectionId(1), 321, QuicVersion::MVFST, - folly::IOBuf::copyBuffer("454358"), + std::string("454358"), getTestConnectionId(2)); RegularQuicPacketBuilder builder( @@ -140,16 +140,15 @@ TEST_F(QuicPacketBuilderTest, SimpleRetryPacket) { EXPECT_NO_THROW(boost::get(decodedPacket)); auto retryPacket = boost::get(decodedPacket); - auto headerOut = boost::get(retryPacket.header); + auto& headerOut = *retryPacket.header.asLong(); EXPECT_EQ(*headerOut.getOriginalDstConnId(), getTestConnectionId(2)); EXPECT_EQ(headerOut.getVersion(), QuicVersion::MVFST); EXPECT_EQ(headerOut.getSourceConnId(), getTestConnectionId(0)); EXPECT_EQ(headerOut.getDestinationConnId(), getTestConnectionId(1)); - folly::IOBufEqualTo eq; - auto expectedBuf = folly::IOBuf::copyBuffer("454358"); - EXPECT_TRUE(eq(*headerOut.getToken(), *expectedBuf)); + auto expected = std::string("454358"); + EXPECT_EQ(headerOut.getToken(), expected); } TEST_F(QuicPacketBuilderTest, TooManyVersions) { @@ -210,8 +209,8 @@ TEST_F(QuicPacketBuilderTest, LongHeaderRegularPacket) { auto resultBuf = packetToBufCleartext( resultRegularPacket, *cleartextAead, *headerCipher, pktNum); auto& resultHeader = resultRegularPacket.packet.header; - EXPECT_NO_THROW(boost::get(resultHeader)); - auto& resultLongHeader = boost::get(resultHeader); + EXPECT_NE(resultHeader.asLong(), nullptr); + auto& resultLongHeader = *resultHeader.asLong(); EXPECT_EQ(LongHeader::Types::Initial, resultLongHeader.getHeaderType()); EXPECT_EQ(serverConnId, resultLongHeader.getSourceConnId()); EXPECT_EQ(pktNum, resultLongHeader.getPacketSequenceNum()); @@ -225,7 +224,7 @@ TEST_F(QuicPacketBuilderTest, LongHeaderRegularPacket) { auto decodedPacket = boost::get(optionalDecodedPacket); EXPECT_NO_THROW(boost::get(decodedPacket)); auto decodedRegularPacket = boost::get(decodedPacket); - auto& decodedHeader = boost::get(decodedRegularPacket.header); + auto& decodedHeader = *decodedRegularPacket.header.asLong(); EXPECT_EQ(LongHeader::Types::Initial, decodedHeader.getHeaderType()); EXPECT_EQ(clientConnId, decodedHeader.getDestinationConnId()); EXPECT_EQ(pktNum, decodedHeader.getPacketSequenceNum()); @@ -240,7 +239,7 @@ TEST_F(QuicPacketBuilderTest, ShortHeaderRegularPacket) { auto encodedPacketNum = encodePacketNumber(pktNum, largestAckedPacketNum); RegularQuicPacketBuilder builder( kDefaultUDPSendPacketLen, - PacketHeader(ShortHeader(ProtectionType::KeyPhaseZero, connId, pktNum)), + ShortHeader(ProtectionType::KeyPhaseZero, connId, pktNum), largestAckedPacketNum); // write out at least one frame @@ -255,7 +254,7 @@ TEST_F(QuicPacketBuilderTest, ShortHeaderRegularPacket) { EXPECT_EQ(builtOut.body->computeChainDataLength(), expectedOutputSize); auto resultBuf = packetToBuf(builtOut); - auto resultShortHeader = boost::get(resultRegularPacket.header); + auto& resultShortHeader = *resultRegularPacket.header.asShort(); EXPECT_EQ( ProtectionType::KeyPhaseZero, resultShortHeader.getProtectionType()); EXPECT_EQ(connId, resultShortHeader.getConnectionId()); @@ -270,7 +269,7 @@ TEST_F(QuicPacketBuilderTest, ShortHeaderRegularPacket) { ->parsePacket(packetQueue, ackStates); auto decodedPacket = boost::get(parsedPacket); auto decodedRegularPacket = boost::get(decodedPacket); - auto decodedHeader = boost::get(decodedRegularPacket.header); + auto& decodedHeader = *decodedRegularPacket.header.asShort(); EXPECT_EQ(ProtectionType::KeyPhaseZero, decodedHeader.getProtectionType()); EXPECT_EQ(connId, decodedHeader.getConnectionId()); EXPECT_EQ(pktNum, decodedHeader.getPacketSequenceNum()); @@ -284,7 +283,7 @@ TEST_F(QuicPacketBuilderTest, ShortHeaderWithNoFrames) { // frames already and will be too small to parse. RegularQuicPacketBuilder builder( kDefaultUDPSendPacketLen, - PacketHeader(ShortHeader(ProtectionType::KeyPhaseZero, connId, pktNum)), + ShortHeader(ProtectionType::KeyPhaseZero, connId, pktNum), 0 /* largestAcked */); EXPECT_TRUE(builder.canBuildPacket()); auto builtOut = std::move(builder).buildPacket(); @@ -312,7 +311,7 @@ TEST_F(QuicPacketBuilderTest, TestPaddingAccountsForCipherOverhead) { size_t cipherOverhead = 2; RegularQuicPacketBuilder builder( kDefaultUDPSendPacketLen, - PacketHeader(ShortHeader(ProtectionType::KeyPhaseZero, connId, pktNum)), + ShortHeader(ProtectionType::KeyPhaseZero, connId, pktNum), largestAckedPacketNum); builder.setCipherOverhead(cipherOverhead); EXPECT_TRUE(builder.canBuildPacket()); @@ -337,7 +336,7 @@ TEST_F(QuicPacketBuilderTest, TestPaddingRespectsRemainingBytes) { size_t totalPacketSize = 20; RegularQuicPacketBuilder builder( totalPacketSize, - PacketHeader(ShortHeader(ProtectionType::KeyPhaseZero, connId, pktNum)), + ShortHeader(ProtectionType::KeyPhaseZero, connId, pktNum), largestAckedPacketNum); EXPECT_TRUE(builder.canBuildPacket()); writeFrame(PaddingFrame(), builder); diff --git a/quic/codec/test/QuicPacketRebuilderTest.cpp b/quic/codec/test/QuicPacketRebuilderTest.cpp index e746c3d9d..0da47504d 100644 --- a/quic/codec/test/QuicPacketRebuilderTest.cpp +++ b/quic/codec/test/QuicPacketRebuilderTest.cpp @@ -40,8 +40,7 @@ class QuicPacketRebuilderTest : public Test {}; TEST_F(QuicPacketRebuilderTest, RebuildEmpty) { RegularQuicPacketBuilder regularBuilder( kDefaultUDPSendPacketLen, - PacketHeader( - ShortHeader(ProtectionType::KeyPhaseZero, getTestConnectionId(), 0)), + ShortHeader(ProtectionType::KeyPhaseZero, getTestConnectionId(), 0), 0 /* largestAcked */); QuicConnectionStateBase conn(QuicNodeType::Client); PacketRebuilder rebuilder(regularBuilder, conn); @@ -52,10 +51,10 @@ TEST_F(QuicPacketRebuilderTest, RebuildEmpty) { } TEST_F(QuicPacketRebuilderTest, RebuildPacket) { - ShortHeader shortHeader( + ShortHeader shortHeader1( ProtectionType::KeyPhaseZero, getTestConnectionId(), 0); RegularQuicPacketBuilder regularBuilder1( - kDefaultUDPSendPacketLen, shortHeader, 0 /* largestAcked */); + kDefaultUDPSendPacketLen, std::move(shortHeader1), 0 /* largestAcked */); // Get a bunch frames ConnectionCloseFrame connCloseFrame( @@ -107,8 +106,10 @@ TEST_F(QuicPacketRebuilderTest, RebuildPacket) { true); // rebuild a packet from the built out packet + ShortHeader shortHeader2( + ProtectionType::KeyPhaseZero, getTestConnectionId(), 0); RegularQuicPacketBuilder regularBuilder2( - kDefaultUDPSendPacketLen, shortHeader, 0 /* largestAcked */); + kDefaultUDPSendPacketLen, std::move(shortHeader2), 0 /* largestAcked */); PacketRebuilder rebuilder(regularBuilder2, conn); auto outstanding = makeDummyOutstandingPacket(packet1.packet, 1000); EXPECT_TRUE(rebuilder.rebuildFromPacket(outstanding).hasValue()); @@ -166,10 +167,10 @@ TEST_F(QuicPacketRebuilderTest, RebuildPacket) { } TEST_F(QuicPacketRebuilderTest, RebuildAfterResetStream) { - ShortHeader shortHeader( + ShortHeader shortHeader1( ProtectionType::KeyPhaseZero, getTestConnectionId(), 0); RegularQuicPacketBuilder regularBuilder1( - kDefaultUDPSendPacketLen, shortHeader, 0 /* largestAcked */); + kDefaultUDPSendPacketLen, std::move(shortHeader1), 0 /* largestAcked */); QuicServerConnectionState conn; conn.streamManager->setMaxLocalBidirectionalStreams(10); auto stream = conn.streamManager->createNextBidirectionalStream().value(); @@ -192,18 +193,20 @@ TEST_F(QuicPacketRebuilderTest, RebuildAfterResetStream) { conn, *stream, StreamEvents::SendReset(GenericApplicationErrorCode::UNKNOWN)); + ShortHeader shortHeader2( + ProtectionType::KeyPhaseZero, getTestConnectionId(), 0); RegularQuicPacketBuilder regularBuilder2( - kDefaultUDPSendPacketLen, shortHeader, 0 /* largestAcked */); + kDefaultUDPSendPacketLen, std::move(shortHeader2), 0 /* largestAcked */); PacketRebuilder rebuilder(regularBuilder2, conn); auto outstanding = makeDummyOutstandingPacket(packet1.packet, 1000); EXPECT_FALSE(rebuilder.rebuildFromPacket(outstanding).hasValue()); } TEST_F(QuicPacketRebuilderTest, FinOnlyStreamRebuild) { - ShortHeader shortHeader( + ShortHeader shortHeader1( ProtectionType::KeyPhaseZero, getTestConnectionId(), 0); RegularQuicPacketBuilder regularBuilder1( - kDefaultUDPSendPacketLen, shortHeader, 0 /* largestAcked */); + kDefaultUDPSendPacketLen, std::move(shortHeader1), 0 /* largestAcked */); QuicServerConnectionState conn; conn.streamManager->setMaxLocalBidirectionalStreams(10); auto stream = conn.streamManager->createNextBidirectionalStream().value(); @@ -216,8 +219,10 @@ TEST_F(QuicPacketRebuilderTest, FinOnlyStreamRebuild) { stream->retransmissionBuffer.begin(), nullptr, 0, true); // rebuild a packet from the built out packet + ShortHeader shortHeader2( + ProtectionType::KeyPhaseZero, getTestConnectionId(), 0); RegularQuicPacketBuilder regularBuilder2( - kDefaultUDPSendPacketLen, shortHeader, 0 /* largestAcked */); + kDefaultUDPSendPacketLen, std::move(shortHeader2), 0 /* largestAcked */); PacketRebuilder rebuilder(regularBuilder2, conn); auto outstanding = makeDummyOutstandingPacket(packet1.packet, 2000); EXPECT_TRUE(rebuilder.rebuildFromPacket(outstanding).hasValue()); @@ -236,10 +241,10 @@ TEST_F(QuicPacketRebuilderTest, FinOnlyStreamRebuild) { } TEST_F(QuicPacketRebuilderTest, RebuildDataStreamAndEmptyCryptoStream) { - ShortHeader shortHeader( + ShortHeader shortHeader1( ProtectionType::KeyPhaseZero, getTestConnectionId(), 0); RegularQuicPacketBuilder regularBuilder1( - kDefaultUDPSendPacketLen, shortHeader, 0 /* largestAcked */); + kDefaultUDPSendPacketLen, std::move(shortHeader1), 0 /* largestAcked */); // Get a bunch frames QuicServerConnectionState conn; @@ -270,8 +275,10 @@ TEST_F(QuicPacketRebuilderTest, RebuildDataStreamAndEmptyCryptoStream) { // imagine it was cleared // rebuild a packet from the built out packet + ShortHeader shortHeader2( + ProtectionType::KeyPhaseZero, getTestConnectionId(), 0); RegularQuicPacketBuilder regularBuilder2( - kDefaultUDPSendPacketLen, shortHeader, 0 /* largestAcked */); + kDefaultUDPSendPacketLen, std::move(shortHeader2), 0 /* largestAcked */); PacketRebuilder rebuilder(regularBuilder2, conn); auto outstanding = makeDummyOutstandingPacket(packet1.packet, 1000); EXPECT_TRUE(rebuilder.rebuildFromPacket(outstanding).hasValue()); @@ -295,10 +302,10 @@ TEST_F(QuicPacketRebuilderTest, RebuildDataStreamAndEmptyCryptoStream) { } TEST_F(QuicPacketRebuilderTest, CannotRebuildEmptyCryptoStream) { - ShortHeader shortHeader( + ShortHeader shortHeader1( ProtectionType::KeyPhaseZero, getTestConnectionId(), 0); RegularQuicPacketBuilder regularBuilder1( - kDefaultUDPSendPacketLen, shortHeader, 0 /* largestAcked */); + kDefaultUDPSendPacketLen, std::move(shortHeader1), 0 /* largestAcked */); // Get a bunch frames QuicServerConnectionState conn; @@ -313,18 +320,20 @@ TEST_F(QuicPacketRebuilderTest, CannotRebuildEmptyCryptoStream) { // imagine it was cleared // rebuild a packet from the built out packet + ShortHeader shortHeader2( + ProtectionType::KeyPhaseZero, getTestConnectionId(), 0); RegularQuicPacketBuilder regularBuilder2( - kDefaultUDPSendPacketLen, shortHeader, 0 /* largestAcked */); + kDefaultUDPSendPacketLen, std::move(shortHeader2), 0 /* largestAcked */); PacketRebuilder rebuilder(regularBuilder2, conn); auto outstanding = makeDummyOutstandingPacket(packet1.packet, 1000); EXPECT_FALSE(rebuilder.rebuildFromPacket(outstanding).hasValue()); } TEST_F(QuicPacketRebuilderTest, CannotRebuild) { - ShortHeader shortHeader( + ShortHeader shortHeader1( ProtectionType::KeyPhaseZero, getTestConnectionId(), 0); RegularQuicPacketBuilder regularBuilder1( - kDefaultUDPSendPacketLen, shortHeader, 0 /* largestAcked */); + kDefaultUDPSendPacketLen, std::move(shortHeader1), 0 /* largestAcked */); // Get a bunch frames ConnectionCloseFrame connCloseFrame( @@ -364,11 +373,13 @@ TEST_F(QuicPacketRebuilderTest, CannotRebuild) { stream->retransmissionBuffer.begin(), buf->clone(), 0, true); // new builder has a much smaller writable bytes limit + ShortHeader shortHeader2( + ProtectionType::KeyPhaseZero, getTestConnectionId(), 0); RegularQuicPacketBuilder regularBuilder2( (packet1.header->computeChainDataLength() + packet1.body->computeChainDataLength()) / 2, - shortHeader, + std::move(shortHeader2), 0 /* largestAcked */); PacketRebuilder rebuilder(regularBuilder2, conn); auto outstanding = makeDummyOutstandingPacket(packet1.packet, 1000); @@ -376,17 +387,19 @@ TEST_F(QuicPacketRebuilderTest, CannotRebuild) { } TEST_F(QuicPacketRebuilderTest, CloneCounter) { - ShortHeader shortHeader( + ShortHeader shortHeader1( ProtectionType::KeyPhaseZero, getTestConnectionId(), 0); RegularQuicPacketBuilder regularBuilder( - kDefaultUDPSendPacketLen, shortHeader, 0 /* largestAcked */); + kDefaultUDPSendPacketLen, std::move(shortHeader1), 0 /* largestAcked */); PingFrame pingFrame; writeFrame(pingFrame, regularBuilder); auto packet = std::move(regularBuilder).buildPacket(); auto outstandingPacket = makeDummyOutstandingPacket(packet.packet, 1000); QuicServerConnectionState conn; + ShortHeader shortHeader2( + ProtectionType::KeyPhaseZero, getTestConnectionId(), 0); RegularQuicPacketBuilder regularBuilder2( - kDefaultUDPSendPacketLen, shortHeader, 0 /* largestAcked */); + kDefaultUDPSendPacketLen, std::move(shortHeader2), 0 /* largestAcked */); PacketRebuilder rebuilder(regularBuilder2, conn); rebuilder.rebuildFromPacket(outstandingPacket); EXPECT_TRUE(outstandingPacket.associatedEvent.hasValue()); diff --git a/quic/codec/test/QuicReadCodecTest.cpp b/quic/codec/test/QuicReadCodecTest.cpp index 684be0448..15a917be5 100644 --- a/quic/codec/test/QuicReadCodecTest.cpp +++ b/quic/codec/test/QuicReadCodecTest.cpp @@ -109,7 +109,7 @@ TEST_F(QuicReadCodecTest, RetryPacketTest) { getTestConnectionId(90), 321, static_cast(0xffff), - folly::IOBuf::copyBuffer("fluffydog"), + std::string("fluffydog"), getTestConnectionId(110)); RegularQuicPacketBuilder builder( @@ -121,16 +121,15 @@ TEST_F(QuicReadCodecTest, RetryPacketTest) { auto retryPacket = boost::get(boost::get( makeUnencryptedCodec()->parsePacket(packetQueue, ackStates))); - auto headerOut = boost::get(retryPacket.header); + auto headerOut = *retryPacket.header.asLong(); EXPECT_EQ(*headerOut.getOriginalDstConnId(), getTestConnectionId(110)); EXPECT_EQ(headerOut.getVersion(), static_cast(0xffff)); EXPECT_EQ(headerOut.getSourceConnId(), getTestConnectionId(70)); EXPECT_EQ(headerOut.getDestinationConnId(), getTestConnectionId(90)); - folly::IOBufEqualTo eq; - auto expectedBuf = folly::IOBuf::copyBuffer("fluffydog"); - EXPECT_TRUE(eq(*headerOut.getToken(), *expectedBuf)); + auto expected = std::string("fluffydog"); + EXPECT_EQ(headerOut.getToken(), expected); } TEST_F(QuicReadCodecTest, EmptyVersionNegotiationPacketTest) { @@ -486,10 +485,10 @@ TEST_F(QuicReadCodecTest, TestInitialPacket) { EXPECT_NO_THROW(boost::get(quicPacket)); auto regularQuicPacket = boost::get(quicPacket); - EXPECT_NO_THROW(boost::get(regularQuicPacket.header)); - auto longPacketHeader = boost::get(regularQuicPacket.header); + EXPECT_NE(regularQuicPacket.header.asLong(), nullptr); + auto longPacketHeader = regularQuicPacket.header.asLong(); - EXPECT_FALSE(longPacketHeader.hasToken()); + EXPECT_FALSE(longPacketHeader->hasToken()); } TEST_F(QuicReadCodecTest, TestHandshakeDone) { diff --git a/quic/codec/test/QuicWriteCodecTest.cpp b/quic/codec/test/QuicWriteCodecTest.cpp index 5eb6057e8..fb3b3a676 100644 --- a/quic/codec/test/QuicWriteCodecTest.cpp +++ b/quic/codec/test/QuicWriteCodecTest.cpp @@ -124,7 +124,7 @@ TEST_F(QuicWriteCodecTest, WriteStreamFrameToEmptyPacket) { EXPECT_EQ( kDefaultUDPSendPacketLen - 3 - 10, pktBuilder.remainingSpaceInPkt()); auto builtOut = std::move(pktBuilder).buildPacket(); - auto regularPacket = builtOut.first; + auto regularPacket = std::move(builtOut.first); EXPECT_EQ(regularPacket.frames.size(), 1); auto resultFrame = boost::get(regularPacket.frames.back()); @@ -171,7 +171,7 @@ TEST_F(QuicWriteCodecTest, WriteStreamFrameToPartialPacket) { pktBuilder.remainingSpaceInPkt()); auto builtOut = std::move(pktBuilder).buildPacket(); - auto regularPacket = builtOut.first; + auto regularPacket = std::move(builtOut.first); EXPECT_EQ(regularPacket.frames.size(), 1); auto resultFrame = boost::get(regularPacket.frames.back()); EXPECT_EQ(resultFrame.streamId, streamId); @@ -237,7 +237,7 @@ TEST_F(QuicWriteCodecTest, WriteTwoStreamFrames) { kDefaultUDPSendPacketLen - consumedSize, pktBuilder.remainingSpaceInPkt()); auto builtOut = std::move(pktBuilder).buildPacket(); - auto regularPacket = builtOut.first; + auto regularPacket = std::move(builtOut.first); EXPECT_EQ(regularPacket.frames.size(), 2); auto resultFrame = boost::get(regularPacket.frames.front()); EXPECT_EQ(resultFrame.streamId, streamId1); diff --git a/quic/codec/test/TypesTest.cpp b/quic/codec/test/TypesTest.cpp index ed01a4fde..489ce60d9 100644 --- a/quic/codec/test/TypesTest.cpp +++ b/quic/codec/test/TypesTest.cpp @@ -61,7 +61,7 @@ folly::Expected makeLongHeader( getTestConnectionId(), 321, QuicVersion::QUIC_DRAFT, - IOBuf::copyBuffer("this is a retry token :)"), + std::string("this is a retry token :)"), getTestConnectionId()); RegularQuicPacketBuilder builder( @@ -261,5 +261,39 @@ TEST_F(TypesTest, LongHeaderPacketNumberSpace) { PacketNumberSpace::AppData, zeroRttLongHeader.getPacketNumberSpace()); } +class PacketHeaderTest : public Test {}; + +TEST_F(PacketHeaderTest, LongHeader) { + PacketNum packetNumber = 202; + LongHeader handshakeLongHeader( + LongHeader::Types::Handshake, + getTestConnectionId(4), + getTestConnectionId(5), + packetNumber, + QuicVersion::QUIC_DRAFT); + PacketHeader readHeader(std::move(handshakeLongHeader)); + EXPECT_NE(readHeader.asLong(), nullptr); + EXPECT_EQ(readHeader.asShort(), nullptr); + EXPECT_EQ(readHeader.getPacketSequenceNum(), packetNumber); + EXPECT_EQ(readHeader.getHeaderForm(), HeaderForm::Long); + EXPECT_EQ(readHeader.getProtectionType(), ProtectionType::Handshake); + EXPECT_EQ(readHeader.getPacketNumberSpace(), PacketNumberSpace::Handshake); + EXPECT_EQ(readHeader.asLong()->getHeaderType(), LongHeader::Types::Handshake); +} + +TEST_F(PacketHeaderTest, ShortHeader) { + PacketNum packetNumber = 202; + ConnectionId connid = getTestConnectionId(4); + ShortHeader shortHeader(ProtectionType::KeyPhaseZero, connid, packetNumber); + PacketHeader readHeader(std::move(shortHeader)); + EXPECT_EQ(readHeader.asLong(), nullptr); + EXPECT_NE(readHeader.asShort(), nullptr); + EXPECT_EQ(readHeader.getPacketSequenceNum(), packetNumber); + EXPECT_EQ(readHeader.getHeaderForm(), HeaderForm::Short); + EXPECT_EQ(readHeader.getProtectionType(), ProtectionType::KeyPhaseZero); + EXPECT_EQ(readHeader.getPacketNumberSpace(), PacketNumberSpace::AppData); + + EXPECT_EQ(readHeader.asShort()->getConnectionId(), connid); +} } // namespace test } // namespace quic diff --git a/quic/common/test/TestUtils.cpp b/quic/common/test/TestUtils.cpp index ca683f196..caf3b2341 100644 --- a/quic/common/test/TestUtils.cpp +++ b/quic/common/test/TestUtils.cpp @@ -25,10 +25,7 @@ getPreviousOutstandingPacket( std::deque::reverse_iterator from) { return std::find_if( from, conn.outstandingPackets.rend(), [=](const auto& op) { - return packetNumberSpace == - folly::variant_match(op.packet.header, [](const auto& h) { - return h.getPacketNumberSpace(); - }); + return packetNumberSpace == op.packet.header.getPacketNumberSpace(); }); } } // namespace @@ -65,11 +62,6 @@ const RegularQuicWritePacket& writeQuicPacket( return getLastOutstandingPacket(conn, PacketNumberSpace::AppData)->packet; } -PacketNum getPacketSequenceNum(const RegularQuicWritePacket& packet) { - return folly::variant_match( - packet.header, [](const auto& h) { return h.getPacketSequenceNum(); }); -} - PacketNum rstStreamAndSendPacket( QuicServerConnectionState& conn, folly::AsyncUDPSocket& sock, @@ -93,9 +85,7 @@ PacketNum rstStreamAndSendPacket( for (const auto& packet : conn.outstandingPackets) { for (const auto& frame : all_frames(packet.packet.frames)) { if (frame.streamId == stream.id) { - return folly::variant_match(packet.packet.header, [](const auto& h) { - return h.getPacketSequenceNum(); - }); + return packet.packet.header.getPacketSequenceNum(); } } } @@ -394,32 +384,24 @@ RegularQuicPacketBuilder::Packet createCryptoPacket( folly::Optional header; switch (protectionType) { case ProtectionType::Initial: - header = PacketHeader(LongHeader( - LongHeader::Types::Initial, - srcConnId, - dstConnId, - packetNum, - version)); + header = LongHeader( + LongHeader::Types::Initial, srcConnId, dstConnId, packetNum, version); break; case ProtectionType::Handshake: - header = PacketHeader(LongHeader( + header = LongHeader( LongHeader::Types::Handshake, srcConnId, dstConnId, packetNum, - version)); + version); break; case ProtectionType::ZeroRtt: - header = PacketHeader(LongHeader( - LongHeader::Types::ZeroRtt, - srcConnId, - dstConnId, - packetNum, - version)); + header = LongHeader( + LongHeader::Types::ZeroRtt, srcConnId, dstConnId, packetNum, version); break; case ProtectionType::KeyPhaseOne: case ProtectionType::KeyPhaseZero: - header = PacketHeader(ShortHeader(protectionType, dstConnId, packetNum)); + header = ShortHeader(protectionType, dstConnId, packetNum); break; } RegularQuicPacketBuilder builder( @@ -449,10 +431,7 @@ Buf packetToBufCleartext( if (packet.body) { body = packet.body->clone(); } - auto headerForm = folly::variant_match( - packet.packet.header, - [](const LongHeader&) { return HeaderForm::Long; }, - [](const ShortHeader&) { return HeaderForm::Short; }); + auto headerForm = packet.packet.header.getHeaderForm(); auto encryptedBody = cleartextCipher.encrypt(std::move(body), packet.header.get(), packetNum); encryptPacketHeader(headerForm, *packet.header, *encryptedBody, headerCipher); diff --git a/quic/common/test/TestUtils.h b/quic/common/test/TestUtils.h index 081da7291..78630d3ab 100644 --- a/quic/common/test/TestUtils.h +++ b/quic/common/test/TestUtils.h @@ -52,8 +52,6 @@ const RegularQuicWritePacket& writeQuicPacket( const folly::IOBuf& data, bool eof = false); -PacketNum getPacketSequenceNum(const RegularQuicWritePacket& packet); - RegularQuicPacketBuilder::Packet createAckPacket( QuicConnectionStateBase& dstConn, PacketNum pn, diff --git a/quic/congestion_control/Copa.cpp b/quic/congestion_control/Copa.cpp index a0abbfd55..798750fdc 100644 --- a/quic/congestion_control/Copa.cpp +++ b/quic/congestion_control/Copa.cpp @@ -50,10 +50,7 @@ void Copa::onPacketSent(const OutstandingPacket& packet) { VLOG(10) << __func__ << " writable=" << getWritableBytes() << " cwnd=" << cwndBytes_ << " inflight=" << bytesInFlight_ << " bytesBufferred=" << conn_.flowControlState.sumCurStreamBufferLen - << " packetNum=" - << folly::variant_match( - packet.packet.header, - [](auto& h) { return h.getPacketSequenceNum(); }) + << " packetNum=" << packet.packet.header.getPacketSequenceNum() << " " << conn_; if (conn_.qLogger) { conn_.qLogger->addCongestionMetricUpdate( diff --git a/quic/congestion_control/NewReno.cpp b/quic/congestion_control/NewReno.cpp index e192e9e0f..f9ad28e44 100644 --- a/quic/congestion_control/NewReno.cpp +++ b/quic/congestion_control/NewReno.cpp @@ -40,10 +40,7 @@ void NewReno::onPacketSent(const OutstandingPacket& packet) { addAndCheckOverflow(bytesInFlight_, packet.encodedSize); VLOG(10) << __func__ << " writable=" << getWritableBytes() << " cwnd=" << cwndBytes_ << " inflight=" << bytesInFlight_ - << " packetNum=" - << folly::variant_match( - packet.packet.header, - [](auto& h) { return h.getPacketSequenceNum(); }) + << " packetNum=" << packet.packet.header.getPacketSequenceNum() << " " << conn_; if (conn_.qLogger) { conn_.qLogger->addCongestionMetricUpdate( diff --git a/quic/logging/QLogger.cpp b/quic/logging/QLogger.cpp index 063ab1485..7e6824e13 100644 --- a/quic/logging/QLogger.cpp +++ b/quic/logging/QLogger.cpp @@ -23,17 +23,16 @@ std::unique_ptr QLogger::createPacketEvent( std::chrono::steady_clock::now() - refTimePoint); event->packetSize = packetSize; event->eventType = QLogEventType::PacketReceived; - event->packetType = folly::variant_match( - regularPacket.header, - [](const LongHeader& header) { return toString(header.getHeaderType()); }, - [](const ShortHeader& /* unused*/) { - return kShortHeaderPacketType.toString(); - }); + const ShortHeader* shortHeader = regularPacket.header.asShort(); + if (shortHeader) { + event->packetType = kShortHeaderPacketType.toString(); + } else { + event->packetType = + toString(regularPacket.header.asLong()->getHeaderType()); + } if (event->packetType != toString(LongHeader::Types::Retry)) { // A Retry packet does not include a packet number. - event->packetNum = folly::variant_match( - regularPacket.header, - [](const auto& h) { return h.getPacketSequenceNum(); }); + event->packetNum = regularPacket.header.getPacketSequenceNum(); } uint64_t numPaddingFrames = 0; @@ -149,17 +148,15 @@ std::unique_ptr QLogger::createPacketEvent( auto event = std::make_unique(); event->refTime = std::chrono::duration_cast( std::chrono::steady_clock::now() - refTimePoint); - event->packetNum = folly::variant_match( - writePacket.header, - [](const auto& h) { return h.getPacketSequenceNum(); }); + event->packetNum = writePacket.header.getPacketSequenceNum(); event->packetSize = packetSize; event->eventType = QLogEventType::PacketSent; - event->packetType = folly::variant_match( - writePacket.header, - [](const LongHeader& header) { return toString(header.getHeaderType()); }, - [](const ShortHeader& /* unused*/) { - return kShortHeaderPacketType.toString(); - }); + const ShortHeader* shortHeader = writePacket.header.asShort(); + if (shortHeader) { + event->packetType = kShortHeaderPacketType.toString(); + } else { + event->packetType = toString(writePacket.header.asLong()->getHeaderType()); + } uint64_t numPaddingFrames = 0; // looping through the packet to store logs created from frames in the packet diff --git a/quic/logging/test/QLoggerTest.cpp b/quic/logging/test/QLoggerTest.cpp index 9224e1535..eed38f513 100644 --- a/quic/logging/test/QLoggerTest.cpp +++ b/quic/logging/test/QLoggerTest.cpp @@ -50,7 +50,7 @@ TEST_F(QLoggerTest, TestRegularWritePacket) { TEST_F(QLoggerTest, TestRegularPacket) { auto headerIn = ShortHeader(ProtectionType::KeyPhaseZero, getTestConnectionId(1), 1); - RegularQuicPacket regularQuicPacket(headerIn); + RegularQuicPacket regularQuicPacket(std::move(headerIn)); ReadStreamFrame frame(streamId, offset, fin); regularQuicPacket.frames.emplace_back(std::move(frame)); @@ -341,7 +341,7 @@ TEST_F(QLoggerTest, QLoggerFollyDynamic) { auto headerIn = ShortHeader(ProtectionType::KeyPhaseZero, getTestConnectionId(1), 1); - RegularQuicPacket regularQuicPacket(headerIn); + RegularQuicPacket regularQuicPacket(std::move(headerIn)); ReadStreamFrame frame(streamId, offset, fin); regularQuicPacket.frames.emplace_back(std::move(frame)); @@ -384,7 +384,7 @@ TEST_F(QLoggerTest, RegularPacketFollyDynamic) { auto headerIn = ShortHeader(ProtectionType::KeyPhaseZero, getTestConnectionId(1), 1); - RegularQuicPacket regularQuicPacket(headerIn); + RegularQuicPacket regularQuicPacket(std::move(headerIn)); ReadStreamFrame frame(streamId, offset, fin); regularQuicPacket.frames.emplace_back(std::move(frame)); diff --git a/quic/loss/QuicLossFunctions.cpp b/quic/loss/QuicLossFunctions.cpp index ca30cd896..d8068e214 100644 --- a/quic/loss/QuicLossFunctions.cpp +++ b/quic/loss/QuicLossFunctions.cpp @@ -126,9 +126,7 @@ void markPacketLoss( if (processed) { return; } - auto protectionType = folly::variant_match( - packet.header, - [](auto& header) { return header.getProtectionType(); }); + auto protectionType = packet.header.getProtectionType(); auto encryptionLevel = protectionTypeToEncryptionLevel(protectionType); auto cryptoStream = diff --git a/quic/loss/QuicLossFunctions.h b/quic/loss/QuicLossFunctions.h index 6c52eda3c..2cf7f0ef4 100644 --- a/quic/loss/QuicLossFunctions.h +++ b/quic/loss/QuicLossFunctions.h @@ -105,9 +105,8 @@ calculateAlarmDuration(const QuicConnectionStateBase& conn) { std::chrono::duration_cast( lastSentPacketTime + alarmDuration - now); } else { - auto lastSentPacketNum = folly::variant_match( - conn.outstandingPackets.back().packet.header, - [](const auto& h) { return h.getPacketSequenceNum(); }); + auto lastSentPacketNum = + conn.outstandingPackets.back().packet.header.getPacketSequenceNum(); VLOG(10) << __func__ << " alarm already due method=" << *alarmMethod << " lastSentPacketNum=" << lastSentPacketNum << " lastSentPacketTime=" @@ -224,15 +223,11 @@ folly::Optional detectLossPackets( bool shouldSetTimer = false; while (iter != conn.outstandingPackets.end()) { auto& pkt = *iter; - auto currentPacketNum = folly::variant_match( - pkt.packet.header, - [](const auto& h) { return h.getPacketSequenceNum(); }); + auto currentPacketNum = pkt.packet.header.getPacketSequenceNum(); if (currentPacketNum >= largestAcked) { break; } - auto currentPacketNumberSpace = folly::variant_match( - pkt.packet.header, - [](const auto& h) { return h.getPacketNumberSpace(); }); + auto currentPacketNumberSpace = pkt.packet.header.getPacketNumberSpace(); if (currentPacketNumberSpace != pnSpace) { iter++; continue; @@ -347,12 +342,8 @@ void onHandshakeAlarm( // the word "handshake" in our code base is unfortunately overloaded. if (iter->isHandshake) { auto& packet = *iter; - auto currentPacketNum = folly::variant_match( - packet.packet.header, - [](const auto& h) { return h.getPacketSequenceNum(); }); - auto currentPacketNumSpace = folly::variant_match( - packet.packet.header, - [](const auto& h) { return h.getPacketNumberSpace(); }); + auto currentPacketNum = packet.packet.header.getPacketSequenceNum(); + auto currentPacketNumSpace = packet.packet.header.getPacketNumberSpace(); VLOG(10) << "HandshakeAlarm, removing packetNum=" << currentPacketNum << " packetNumSpace=" << currentPacketNumSpace << " " << conn; DCHECK(!packet.pureAck); @@ -478,22 +469,15 @@ void markZeroRttPacketsLost( CongestionController::LossEvent lossEvent(ClockType::now()); auto iter = getFirstOutstandingPacket(conn, PacketNumberSpace::AppData); while (iter != conn.outstandingPackets.end()) { - DCHECK( - PacketNumberSpace::AppData == - folly::variant_match(iter->packet.header, [](const auto& h) { - return h.getPacketNumberSpace(); - })); + DCHECK_EQ( + iter->packet.header.getPacketNumberSpace(), PacketNumberSpace::AppData); auto isZeroRttPacket = - folly::variant_match(iter->packet.header, [&](const auto& h) { - return h.getProtectionType() == ProtectionType::ZeroRtt; - }); + iter->packet.header.getProtectionType() == ProtectionType::ZeroRtt; if (isZeroRttPacket) { auto& pkt = *iter; DCHECK(!pkt.pureAck); DCHECK(!pkt.isHandshake); - auto currentPacketNum = folly::variant_match( - pkt.packet.header, - [](const auto& h) { return h.getPacketSequenceNum(); }); + auto currentPacketNum = pkt.packet.header.getPacketSequenceNum(); bool processed = pkt.associatedEvent && !conn.outstandingPacketEvents.count(*pkt.associatedEvent); lossVisitor(conn, pkt.packet, processed, currentPacketNum); diff --git a/quic/loss/test/QuicLossFunctionsTest.cpp b/quic/loss/test/QuicLossFunctionsTest.cpp index 02e44d6c7..8dc8f1af0 100644 --- a/quic/loss/test/QuicLossFunctionsTest.cpp +++ b/quic/loss/test/QuicLossFunctionsTest.cpp @@ -120,9 +120,7 @@ auto testingLossMarkFunc(std::vector& lostPackets) { return [&lostPackets]( auto& /* conn */, auto& packet, bool processed, PacketNum) { if (!processed) { - auto packetNum = folly::variant_match(packet.header, [](const auto& h) { - return h.getPacketSequenceNum(); - }); + auto packetNum = packet.header.getPacketSequenceNum(); lostPackets.push_back(packetNum); } }; @@ -161,9 +159,13 @@ PacketNum QuicLossFunctionsTest::sendPacket( conn.ackStates.appDataAckState.nextPacketNum); break; } - auto packetNumberSpace = folly::variant_match( - *header, [](const auto& h) { return h.getPacketNumberSpace(); }); - + PacketNumberSpace packetNumberSpace; + auto shortHeader = header->asShort(); + if (shortHeader) { + packetNumberSpace = shortHeader->getPacketNumberSpace(); + } else { + packetNumberSpace = header->asLong()->getPacketNumberSpace(); + } RegularQuicPacketBuilder builder( conn.udpSendPacketLen, std::move(*header), @@ -199,9 +201,7 @@ PacketNum QuicLossFunctionsTest::sendPacket( conn.outstandingPackets.begin(), conn.outstandingPackets.end(), [&associatedEvent](const auto& packet) { - auto packetNum = folly::variant_match( - packet.packet.header, - [](const auto& h) { return h.getPacketSequenceNum(); }); + auto packetNum = packet.packet.header.getPacketSequenceNum(); return packetNum == *associatedEvent; }); if (it != conn.outstandingPackets.end()) { @@ -376,8 +376,7 @@ TEST_F(QuicLossFunctionsTest, TestMarkPacketLoss) { EXPECT_EQ(1, conn->outstandingPackets.size()); auto& packet = getFirstOutstandingPacket(*conn, PacketNumberSpace::AppData)->packet; - auto packetNum = folly::variant_match( - packet.header, [](const auto& h) { return h.getPacketSequenceNum(); }); + auto packetNum = packet.header.getPacketSequenceNum(); markPacketLoss(*conn, packet, false, packetNum); EXPECT_EQ(stream1->retransmissionBuffer.size(), 0); EXPECT_EQ(stream2->retransmissionBuffer.size(), 0); @@ -417,8 +416,7 @@ TEST_F(QuicLossFunctionsTest, TestMarkCryptoLostAfterCancelRetransmission) { ASSERT_EQ(conn->outstandingPackets.size(), 1); EXPECT_GT(conn->cryptoState->handshakeStream.retransmissionBuffer.size(), 0); auto& packet = conn->outstandingPackets.front().packet; - auto packetNum = folly::variant_match( - packet.header, [](const auto& h) { return h.getPacketSequenceNum(); }); + auto packetNum = packet.header.getPacketSequenceNum(); cancelHandshakeCryptoStreamRetransmissions(*conn->cryptoState); markPacketLoss(*conn, packet, false, packetNum); EXPECT_EQ(conn->cryptoState->handshakeStream.retransmissionBuffer.size(), 0); @@ -445,13 +443,7 @@ TEST_F(QuicLossFunctionsTest, TestMarkPacketLossAfterStreamReset) { *stream1, StreamEvents::SendReset(GenericApplicationErrorCode::UNKNOWN)); - markPacketLoss( - *conn, - packet, - false, - folly::variant_match(packet.header, [](const auto& h) { - return h.getPacketSequenceNum(); - })); + markPacketLoss(*conn, packet, false, packet.header.getPacketSequenceNum()); EXPECT_TRUE(stream1->lossBuffer.empty()); EXPECT_TRUE(stream1->retransmissionBuffer.empty()); @@ -470,9 +462,7 @@ TEST_F(QuicLossFunctionsTest, TestReorderingThreshold) { auto testingLossMarkFunc = [&lostPacket](auto& /*conn*/, auto& packet, bool, PacketNum) { - auto packetNum = folly::variant_match(packet.header, [](const auto& h) { - return h.getPacketSequenceNum(); - }); + auto packetNum = packet.header.getPacketSequenceNum(); lostPacket.push_back(packetNum); }; for (int i = 0; i < 6; ++i) { @@ -518,9 +508,8 @@ TEST_F(QuicLossFunctionsTest, TestReorderingThreshold) { // Packet 6 should remain in packet as the delta is less than threshold EXPECT_EQ(conn->outstandingPackets.size(), 1); - auto packetNum = folly::variant_match( - conn->outstandingPackets.front().packet.header, - [](const auto& h) { return h.getPacketSequenceNum(); }); + auto packetNum = + conn->outstandingPackets.front().packet.header.getPacketSequenceNum(); EXPECT_EQ(packetNum, 6); } @@ -644,13 +633,7 @@ TEST_F(QuicLossFunctionsTest, TestMarkRstLoss) { EXPECT_TRUE(conn->pendingEvents.resets.empty()); auto& packet = getFirstOutstandingPacket(*conn, PacketNumberSpace::AppData)->packet; - markPacketLoss( - *conn, - packet, - false, - folly::variant_match(packet.header, [](const auto& h) { - return h.getPacketSequenceNum(); - })); + markPacketLoss(*conn, packet, false, packet.header.getPacketSequenceNum()); EXPECT_EQ(1, conn->pendingEvents.resets.size()); EXPECT_EQ(1, conn->pendingEvents.resets.count(stream->id)); @@ -738,8 +721,7 @@ TEST_F(QuicLossFunctionsTest, TestMarkWindowUpdateLoss) { auto& packet = getFirstOutstandingPacket(*conn, PacketNumberSpace::AppData)->packet; - auto packetNum = folly::variant_match( - packet.header, [](const auto& h) { return h.getPacketSequenceNum(); }); + auto packetNum = packet.header.getPacketSequenceNum(); markPacketLoss(*conn, packet, false, packetNum); EXPECT_TRUE(conn->streamManager->pendingWindowUpdate(stream->id)); } @@ -779,10 +761,8 @@ TEST_F(QuicLossFunctionsTest, TestTimeReordering) { // Packet 6, 7 should remain in outstanding packet list EXPECT_EQ(2, conn->outstandingPackets.size()); - auto packetNum = folly::variant_match( - getFirstOutstandingPacket(*conn, PacketNumberSpace::AppData) - ->packet.header, - [](const auto& h) { return h.getPacketSequenceNum(); }); + auto packetNum = getFirstOutstandingPacket(*conn, PacketNumberSpace::AppData) + ->packet.header.getPacketSequenceNum(); EXPECT_EQ(packetNum, 6); EXPECT_TRUE(conn->lossState.appDataLossTime); } @@ -1335,8 +1315,7 @@ TEST_F(QuicLossFunctionsTest, TestZeroRttRejected) { EXPECT_FALSE(lostPacket.second); } for (size_t i = 0; i < conn->outstandingPackets.size(); ++i) { - auto longHeader = - boost::get(&conn->outstandingPackets[i].packet.header); + auto longHeader = conn->outstandingPackets[i].packet.header.asLong(); EXPECT_FALSE( longHeader && longHeader->getProtectionType() == ProtectionType::ZeroRtt); @@ -1388,8 +1367,7 @@ TEST_F(QuicLossFunctionsTest, TestZeroRttRejectedWithClones) { } EXPECT_EQ(numProcessed, 1); for (size_t i = 0; i < conn->outstandingPackets.size(); ++i) { - auto longHeader = - boost::get(&conn->outstandingPackets[i].packet.header); + auto longHeader = conn->outstandingPackets[i].packet.header.asLong(); EXPECT_FALSE( longHeader && longHeader->getProtectionType() == ProtectionType::ZeroRtt); diff --git a/quic/server/state/ServerStateMachine.cpp b/quic/server/state/ServerStateMachine.cpp index 0f8acd1a8..73866c348 100644 --- a/quic/server/state/ServerStateMachine.cpp +++ b/quic/server/state/ServerStateMachine.cpp @@ -645,17 +645,11 @@ void onServerReadDataFromOpen( conn.infoCallback, onPacketDropped, PacketDropReason::INVALID_PACKET); continue; } - auto protectionLevel = folly::variant_match( - regularOptional->header, - [](auto& header) { return header.getProtectionType(); }); + auto protectionLevel = regularOptional->header.getProtectionType(); auto encryptionLevel = protectionTypeToEncryptionLevel(protectionLevel); - auto packetNum = folly::variant_match( - regularOptional->header, - [](const auto& h) { return h.getPacketSequenceNum(); }); - auto packetNumberSpace = folly::variant_match( - regularOptional->header, - [](auto& header) { return header.getPacketNumberSpace(); }); + auto packetNum = regularOptional->header.getPacketSequenceNum(); + auto packetNumberSpace = regularOptional->header.getPacketNumberSpace(); // TODO: enforce constraints on other protection levels. auto& regularPacket = *regularOptional; @@ -698,7 +692,12 @@ void onServerReadDataFromOpen( // We assume that the higher layer takes care of validating that the version // is supported. if (!conn.version) { - conn.version = boost::get(regularPacket.header).getVersion(); + LongHeader* longHeader = regularPacket.header.asLong(); + if (!longHeader) { + throw QuicTransportException( + "Invalid packet type", TransportErrorCode::PROTOCOL_VIOLATION); + } + conn.version = longHeader->getVersion(); } if (conn.peerAddress != readData.peer) { @@ -1096,16 +1095,9 @@ void onServerReadDataFromClosed( } auto& regularPacket = *regularOptional; - auto protectionLevel = folly::variant_match( - regularPacket.header, - [](auto& header) { return header.getProtectionType(); }); - - auto packetNum = folly::variant_match( - regularOptional->header, - [](const auto& h) { return h.getPacketSequenceNum(); }); - auto pnSpace = folly::variant_match( - regularOptional->header, - [](const auto& h) { return h.getPacketNumberSpace(); }); + auto protectionLevel = regularPacket.header.getProtectionType(); + auto packetNum = regularPacket.header.getPacketSequenceNum(); + auto pnSpace = regularPacket.header.getPacketNumberSpace(); if (conn.qLogger) { conn.qLogger->addPacket(regularPacket, packetSize); } diff --git a/quic/server/test/QuicServerTest.cpp b/quic/server/test/QuicServerTest.cpp index 720a39c3f..8d1306723 100644 --- a/quic/server/test/QuicServerTest.cpp +++ b/quic/server/test/QuicServerTest.cpp @@ -448,7 +448,7 @@ TEST_F(QuicServerWorkerTest, ZeroLengthConnectionId) { EXPECT_CALL(*transportInfoCb_, onPacketDropped(_)).Times(0); RegularQuicPacketBuilder builder( - kDefaultUDPSendPacketLen, header, 0 /* largestAcked */); + kDefaultUDPSendPacketLen, std::move(header), 0 /* largestAcked */); auto packet = packetToBuf(std::move(builder).buildPacket()); worker_->handleNetworkData(kClientAddr, std::move(packet), Clock::now()); eventbase_.loop(); @@ -463,7 +463,7 @@ TEST_F(QuicServerWorkerTest, ConnectionIdTooShort) { EXPECT_CALL(*transportInfoCb_, onPacketDropped(_)); RegularQuicPacketBuilder builder( - kDefaultUDPSendPacketLen, header, 0 /* largestAcked */); + kDefaultUDPSendPacketLen, std::move(header), 0 /* largestAcked */); auto packet = packetToBuf(std::move(builder).buildPacket()); worker_->handleNetworkData(kClientAddr, std::move(packet), Clock::now()); eventbase_.loop(); @@ -501,7 +501,7 @@ TEST_F(QuicServerWorkerTest, PacketAfterShutdown) { EXPECT_CALL(*factory_, _make(_, _, _, _)).Times(0); RegularQuicPacketBuilder builder( - kDefaultUDPSendPacketLen, header, 0 /* largestAcked */); + kDefaultUDPSendPacketLen, std::move(header), 0 /* largestAcked */); auto packet = packetToBuf(std::move(builder).buildPacket()); worker_->handleNetworkData(kClientAddr, std::move(packet), Clock::now()); eventbase_.terminateLoopSoon(); @@ -544,7 +544,7 @@ auto createInitialStream( destConnId, packetNum, version, - IOBuf::copyBuffer("this is a retry token :)"), + std::string("this is a retry token :)"), getTestConnectionId()); RegularQuicPacketBuilder builder( kDefaultUDPSendPacketLen, @@ -806,15 +806,12 @@ void QuicServerWorkerTakeoverTest::testPacketForwarding( // parse header and check connId to verify the integrity of the packet auto parsedHeader = parseHeader(*writtenData); auto& header = parsedHeader->parsedHeader; - const auto& connectionId = folly::variant_match( - header.value(), - [](const LongHeader& longHeader) { - return longHeader.getDestinationConnId(); - }, - [](const ShortHeader& shortHeader) { - return shortHeader.getConnectionId(); - }); - EXPECT_EQ(connId, connectionId); + LongHeader* longHeader = header->asLong(); + if (longHeader) { + EXPECT_EQ(connId, longHeader->getDestinationConnId()); + } else { + EXPECT_EQ(connId, header->asShort()->getConnectionId()); + } return data->computeChainDataLength(); })); takeoverWorker_->startPacketForwarding(folly::SocketAddress("0", 0)); diff --git a/quic/server/test/QuicServerTransportTest.cpp b/quic/server/test/QuicServerTransportTest.cpp index 252daff49..cf3696141 100644 --- a/quic/server/test/QuicServerTransportTest.cpp +++ b/quic/server/test/QuicServerTransportTest.cpp @@ -494,16 +494,12 @@ class QuicServerTransportTest : public Test { auto aead = getInitialCipher(); auto headerCipher = getInitialHeaderCipher(); IntervalSet acks; - auto start = folly::variant_match( - getFirstOutstandingPacket( - server->getNonConstConn(), PacketNumberSpace::Initial) - ->packet.header, - [](auto& h) { return h.getPacketSequenceNum(); }); - auto end = folly::variant_match( - getLastOutstandingPacket( - server->getNonConstConn(), PacketNumberSpace::Initial) - ->packet.header, - [](auto& h) { return h.getPacketSequenceNum(); }); + auto start = getFirstOutstandingPacket( + server->getNonConstConn(), PacketNumberSpace::Initial) + ->packet.header.getPacketSequenceNum(); + auto end = getLastOutstandingPacket( + server->getNonConstConn(), PacketNumberSpace::Initial) + ->packet.header.getPacketSequenceNum(); acks.insert(start, end); auto pn = clientNextInitialPacketNum++; auto ackPkt = createAckPacket( @@ -879,16 +875,12 @@ TEST_F(QuicServerTransportTest, TestCloseConnectionWithNoErrorPendingStreams) { loopForWrites(); IntervalSet acks; - auto start = folly::variant_match( - getFirstOutstandingPacket( - server->getNonConstConn(), PacketNumberSpace::AppData) - ->packet.header, - [](auto& h) { return h.getPacketSequenceNum(); }); - auto end = folly::variant_match( - getLastOutstandingPacket( - server->getNonConstConn(), PacketNumberSpace::AppData) - ->packet.header, - [](auto& h) { return h.getPacketSequenceNum(); }); + auto start = getFirstOutstandingPacket( + server->getNonConstConn(), PacketNumberSpace::AppData) + ->packet.header.getPacketSequenceNum(); + auto end = getLastOutstandingPacket( + server->getNonConstConn(), PacketNumberSpace::AppData) + ->packet.header.getPacketSequenceNum(); acks.insert(start, end); deliverData(packetToBuf(createAckPacket( server->getNonConstConn(), @@ -953,7 +945,9 @@ TEST_F(QuicServerTransportTest, ReceiveCloseAfterLocalError) { *server->getConn().serverConnectionId, clientNextAppDataPacketNum++); RegularQuicPacketBuilder builder( - server->getConn().udpSendPacketLen, header, 0 /* largestAcked */); + server->getConn().udpSendPacketLen, + std::move(header), + 0 /* largestAcked */); ASSERT_TRUE(builder.canBuildPacket()); // Deliver a reset to non existent stream to trigger a local conn error @@ -1071,30 +1065,25 @@ TEST_F(QuicServerTransportTest, TestOpenAckStreamFrame) { // We need more than one packet for this test. ASSERT_FALSE(server->getConn().outstandingPackets.empty()); - PacketNum packetNum1 = folly::variant_match( + PacketNum packetNum1 = getFirstOutstandingPacket( server->getNonConstConn(), PacketNumberSpace::AppData) - ->packet.header, - [](auto& h) { return h.getPacketSequenceNum(); }); + ->packet.header.getPacketSequenceNum(); - PacketNum lastPacketNum = folly::variant_match( + PacketNum lastPacketNum = getLastOutstandingPacket( server->getNonConstConn(), PacketNumberSpace::AppData) - ->packet.header, - [](auto& h) { return h.getPacketSequenceNum(); }); + ->packet.header.getPacketSequenceNum(); uint32_t buffersInPacket1 = 0; for (size_t i = 0; i < server->getNonConstConn().outstandingPackets.size(); ++i) { auto& packet = server->getNonConstConn().outstandingPackets[i]; - if (PacketNumberSpace::AppData != - folly::variant_match(packet.packet.header, [](auto& h) { - return h.getPacketNumberSpace(); - })) { + if (packet.packet.header.getPacketNumberSpace() != + PacketNumberSpace::AppData) { continue; } - PacketNum currentPacket = folly::variant_match( - packet.packet.header, [](auto& h) { return h.getPacketSequenceNum(); }); + PacketNum currentPacket = packet.packet.header.getPacketSequenceNum(); ASSERT_FALSE(packet.packet.frames.empty()); for (auto& quicFrame : packet.packet.frames) { auto frame = boost::get(&quicFrame); @@ -1159,11 +1148,10 @@ TEST_F(QuicServerTransportTest, TestOpenAckStreamFrame) { loopForWrites(); ASSERT_FALSE(server->getConn().outstandingPackets.empty()); - PacketNum finPacketNum = folly::variant_match( + PacketNum finPacketNum = getFirstOutstandingPacket( server->getNonConstConn(), PacketNumberSpace::AppData) - ->packet.header, - [](auto& h) { return h.getPacketSequenceNum(); }); + ->packet.header.getPacketSequenceNum(); IntervalSet acks3 = {{lastPacketNum, finPacketNum}}; auto packet4 = createAckPacket( @@ -1670,9 +1658,7 @@ TEST_F(QuicServerTransportTest, TestAckStopSending) { }; auto op = findOutstandingPacket(server->getNonConstConn(), match); ASSERT_TRUE(op != nullptr); - PacketNum packetNum = folly::variant_match( - op->packet.header, - [](const auto& h) { return h.getPacketSequenceNum(); }); + PacketNum packetNum = op->packet.header.getPacketSequenceNum(); IntervalSet acks = {{packetNum, packetNum}}; auto packet1 = createAckPacket( server->getNonConstConn(), @@ -2986,13 +2972,9 @@ TEST_F(QuicUnencryptedServerTransportTest, TestWriteHandshakeAndZeroRtt) { auto parsedPacket = boost::get(&result); CHECK(parsedPacket); auto& regularPacket = boost::get(*parsedPacket); - bool handshakePacket = folly::variant_match( - regularPacket.header, - [](const LongHeader& h) { - return h.getProtectionType() == ProtectionType::Initial || - h.getProtectionType() == ProtectionType::Handshake; - }, - [](const auto&) { return false; }); + ProtectionType protectionType = regularPacket.header.getProtectionType(); + bool handshakePacket = protectionType == ProtectionType::Initial || + protectionType == ProtectionType::Handshake; EXPECT_GE(regularPacket.frames.size(), 1); bool hasCryptoFrame = false; bool hasNonCryptoStream = false; diff --git a/quic/state/AckHandlers.cpp b/quic/state/AckHandlers.cpp index 2a70a9998..00bbd9e49 100644 --- a/quic/state/AckHandlers.cpp +++ b/quic/state/AckHandlers.cpp @@ -39,9 +39,7 @@ void processAckFrame( conn.outstandingPackets.end(), ackBlockIt->startPacket, [&](const auto& packetWithTime, const auto& val) { - return folly::variant_match( - packetWithTime.packet.header, - [&val](const auto& h) { return h.getPacketSequenceNum() < val; }); + return packetWithTime.packet.header.getPacketSequenceNum() < val; }); if (packetIt == conn.outstandingPackets.end()) { // This means that all the packets are less than the start packet. @@ -60,12 +58,9 @@ void processAckFrame( // or equal to crypto protection level. auto packetItEnd = packetIt; while (packetItEnd != conn.outstandingPackets.end()) { - auto currentPacketNum = folly::variant_match( - packetItEnd->packet.header, - [](const auto& h) { return h.getPacketSequenceNum(); }); - auto currentPacketNumberSpace = folly::variant_match( - packetItEnd->packet.header, - [](const auto& h) { return h.getPacketNumberSpace(); }); + auto currentPacketNum = packetItEnd->packet.header.getPacketSequenceNum(); + auto currentPacketNumberSpace = + packetItEnd->packet.header.getPacketNumberSpace(); if (pnSpace != currentPacketNumberSpace) { packetItEnd++; continue; diff --git a/quic/state/QuicStateFunctions.cpp b/quic/state/QuicStateFunctions.cpp index c776fcab4..6a03d8848 100644 --- a/quic/state/QuicStateFunctions.cpp +++ b/quic/state/QuicStateFunctions.cpp @@ -210,10 +210,7 @@ std::deque::iterator getNextOutstandingPacket( PacketNumberSpace packetNumberSpace, std::deque::iterator from) { return std::find_if(from, conn.outstandingPackets.end(), [=](const auto& op) { - return packetNumberSpace == - folly::variant_match(op.packet.header, [](const auto& h) { - return h.getPacketNumberSpace(); - }); + return packetNumberSpace == op.packet.header.getPacketNumberSpace(); }); } diff --git a/quic/state/StateData.h b/quic/state/StateData.h index ddf8f6479..54521b97d 100644 --- a/quic/state/StateData.h +++ b/quic/state/StateData.h @@ -215,9 +215,7 @@ struct CongestionController { "LossEvent: lostBytes overflow", LocalErrorCode::LOST_BYTES_OVERFLOW); } - auto packetNum = folly::variant_match( - packet.packet.header, - [](const auto& header) { return header.getPacketSequenceNum(); }); + PacketNum packetNum = packet.packet.header.getPacketSequenceNum(); largestLostPacketNum = std::max(packetNum, largestLostPacketNum.value_or(packetNum)); lostBytes += packet.encodedSize; diff --git a/quic/state/test/AckHandlersTest.cpp b/quic/state/test/AckHandlersTest.cpp index 4ea012496..c7ce1c9a8 100644 --- a/quic/state/test/AckHandlersTest.cpp +++ b/quic/state/test/AckHandlersTest.cpp @@ -29,8 +29,7 @@ class AckHandlersTest : public TestWithParam {}; auto testLossHandler(std::vector& lostPackets) -> decltype(auto) { return [&lostPackets]( QuicConnectionStateBase&, auto& packet, bool, PacketNum) { - auto packetNum = folly::variant_match( - packet.header, [](const auto& h) { return h.getPacketSequenceNum(); }); + auto packetNum = packet.header.getPacketSequenceNum(); lostPackets.push_back(packetNum); }; } @@ -102,9 +101,7 @@ TEST_P(AckHandlersTest, TestAckMultipleSequentialBlocks) { } PacketNum packetNum = 16; for (auto& packet : conn.outstandingPackets) { - auto currentPacketNum = folly::variant_match( - packet.packet.header, - [](const auto& h) { return h.getPacketSequenceNum(); }); + auto currentPacketNum = packet.packet.header.getPacketSequenceNum(); EXPECT_EQ(currentPacketNum, packetNum); packetNum++; } @@ -195,9 +192,7 @@ TEST_P(AckHandlersTest, TestAckBlocksWithGaps) { std::back_insert_iterator( actualPacketNumbers), [](const auto& packet) { - return folly::variant_match(packet.packet.header, [](const auto& h) { - return h.getPacketSequenceNum(); - }); + return packet.packet.header.getPacketSequenceNum(); }); EXPECT_TRUE(std::equal( @@ -299,9 +294,7 @@ TEST_P(AckHandlersTest, TestNonSequentialPacketNumbers) { std::back_insert_iterator( actualPacketNumbers), [](const auto& packet) { - return folly::variant_match(packet.packet.header, [](const auto& h) { - return h.getPacketSequenceNum(); - }); + return packet.packet.header.getPacketSequenceNum(); }); EXPECT_TRUE(std::equal( @@ -344,9 +337,8 @@ TEST_P(AckHandlersTest, AckVisitorForAckTest) { [&](const auto& outstandingPacket, const auto& packetFrame, const ReadAckFrame&) { - auto ackedPacketNum = folly::variant_match( - outstandingPacket.packet.header, - [](const auto& h) { return h.getPacketSequenceNum(); }); + auto ackedPacketNum = + outstandingPacket.packet.header.getPacketSequenceNum(); EXPECT_EQ(ackedPacketNum, firstReceivedAck.largestAcked); folly::variant_match( packetFrame,