/* * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. * */ #include #include #include #include #include #include #include #include #include #include namespace quic { namespace test { std::function MockClock::mockNow; const RegularQuicWritePacket& writeQuicPacket( QuicServerConnectionState& conn, ConnectionId srcConnId, ConnectionId dstConnId, folly::test::MockAsyncUDPSocket& sock, QuicStreamState& stream, const folly::IOBuf& data, bool eof) { auto version = conn.version.value_or(*conn.originalVersion); auto aead = createNoOpAead(); auto headerCipher = createNoOpHeaderCipher(); writeDataToQuicStream(stream, data.clone(), eof); writeQuicDataToSocket( sock, conn, srcConnId, dstConnId, *aead, *headerCipher, version, conn.transportSettings.writeConnectionDataPacketsLimit); CHECK( conn.outstandingPackets.rend() != getLastOutstandingPacket(conn, PacketNumberSpace::AppData)); return getLastOutstandingPacket(conn, PacketNumberSpace::AppData)->packet; } PacketNum rstStreamAndSendPacket( QuicServerConnectionState& conn, folly::AsyncUDPSocket& sock, QuicStreamState& stream, ApplicationErrorCode errorCode) { auto aead = createNoOpAead(); auto headerCipher = createNoOpHeaderCipher(); auto version = conn.version.value_or(*conn.originalVersion); sendRstSMHandler(stream, errorCode); writeQuicDataToSocket( sock, conn, *conn.clientConnectionId, *conn.serverConnectionId, *aead, *headerCipher, version, conn.transportSettings.writeConnectionDataPacketsLimit); for (const auto& packet : conn.outstandingPackets) { for (const auto& frame : packet.packet.frames) { auto rstFrame = frame.asRstStreamFrame(); if (!rstFrame) { continue; } if (rstFrame->streamId == stream.id) { return packet.packet.header.getPacketSequenceNum(); } } } CHECK(false) << "no packet with reset stream"; // some compilers are weird. return 0; } RegularQuicPacketBuilder::Packet createAckPacket( QuicConnectionStateBase& dstConn, PacketNum pn, WriteAckFrame::AckBlocks& acks, PacketNumberSpace pnSpace, const Aead* aead) { // This function sends ACK to dstConn auto srcConnId = (dstConn.nodeType == QuicNodeType::Client ? *dstConn.serverConnectionId : *dstConn.clientConnectionId); auto dstConnId = (dstConn.nodeType == QuicNodeType::Client ? *dstConn.clientConnectionId : *dstConn.serverConnectionId); folly::Optional header; if (pnSpace == PacketNumberSpace::Initial) { header = LongHeader( LongHeader::Types::Initial, srcConnId, dstConnId, pn, QuicVersion::MVFST); } else if (pnSpace == PacketNumberSpace::Handshake) { header = LongHeader( LongHeader::Types::Handshake, srcConnId, dstConnId, pn, QuicVersion::MVFST); } else { header = ShortHeader(ProtectionType::KeyPhaseZero, dstConnId, pn); } RegularQuicPacketBuilder builder( dstConn.udpSendPacketLen, std::move(*header), getAckState(dstConn, pnSpace).largestAckedByPeer); if (aead) { builder.setCipherOverhead(aead->getCipherOverhead()); } DCHECK(builder.canBuildPacket()); AckFrameMetaData ackData( acks, 0us, dstConn.transportSettings.ackDelayExponent); writeAckFrame(ackData, builder); return std::move(builder).buildPacket(); } std::shared_ptr readCert() { auto certificate = fizz::test::getCert(fizz::test::kP256Certificate); auto privKey = fizz::test::getPrivateKey(fizz::test::kP256Key); std::vector certs; certs.emplace_back(std::move(certificate)); return std::make_shared>( std::move(privKey), std::move(certs)); } std::shared_ptr createServerCtx() { auto cert = readCert(); auto certManager = std::make_unique(); certManager->addCert(std::move(cert), true); auto serverCtx = std::make_shared(); serverCtx->setFactory(std::make_shared()); serverCtx->setCertManager(std::move(certManager)); serverCtx->setOmitEarlyRecordLayer(true); serverCtx->setClock(std::make_shared()); return serverCtx; } class AcceptingTicketCipher : public fizz::server::TicketCipher { public: ~AcceptingTicketCipher() override = default; folly::Future, std::chrono::seconds>>> encrypt(fizz::server::ResumptionState) const override { // Fake handshake, no need todo anything here. return std::make_pair(folly::IOBuf::create(0), 2s); } void setPsk(const QuicCachedPsk& cachedPsk) { cachedPsk_ = cachedPsk; } fizz::server::ResumptionState createResumptionState() const { fizz::server::ResumptionState resState; resState.version = cachedPsk_.cachedPsk.version; resState.cipher = cachedPsk_.cachedPsk.cipher; resState.resumptionSecret = folly::IOBuf::copyBuffer(cachedPsk_.cachedPsk.secret); resState.serverCert = cachedPsk_.cachedPsk.serverCert; resState.alpn = cachedPsk_.cachedPsk.alpn; resState.ticketAgeAdd = 0; resState.ticketIssueTime = std::chrono::system_clock::time_point(); resState.handshakeTime = std::chrono::system_clock::time_point(); auto version = cachedPsk_.transportParams.negotiatedVersion; AppToken appToken; appToken.transportParams = createTicketTransportParameters( kDefaultIdleTimeout.count(), kDefaultUDPReadBufferSize, kDefaultConnectionWindowSize, kDefaultStreamWindowSize, kDefaultStreamWindowSize, kDefaultStreamWindowSize, std::numeric_limits::max(), std::numeric_limits::max()); appToken.version = version; resState.appToken = encodeAppToken(appToken); return resState; } folly::Future< std::pair>> decrypt( std::unique_ptr, const fizz::server::State* /*state*/ = nullptr) const override { return std::make_pair(fizz::PskType::Resumption, createResumptionState()); } private: QuicCachedPsk cachedPsk_; }; void setupZeroRttOnServerCtx( fizz::server::FizzServerContext& serverCtx, const QuicCachedPsk& cachedPsk) { serverCtx.setEarlyDataSettings( true, fizz::server::ClockSkewTolerance{-100000ms, 100000ms}, std::make_shared()); auto ticketCipher = std::make_shared(); ticketCipher->setPsk(cachedPsk); serverCtx.setTicketCipher(ticketCipher); } QuicCachedPsk setupZeroRttOnClientCtx( fizz::client::FizzClientContext& clientCtx, std::string hostname, QuicVersion version) { clientCtx.setSendEarlyData(true); QuicCachedPsk quicCachedPsk; auto& psk = quicCachedPsk.cachedPsk; psk.psk = std::string("psk"); psk.secret = std::string("secret"); psk.type = fizz::PskType::Resumption; psk.version = clientCtx.getSupportedVersions()[0]; psk.cipher = clientCtx.getSupportedCiphers()[0]; psk.group = clientCtx.getSupportedGroups()[0]; auto mockCert = std::make_shared(); ON_CALL(*mockCert, getIdentity()).WillByDefault(Return(hostname)); psk.serverCert = mockCert; psk.alpn = clientCtx.getSupportedAlpns()[0]; psk.ticketAgeAdd = 1; psk.ticketIssueTime = std::chrono::system_clock::time_point(); psk.ticketExpirationTime = std::chrono::system_clock::time_point(std::chrono::minutes(100)); psk.ticketHandshakeTime = std::chrono::system_clock::time_point(); psk.maxEarlyDataSize = 2; quicCachedPsk.transportParams.negotiatedVersion = version; quicCachedPsk.transportParams.idleTimeout = kDefaultIdleTimeout.count(); quicCachedPsk.transportParams.maxRecvPacketSize = kDefaultUDPReadBufferSize; quicCachedPsk.transportParams.initialMaxData = kDefaultConnectionWindowSize; quicCachedPsk.transportParams.initialMaxStreamDataBidiLocal = kDefaultStreamWindowSize; quicCachedPsk.transportParams.initialMaxStreamDataBidiRemote = kDefaultStreamWindowSize; quicCachedPsk.transportParams.initialMaxStreamDataUni = kDefaultStreamWindowSize; quicCachedPsk.transportParams.initialMaxStreamsBidi = std::numeric_limits::max(); quicCachedPsk.transportParams.initialMaxStreamsUni = std::numeric_limits::max(); return quicCachedPsk; } void setupCtxWithTestCert(fizz::server::FizzServerContext& ctx) { auto cert = readCert(); auto certManager = std::make_unique(); certManager->addCert(std::move(cert), true); ctx.setCertManager(std::move(certManager)); } template std::unique_ptr createNoOpAeadImpl() { // Fake that the handshake has already occured auto aead = std::make_unique>(); ON_CALL(*aead, _encrypt(_, _, _)) .WillByDefault(Invoke([&](auto& buf, auto, auto) { if (buf) { return buf->clone(); } else { return folly::IOBuf::create(0); } })); // Fake that the handshake has already occured and fix the keys. ON_CALL(*aead, _decrypt(_, _, _)) .WillByDefault( Invoke([&](auto& buf, auto, auto) { return buf->clone(); })); ON_CALL(*aead, _tryDecrypt(_, _, _)) .WillByDefault( Invoke([&](auto& buf, auto, auto) { return buf->clone(); })); ON_CALL(*aead, getCipherOverhead()).WillByDefault(Return(0)); return aead; } std::unique_ptr createNoOpAead() { return createNoOpAeadImpl(); } std::unique_ptr createNoOpHeaderCipher() { auto headerCipher = std::make_unique>(); ON_CALL(*headerCipher, mask(_)).WillByDefault(Return(HeaderProtectionMask{})); ON_CALL(*headerCipher, keyLength()).WillByDefault(Return(16)); return headerCipher; } RegularQuicPacketBuilder::Packet createStreamPacket( ConnectionId srcConnId, ConnectionId dstConnId, PacketNum packetNum, StreamId streamId, folly::IOBuf& data, uint8_t cipherOverhead, PacketNum largestAcked, folly::Optional> longHeaderOverride, bool eof, folly::Optional shortHeaderOverride, uint64_t offset, uint64_t packetSizeLimit) { std::unique_ptr builder; if (longHeaderOverride) { LongHeader header( longHeaderOverride->first, srcConnId, dstConnId, packetNum, longHeaderOverride->second); builder.reset(new RegularQuicPacketBuilder( packetSizeLimit, std::move(header), largestAcked)); } else { ProtectionType protectionType = ProtectionType::KeyPhaseZero; if (shortHeaderOverride) { protectionType = *shortHeaderOverride; } ShortHeader header(protectionType, dstConnId, packetNum); builder.reset(new RegularQuicPacketBuilder( packetSizeLimit, std::move(header), largestAcked)); } builder->setCipherOverhead(cipherOverhead); writeStreamFrameHeader( *builder, streamId, offset, data.computeChainDataLength(), data.computeChainDataLength(), eof); writeStreamFrameData(*builder, data.clone(), data.computeChainDataLength()); return std::move(*builder).buildPacket(); } RegularQuicPacketBuilder::Packet createInitialCryptoPacket( ConnectionId srcConnId, ConnectionId dstConnId, PacketNum packetNum, QuicVersion version, folly::IOBuf& data, const Aead& aead, PacketNum largestAcked, uint64_t offset) { LongHeader header( LongHeader::Types::Initial, srcConnId, dstConnId, packetNum, version); RegularQuicPacketBuilder builder( kDefaultUDPSendPacketLen, std::move(header), largestAcked); builder.setCipherOverhead(aead.getCipherOverhead()); writeCryptoFrame(offset, data.clone(), builder); return std::move(builder).buildPacket(); } RegularQuicPacketBuilder::Packet createCryptoPacket( ConnectionId srcConnId, ConnectionId dstConnId, PacketNum packetNum, QuicVersion version, ProtectionType protectionType, folly::IOBuf& data, const Aead& aead, PacketNum largestAcked, uint64_t offset, uint64_t packetSizeLimit) { folly::Optional header; switch (protectionType) { case ProtectionType::Initial: header = LongHeader( LongHeader::Types::Initial, srcConnId, dstConnId, packetNum, version); break; case ProtectionType::Handshake: header = LongHeader( LongHeader::Types::Handshake, srcConnId, dstConnId, packetNum, version); break; case ProtectionType::ZeroRtt: header = LongHeader( LongHeader::Types::ZeroRtt, srcConnId, dstConnId, packetNum, version); break; case ProtectionType::KeyPhaseOne: case ProtectionType::KeyPhaseZero: header = ShortHeader(protectionType, dstConnId, packetNum); break; } RegularQuicPacketBuilder builder( packetSizeLimit, std::move(*header), largestAcked); builder.setCipherOverhead(aead.getCipherOverhead()); writeCryptoFrame(offset, data.clone(), builder); return std::move(builder).buildPacket(); } Buf packetToBuf(const RegularQuicPacketBuilder::Packet& packet) { auto packetBuf = packet.header->clone(); if (packet.body) { packetBuf->prependChain(packet.body->clone()); } return packetBuf; } Buf packetToBufCleartext( const RegularQuicPacketBuilder::Packet& packet, const Aead& cleartextCipher, const PacketNumberCipher& headerCipher, PacketNum packetNum) { VLOG(10) << __func__ << " packet header: " << folly::hexlify(packet.header->clone()->moveToFbString()); auto packetBuf = packet.header->clone(); Buf body; if (packet.body) { body = packet.body->clone(); } auto headerForm = packet.packet.header.getHeaderForm(); auto encryptedBody = cleartextCipher.encrypt(std::move(body), packet.header.get(), packetNum); encryptPacketHeader(headerForm, *packet.header, *encryptedBody, headerCipher); packetBuf->prependChain(std::move(encryptedBody)); return packetBuf; } uint64_t computeExpectedDelay( std::chrono::microseconds ackDelay, uint8_t ackDelayExponent) { uint64_t divide = uint64_t(ackDelay.count()) >> ackDelayExponent; return divide << ackDelayExponent; } ConnectionId getTestConnectionId(uint16_t hostId) { ServerConnectionIdParams params(hostId, 0, 0); DefaultConnectionIdAlgo connIdAlgo; auto connId = connIdAlgo.encodeConnectionId(params); connId.data()[3] = 3; connId.data()[4] = 4; connId.data()[5] = 5; connId.data()[6] = 6; connId.data()[7] = 7; return connId; } class TestCertificateVerifier : public fizz::CertificateVerifier { public: ~TestCertificateVerifier() override = default; void verify(const std::vector>&) const override { return; } std::vector getCertificateRequestExtensions() const override { return std::vector(); } }; std::unique_ptr createTestCertificateVerifier() { return std::make_unique(); } ProtectionType encryptionLevelToProtectionType( fizz::EncryptionLevel encryptionLevel) { switch (encryptionLevel) { case fizz::EncryptionLevel::Plaintext: return ProtectionType::Initial; case fizz::EncryptionLevel::Handshake: // TODO: change this in draft-14 return ProtectionType::Initial; case fizz::EncryptionLevel::EarlyData: return ProtectionType::ZeroRtt; case fizz::EncryptionLevel::AppTraffic: return ProtectionType::KeyPhaseZero; } folly::assume_unreachable(); } void updateAckState( QuicConnectionStateBase& conn, PacketNumberSpace pnSpace, PacketNum packetNum, bool pkHasRetransmittableData, bool pkHasCryptoData, TimePoint receivedTime) { bool outOfOrder = updateLargestReceivedPacketNum( getAckState(conn, pnSpace), packetNum, receivedTime); updateAckSendStateOnRecvPacket( conn, getAckState(conn, pnSpace), outOfOrder, pkHasRetransmittableData, pkHasCryptoData); } std::unique_ptr buildRandomInputData(size_t length) { auto buf = folly::IOBuf::create(length); buf->append(length); folly::Random::secureRandom(buf->writableData(), buf->length()); return buf; } void addAckStatesWithCurrentTimestamps( AckState& ackState, PacketNum start, PacketNum end) { ackState.acks.insert(start, end); ackState.largestRecvdPacketTime = Clock::now(); } OutstandingPacket makeTestingWritePacket( PacketNum desiredPacketSeqNum, size_t desiredSize, uint64_t totalBytesSent, TimePoint sentTime) { LongHeader longHeader( LongHeader::Types::ZeroRtt, getTestConnectionId(1), getTestConnectionId(), desiredPacketSeqNum, QuicVersion::MVFST); RegularQuicWritePacket packet(std::move(longHeader)); return OutstandingPacket( packet, sentTime, desiredSize, false, totalBytesSent); } CongestionController::AckEvent makeAck( PacketNum seq, uint64_t ackedSize, TimePoint ackedTime, TimePoint sentTime) { CHECK(sentTime < ackedTime); RegularQuicWritePacket packet( ShortHeader(ProtectionType::KeyPhaseZero, getTestConnectionId(), seq)); CongestionController::AckEvent ack; ack.ackedBytes = ackedSize; ack.ackTime = ackedTime; ack.largestAckedPacket = seq; ack.ackedPackets.emplace_back( CongestionController::AckEvent::AckPacket::Builder() .setSentTime(sentTime) .setEncodedSize(ackedSize) .build()); ack.largestAckedPacketSentTime = sentTime; return ack; } BufQueue bufToQueue(Buf buf) { BufQueue queue; buf->coalesce(); queue.append(std::move(buf)); return queue; } StatelessResetToken generateStatelessResetToken() { StatelessResetSecret secret; folly::Random::secureRandom(secret.data(), secret.size()); folly::SocketAddress address("1.2.3.4", 8080); StatelessResetGenerator generator(secret, address.getFullyQualified()); return generator.generateToken(ConnectionId({0x14, 0x35, 0x22, 0x11})); } std::array getRandSecret() { std::array secret; folly::Random::secureRandom(secret.data(), secret.size()); return secret; } RegularQuicWritePacket createNewPacket( PacketNum packetNum, PacketNumberSpace pnSpace) { switch (pnSpace) { case PacketNumberSpace::Initial: return RegularQuicWritePacket(LongHeader( LongHeader::Types::Initial, getTestConnectionId(1), getTestConnectionId(2), packetNum, QuicVersion::QUIC_DRAFT)); case PacketNumberSpace::Handshake: return RegularQuicWritePacket(LongHeader( LongHeader::Types::Handshake, getTestConnectionId(0), getTestConnectionId(4), packetNum, QuicVersion::QUIC_DRAFT)); case PacketNumberSpace::AppData: return RegularQuicWritePacket(ShortHeader( ProtectionType::KeyPhaseOne, getTestConnectionId(), packetNum)); } folly::assume_unreachable(); } std::vector versionList( std::initializer_list types) { std::vector versions; for (auto type : types) { versions.push_back(static_cast(type)); } return versions; } RegularQuicWritePacket createRegularQuicWritePacket( StreamId streamId, uint64_t offset, uint64_t len, bool fin) { auto regularWritePacket = createNewPacket(10, PacketNumberSpace::Initial); WriteStreamFrame frame(streamId, offset, len, fin); regularWritePacket.frames.emplace_back(frame); return regularWritePacket; } VersionNegotiationPacket createVersionNegotiationPacket() { auto versions = {QuicVersion::VERSION_NEGOTIATION, QuicVersion::MVFST}; auto packet = VersionNegotiationPacketBuilder( getTestConnectionId(0), getTestConnectionId(1), versions) .buildPacket() .first; return packet; } RegularQuicWritePacket createPacketWithAckFrames() { RegularQuicWritePacket packet = createNewPacket(100, PacketNumberSpace::Initial); WriteAckFrame ackFrame; ackFrame.ackDelay = 111us; ackFrame.ackBlocks.insert(900, 1000); ackFrame.ackBlocks.insert(500, 700); packet.frames.emplace_back(std::move(ackFrame)); return packet; } RegularQuicWritePacket createPacketWithPaddingFrames() { RegularQuicWritePacket packet = createNewPacket(100, PacketNumberSpace::Initial); for (int i = 0; i < 20; ++i) { PaddingFrame paddingFrame; packet.frames.emplace_back(paddingFrame); } return packet; } std::vector getQLogEventIndices( QLogEventType type, const std::shared_ptr& q) { std::vector indices; for (uint64_t i = 0; i < q->logs.size(); ++i) { if (q->logs[i]->eventType == type) { indices.push_back(i); } } return indices; } bool matchError( std::pair> errorCode, LocalErrorCode error) { return errorCode.first.type() == QuicErrorCode::Type::LocalErrorCode_E && *errorCode.first.asLocalErrorCode() == error; } bool matchError( std::pair> errorCode, TransportErrorCode error) { return errorCode.first.type() == QuicErrorCode::Type::TransportErrorCode_E && *errorCode.first.asTransportErrorCode() == error; } bool matchError( std::pair> errorCode, ApplicationErrorCode error) { return errorCode.first.type() == QuicErrorCode::Type::ApplicationErrorCode_E && *errorCode.first.asApplicationErrorCode() == error; } bool matchError( std::pair errorCode, ApplicationErrorCode error) { return errorCode.first.type() == QuicErrorCode::Type::ApplicationErrorCode_E && *errorCode.first.asApplicationErrorCode() == error; } bool matchError( std::pair errorCode, TransportErrorCode error) { return errorCode.first.type() == QuicErrorCode::Type::TransportErrorCode_E && *errorCode.first.asTransportErrorCode() == error; } CongestionController::AckEvent::AckPacket makeAckPacketFromOutstandingPacket( OutstandingPacket outstandingPacket) { return CongestionController::AckEvent::AckPacket::Builder() .setSentTime(outstandingPacket.time) .setEncodedSize(outstandingPacket.encodedSize) .setLastAckedPacketInfo(std::move(outstandingPacket.lastAckedPacketInfo)) .setTotalBytesSentThen(outstandingPacket.totalBytesSent) .setAppLimited(outstandingPacket.isAppLimited) .build(); } } // namespace test } // namespace quic