diff --git a/quic/api/QuicTransportFunctions.cpp b/quic/api/QuicTransportFunctions.cpp index bf8f2cdb8..eddacd5c7 100644 --- a/quic/api/QuicTransportFunctions.cpp +++ b/quic/api/QuicTransportFunctions.cpp @@ -302,8 +302,12 @@ continuousMemoryBuildScheduleEncrypt( connection.bufAccessor->trimStart(prevSize + headerLen); // buf and packetBuf is actually the same. auto buf = connection.bufAccessor->obtain(); - auto packetBuf = + auto encryptResult = aead.inplaceEncrypt(std::move(buf), &packet->header, packetNum); + if (encryptResult.hasError()) { + return folly::makeUnexpected(encryptResult.error()); + } + auto packetBuf = std::move(encryptResult.value()); CHECK(packetBuf->headroom() == headerLen + prevSize); // Include header back. packetBuf->prepend(headerLen); @@ -410,8 +414,12 @@ iobufChainBasedBuildScheduleEncrypt( bodyCursor.pull(unencrypted->writableData() + headerLen, bodyLen); unencrypted->advance(headerLen); unencrypted->append(bodyLen); - auto packetBuf = + auto encryptResult = aead.inplaceEncrypt(std::move(unencrypted), &packet->header, packetNum); + if (encryptResult.hasError()) { + return folly::makeUnexpected(encryptResult.error()); + } + auto packetBuf = std::move(encryptResult.value()); DCHECK(packetBuf->headroom() == headerLen); packetBuf->clear(); auto headerCursor = Cursor(&packet->header); @@ -1447,8 +1455,13 @@ void writeCloseCommon( auto packet = std::move(packetBuilder).buildPacket(); CHECK_GE(packet.body.tailroom(), aead.getCipherOverhead()); auto bufUniquePtr = packet.body.clone(); - bufUniquePtr = + auto encryptResult = aead.inplaceEncrypt(std::move(bufUniquePtr), &packet.header, packetNum); + if (encryptResult.hasError()) { + LOG(ERROR) << "Error encrypting packet: " << encryptResult.error().message; + return; + } + bufUniquePtr = std::move(encryptResult.value()); bufUniquePtr->coalesce(); encryptPacketHeader( headerForm, diff --git a/quic/codec/test/QuicPacketBuilderTest.cpp b/quic/codec/test/QuicPacketBuilderTest.cpp index 9ab5d2b79..6dde89a8f 100644 --- a/quic/codec/test/QuicPacketBuilderTest.cpp +++ b/quic/codec/test/QuicPacketBuilderTest.cpp @@ -39,7 +39,9 @@ BufPtr packetToBuf( } if (aead && !packet.header.empty()) { auto bodySize = body->computeChainDataLength(); - body = aead->inplaceEncrypt(std::move(body), &packet.header, num); + auto result = aead->inplaceEncrypt(std::move(body), &packet.header, num); + CHECK(!result.hasError()); + body = std::move(result.value()); EXPECT_GT(body->computeChainDataLength(), bodySize); } if (body) { diff --git a/quic/common/test/TestUtils.cpp b/quic/common/test/TestUtils.cpp index dcabd7bbe..f688468a9 100644 --- a/quic/common/test/TestUtils.cpp +++ b/quic/common/test/TestUtils.cpp @@ -448,8 +448,13 @@ BufPtr packetToBufCleartext( body->appendToChain(folly::IOBuf::create(tagLen)); } body->coalesce(); - auto encryptedBody = cleartextCipher.inplaceEncrypt( + auto encryptResult = cleartextCipher.inplaceEncrypt( std::move(body), &packet.header, packetNum); + if (encryptResult.hasError()) { + throw std::runtime_error( + "Failed to encrypt packet: " + encryptResult.error().message); + } + auto encryptedBody = std::move(encryptResult.value()); encryptedBody->coalesce(); encryptPacketHeader( headerForm, diff --git a/quic/dsr/backend/DSRPacketizer.cpp b/quic/dsr/backend/DSRPacketizer.cpp index dbaac5a0d..a9198d182 100644 --- a/quic/dsr/backend/DSRPacketizer.cpp +++ b/quic/dsr/backend/DSRPacketizer.cpp @@ -88,8 +88,14 @@ bool PacketGroupWriter::writeSingleQuicPacket( // buildBuf's data starts from the body part of buildBuf. buildBuf->trimStart(prevSize_ + headerLen); // buildBuf and packetbuildBuf is actually the same. - auto packetbuildBuf = + auto encryptResult = aead.inplaceEncrypt(std::move(buildBuf), &packet.header, packetNum); + if (encryptResult.hasError()) { + throw QuicInternalException( + "DSR Send failed: Encryption error: " + encryptResult.error().message, + LocalErrorCode::INTERNAL_ERROR); + } + auto packetbuildBuf = std::move(encryptResult.value()); CHECK_EQ(packetbuildBuf->headroom(), headerLen + prevSize_); // Include header back. packetbuildBuf->prepend(headerLen); diff --git a/quic/fizz/handshake/FizzBridge.h b/quic/fizz/handshake/FizzBridge.h index aba6ffea1..f132f9242 100644 --- a/quic/fizz/handshake/FizzBridge.h +++ b/quic/fizz/handshake/FizzBridge.h @@ -31,14 +31,20 @@ class FizzAead final : public Aead { Optional getKey() const override; /** - * Simply forward all calls to fizz::Aead. + * Forward calls to fizz::Aead, catching any exceptions and converting them to + * folly::Expected. */ - std::unique_ptr inplaceEncrypt( + folly::Expected, QuicError> inplaceEncrypt( std::unique_ptr&& plaintext, const folly::IOBuf* associatedData, uint64_t seqNum) const override { - return fizzAead->inplaceEncrypt( - std::move(plaintext), associatedData, seqNum); + try { + return fizzAead->inplaceEncrypt( + std::move(plaintext), associatedData, seqNum); + } catch (const std::exception& ex) { + return folly::makeUnexpected( + QuicError(TransportErrorCode::INTERNAL_ERROR, ex.what())); + } } std::unique_ptr decrypt( diff --git a/quic/handshake/Aead.h b/quic/handshake/Aead.h index ebf056b3d..4477db478 100644 --- a/quic/handshake/Aead.h +++ b/quic/handshake/Aead.h @@ -7,7 +7,9 @@ #pragma once +#include #include +#include #include namespace quic { @@ -27,12 +29,15 @@ class Aead { virtual Optional getKey() const = 0; /** - * Encrypts plaintext inplace. Will throw on error. + * Encrypts plaintext inplace. Returns folly::Expected with the encrypted + * buffer or an error. */ - virtual std::unique_ptr inplaceEncrypt( - std::unique_ptr&& plaintext, - const folly::IOBuf* associatedData, - uint64_t seqNum) const = 0; + [[nodiscard]] virtual folly:: + Expected, QuicError> + inplaceEncrypt( + std::unique_ptr&& plaintext, + const folly::IOBuf* associatedData, + uint64_t seqNum) const = 0; /** * Decrypt ciphertext. Will throw if the ciphertext does not decrypt diff --git a/quic/handshake/BUCK b/quic/handshake/BUCK index fbd2a1c08..7291850b6 100644 --- a/quic/handshake/BUCK +++ b/quic/handshake/BUCK @@ -8,7 +8,9 @@ mvfst_cpp_library( "Aead.h", ], exported_deps = [ + "//folly:expected", "//folly/io:iobuf", + "//quic:exception", "//quic/common:optional", ], ) diff --git a/quic/handshake/test/Mocks.h b/quic/handshake/test/Mocks.h index 9eb1b8a17..afbab79e9 100644 --- a/quic/handshake/test/Mocks.h +++ b/quic/handshake/test/Mocks.h @@ -54,14 +54,14 @@ class MockAead : public Aead { MOCK_METHOD(Optional, getKey, (), (const)); MOCK_METHOD( - std::unique_ptr, + (folly::Expected, QuicError>), _inplaceEncrypt, (std::unique_ptr & plaintext, const folly::IOBuf* associatedData, uint64_t seqNum), (const)); - std::unique_ptr inplaceEncrypt( + folly::Expected, QuicError> inplaceEncrypt( std::unique_ptr&& plaintext, const folly::IOBuf* associatedData, uint64_t seqNum) const override { @@ -102,7 +102,9 @@ class MockAead : public Aead { using namespace testing; ON_CALL(*this, _inplaceEncrypt(_, _, _)) .WillByDefault(InvokeWithoutArgs( - []() { return folly::IOBuf::copyBuffer("ciphertext"); })); + []() -> folly::Expected, QuicError> { + return folly::IOBuf::copyBuffer("ciphertext"); + })); ON_CALL(*this, _decrypt(_, _, _)).WillByDefault(InvokeWithoutArgs([]() { return folly::IOBuf::copyBuffer("plaintext"); }));