/* * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. * */ #include #include #include #include namespace { quic::ConnectionId zeroConnId() { std::vector zeroData(quic::kDefaultConnectionIdSize, 0); return quic::ConnectionId(zeroData); } } // namespace namespace quic { QuicReadCodec::QuicReadCodec(QuicNodeType nodeType) : nodeType_(nodeType) {} folly::Optional QuicReadCodec::tryParsingVersionNegotiation(BufQueue& queue) { folly::io::Cursor cursor(queue.front()); if (!cursor.canAdvance(sizeof(uint8_t))) { return folly::none; } uint8_t initialByte = cursor.readBE(); auto headerForm = getHeaderForm(initialByte); if (headerForm != HeaderForm::Long) { return folly::none; } auto longHeaderInvariant = parseLongHeaderInvariant(initialByte, cursor); if (!longHeaderInvariant) { // if it is an invalid packet, it's definitely not a VN packet, so ignore // it. return folly::none; } if (longHeaderInvariant->invariant.version != QuicVersion::VERSION_NEGOTIATION) { return folly::none; } return decodeVersionNegotiation(*longHeaderInvariant, cursor); } CodecResult QuicReadCodec::parseLongHeaderPacket( BufQueue& queue, const AckStates& ackStates) { folly::io::Cursor cursor(queue.front()); auto initialByte = cursor.readBE(); auto longHeaderInvariant = parseLongHeaderInvariant(initialByte, cursor); if (!longHeaderInvariant) { VLOG(4) << "Dropping packet, failed to parse invariant " << connIdToHex(); // We've failed to parse the long header, so we have no idea where this // packet ends. Clear the queue since no other data in this packet is // parse-able. queue.move(); return CodecResult(Nothing()); } if (longHeaderInvariant->invariant.version == QuicVersion::VERSION_NEGOTIATION) { // We shouldn't handle VN packets while parsing the long header. // We assume here that they have been handled before calling this // function. // Since VN is not allowed to be coalesced with another packet // type, we clear out the buffer to avoid anyone else parsing it. queue.move(); return CodecResult(Nothing()); } auto type = parseLongHeaderType(initialByte); auto parsedLongHeader = parseLongHeaderVariants( type, std::move(*longHeaderInvariant), cursor, nodeType_); if (!parsedLongHeader) { VLOG(4) << "Dropping due to failed to parse header " << connIdToHex(); // We've failed to parse the long header, so we have no idea where this // packet ends. Clear the queue since no other data in this packet is // parse-able. queue.move(); return CodecResult(Nothing()); } // As soon as we have parsed out the long header we can split off any // coalesced packets. We do this early since the spec mandates that decryption // failure must not stop the processing of subsequent coalesced packets. auto longHeader = std::move(parsedLongHeader->header); if (type == LongHeader::Types::Retry) { return RegularQuicPacket(std::move(longHeader)); } uint64_t packetNumberOffset = cursor.getCurrentPosition(); size_t currentPacketLen = packetNumberOffset + parsedLongHeader->packetLength.packetLength; if (queue.chainLength() < currentPacketLen) { // Packet appears truncated, there's no parse-able data left. queue.move(); return CodecResult(Nothing()); } auto currentPacketData = queue.splitAtMost(currentPacketLen); cursor.reset(currentPacketData.get()); cursor.skip(packetNumberOffset); // Sample starts after the max packet number size. This ensures that we // have enough bytes to skip before we can start reading the sample. if (!cursor.canAdvance(kMaxPacketNumEncodingSize)) { VLOG(4) << "Dropping packet, not enough for packet number " << connIdToHex(); // Packet appears truncated, there's no parse-able data left. queue.move(); return CodecResult(Nothing()); } cursor.skip(kMaxPacketNumEncodingSize); Sample sample; if (!cursor.canAdvance(sample.size())) { VLOG(4) << "Dropping packet, sample too small " << connIdToHex(); // Packet appears truncated, there's no parse-able data left. queue.move(); return CodecResult(Nothing()); } cursor.pull(sample.data(), sample.size()); const PacketNumberCipher* headerCipher{nullptr}; const Aead* cipher{nullptr}; auto protectionType = longHeader.getProtectionType(); switch (protectionType) { case ProtectionType::Initial: if (!initialHeaderCipher_) { VLOG(4) << nodeToString(nodeType_) << " dropping initial packet after initial keys dropped" << connIdToHex(); return CodecResult(Nothing()); } headerCipher = initialHeaderCipher_.get(); cipher = initialReadCipher_.get(); break; case ProtectionType::Handshake: headerCipher = handshakeHeaderCipher_.get(); cipher = handshakeReadCipher_.get(); break; case ProtectionType::ZeroRtt: if (handshakeDoneTime_) { // TODO actually drop the 0-rtt keys in addition to dropping packets. auto timeBetween = Clock::now() - *handshakeDoneTime_; if (timeBetween > kTimeToRetainZeroRttKeys) { VLOG(4) << nodeToString(nodeType_) << " dropping zero rtt packet for exceeding key timeout" << connIdToHex(); return CodecResult(Nothing()); } } headerCipher = zeroRttHeaderCipher_.get(); cipher = zeroRttReadCipher_.get(); break; case ProtectionType::KeyPhaseZero: case ProtectionType::KeyPhaseOne: CHECK(false) << "one rtt protection type in long header"; } if (!headerCipher || !cipher) { // TODO: remove packet number here. return CodecResult( CipherUnavailable(std::move(currentPacketData), 0, protectionType)); } // TODO: decrypt the long header. PacketNum expectedNextPacketNum = 0; folly::Optional largestReceivedPacketNum; switch (longHeaderTypeToProtectionType(type)) { case ProtectionType::Initial: largestReceivedPacketNum = ackStates.initialAckState.largestReceivedPacketNum; break; case ProtectionType::Handshake: largestReceivedPacketNum = ackStates.handshakeAckState.largestReceivedPacketNum; break; case ProtectionType::ZeroRtt: largestReceivedPacketNum = ackStates.appDataAckState.largestReceivedPacketNum; break; default: folly::assume_unreachable(); } if (largestReceivedPacketNum) { expectedNextPacketNum = 1 + *largestReceivedPacketNum; } folly::MutableByteRange initialByteRange( currentPacketData->writableData(), 1); folly::MutableByteRange packetNumberByteRange( currentPacketData->writableData() + packetNumberOffset, kMaxPacketNumEncodingSize); headerCipher->decryptLongHeader( folly::range(sample), initialByteRange, packetNumberByteRange); std::pair packetNum = parsePacketNumber( initialByteRange.data()[0], packetNumberByteRange, expectedNextPacketNum); longHeader.setPacketNumber(packetNum.first); BufQueue decryptQueue; decryptQueue.append(std::move(currentPacketData)); size_t aadLen = packetNumberOffset + packetNum.second; auto headerData = decryptQueue.splitAtMost(aadLen); // parsing verifies that packetLength >= packet number length. auto encryptedData = decryptQueue.splitAtMost( parsedLongHeader->packetLength.packetLength - packetNum.second); if (!encryptedData) { // There should normally be some integrity tag at least in the data, // however allowing the aead to process the data even if the tag is not // present helps with writing tests. encryptedData = folly::IOBuf::create(0); } Buf decrypted; auto decryptAttempt = cipher->tryDecrypt( std::move(encryptedData), headerData.get(), packetNum.first); if (!decryptAttempt) { VLOG(4) << "Unable to decrypt packet=" << packetNum.first << " packetNumLen=" << parsePacketNumberLength(initialByte) << " protectionType=" << toString(protectionType) << " " << connIdToHex(); return CodecResult(Nothing()); } decrypted = std::move(*decryptAttempt); if (!decrypted) { // TODO better way of handling this (tests break without this) decrypted = folly::IOBuf::create(0); } return decodeRegularPacket( std::move(longHeader), params_, std::move(decrypted)); } CodecResult QuicReadCodec::parsePacket( BufQueue& queue, const AckStates& ackStates, size_t dstConnIdSize) { if (queue.empty()) { return CodecResult(Nothing()); } DCHECK(!queue.front()->isChained()); folly::io::Cursor cursor(queue.front()); if (!cursor.canAdvance(sizeof(uint8_t))) { return CodecResult(Nothing()); } uint8_t initialByte = cursor.readBE(); auto headerForm = getHeaderForm(initialByte); if (headerForm == HeaderForm::Long) { return parseLongHeaderPacket(queue, ackStates); } // Short header: // TODO: support key phase one. if (!oneRttReadCipher_ || !oneRttHeaderCipher_) { VLOG(4) << nodeToString(nodeType_) << " cannot read key phase zero packet"; VLOG(20) << "cannot read data=" << folly::hexlify(queue.front()->clone()->moveToFbString()) << " " << connIdToHex(); return CodecResult( CipherUnavailable(queue.move(), 0, ProtectionType::KeyPhaseZero)); } // TODO: allow other connid lengths from the state. size_t packetNumberOffset = 1 + dstConnIdSize; PacketNum expectedNextPacketNum = ackStates.appDataAckState.largestReceivedPacketNum ? (1 + *ackStates.appDataAckState.largestReceivedPacketNum) : 0; size_t sampleOffset = packetNumberOffset + kMaxPacketNumEncodingSize; Sample sample; if (queue.chainLength() < sampleOffset + sample.size()) { VLOG(10) << "Dropping packet, too small for sample " << connIdToHex(); // There's not enough space for the short header packet, clear the queue // to indicate there's no more parse-able data. queue.move(); return CodecResult(Nothing()); } // Take it out of the queue so we can do some writing. auto data = queue.move(); folly::MutableByteRange initialByteRange(data->writableData(), 1); folly::MutableByteRange packetNumberByteRange( data->writableData() + packetNumberOffset, kMaxPacketNumEncodingSize); folly::ByteRange sampleByteRange( data->writableData() + sampleOffset, sample.size()); oneRttHeaderCipher_->decryptShortHeader( sampleByteRange, initialByteRange, packetNumberByteRange); std::pair packetNum = parsePacketNumber( initialByteRange.data()[0], packetNumberByteRange, expectedNextPacketNum); auto shortHeader = parseShortHeader(initialByteRange.data()[0], cursor, dstConnIdSize); if (!shortHeader) { VLOG(10) << "Dropping packet, cannot parse " << connIdToHex(); return CodecResult(Nothing()); } shortHeader->setPacketNumber(packetNum.first); if (shortHeader->getProtectionType() == ProtectionType::KeyPhaseOne) { VLOG(4) << nodeToString(nodeType_) << " cannot read key phase one packet " << connIdToHex(); return CodecResult(Nothing()); } // We know that the iobuf is not chained. This means that we can safely have a // non-owning reference to the header without cloning the buffer. If we don't // clone the buffer, the buffer will not show up as shared and we can decrypt // in-place. size_t aadLen = packetNumberOffset + packetNum.second; folly::IOBuf headerData = folly::IOBuf::wrapBufferAsValue(data->data(), aadLen); data->trimStart(aadLen); Buf decrypted; // TODO: small optimization we can do here: only read the token if // decryption fails folly::Optional token; auto encryptedDataLength = data->length(); if (statelessResetToken_ && encryptedDataLength > sizeof(StatelessResetToken)) { token = StatelessResetToken(); memcpy( token->data(), data->data() + (encryptedDataLength - sizeof(StatelessResetToken)), token->size()); } auto decryptAttempt = oneRttReadCipher_->tryDecrypt( std::move(data), &headerData, packetNum.first); if (!decryptAttempt) { // Can't return the data now, already consumed it to try decrypting it. if (token) { return StatelessReset(*token); } auto protectionType = shortHeader->getProtectionType(); VLOG(10) << "Unable to decrypt packet=" << packetNum.first << " protectionType=" << (int)protectionType << " " << connIdToHex(); return CodecResult(Nothing()); } decrypted = std::move(*decryptAttempt); if (!decrypted) { // TODO better way of handling this (tests break without this) decrypted = folly::IOBuf::create(0); } return decodeRegularPacket( std::move(*shortHeader), params_, std::move(decrypted)); } const Aead* QuicReadCodec::getOneRttReadCipher() const { return oneRttReadCipher_.get(); } const Aead* QuicReadCodec::getZeroRttReadCipher() const { return zeroRttReadCipher_.get(); } const Aead* QuicReadCodec::getHandshakeReadCipher() const { return handshakeReadCipher_.get(); } const folly::Optional& QuicReadCodec::getStatelessResetToken() const { return statelessResetToken_; } CodecParameters QuicReadCodec::getCodecParameters() const { return params_; } void QuicReadCodec::setInitialReadCipher( std::unique_ptr initialReadCipher) { initialReadCipher_ = std::move(initialReadCipher); } void QuicReadCodec::setOneRttReadCipher( std::unique_ptr oneRttReadCipher) { oneRttReadCipher_ = std::move(oneRttReadCipher); } void QuicReadCodec::setZeroRttReadCipher( std::unique_ptr zeroRttReadCipher) { if (nodeType_ == QuicNodeType::Client) { throw QuicTransportException( "Invalid cipher", TransportErrorCode::INTERNAL_ERROR); } zeroRttReadCipher_ = std::move(zeroRttReadCipher); } void QuicReadCodec::setHandshakeReadCipher( std::unique_ptr handshakeReadCipher) { handshakeReadCipher_ = std::move(handshakeReadCipher); } void QuicReadCodec::setInitialHeaderCipher( std::unique_ptr initialHeaderCipher) { initialHeaderCipher_ = std::move(initialHeaderCipher); } void QuicReadCodec::setOneRttHeaderCipher( std::unique_ptr oneRttHeaderCipher) { oneRttHeaderCipher_ = std::move(oneRttHeaderCipher); } void QuicReadCodec::setZeroRttHeaderCipher( std::unique_ptr zeroRttHeaderCipher) { zeroRttHeaderCipher_ = std::move(zeroRttHeaderCipher); } void QuicReadCodec::setHandshakeHeaderCipher( std::unique_ptr handshakeHeaderCipher) { handshakeHeaderCipher_ = std::move(handshakeHeaderCipher); } void QuicReadCodec::setCodecParameters(CodecParameters params) { params_ = std::move(params); } void QuicReadCodec::setClientConnectionId(ConnectionId connId) { clientConnectionId_ = connId; } void QuicReadCodec::setServerConnectionId(ConnectionId connId) { serverConnectionId_ = connId; } void QuicReadCodec::setStatelessResetToken( StatelessResetToken statelessResetToken) { statelessResetToken_ = std::move(statelessResetToken); } const Aead* QuicReadCodec::getInitialCipher() const { return initialReadCipher_.get(); } const PacketNumberCipher* QuicReadCodec::getInitialHeaderCipher() const { return initialHeaderCipher_.get(); } const PacketNumberCipher* QuicReadCodec::getOneRttHeaderCipher() const { return oneRttHeaderCipher_.get(); } const PacketNumberCipher* QuicReadCodec::getHandshakeHeaderCipher() const { return handshakeHeaderCipher_.get(); } const PacketNumberCipher* QuicReadCodec::getZeroRttHeaderCipher() const { return zeroRttHeaderCipher_.get(); } void QuicReadCodec::onHandshakeDone(TimePoint handshakeDoneTime) { if (!handshakeDoneTime_) { handshakeDoneTime_ = handshakeDoneTime; } } folly::Optional QuicReadCodec::getHandshakeDoneTime() { return handshakeDoneTime_; } std::string QuicReadCodec::connIdToHex() { static ConnectionId zeroConn = zeroConnId(); const auto& serverId = serverConnectionId_.value_or(zeroConn); const auto& clientId = clientConnectionId_.value_or(zeroConn); return folly::to( "server=", serverId.hex(), " ", "client=", clientId.hex()); } CodecResult::CodecResult(RegularQuicPacket&& regularPacketIn) : type_(CodecResult::Type::REGULAR_PACKET) { new (&packet) RegularQuicPacket(std::move(regularPacketIn)); } CodecResult::CodecResult(CipherUnavailable&& cipherUnavailableIn) : type_(CodecResult::Type::CIPHER_UNAVAILABLE) { new (&cipher) CipherUnavailable(std::move(cipherUnavailableIn)); } CodecResult::CodecResult(StatelessReset&& statelessResetIn) : type_(CodecResult::Type::STATELESS_RESET) { new (&reset) StatelessReset(std::move(statelessResetIn)); } CodecResult::CodecResult(Nothing&&) : type_(CodecResult::Type::NOTHING) { new (&none) Nothing(); } void CodecResult::destroyCodecResult() { switch (type_) { case CodecResult::Type::REGULAR_PACKET: packet.~RegularQuicPacket(); break; case CodecResult::Type::CIPHER_UNAVAILABLE: cipher.~CipherUnavailable(); break; case CodecResult::Type::STATELESS_RESET: reset.~StatelessReset(); break; case CodecResult::Type::NOTHING: none.~Nothing(); break; } } CodecResult::~CodecResult() { destroyCodecResult(); } CodecResult::CodecResult(CodecResult&& other) noexcept { switch (other.type_) { case CodecResult::Type::REGULAR_PACKET: new (&packet) RegularQuicPacket(std::move(other.packet)); break; case CodecResult::Type::CIPHER_UNAVAILABLE: new (&cipher) CipherUnavailable(std::move(other.cipher)); break; case CodecResult::Type::STATELESS_RESET: new (&reset) StatelessReset(std::move(other.reset)); break; case CodecResult::Type::NOTHING: new (&none) Nothing(std::move(other.none)); break; } type_ = other.type_; } CodecResult& CodecResult::operator=(CodecResult&& other) noexcept { destroyCodecResult(); switch (other.type_) { case CodecResult::Type::REGULAR_PACKET: new (&packet) RegularQuicPacket(std::move(other.packet)); break; case CodecResult::Type::CIPHER_UNAVAILABLE: new (&cipher) CipherUnavailable(std::move(other.cipher)); break; case CodecResult::Type::STATELESS_RESET: new (&reset) StatelessReset(std::move(other.reset)); break; case CodecResult::Type::NOTHING: new (&none) Nothing(std::move(other.none)); break; } type_ = other.type_; return *this; } CodecResult::Type CodecResult::type() { return type_; } RegularQuicPacket* CodecResult::regularPacket() { if (type_ == CodecResult::Type::REGULAR_PACKET) { return &packet; } return nullptr; } CipherUnavailable* CodecResult::cipherUnavailable() { if (type_ == CodecResult::Type::CIPHER_UNAVAILABLE) { return &cipher; } return nullptr; } StatelessReset* CodecResult::statelessReset() { if (type_ == CodecResult::Type::STATELESS_RESET) { return &reset; } return nullptr; } Nothing* CodecResult::nothing() { if (type_ == CodecResult::Type::NOTHING) { return &none; } return nullptr; } } // namespace quic