From 089bf581a72f82cde8689d5a63f0442c086f1523 Mon Sep 17 00:00:00 2001 From: Matt Joras Date: Sat, 19 Apr 2025 15:20:15 -0700 Subject: [PATCH] Remove throws from socket layer Summary: More in the theme of returning Expected instead of throwing. For the folly case, we keep the try/catches in there and translate to Expected. For Libev, we convert directly to Expected. Reviewed By: kvtsoy Differential Revision: D73217128 fbshipit-source-id: d00a978f24e3b29a77a8ac99a19765ae49f64df8 --- quic/api/IoBufQuicBatch.cpp | 27 +- quic/api/IoBufQuicBatch.h | 8 +- quic/api/QuicTransportBase.cpp | 11 +- quic/api/QuicTransportBaseLite.cpp | 43 +- quic/api/QuicTransportBaseLite.h | 2 +- quic/api/QuicTransportFunctions.cpp | 69 ++- quic/api/test/QuicBatchWriterTest.cpp | 36 +- quic/api/test/QuicTransportBaseTest.cpp | 21 +- quic/api/test/QuicTransportFunctionsTest.cpp | 26 + quic/api/test/QuicTransportTest.cpp | 17 + quic/api/test/TestQuicTransport.h | 2 +- quic/client/QuicClientTransport.cpp | 70 ++- quic/client/QuicClientTransport.h | 8 +- quic/client/QuicClientTransportLite.cpp | 354 ++++++++---- quic/client/QuicClientTransportLite.h | 23 +- quic/client/test/QuicClientTransportTest.cpp | 49 +- quic/common/BUCK | 2 + quic/common/SocketUtil.h | 6 +- quic/common/test/SocketUtilTest.cpp | 13 +- quic/common/testutil/MockAsyncUDPSocket.h | 82 ++- quic/common/udpsocket/BUCK | 21 + .../udpsocket/FollyQuicAsyncUDPSocket.cpp | 537 +++++++++++++++--- .../udpsocket/FollyQuicAsyncUDPSocket.h | 64 ++- .../udpsocket/LibevQuicAsyncUDPSocket.cpp | 418 ++++++++++---- .../udpsocket/LibevQuicAsyncUDPSocket.h | 80 ++- quic/common/udpsocket/QuicAsyncUDPSocket.h | 102 ++-- .../udpsocket/QuicAsyncUDPSocketImpl.cpp | 48 +- .../common/udpsocket/QuicAsyncUDPSocketImpl.h | 3 +- .../udpsocket/test/QuicAsyncUDPSocketMock.h | 86 ++- .../test/QuicAsyncUDPSocketTestBase.h | 15 +- quic/dsr/backend/DSRPacketizer.cpp | 6 +- quic/dsr/backend/test/DSRPacketizerTest.cpp | 8 +- .../client/test/QuicClientTransportTest.cpp | 260 ++++++--- .../client/test/QuicClientTransportTestUtil.h | 45 +- .../QuicHappyEyeballsFunctions.cpp | 105 +++- .../QuicHappyEyeballsFunctions.h | 2 +- quic/loss/test/QuicLossFunctionsTest.cpp | 12 + quic/server/QuicServerTransport.cpp | 6 +- quic/server/QuicServerWorker.cpp | 14 +- .../server/test/QuicServerTransportTestUtil.h | 44 +- .../stream/test/StreamStateMachineTest.cpp | 4 + 41 files changed, 2057 insertions(+), 692 deletions(-) diff --git a/quic/api/IoBufQuicBatch.cpp b/quic/api/IoBufQuicBatch.cpp index 29c192c35..800964e9f 100644 --- a/quic/api/IoBufQuicBatch.cpp +++ b/quic/api/IoBufQuicBatch.cpp @@ -24,14 +24,19 @@ IOBufQuicBatch::IOBufQuicBatch( statsCallback_(statsCallback), happyEyeballsState_(happyEyeballsState) {} -bool IOBufQuicBatch::write(Buf&& buf, size_t encodedSize) { +folly::Expected IOBufQuicBatch::write( + Buf&& buf, + size_t encodedSize) { result_.packetsSent++; result_.bytesSent += encodedSize; // see if we need to flush the prev buffer(s) if (batchWriter_->needsFlush(encodedSize)) { // continue even if we get an error here - flush(); + auto result = flush(); + if (result.hasError()) { + return result; + } } // try to append the new buffers @@ -43,8 +48,8 @@ bool IOBufQuicBatch::write(Buf&& buf, size_t encodedSize) { return true; } -bool IOBufQuicBatch::flush() { - bool ret = flushInternal(); +folly::Expected IOBufQuicBatch::flush() { + auto ret = flushInternal(); reset(); return ret; @@ -58,7 +63,7 @@ bool IOBufQuicBatch::isRetriableError(int err) { return err == EAGAIN || err == EWOULDBLOCK || err == ENOBUFS; } -bool IOBufQuicBatch::flushInternal() { +folly::Expected IOBufQuicBatch::flushInternal() { if (batchWriter_->empty()) { return true; } @@ -148,13 +153,13 @@ bool IOBufQuicBatch::flushInternal() { // We can get write error for any reason, close the conn only if network // is unreachable, for all others, we throw a transport exception if (isNetworkUnreachable(errno)) { - throw QuicInternalException( - folly::to("Error on socket write ", errorMsg), - LocalErrorCode::CONNECTION_ABANDONED); + return folly::makeUnexpected(QuicError( + LocalErrorCode::CONNECTION_ABANDONED, + folly::to("Error on socket write ", errorMsg))); } else { - throw QuicTransportException( - folly::to("Error on socket write ", errorMsg), - TransportErrorCode::INTERNAL_ERROR); + return folly::makeUnexpected(QuicError( + TransportErrorCode::INTERNAL_ERROR, + folly::to("Error on socket write ", errorMsg))); } } diff --git a/quic/api/IoBufQuicBatch.h b/quic/api/IoBufQuicBatch.h index 97e184058..a9d7c9aa5 100644 --- a/quic/api/IoBufQuicBatch.h +++ b/quic/api/IoBufQuicBatch.h @@ -30,9 +30,11 @@ class IOBufQuicBatch { ~IOBufQuicBatch() = default; // returns true if it succeeds and false if the loop should end - bool write(Buf&& buf, size_t encodedSize); + [[nodiscard]] folly::Expected write( + Buf&& buf, + size_t encodedSize); - bool flush(); + [[nodiscard]] folly::Expected flush(); FOLLY_ALWAYS_INLINE uint64_t getPktSent() const { return result_.packetsSent; @@ -50,7 +52,7 @@ class IOBufQuicBatch { void reset(); // flushes the internal buffers - bool flushInternal(); + [[nodiscard]] folly::Expected flushInternal(); /** * Returns whether or not the errno can be retried later. diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index 70f9b8bbc..29dc415cf 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -46,7 +46,10 @@ QuicTransportBase::QuicTransportBase( folly::Function()> func = [&]() { return getAdditionalCmsgsForAsyncUDPSocket(); }; - socket_->setAdditionalCmsgsFunc(std::move(func)); + // TODO we probably should have a better way to return error from + // creating a connection. + // Can't really do anything with this at this point. + (void)socket_->setAdditionalCmsgsFunc(std::move(func)); } } @@ -791,11 +794,13 @@ QuicTransportBase::maybeResetStreamFromReadError( } void QuicTransportBase::setCmsgs(const folly::SocketCmsgMap& options) { - socket_->setCmsgs(options); + // TODO figure out what we want to do here in the unlikely error case. + (void)socket_->setCmsgs(options); } void QuicTransportBase::appendCmsgs(const folly::SocketCmsgMap& options) { - socket_->appendCmsgs(options); + // TODO figure out what we want to do here in the unlikely error case. + (void)socket_->appendCmsgs(options); } bool QuicTransportBase::checkCustomRetransmissionProfilesEnabled() const { diff --git a/quic/api/QuicTransportBaseLite.cpp b/quic/api/QuicTransportBaseLite.cpp index e351ca435..0e744b5f2 100644 --- a/quic/api/QuicTransportBaseLite.cpp +++ b/quic/api/QuicTransportBaseLite.cpp @@ -170,7 +170,12 @@ void QuicTransportBaseLite::onNetworkData( // If ECN is enabled, make sure that the packet marking is happening as // expected - validateECNState(); + auto ecnResult = validateECNState(); + if (ecnResult.hasError()) { + VLOG(4) << __func__ << " " << ecnResult.error().message << " " << *this; + exceptionCloseWhat_ = ecnResult.error().message; + closeImpl(ecnResult.error()); + } } else { // In the closed state, we would want to write a close if possible // however the write looper will not be set. @@ -1121,7 +1126,15 @@ void QuicTransportBaseLite::maybeStopWriteLooperAndArmSocketWritableEvent() { if (haveBufferToRetry || (haveNewDataToWrite && connHasWriteWindow)) { // Re-arm the write event and stop the write // looper. - socket_->resumeWrite(this); + auto resumeResult = socket_->resumeWrite(this); + if (resumeResult.hasError()) { + exceptionCloseWhat_ = resumeResult.error().message; + closeImpl(QuicError( + resumeResult.error().code, + std::string( + "maybeStopWriteLooperAndArmSocketWritableEvent() error"))); + return; + } writeLooper_->stop(); } } @@ -2177,7 +2190,9 @@ void QuicTransportBaseLite::closeUdpSocket() { auto sock = std::move(socket_); socket_ = nullptr; sock->pauseRead(); - sock->close(); + auto closeResult = sock->close(); + LOG_IF(ERROR, closeResult.hasError()) + << "close hit an error: " << closeResult.error().message; } folly::Expected @@ -3114,22 +3129,27 @@ void QuicTransportBaseLite::updateSocketTosSettings(uint8_t dscpValue) { if (socket_ && socket_->isBound() && conn_->socketTos.value != initialTosValue) { - socket_->setTosOrTrafficClass(conn_->socketTos.value); + auto tosResult = socket_->setTosOrTrafficClass(conn_->socketTos.value); + if (tosResult.hasError()) { + exceptionCloseWhat_ = tosResult.error().message; + return closeImpl(tosResult.error()); + } } } -void QuicTransportBaseLite::validateECNState() { +folly::Expected +QuicTransportBaseLite::validateECNState() { if (conn_->ecnState == ECNState::NotAttempted || conn_->ecnState == ECNState::FailedValidation) { // Verification not needed - return; + return folly::unit; } const auto& minExpectedMarkedPacketsCount = conn_->ackStates.appDataAckState.minimumExpectedEcnMarksEchoed; if (minExpectedMarkedPacketsCount < 10) { // We wait for 10 ack-eliciting app data packets to be marked before trying // to validate ECN. - return; + return folly::unit; } const auto& maxExpectedMarkedPacketsCount = conn_->lossState.totalPacketsSent; @@ -3189,7 +3209,11 @@ void QuicTransportBaseLite::validateECNState() { if (conn_->ecnState == ECNState::FailedValidation) { conn_->socketTos.fields.ecn = 0; CHECK(socket_ && socket_->isBound()); - socket_->setTosOrTrafficClass(conn_->socketTos.value); + auto result = socket_->setTosOrTrafficClass(conn_->socketTos.value); + if (result.hasError()) { + return result; + } + VLOG(4) << "ECN validation failed. Disabling ECN"; if (conn_->ecnL4sTracker) { conn_->packetProcessors.erase( @@ -3201,6 +3225,7 @@ void QuicTransportBaseLite::validateECNState() { conn_->ecnL4sTracker.reset(); } } + return folly::unit; } void QuicTransportBaseLite::scheduleAckTimeout() { @@ -3335,7 +3360,7 @@ QuicSocketLite::TransportInfo QuicTransportBaseLite::getTransportInfo() const { } const folly::SocketAddress& QuicTransportBaseLite::getLocalAddress() const { - return socket_ && socket_->isBound() ? socket_->address() + return socket_ && socket_->isBound() ? socket_->addressRef() : localFallbackAddress; } diff --git a/quic/api/QuicTransportBaseLite.h b/quic/api/QuicTransportBaseLite.h index c13de0310..ad56a9ba8 100644 --- a/quic/api/QuicTransportBaseLite.h +++ b/quic/api/QuicTransportBaseLite.h @@ -777,7 +777,7 @@ class QuicTransportBaseLite : virtual public QuicSocketLite, * is not enabled or has already failed validation, this function does * nothing. */ - void validateECNState(); + [[nodiscard]] folly::Expected validateECNState(); std::shared_ptr evb_; std::unique_ptr socket_; diff --git a/quic/api/QuicTransportFunctions.cpp b/quic/api/QuicTransportFunctions.cpp index 7d4e25222..59eeb3f68 100644 --- a/quic/api/QuicTransportFunctions.cpp +++ b/quic/api/QuicTransportFunctions.cpp @@ -272,7 +272,10 @@ continuousMemoryBuildScheduleEncrypt( auto& packet = result->packet; if (!packet || packet->packet.frames.empty()) { rollbackBuf(); - ioBufBatch.flush(); + auto flushResult = ioBufBatch.flush(); + if (flushResult.hasError()) { + return folly::makeUnexpected(flushResult.error()); + } updateErrnoCount(connection, ioBufBatch); if (connection.loopDetectorCallback) { connection.writeDebugState.noWriteReason = NoWriteReason::NO_FRAME; @@ -282,7 +285,10 @@ continuousMemoryBuildScheduleEncrypt( if (packet->body.empty()) { // No more space remaining. rollbackBuf(); - ioBufBatch.flush(); + auto flushResult = ioBufBatch.flush(); + if (flushResult.hasError()) { + return folly::makeUnexpected(flushResult.error()); + } updateErrnoCount(connection, ioBufBatch); if (connection.loopDetectorCallback) { connection.writeDebugState.noWriteReason = NoWriteReason::NO_BODY; @@ -327,10 +333,17 @@ continuousMemoryBuildScheduleEncrypt( << encodedSize; } // TODO: I think we should add an API that doesn't need a buffer. - bool ret = ioBufBatch.write(nullptr /* no need to pass buf */, encodedSize); + auto writeResult = + ioBufBatch.write(nullptr /* no need to pass buf */, encodedSize); + if (writeResult.hasError()) { + return folly::makeUnexpected(writeResult.error()); + } updateErrnoCount(connection, ioBufBatch); return DataPathResult::makeWriteResult( - ret, std::move(result.value()), encodedSize, encodedBodySize); + writeResult.value(), + std::move(result.value()), + encodedSize, + encodedBodySize); } [[nodiscard]] folly::Expected @@ -358,7 +371,10 @@ iobufChainBasedBuildScheduleEncrypt( } auto& packet = result->packet; if (!packet || packet->packet.frames.empty()) { - ioBufBatch.flush(); + auto flushResult = ioBufBatch.flush(); + if (flushResult.hasError()) { + return folly::makeUnexpected(flushResult.error()); + } updateErrnoCount(connection, ioBufBatch); if (connection.loopDetectorCallback) { connection.writeDebugState.noWriteReason = NoWriteReason::NO_FRAME; @@ -367,7 +383,10 @@ iobufChainBasedBuildScheduleEncrypt( } if (packet->body.empty()) { // No more space remaining. - ioBufBatch.flush(); + auto flushResult = ioBufBatch.flush(); + if (flushResult.hasError()) { + return folly::makeUnexpected(flushResult.error()); + } updateErrnoCount(connection, ioBufBatch); if (connection.loopDetectorCallback) { connection.writeDebugState.noWriteReason = NoWriteReason::NO_BODY; @@ -405,10 +424,16 @@ iobufChainBasedBuildScheduleEncrypt( VLOG(3) << "Quic sending pkt larger than limit, encodedSize=" << encodedSize << " encodedBodySize=" << encodedBodySize; } - bool ret = ioBufBatch.write(std::move(packetBuf), encodedSize); + auto writeResult = ioBufBatch.write(std::move(packetBuf), encodedSize); + if (writeResult.hasError()) { + return folly::makeUnexpected(writeResult.error()); + } updateErrnoCount(connection, ioBufBatch); return DataPathResult::makeWriteResult( - ret, std::move(result.value()), encodedSize, encodedBodySize); + writeResult.value(), + std::move(result.value()), + encodedSize, + encodedBodySize); } } // namespace @@ -1610,7 +1635,12 @@ folly::Expected writeConnectionDataToSocket( << connection; if (!connection.gsoSupported.has_value()) { - connection.gsoSupported = sock.getGSO() >= 0; + auto gsoResult = sock.getGSO(); + if (gsoResult.hasError()) { + LOG(ERROR) << "Failed to get GSO: " << gsoResult.error().message; + return folly::makeUnexpected(gsoResult.error()); + } + connection.gsoSupported = sock.getGSO().value() >= 0; if (!*connection.gsoSupported) { if (!useSinglePacketInplaceBatchWriter( connection.transportSettings.maxBatchSize, @@ -1648,7 +1678,11 @@ folly::Expected writeConnectionDataToSocket( // If we have a pending write to retry. Flush that first and make sure it // succeeds before scheduling any new data. if (pendingBufferedWrite) { - bool flushSuccess = ioBufBatch.flush(); + auto flushResult = ioBufBatch.flush(); + if (flushResult.hasError()) { + return folly::makeUnexpected(flushResult.error()); + } + auto flushSuccess = flushResult.value(); updateErrnoCount(connection, ioBufBatch); if (!flushSuccess) { // Could not flush retried data. Return empty write result and wait for @@ -1725,7 +1759,10 @@ folly::Expected writeConnectionDataToSocket( if (!ret->buildSuccess) { // If we're returning because we couldn't schedule more packets, // make sure we flush the buffer in this function. - ioBufBatch.flush(); + auto flushResult = ioBufBatch.flush(); + if (flushResult.hasError()) { + return folly::makeUnexpected(flushResult.error()); + } updateErrnoCount(connection, ioBufBatch); return WriteQuicDataResult{ioBufBatch.getPktSent(), 0, bytesWritten}; } @@ -1777,13 +1814,19 @@ folly::Expected writeConnectionDataToSocket( connection.transportSettings.dataPathType)) { // With SinglePacketInplaceBatchWriter we always write one packet, and so // ioBufBatch needs a flush. - ioBufBatch.flush(); + auto flushResult = ioBufBatch.flush(); + if (flushResult.hasError()) { + return folly::makeUnexpected(flushResult.error()); + } updateErrnoCount(connection, ioBufBatch); } } // Ensure that the buffer is flushed before returning - ioBufBatch.flush(); + auto flushResult = ioBufBatch.flush(); + if (flushResult.hasError()) { + return folly::makeUnexpected(flushResult.error()); + } updateErrnoCount(connection, ioBufBatch); if (connection.transportSettings.dataPathType == diff --git a/quic/api/test/QuicBatchWriterTest.cpp b/quic/api/test/QuicBatchWriterTest.cpp index 9cbf75f63..fe5fee800 100644 --- a/quic/api/test/QuicBatchWriterTest.cpp +++ b/quic/api/test/QuicBatchWriterTest.cpp @@ -65,8 +65,10 @@ TEST_F(QuicBatchWriterTest, TestBatchingGSOBase) { std::make_shared(&evb); FollyQuicAsyncUDPSocket sock(qEvb); sock.setReuseAddr(false); - sock.bind(folly::SocketAddress("127.0.0.1", 0)); - gsoSupported_ = sock.getGSO() >= 0; + ASSERT_FALSE(sock.bind(folly::SocketAddress("127.0.0.1", 0)).hasError()); + auto gsoResult = sock.getGSO(); + ASSERT_FALSE(gsoResult.hasError()); + gsoSupported_ = gsoResult.value(); auto batchWriter = quic::BatchWriterFactory::makeBatchWriter( quic::QuicBatchingMode::BATCHING_MODE_GSO, @@ -95,8 +97,10 @@ TEST_F(QuicBatchWriterTest, TestBatchingGSOLastSmallPacket) { std::make_shared(&evb); FollyQuicAsyncUDPSocket sock(qEvb); sock.setReuseAddr(false); - sock.bind(folly::SocketAddress("127.0.0.1", 0)); - gsoSupported_ = sock.getGSO() >= 0; + ASSERT_FALSE(sock.bind(folly::SocketAddress("127.0.0.1", 0)).hasError()); + auto gsoResult = sock.getGSO(); + ASSERT_FALSE(gsoResult.hasError()); + gsoSupported_ = gsoResult.value(); auto batchWriter = quic::BatchWriterFactory::makeBatchWriter( quic::QuicBatchingMode::BATCHING_MODE_GSO, @@ -137,8 +141,10 @@ TEST_F(QuicBatchWriterTest, TestBatchingGSOLastBigPacket) { std::make_shared(&evb); FollyQuicAsyncUDPSocket sock(qEvb); sock.setReuseAddr(false); - sock.bind(folly::SocketAddress("127.0.0.1", 0)); - gsoSupported_ = sock.getGSO() >= 0; + ASSERT_FALSE(sock.bind(folly::SocketAddress("127.0.0.1", 0)).hasError()); + auto gsoResult = sock.getGSO(); + ASSERT_FALSE(gsoResult.hasError()); + gsoSupported_ = gsoResult.value(); auto batchWriter = quic::BatchWriterFactory::makeBatchWriter( quic::QuicBatchingMode::BATCHING_MODE_GSO, @@ -174,8 +180,10 @@ TEST_F(QuicBatchWriterTest, TestBatchingGSOBatchNum) { std::make_shared(&evb); FollyQuicAsyncUDPSocket sock(qEvb); sock.setReuseAddr(false); - sock.bind(folly::SocketAddress("127.0.0.1", 0)); - gsoSupported_ = sock.getGSO() >= 0; + ASSERT_FALSE(sock.bind(folly::SocketAddress("127.0.0.1", 0)).hasError()); + auto gsoResult = sock.getGSO(); + ASSERT_FALSE(gsoResult.hasError()); + gsoSupported_ = gsoResult.value(); auto batchWriter = quic::BatchWriterFactory::makeBatchWriter( quic::QuicBatchingMode::BATCHING_MODE_GSO, @@ -475,8 +483,10 @@ TEST_F(QuicBatchWriterTest, TestBatchingSendmmsgGSOBatchNum) { std::make_shared(&evb); FollyQuicAsyncUDPSocket sock(qEvb); sock.setReuseAddr(false); - sock.bind(folly::SocketAddress("127.0.0.1", 0)); - gsoSupported_ = sock.getGSO() >= 0; + ASSERT_FALSE(sock.bind(folly::SocketAddress("127.0.0.1", 0)).hasError()); + auto gsoResult = sock.getGSO(); + ASSERT_FALSE(gsoResult.hasError()); + gsoSupported_ = gsoResult.value(); auto batchWriter = quic::BatchWriterFactory::makeBatchWriter( quic::QuicBatchingMode::BATCHING_MODE_SENDMMSG_GSO, @@ -521,8 +531,10 @@ TEST_F(QuicBatchWriterTest, TestBatchingSendmmsgGSOBatcBigSmallPacket) { std::make_shared(&evb); FollyQuicAsyncUDPSocket sock(qEvb); sock.setReuseAddr(false); - sock.bind(folly::SocketAddress("127.0.0.1", 0)); - gsoSupported_ = sock.getGSO() >= 0; + ASSERT_FALSE(sock.bind(folly::SocketAddress("127.0.0.1", 0)).hasError()); + auto gsoResult = sock.getGSO(); + ASSERT_FALSE(gsoResult.hasError()); + gsoSupported_ = gsoResult.value(); auto batchWriter = quic::BatchWriterFactory::makeBatchWriter( quic::QuicBatchingMode::BATCHING_MODE_SENDMMSG_GSO, diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index f54b496bf..a9040b4f0 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -651,6 +651,19 @@ class QuicTransportImplTest : public Test { qEvb = std::make_shared(fEvb.get()); auto socket = std::make_unique>(qEvb); + ON_CALL(*socket, setAdditionalCmsgsFunc(_)) + .WillByDefault(Return(folly::unit)); + ON_CALL(*socket, close()).WillByDefault(Return(folly::unit)); + ON_CALL(*socket, resumeWrite(_)).WillByDefault(Return(folly::unit)); + ON_CALL(*socket, bind(_)).WillByDefault(Return(folly::unit)); + ON_CALL(*socket, connect(_)).WillByDefault(Return(folly::unit)); + ON_CALL(*socket, setReuseAddr(_)).WillByDefault(Return(folly::unit)); + ON_CALL(*socket, setReusePort(_)).WillByDefault(Return(folly::unit)); + ON_CALL(*socket, setRecvTos(_)).WillByDefault(Return(folly::unit)); + ON_CALL(*socket, getRecvTos()).WillByDefault(Return(false)); + ON_CALL(*socket, getGSO()).WillByDefault(Return(0)); + ON_CALL(*socket, setCmsgs(_)).WillByDefault(Return(folly::unit)); + ON_CALL(*socket, appendCmsgs(_)).WillByDefault(Return(folly::unit)); socketPtr = socket.get(); transport = std::make_shared( qEvb, std::move(socket), &connSetupCallback, &connCallback); @@ -3378,7 +3391,7 @@ TEST_P(QuicTransportImplTestBase, UncleanShutdownEventBase) { TEST_P(QuicTransportImplTestBase, GetLocalAddressBoundSocket) { SocketAddress addr("127.0.0.1", 443); EXPECT_CALL(*socketPtr, isBound()).WillOnce(Return(true)); - EXPECT_CALL(*socketPtr, address()).WillRepeatedly(ReturnRef(addr)); + EXPECT_CALL(*socketPtr, addressRef()).WillRepeatedly(ReturnRef(addr)); SocketAddress localAddr = transport->getLocalAddress(); EXPECT_TRUE(localAddr == addr); } @@ -4859,9 +4872,7 @@ TEST_P( // Write event is not armed. EXPECT_CALL(*socketPtr, isWritableCallbackSet()).WillOnce(Return(false)); - EXPECT_CALL(*socketPtr, resumeWrite(_)) - .WillOnce(Return( - folly::makeExpected(folly::Unit()))); + EXPECT_CALL(*socketPtr, resumeWrite(_)).WillOnce(Return(folly::unit)); transport->maybeStopWriteLooperAndArmSocketWritableEvent(); // Write looper is stopped. EXPECT_FALSE(transport->writeLooper()->isRunning()); @@ -5072,7 +5083,7 @@ TEST_P( EXPECT_CALL(*socketPtr, resumeWrite(_)) .WillOnce(Invoke([&](QuicAsyncUDPSocket::WriteCallback*) { writeCallbackArmed = true; - return folly::makeExpected(folly::Unit()); + return folly::unit; })); // Fail the first write loop. diff --git a/quic/api/test/QuicTransportFunctionsTest.cpp b/quic/api/test/QuicTransportFunctionsTest.cpp index 59cce3d2c..bd3cdc673 100644 --- a/quic/api/test/QuicTransportFunctionsTest.cpp +++ b/quic/api/test/QuicTransportFunctionsTest.cpp @@ -2690,6 +2690,7 @@ TEST_F(QuicTransportFunctionsTest, WriteQuicDataToSocketWithCC) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); auto stream1 = conn->streamManager->createNextBidirectionalStream().value(); auto buf = @@ -2734,6 +2735,7 @@ TEST_F(QuicTransportFunctionsTest, WriteQuicdataToSocketWithPacer) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); auto stream1 = conn->streamManager->createNextBidirectionalStream().value(); auto buf = @@ -2768,6 +2770,7 @@ TEST_F(QuicTransportFunctionsTest, WriteQuicDataToSocketLimitTest) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); auto stream1 = conn->streamManager->createNextBidirectionalStream().value(); // ~50 bytes auto buf = @@ -2879,6 +2882,7 @@ TEST_F( auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); auto stream1 = conn->streamManager->createNextBidirectionalStream().value(); auto buf = @@ -2927,6 +2931,7 @@ TEST_F(QuicTransportFunctionsTest, WriteQuicDataToSocketWithNoBytesForHeader) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); auto stream1 = conn->streamManager->createNextBidirectionalStream().value(); auto buf = IOBuf::copyBuffer("0123456789012"); @@ -2954,6 +2959,7 @@ TEST_F(QuicTransportFunctionsTest, WriteQuicDataToSocketRetxBufferSorted) { std::shared_ptr qEvb = std::make_shared(&evb); quic::test::MockAsyncUDPSocket socket(qEvb); + ON_CALL(socket, getGSO).WillByDefault(testing::Return(0)); auto conn = createConn(); auto stream = conn->streamManager->createNextBidirectionalStream().value(); auto buf1 = IOBuf::copyBuffer("Whatsapp"); @@ -3000,6 +3006,7 @@ TEST_F(QuicTransportFunctionsTest, NothingWritten) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); // 18 isn't enough to write 3 ack blocks, but is enough to write a pure // header packet, which we shouldn't write @@ -3047,6 +3054,7 @@ TEST_F(QuicTransportFunctionsTest, WriteBlockedFrameWhenBlocked) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); auto stream1 = conn->streamManager->createNextBidirectionalStream().value(); auto buf = buildRandomInputData(200); ASSERT_FALSE(writeDataToQuicStream(*stream1, buf->clone(), true).hasError()); @@ -3120,6 +3128,7 @@ TEST_F(QuicTransportFunctionsTest, WriteProbingNewData) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); auto stream1 = conn->streamManager->createNextBidirectionalStream().value(); auto buf = buildRandomInputData(conn->udpSendPacketLen * 2); ASSERT_FALSE( @@ -3155,6 +3164,7 @@ TEST_F(QuicTransportFunctionsTest, WriteProbingOldData) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); EXPECT_CALL(*rawSocket, write(_, _, _)).WillRepeatedly(Return(100)); auto capturingAead = std::make_unique(); auto stream = conn->streamManager->createNextBidirectionalStream().value(); @@ -3200,6 +3210,7 @@ TEST_F(QuicTransportFunctionsTest, WriteProbingOldDataAckFreq) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); EXPECT_CALL(*rawSocket, write(_, _, _)).WillRepeatedly(Return(100)); auto capturingAead = std::make_unique(); auto stream = conn->streamManager->createNextBidirectionalStream().value(); @@ -3265,6 +3276,7 @@ TEST_F(QuicTransportFunctionsTest, WriteProbingCryptoData) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); auto cryptoStream = &conn.cryptoState->initialStream; auto buf = buildRandomInputData(conn.udpSendPacketLen * 2); writeDataToQuicStream(*cryptoStream, buf->clone()); @@ -3310,6 +3322,7 @@ TEST_F(QuicTransportFunctionsTest, WriteableBytesLimitedProbingCryptoData) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); auto cryptoStream = &conn.cryptoState->initialStream; uint8_t probesToSend = 4; auto buf = buildRandomInputData(conn.udpSendPacketLen * probesToSend); @@ -3349,6 +3362,7 @@ TEST_F(QuicTransportFunctionsTest, ProbingNotFallbackToPingWhenNoQuota) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); EXPECT_CALL(*rawCongestionController, onPacketSent(_)).Times(0); EXPECT_CALL(*rawSocket, write(_, _, _)).Times(0); uint8_t probesToSend = 0; @@ -3371,6 +3385,7 @@ TEST_F(QuicTransportFunctionsTest, ProbingFallbackToPing) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); EXPECT_CALL(*rawSocket, write(_, _, _)) .Times(1) .WillOnce(Invoke( @@ -3400,6 +3415,7 @@ TEST_F(QuicTransportFunctionsTest, ProbingFallbackToImmediateAck) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); EXPECT_CALL(*rawSocket, write(_, _, _)) .Times(1) .WillOnce(Invoke( @@ -3432,6 +3448,7 @@ TEST_F(QuicTransportFunctionsTest, NoCryptoProbeWriteIfNoProbeCredit) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); auto res = writeCryptoAndAckDataToSocket( *rawSocket, *conn, @@ -3483,6 +3500,7 @@ TEST_F(QuicTransportFunctionsTest, ImmediatelyRetransmitInitialPackets) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); auto res = writeCryptoAndAckDataToSocket( *rawSocket, *conn, @@ -3512,6 +3530,7 @@ TEST_F(QuicTransportFunctionsTest, ResetNumProbePackets) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); conn->pendingEvents.numProbePackets[PacketNumberSpace::Initial] = 2; auto writeRes1 = writeCryptoAndAckDataToSocket( @@ -3574,6 +3593,7 @@ TEST_F(QuicTransportFunctionsTest, WritePureAckWhenNoWritableBytes) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); auto stream1 = conn->streamManager->createNextBidirectionalStream().value(); auto buf = IOBuf::copyBuffer("0123456789012"); @@ -3626,6 +3646,7 @@ TEST_F(QuicTransportFunctionsTest, ShouldWriteDataTest) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); // Pure acks without an oneRttCipher CHECK(!conn->oneRttWriteCipher); @@ -4187,6 +4208,7 @@ TEST_F(QuicTransportFunctionsTest, WriteLimitBytRttFraction) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); auto stream1 = conn->streamManager->createNextBidirectionalStream().value(); auto buf = buildRandomInputData(2048 * 2048); @@ -4248,6 +4270,7 @@ TEST_F(QuicTransportFunctionsTest, WriteLimitBytRttFractionNoLimit) { auto socket = std::make_unique>(qEvb); auto rawSocket = socket.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); auto stream1 = conn->streamManager->createNextBidirectionalStream().value(); auto buf = buildRandomInputData(2048 * 2048); @@ -4321,6 +4344,7 @@ TEST_F(QuicTransportFunctionsTest, HandshakeConfirmedDropCipher) { std::make_shared(&evb); auto socket = std::make_unique>(qEvb); + ON_CALL(*socket, getGSO).WillByDefault(testing::Return(0)); auto initialStream = getCryptoStream(*conn->cryptoState, EncryptionLevel::Initial); auto handshakeStream = @@ -4391,6 +4415,7 @@ TEST_F(QuicTransportFunctionsTest, ProbeWriteNewFunctionalFrames) { std::make_shared(&evb); auto sock = std::make_unique>(qEvb); auto rawSocket = sock.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); EXPECT_CALL(*rawSocket, write(_, _, _)) .WillRepeatedly(Invoke( @@ -4444,6 +4469,7 @@ TEST_F(QuicTransportFunctionsTest, ProbeWriteNewFunctionalFramesAckFreq) { std::make_shared(&evb); auto sock = std::make_unique>(qEvb); auto rawSocket = sock.get(); + ON_CALL(*rawSocket, getGSO).WillByDefault(testing::Return(0)); EXPECT_CALL(*rawSocket, write(_, _, _)) .WillRepeatedly(Invoke( diff --git a/quic/api/test/QuicTransportTest.cpp b/quic/api/test/QuicTransportTest.cpp index db0379784..5753824bc 100644 --- a/quic/api/test/QuicTransportTest.cpp +++ b/quic/api/test/QuicTransportTest.cpp @@ -75,6 +75,23 @@ class QuicTransportTest : public Test { std::unique_ptr sock = std::make_unique>(qEvb_); socket_ = sock.get(); + ON_CALL(*socket_, setAdditionalCmsgsFunc(_)) + .WillByDefault(Return(folly::unit)); + ON_CALL(*socket_, setTosOrTrafficClass(_)) + .WillByDefault(Return(folly::unit)); + ON_CALL(*socket_, close()).WillByDefault(Return(folly::unit)); + ON_CALL(*socket_, getGSO()).WillByDefault(Return(0)); + ON_CALL(*socket_, init(_)).WillByDefault(Return(folly::unit)); + ON_CALL(*socket_, bind(_)).WillByDefault(Return(folly::unit)); + ON_CALL(*socket_, connect(_)).WillByDefault(Return(folly::unit)); + ON_CALL(*socket_, setRecvTos(_)).WillByDefault(Return(folly::unit)); + ON_CALL(*socket_, setReuseAddr(_)).WillByDefault(Return(folly::unit)); + ON_CALL(*socket_, setReusePort(_)).WillByDefault(Return(folly::unit)); + ON_CALL(*socket_, setRcvBuf(_)).WillByDefault(Return(folly::unit)); + ON_CALL(*socket_, setSndBuf(_)).WillByDefault(Return(folly::unit)); + ON_CALL(*socket_, setErrMessageCallback(_)) + .WillByDefault(Return(folly::unit)); + ON_CALL(*socket_, applyOptions(_, _)).WillByDefault(Return(folly::unit)); transport_.reset(new TestQuicTransport( qEvb_, std::move(sock), &connSetupCallback_, &connCallback_)); // Set the write handshake state to tell the client that the handshake has diff --git a/quic/api/test/TestQuicTransport.h b/quic/api/test/TestQuicTransport.h index 4ba199e8a..25d92fa21 100644 --- a/quic/api/test/TestQuicTransport.h +++ b/quic/api/test/TestQuicTransport.h @@ -206,7 +206,7 @@ class TestQuicTransport } void validateECN() { - QuicTransportBase::validateECNState(); + CHECK(!QuicTransportBase::validateECNState().hasError()); } std::unique_ptr aead; diff --git a/quic/client/QuicClientTransport.cpp b/quic/client/QuicClientTransport.cpp index aaf891b4a..496d907ca 100644 --- a/quic/client/QuicClientTransport.cpp +++ b/quic/client/QuicClientTransport.cpp @@ -26,7 +26,7 @@ QuicClientTransport::~QuicClientTransport() { if (clientConn_->happyEyeballsState.secondSocket) { auto sock = std::move(clientConn_->happyEyeballsState.secondSocket); sock->pauseRead(); - sock->close(); + (void)sock->close(); } } @@ -43,20 +43,26 @@ void QuicClientTransport::onNotifyDataAvailable( ? conn_->transportSettings.readCoalescingSize : readBufferSize; - if (conn_->transportSettings.networkDataPerSocketRead) { - readWithRecvmsgSinglePacketLoop(sock, readAllocSize); - } else if (conn_->transportSettings.shouldUseWrapperRecvmmsgForBatchRecv) { - readWithRecvmmsgWrapper(sock, readAllocSize, numPackets); - } else if (conn_->transportSettings.shouldUseRecvmmsgForBatchRecv) { - readWithRecvmmsg(sock, readAllocSize, numPackets); - } else if (conn_->transportSettings.shouldUseRecvfromForBatchRecv) { - readWithRecvfrom(sock, readAllocSize, numPackets); - } else { - readWithRecvmsg(sock, readAllocSize, numPackets); + auto result = [&]() -> folly::Expected { + if (conn_->transportSettings.networkDataPerSocketRead) { + return readWithRecvmsgSinglePacketLoop(sock, readAllocSize); + } else if (conn_->transportSettings.shouldUseWrapperRecvmmsgForBatchRecv) { + return readWithRecvmmsgWrapper(sock, readAllocSize, numPackets); + } else if (conn_->transportSettings.shouldUseRecvmmsgForBatchRecv) { + return readWithRecvmmsg(sock, readAllocSize, numPackets); + } else if (conn_->transportSettings.shouldUseRecvfromForBatchRecv) { + return readWithRecvfrom(sock, readAllocSize, numPackets); + } else { + return readWithRecvmsg(sock, readAllocSize, numPackets); + } + }(); + if (result.hasError()) { + asyncClose(result.error()); } } -void QuicClientTransport::readWithRecvmmsgWrapper( +folly::Expected +QuicClientTransport::readWithRecvmmsgWrapper( QuicAsyncUDPSocket& sock, uint64_t readBufferSize, uint16_t numPackets) { @@ -68,6 +74,10 @@ void QuicClientTransport::readWithRecvmmsgWrapper( const auto result = sock.recvmmsgNetworkData( readBufferSize, numPackets, networkData, server, totalData); + if (result.hasError()) { + return folly::makeUnexpected(result.error()); + } + // track the received packets for (const auto& packet : networkData.getPackets()) { if (packet.buf.empty()) { @@ -82,8 +92,8 @@ void QuicClientTransport::readWithRecvmmsgWrapper( // Propagate errors // TODO(bschlinker): Investigate generalization of loopDetectorCallback // TODO(bschlinker): Consider merging this into ReadCallback - if (result.maybeNoReadReason) { - const auto& noReadReason = result.maybeNoReadReason.value(); + if (result->maybeNoReadReason) { + const auto& noReadReason = result->maybeNoReadReason.value(); switch (noReadReason) { case NoReadReason::RETRIABLE_ERROR: if (conn_->loopDetectorCallback) { @@ -109,10 +119,10 @@ void QuicClientTransport::readWithRecvmmsgWrapper( break; } } - processPackets(std::move(networkData), server); + return processPackets(std::move(networkData), server); } -void QuicClientTransport::readWithRecvmmsg( +folly::Expected QuicClientTransport::readWithRecvmmsg( QuicAsyncUDPSocket& sock, uint64_t readBufferSize, uint16_t numPackets) { @@ -123,12 +133,16 @@ void QuicClientTransport::readWithRecvmmsg( // TODO(bschlinker): Deprecate in favor of Wrapper::recvmmsg recvmmsgStorage_.resize(numPackets); - recvMmsg(sock, readBufferSize, numPackets, networkData, server, totalData); + auto recvResult = recvMmsg( + sock, readBufferSize, numPackets, networkData, server, totalData); + if (recvResult.hasError()) { + return recvResult; + } - processPackets(std::move(networkData), server); + return processPackets(std::move(networkData), server); } -void QuicClientTransport::readWithRecvfrom( +folly::Expected QuicClientTransport::readWithRecvfrom( QuicAsyncUDPSocket& sock, uint64_t readBufferSize, uint16_t numPackets) { @@ -136,11 +150,15 @@ void QuicClientTransport::readWithRecvfrom( networkData.reserve(numPackets); size_t totalData = 0; Optional server; - recvFrom(sock, readBufferSize, numPackets, networkData, server, totalData); - processPackets(std::move(networkData), server); + auto recvResult = recvFrom( + sock, readBufferSize, numPackets, networkData, server, totalData); + if (recvResult.hasError()) { + return recvResult; + } + return processPackets(std::move(networkData), server); } -void QuicClientTransport::readWithRecvmsg( +folly::Expected QuicClientTransport::readWithRecvmsg( QuicAsyncUDPSocket& sock, uint64_t readBufferSize, uint16_t numPackets) { @@ -150,9 +168,13 @@ void QuicClientTransport::readWithRecvmsg( Optional server; // TODO(bschlinker): Deprecate in favor of Wrapper::recvmmsg - recvMsg(sock, readBufferSize, numPackets, networkData, server, totalData); + auto recvResult = + recvMsg(sock, readBufferSize, numPackets, networkData, server, totalData); + if (recvResult.hasError()) { + return recvResult; + } - processPackets(std::move(networkData), server); + return processPackets(std::move(networkData), server); } } // namespace quic diff --git a/quic/client/QuicClientTransport.h b/quic/client/QuicClientTransport.h index 3ccd927d3..788d93e57 100644 --- a/quic/client/QuicClientTransport.h +++ b/quic/client/QuicClientTransport.h @@ -101,22 +101,22 @@ class QuicClientTransport : public QuicTransportBase, return wrappedObserverContainer_.getPtr(); } - void readWithRecvmmsgWrapper( + [[nodiscard]] folly::Expected readWithRecvmmsgWrapper( QuicAsyncUDPSocket& sock, uint64_t readBufferSize, uint16_t numPackets); - void readWithRecvmmsg( + [[nodiscard]] folly::Expected readWithRecvmmsg( QuicAsyncUDPSocket& sock, uint64_t readBufferSize, uint16_t numPackets); - void readWithRecvfrom( + [[nodiscard]] folly::Expected readWithRecvfrom( QuicAsyncUDPSocket& sock, uint64_t readBufferSize, uint16_t numPackets); - void readWithRecvmsg( + [[nodiscard]] folly::Expected readWithRecvmsg( QuicAsyncUDPSocket& sock, uint64_t readBufferSize, uint16_t numPackets); diff --git a/quic/client/QuicClientTransportLite.cpp b/quic/client/QuicClientTransportLite.cpp index 14d876e7e..c832b6034 100644 --- a/quic/client/QuicClientTransportLite.cpp +++ b/quic/client/QuicClientTransportLite.cpp @@ -114,7 +114,7 @@ QuicClientTransportLite::~QuicClientTransportLite() { if (clientConn_->happyEyeballsState.secondSocket) { auto sock = std::move(clientConn_->happyEyeballsState.secondSocket); sock->pauseRead(); - sock->close(); + (void)sock->close(); } } @@ -970,11 +970,18 @@ folly::Expected QuicClientTransportLite::onReadData( replaySafeNotified_ = true; // We don't need this any more. Also unset it so that we don't allow random // middleboxes to shutdown our connection once we have crypto keys. - socket_->setErrMessageCallback(nullptr); + auto result = socket_->setErrMessageCallback(nullptr); + if (result.hasError()) { + return folly::makeUnexpected(result.error()); + } connSetupCallback_->onReplaySafe(); } - maybeSendTransportKnobs(); + auto result = maybeSendTransportKnobs(); + if (result.hasError()) { + return result; + } + return folly::unit; } @@ -1204,6 +1211,11 @@ std::shared_ptr QuicClientTransportLite::sharedGuard() { return shared_from_this(); } +std::shared_ptr +QuicClientTransportLite::sharedGuardClient() { + return shared_from_this(); +} + bool QuicClientTransportLite::isTLSResumed() const { return clientConn_->clientHandshakeLayer->isTLSResumed(); } @@ -1241,13 +1253,8 @@ void QuicClientTransportLite::errMessage( auto errStr = folly::errnoStr(serr->ee_errno); if (!happyEyeballsState.shouldWriteToFirstSocket && !happyEyeballsState.shouldWriteToSecondSocket) { - runOnEvbAsync([errString = std::move(errStr)](auto self) mutable { - auto quicError = QuicError( - QuicErrorCode(LocalErrorCode::CONNECT_FAILED), - std::move(errString)); - auto clientPtr = dynamic_cast(self.get()); - clientPtr->closeImpl(std::move(quicError), false, false); - }); + asyncClose(QuicError( + QuicErrorCode(LocalErrorCode::CONNECT_FAILED), std::move(errStr))); } } #endif @@ -1259,12 +1266,8 @@ void QuicClientTransportLite::onReadError( // closeNow will skip draining the socket. onReadError doesn't gets // triggered by retriable errors. If we are here, there is no point of // draining the socket. - runOnEvbAsync([ex](auto self) { - auto clientPtr = dynamic_cast(self.get()); - clientPtr->closeNow(QuicError( - QuicErrorCode(LocalErrorCode::CONNECTION_ABANDONED), - std::string(ex.what()))); - }); + asyncClose(QuicError( + QuicErrorCode(LocalErrorCode::CONNECTION_ABANDONED), ex.what())); } } @@ -1286,7 +1289,7 @@ bool QuicClientTransportLite::shouldOnlyNotify() { return true; } -void QuicClientTransportLite::recvMsg( +folly::Expected QuicClientTransportLite::recvMsg( QuicAsyncUDPSocket& sock, uint64_t readBufferSize, int numPackets, @@ -1295,8 +1298,7 @@ void QuicClientTransportLite::recvMsg( size_t& totalData) { for (int packetNum = 0; packetNum < numPackets; ++packetNum) { // We create 1 buffer per packet so that it is not shared, this enables - // us to decrypt in place. If the fizz decrypt api could decrypt in-place - // even if shared, then we could allocate one giant IOBuf here. + // us to decrypt in place. Buf readBuffer = BufHelpers::createCombined(readBufferSize); struct iovec vec; vec.iov_base = readBuffer->writableData(); @@ -1308,7 +1310,15 @@ void QuicClientTransportLite::recvMsg( if (!server) { rawAddr = reinterpret_cast(&addrStorage); - rawAddr->sa_family = sock.getLocalAddressFamily(); + auto familyResult = sock.getLocalAddressFamily(); + if (familyResult.hasError()) { + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + folly::to( + "Failed to get address family: ", + familyResult.error().message))); + } + rawAddr->sa_family = familyResult.value(); } int flags = 0; @@ -1321,9 +1331,35 @@ void QuicClientTransportLite::recvMsg( msg.msg_iov = &vec; msg.msg_iovlen = 1; #ifdef FOLLY_HAVE_MSG_ERRQUEUE - bool useGRO = sock.getGRO() > 0; - bool useTs = sock.getTimestamping() > 0; - bool recvTos = sock.getRecvTos(); + bool useGRO = false; + bool useTs = false; + bool recvTos = false; + + auto groResult = sock.getGRO(); + if (groResult.hasError()) { + // Non-fatal, just log and continue + LOG(WARNING) << "Failed to get GRO status: " << groResult.error().message; + } else { + useGRO = groResult.value() > 0; + } + + auto tsResult = sock.getTimestamping(); + if (tsResult.hasError()) { + // Non-fatal, just log and continue + LOG(WARNING) << "Failed to get timestamping status: " + << tsResult.error().message; + } else { + useTs = tsResult.value() > 0; + } + + auto tosResult = sock.getRecvTos(); + if (tosResult.hasError()) { + // Non-fatal, just log and continue + LOG(WARNING) << "Failed to get TOS status: " << tosResult.error().message; + } else { + recvTos = tosResult.value(); + } + bool checkCmsgs = useGRO || useTs || recvTos; char control [QuicAsyncUDPSocket::ReadCallback::OnDataAvailableParams::kCmsgSpace] = @@ -1353,10 +1389,10 @@ void QuicClientTransportLite::recvMsg( if (conn_->loopDetectorCallback) { conn_->readDebugState.noReadReason = NoReadReason::NONRETRIABLE_ERROR; } - return onReadError(folly::AsyncSocketException( - folly::AsyncSocketException::INTERNAL_ERROR, - "::recvmsg() failed", - errno)); + return folly::makeUnexpected(QuicError( + QuicErrorCode(LocalErrorCode::CONNECTION_ABANDONED), + folly::to( + "recvmsg() failed, errno=", errno, " ", folly::errnoStr(errno)))); } else if (ret == 0) { break; } @@ -1426,9 +1462,11 @@ void QuicClientTransportLite::recvMsg( } trackDatagramsReceived( networkData.getPackets().size(), networkData.getTotalData()); + + return folly::unit; } -void QuicClientTransportLite::recvFrom( +folly::Expected QuicClientTransportLite::recvFrom( QuicAsyncUDPSocket& sock, uint64_t readBufferSize, int numPackets, @@ -1447,7 +1485,15 @@ void QuicClientTransportLite::recvFrom( if (!server) { rawAddr = reinterpret_cast(&addrStorage); - rawAddr->sa_family = sock.getLocalAddressFamily(); + auto familyResult = sock.getLocalAddressFamily(); + if (familyResult.hasError()) { + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + folly::to( + "Failed to get address family: ", + familyResult.error().message))); + } + rawAddr->sa_family = familyResult.value(); } ssize_t ret = @@ -1466,10 +1512,13 @@ void QuicClientTransportLite::recvFrom( if (conn_->loopDetectorCallback) { conn_->readDebugState.noReadReason = NoReadReason::NONRETRIABLE_ERROR; } - return onReadError(folly::AsyncSocketException( - folly::AsyncSocketException::INTERNAL_ERROR, - "::recvmsg() failed", - errno)); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + folly::to( + "recvfrom() failed, errno=", + errno, + " ", + folly::errnoStr(errno)))); } else if (ret == 0) { break; } @@ -1487,9 +1536,11 @@ void QuicClientTransportLite::recvFrom( } trackDatagramsReceived( networkData.getPackets().size(), networkData.getTotalData()); + + return folly::unit; } -void QuicClientTransportLite::recvMmsg( +folly::Expected QuicClientTransportLite::recvMmsg( QuicAsyncUDPSocket& sock, uint64_t readBufferSize, uint16_t numPackets, @@ -1499,9 +1550,33 @@ void QuicClientTransportLite::recvMmsg( auto& msgs = recvmmsgStorage_.msgs; int flags = 0; #ifdef FOLLY_HAVE_MSG_ERRQUEUE - bool useGRO = sock.getGRO() > 0; - bool useTs = sock.getTimestamping() > 0; - bool recvTos = sock.getRecvTos(); + auto groResult = sock.getGRO(); + if (groResult.hasError()) { + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + folly::to( + "Failed to get GRO status: ", groResult.error().message))); + } + bool useGRO = groResult.value() > 0; + + auto tsResult = sock.getTimestamping(); + if (tsResult.hasError()) { + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + folly::to( + "Failed to get timestamping status: ", tsResult.error().message))); + } + bool useTs = tsResult.value() > 0; + + auto tosResult = sock.getRecvTos(); + if (tosResult.hasError()) { + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + folly::to( + "Failed to get TOS status: ", tosResult.error().message))); + } + bool recvTos = tosResult.value(); + bool checkCmsgs = useGRO || useTs || recvTos; std::vector(&addr); - rawAddr->sa_family = sock.address().getFamily(); + auto addrResult = sock.address(); + if (addrResult.hasError()) { + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + folly::to( + "Failed to get socket address: ", addrResult.error().message))); + } + rawAddr->sa_family = addrResult.value().getFamily(); msg->msg_name = rawAddr; msg->msg_namelen = kAddrLen; #ifdef FOLLY_HAVE_MSG_ERRQUEUE @@ -1548,7 +1630,7 @@ void QuicClientTransportLite::recvMmsg( if (conn_->loopDetectorCallback) { conn_->readDebugState.noReadReason = NoReadReason::RETRIABLE_ERROR; } - return; + return folly::unit; } // If we got a non-retriable error, we might have received // a packet that we could process, however let's just quit early. @@ -1556,10 +1638,10 @@ void QuicClientTransportLite::recvMmsg( if (conn_->loopDetectorCallback) { conn_->readDebugState.noReadReason = NoReadReason::NONRETRIABLE_ERROR; } - return onReadError(folly::AsyncSocketException( - folly::AsyncSocketException::INTERNAL_ERROR, - "::recvmmsg() failed", - errno)); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + folly::to( + "recvmmsg() failed, errno=", errno, " ", folly::errnoStr(errno)))); } CHECK_LE(numMsgsRecvd, numPackets); @@ -1644,9 +1726,11 @@ void QuicClientTransportLite::recvMmsg( } trackDatagramsReceived( networkData.getPackets().size(), networkData.getTotalData()); + + return folly::unit; } -void QuicClientTransportLite::processPackets( +folly::Expected QuicClientTransportLite::processPackets( NetworkData&& networkData, const Optional& server) { if (networkData.getPackets().empty()) { @@ -1661,17 +1745,20 @@ void QuicClientTransportLite::processPackets( conn_->readDebugState.noReadReason); } } - return; + return folly::unit; } DCHECK(server.has_value()); // TODO: we can get better receive time accuracy than this, with // SO_TIMESTAMP or SIOCGSTAMP. auto packetReceiveTime = Clock::now(); networkData.setReceiveTimePoint(packetReceiveTime); + onNetworkData(*server, std::move(networkData)); + return folly::unit; } -void QuicClientTransportLite::readWithRecvmsgSinglePacketLoop( +folly::Expected +QuicClientTransportLite::readWithRecvmsgSinglePacketLoop( QuicAsyncUDPSocket& sock, uint64_t readBufferSize) { size_t totalData = 0; @@ -1679,32 +1766,48 @@ void QuicClientTransportLite::readWithRecvmsgSinglePacketLoop( for (size_t i = 0; i < conn_->transportSettings.maxRecvBatchSize; i++) { auto networkDataSinglePacket = NetworkData(); networkDataSinglePacket.reserve(1); - recvMsg( + + auto recvResult = recvMsg( sock, readBufferSize, 1 /* numPackets */, networkDataSinglePacket, server, totalData); + + if (recvResult.hasError()) { + return recvResult; + } + if (!socket_) { // Socket has been closed. - return; + return folly::unit; } + if (networkDataSinglePacket.getPackets().size() == 0) { break; } - processPackets(std::move(networkDataSinglePacket), server); + + auto processResult = + processPackets(std::move(networkDataSinglePacket), server); + if (processResult.hasError()) { + return processResult; + } + if (!socket_) { // Socket has been closed. - return; + return folly::unit; } } + // Call callbacks/updates manually because processPackets()/onNetworkData() // will not schedule it when transportSettings.networkDataPerSocketRead is on. processCallbacksAfterNetworkData(); checkForClosedStream(); updateReadLooper(); updateWriteLooper(true); + + return folly::unit; } void QuicClientTransportLite::onNotifyDataAvailable( @@ -1719,7 +1822,10 @@ void QuicClientTransportLite::onNotifyDataAvailable( ? conn_->transportSettings.readCoalescingSize : readBufferSize; - readWithRecvmsgSinglePacketLoop(sock, readAllocSize); + auto result = readWithRecvmsgSinglePacketLoop(sock, readAllocSize); + if (result.hasError()) { + asyncClose(result.error()); + } } void QuicClientTransportLite:: @@ -1764,45 +1870,32 @@ void QuicClientTransportLite::start( clientConn_->pendingOneRttData.reserve( conn_->transportSettings.maxPacketsToBuffer); - try { - happyEyeballsSetUpSocket( - *socket_, - conn_->localAddress, - conn_->peerAddress, - conn_->transportSettings, - conn_->socketTos.value, - this, - this, - socketOptions_); - // adjust the GRO buffers - adjustGROBuffers(); - auto handshakeResult = startCryptoHandshake(); - if (handshakeResult.hasError()) { - runOnEvbAsync([error = handshakeResult.error()](auto self) { - auto clientPtr = dynamic_cast(self.get()); - clientPtr->closeImpl(error); - }); - } - } catch (const QuicTransportException& ex) { - runOnEvbAsync([ex](auto self) { - auto clientPtr = dynamic_cast(self.get()); - clientPtr->closeImpl( - QuicError(QuicErrorCode(ex.errorCode()), std::string(ex.what()))); - }); - } catch (const QuicInternalException& ex) { - runOnEvbAsync([ex](auto self) { - auto clientPtr = dynamic_cast(self.get()); - clientPtr->closeImpl( - QuicError(QuicErrorCode(ex.errorCode()), std::string(ex.what()))); - }); - } catch (const std::exception& ex) { - LOG(ERROR) << "Connect failed " << ex.what(); - runOnEvbAsync([ex](auto self) { - auto clientPtr = dynamic_cast(self.get()); - clientPtr->closeImpl(QuicError( - QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), - std::string(ex.what()))); - }); + + auto socketResult = happyEyeballsSetUpSocket( + *socket_, + conn_->localAddress, + conn_->peerAddress, + conn_->transportSettings, + conn_->socketTos.value, + this, + this, + socketOptions_); + + if (socketResult.hasError()) { + asyncClose(socketResult.error()); + return; + } + + auto adjustResult = adjustGROBuffers(); + if (adjustResult.hasError()) { + asyncClose(adjustResult.error()); + return; + } + + auto handshakeResult = startCryptoHandshake(); + if (handshakeResult.hasError()) { + asyncClose(handshakeResult.error()); + return; } } @@ -1855,13 +1948,26 @@ void QuicClientTransportLite::setSelfOwning() { selfOwning_ = shared_from_this(); } -void QuicClientTransportLite::adjustGROBuffers() { +folly::Expected +QuicClientTransportLite::adjustGROBuffers() { if (socket_ && conn_) { if (conn_->transportSettings.numGROBuffers_ > kDefaultNumGROBuffers) { - socket_->setGRO(true); - auto ret = socket_->getGRO(); + auto setResult = socket_->setGRO(true); + if (setResult.hasError()) { + // Not a fatal error, just log and continue with default buffers + LOG(WARNING) << "Failed to enable GRO: " << setResult.error().message; + return folly::unit; + } - if (ret > 0) { + auto groResult = socket_->getGRO(); + if (groResult.hasError()) { + // Not a fatal error, just log and continue with default buffers + LOG(WARNING) << "Failed to get GRO status: " + << groResult.error().message; + return folly::unit; + } + + if (groResult.value() > 0) { numGROBuffers_ = (conn_->transportSettings.numGROBuffers_ < kMaxNumGROBuffers) ? conn_->transportSettings.numGROBuffers_ @@ -1869,6 +1975,7 @@ void QuicClientTransportLite::adjustGROBuffers() { } } } + return folly::unit; } void QuicClientTransportLite::closeTransport() { @@ -1888,6 +1995,28 @@ void QuicClientTransportLite::setSupportedVersions( conn_->readCodec->setCodecParameters(params); } +void QuicClientTransportLite::runOnEvbAsync( + folly::Function)> func) { + auto evb = getEventBase(); + evb->runInLoop( + [self = sharedGuardClient(), func = std::move(func), evb]() mutable { + if (self->getEventBase() != evb) { + // The eventbase changed between scheduling the loop and invoking + // the callback, ignore this + return; + } + func(std::move(self)); + }, + true); +} + +void QuicClientTransportLite::asyncClose(QuicError error) { + runOnEvbAsync([error = std::move(error)](auto self) { + auto clientPtr = static_cast(self.get()); + clientPtr->closeImpl(std::move(error), false, false); + }); +} + void QuicClientTransportLite::onNetworkSwitch( std::unique_ptr newSock) { if (!conn_->oneRttWriteCipher) { @@ -1896,14 +2025,24 @@ void QuicClientTransportLite::onNetworkSwitch( if (socket_ && newSock) { auto sock = std::move(socket_); socket_ = nullptr; - sock->setErrMessageCallback(nullptr); + if (auto err = sock->setErrMessageCallback(nullptr); err.hasError()) { + asyncClose(err.error()); + return; + } sock->pauseRead(); - sock->close(); + if (auto err = sock->close(); err.hasError()) { + asyncClose(err.error()); + return; + } socket_ = std::move(newSock); - socket_->setAdditionalCmsgsFunc( - [&]() { return getAdditionalCmsgsForAsyncUDPSocket(); }); - happyEyeballsSetUpSocket( + if (auto err = socket_->setAdditionalCmsgsFunc( + [&]() { return getAdditionalCmsgsForAsyncUDPSocket(); }); + err.hasError()) { + asyncClose(err.error()); + return; + } + auto setupResult = happyEyeballsSetUpSocket( *socket_, conn_->localAddress, conn_->peerAddress, @@ -1912,12 +2051,18 @@ void QuicClientTransportLite::onNetworkSwitch( this, this, socketOptions_); + if (setupResult.hasError()) { + asyncClose(setupResult.error()); + return; + } if (conn_->qLogger) { conn_->qLogger->addConnectionMigrationUpdate(true); } - // adjust the GRO buffers - adjustGROBuffers(); + auto adjustResult = adjustGROBuffers(); + if (adjustResult.hasError()) { + asyncClose(adjustResult.error()); + } } } @@ -1946,7 +2091,8 @@ void QuicClientTransportLite::trackDatagramsReceived( QUIC_STATS(statsCallback_, onRead, totalPacketLen); } -void QuicClientTransportLite::maybeSendTransportKnobs() { +folly::Expected +QuicClientTransportLite::maybeSendTransportKnobs() { if (!transportKnobsSent_ && hasWriteCipher()) { for (const auto& knob : conn_->transportSettings.knobs) { auto res = @@ -1954,6 +2100,9 @@ void QuicClientTransportLite::maybeSendTransportKnobs() { if (res.hasError()) { if (res.error() != LocalErrorCode::KNOB_FRAME_UNSUPPORTED) { LOG(ERROR) << "Unexpected error while sending knob frames"; + return folly::makeUnexpected(QuicError( + QuicErrorCode(res.error()), + "Unexpected error while sending knob frames")); } // No point in keep trying if transport does not support knob frame break; @@ -1961,6 +2110,7 @@ void QuicClientTransportLite::maybeSendTransportKnobs() { } transportKnobsSent_ = true; } + return folly::unit; } Optional> diff --git a/quic/client/QuicClientTransportLite.h b/quic/client/QuicClientTransportLite.h index c0bdf3137..62e22469b 100644 --- a/quic/client/QuicClientTransportLite.h +++ b/quic/client/QuicClientTransportLite.h @@ -165,6 +165,7 @@ class QuicClientTransportLite void unbindConnection() override; bool hasWriteCipher() const override; std::shared_ptr sharedGuard() override; + std::shared_ptr sharedGuardClient(); // QuicAsyncUDPSocket::ReadCallback void onReadClosed() noexcept override {} @@ -283,8 +284,7 @@ class QuicClientTransportLite OnDataAvailableParams params) noexcept override; bool shouldOnlyNotify() override; void onNotifyDataAvailable(QuicAsyncUDPSocket& sock) noexcept override; - - void recvFrom( + [[nodiscard]] folly::Expected recvFrom( QuicAsyncUDPSocket& sock, uint64_t readBufferSize, int numPackets, @@ -292,14 +292,15 @@ class QuicClientTransportLite Optional& server, size_t& totalData); - void recvMsg( + [[nodiscard]] folly::Expected recvMsg( QuicAsyncUDPSocket& sock, uint64_t readBufferSize, int numPackets, NetworkData& networkData, Optional& server, size_t& totalData); - void recvMmsg( + + [[nodiscard]] folly::Expected recvMmsg( QuicAsyncUDPSocket& sock, uint64_t readBufferSize, uint16_t numPackets, @@ -353,11 +354,12 @@ class QuicClientTransportLite const QuicWriteFrame& packetFrame, const ReadAckFrame&); - virtual void processPackets( + [[nodiscard]] virtual folly::Expected processPackets( NetworkData&& networkData, const Optional& server); - void readWithRecvmsgSinglePacketLoop( + [[nodiscard]] folly::Expected + readWithRecvmsgSinglePacketLoop( QuicAsyncUDPSocket& sock, uint64_t readBufferSize); @@ -371,6 +373,8 @@ class QuicClientTransportLite void trackDatagramsReceived(uint32_t totalPackets, uint32_t totalPacketLen); + void asyncClose(QuicError error); + // Same value as conn_->transportSettings.numGROBuffers_ if the kernel // supports GRO. otherwise kDefaultNumGROBuffers uint32_t numGROBuffers_{kDefaultNumGROBuffers}; @@ -394,13 +398,16 @@ class QuicClientTransportLite RecvmmsgStorage recvmmsgStorage_; private: - void adjustGROBuffers(); + [[nodiscard]] folly::Expected adjustGROBuffers(); + + void runOnEvbAsync( + folly::Function)> func); /** * Send quic transport knobs defined by transportSettings.knobs to peer. This * calls setKnobs() internally. */ - void maybeSendTransportKnobs(); + folly::Expected maybeSendTransportKnobs(); bool replaySafeNotified_{false}; // Set it QuicClientTransportLite is in a self owning mode. This will be diff --git a/quic/client/test/QuicClientTransportTest.cpp b/quic/client/test/QuicClientTransportTest.cpp index 13c462223..3d67f0e59 100644 --- a/quic/client/test/QuicClientTransportTest.cpp +++ b/quic/client/test/QuicClientTransportTest.cpp @@ -34,20 +34,25 @@ class QuicClientTransportMock : public QuicClientTransport { QuicAsyncUDPSocket& sock, uint64_t readBufferSize, uint16_t numPackets) { - QuicClientTransport::readWithRecvmsg(sock, readBufferSize, numPackets); + CHECK( + !QuicClientTransport::readWithRecvmsg(sock, readBufferSize, numPackets) + .hasError()); } void readWithRecvmsgSinglePacketLoop( QuicAsyncUDPSocket& sock, uint64_t readBufferSize) { - QuicClientTransport::readWithRecvmsgSinglePacketLoop(sock, readBufferSize); + CHECK(!QuicClientTransport::readWithRecvmsgSinglePacketLoop( + sock, readBufferSize) + .hasError()); } - void processPackets( + folly::Expected processPackets( NetworkData&& networkData, const Optional& server) override { networkDataVec_.push_back(std::move(networkData)); server_ = server; + return folly::unit; } QuicClientConnectionState* getClientConn() { @@ -64,6 +69,44 @@ class QuicClientTransportTest : public Test { evb_ = std::make_shared(); auto sock = std::make_unique(); sockPtr_ = sock.get(); + ON_CALL(*sock, setReuseAddr(_)).WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setAdditionalCmsgsFunc(_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setDFAndTurnOffPMTU()) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setErrMessageCallback(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setTosOrTrafficClass(_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, init(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, applyOptions(testing::_, testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, bind(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, connect(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, close()).WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setGRO(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, getGRO()).WillByDefault(testing::Return(0)); + ON_CALL(*sock, getGSO()).WillByDefault(testing::Return(0)); + ON_CALL(*sock, setRecvTos(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, getRecvTos()).WillByDefault(testing::Return(false)); + ON_CALL(*sock, setCmsgs(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, appendCmsgs(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, getTimestamping()).WillByDefault(testing::Return(0)); + ON_CALL(*sock, setReusePort(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setRcvBuf(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setSndBuf(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setFD(testing::_, testing::_)) + .WillByDefault(testing::Return(folly::unit)); mockFactory_ = std::make_shared(); EXPECT_CALL(*mockFactory_, _makeClientHandshake(_)) diff --git a/quic/common/BUCK b/quic/common/BUCK index a4b8fa98e..cc4f9b375 100644 --- a/quic/common/BUCK +++ b/quic/common/BUCK @@ -97,8 +97,10 @@ mvfst_cpp_library( "SocketUtil.h", ], exported_deps = [ + "//folly:expected", "//folly/io:socket_option_map", "//folly/net:net_ops", + "//quic:exception", ], ) diff --git a/quic/common/SocketUtil.h b/quic/common/SocketUtil.h index dac7ee965..a65a9dac6 100644 --- a/quic/common/SocketUtil.h +++ b/quic/common/SocketUtil.h @@ -7,8 +7,10 @@ #pragma once +#include #include #include +#include namespace quic { @@ -21,7 +23,7 @@ inline bool isNetworkUnreachable(int err) { // const folly::SocketOptionMap& /* options */, // folly::SocketOptionKey::ApplyPos /* pos */) template -void applySocketOptions( +folly::Expected applySocketOptions( T& sock, const folly::SocketOptionMap& options, sa_family_t family, @@ -45,7 +47,7 @@ void applySocketOptions( validOptions.insert(option); } } - sock.applyOptions(validOptions, pos); + return sock.applyOptions(validOptions, pos); } } // namespace quic diff --git a/quic/common/test/SocketUtilTest.cpp b/quic/common/test/SocketUtilTest.cpp index 5e7343c60..7c4d0346f 100644 --- a/quic/common/test/SocketUtilTest.cpp +++ b/quic/common/test/SocketUtilTest.cpp @@ -18,7 +18,9 @@ class MockQuicAsyncUDPSocket : public quic::FollyQuicAsyncUDPSocket { MOCK_METHOD2( applyOptions, - void(const folly::SocketOptionMap&, folly::SocketOptionKey::ApplyPos)); + folly::Expected( + const folly::SocketOptionMap&, + folly::SocketOptionKey::ApplyPos)); }; TEST(SocketUtilTest, applySocketOptions) { @@ -75,12 +77,11 @@ TEST(SocketUtilTest, applySocketOptions) { {{IPPROTO_UDP, TCP_MAXSEG, folly::SocketOptionKey::ApplyPos::POST_BIND}, 576}, }; - EXPECT_CALL( sock, applyOptions( expected_v4_prebind_opts, folly::SocketOptionKey::ApplyPos::PRE_BIND)) - .Times(1); + .WillOnce(testing::Return(folly::unit)); applySocketOptions( sock, opts, AF_INET, folly::SocketOptionKey::ApplyPos::PRE_BIND); EXPECT_CALL( @@ -88,14 +89,14 @@ TEST(SocketUtilTest, applySocketOptions) { applyOptions( expected_v4_postbind_opts, folly::SocketOptionKey::ApplyPos::POST_BIND)) - .Times(1); + .WillOnce(testing::Return(folly::unit)); applySocketOptions( sock, opts, AF_INET, folly::SocketOptionKey::ApplyPos::POST_BIND); EXPECT_CALL( sock, applyOptions( expected_v6_prebind_opts, folly::SocketOptionKey::ApplyPos::PRE_BIND)) - .Times(1); + .WillOnce(testing::Return(folly::unit)); applySocketOptions( sock, opts, AF_INET6, folly::SocketOptionKey::ApplyPos::PRE_BIND); EXPECT_CALL( @@ -103,7 +104,7 @@ TEST(SocketUtilTest, applySocketOptions) { applyOptions( expected_v6_postbind_opts, folly::SocketOptionKey::ApplyPos::POST_BIND)) - .Times(1); + .WillOnce(testing::Return(folly::unit)); applySocketOptions( sock, opts, AF_INET6, folly::SocketOptionKey::ApplyPos::POST_BIND); } diff --git a/quic/common/testutil/MockAsyncUDPSocket.h b/quic/common/testutil/MockAsyncUDPSocket.h index 9de8dca77..9d213e531 100644 --- a/quic/common/testutil/MockAsyncUDPSocket.h +++ b/quic/common/testutil/MockAsyncUDPSocket.h @@ -20,10 +20,21 @@ struct MockAsyncUDPSocket : public FollyQuicAsyncUDPSocket { ~MockAsyncUDPSocket() override {} - MOCK_METHOD(void, init, (sa_family_t)); - MOCK_METHOD(const folly::SocketAddress&, address, (), (const)); - MOCK_METHOD(void, bind, (const folly::SocketAddress&)); - MOCK_METHOD(void, setFD, (int, QuicAsyncUDPSocket::FDOwnership)); + MOCK_METHOD((folly::Expected), init, (sa_family_t)); + MOCK_METHOD( + (folly::Expected), + address, + (), + (const)); + MOCK_METHOD((const folly::SocketAddress&), addressRef, (), (const)); + MOCK_METHOD( + (folly::Expected), + bind, + (const folly::SocketAddress&)); + MOCK_METHOD( + (folly::Expected), + setFD, + (int, QuicAsyncUDPSocket::FDOwnership)); MOCK_METHOD( ssize_t, write, @@ -59,63 +70,78 @@ struct MockAsyncUDPSocket : public FollyQuicAsyncUDPSocket { size_t count, const WriteOptions* options)); MOCK_METHOD( - RecvResult, + (folly::Expected), recvmmsgNetworkData, (uint64_t readBufferSize, uint16_t numPackets, NetworkData& networkData, Optional& peerAddress, size_t& totalData)); - MOCK_METHOD(int, getGRO, ()); - MOCK_METHOD(bool, setGRO, (bool)); + MOCK_METHOD((folly::Expected), getGRO, ()); + MOCK_METHOD((folly::Expected), setGRO, (bool)); MOCK_METHOD( - void, + (folly::Expected), setAdditionalCmsgsFunc, (folly::Function()>&&)); - MOCK_METHOD(void, setRcvBuf, (int)); - MOCK_METHOD(void, setSndBuf, (int)); - MOCK_METHOD(int, getTimestamping, ()); + MOCK_METHOD((folly::Expected), setRcvBuf, (int)); + MOCK_METHOD((folly::Expected), setSndBuf, (int)); + MOCK_METHOD((folly::Expected), getTimestamping, ()); MOCK_METHOD(void, resumeRead, (QuicAsyncUDPSocket::ReadCallback*)); MOCK_METHOD(void, pauseRead, ()); - MOCK_METHOD(void, close, ()); - MOCK_METHOD(void, setDFAndTurnOffPMTU, ()); + MOCK_METHOD((folly::Expected), close, ()); + MOCK_METHOD( + (folly::Expected), + setDFAndTurnOffPMTU, + ()); MOCK_METHOD(int, getFD, ()); - MOCK_METHOD(void, setReusePort, (bool)); - MOCK_METHOD(void, setReuseAddr, (bool)); + MOCK_METHOD((folly::Expected), setReusePort, (bool)); + MOCK_METHOD((folly::Expected), setReuseAddr, (bool)); MOCK_METHOD(void, dontFragment, (bool)); MOCK_METHOD( - void, + (folly::Expected), setErrMessageCallback, (QuicAsyncUDPSocket::ErrMessageCallback*)); - MOCK_METHOD(void, connect, (const folly::SocketAddress&)); + MOCK_METHOD( + (folly::Expected), + connect, + (const folly::SocketAddress&)); MOCK_METHOD(bool, isBound, (), (const)); - MOCK_METHOD(int, getGSO, ()); - MOCK_METHOD(bool, setGSO, (int)); + MOCK_METHOD((folly::Expected), getGSO, ()); + MOCK_METHOD((folly::Expected), setGSO, (int)); MOCK_METHOD(ssize_t, recvmsg, (struct msghdr*, int)); MOCK_METHOD( int, recvmmsg, (struct mmsghdr*, unsigned int, unsigned int, struct timespec*)); - MOCK_METHOD(void, setCmsgs, (const folly::SocketCmsgMap&)); MOCK_METHOD( - void, + (folly::Expected), + setCmsgs, + (const folly::SocketCmsgMap&)); + MOCK_METHOD( + (folly::Expected), setNontrivialCmsgs, (const folly::SocketNontrivialCmsgMap&)); - MOCK_METHOD(void, appendCmsgs, (const folly::SocketCmsgMap&)); MOCK_METHOD( - void, + (folly::Expected), + appendCmsgs, + (const folly::SocketCmsgMap&)); + MOCK_METHOD( + (folly::Expected), appendNontrivialCmsgs, (const folly::SocketNontrivialCmsgMap&)); MOCK_METHOD( - void, + (folly::Expected), applyOptions, (const folly::SocketOptionMap&, folly::SocketOptionKey::ApplyPos)); - MOCK_METHOD((void), setRecvTos, (bool)); - MOCK_METHOD((bool), getRecvTos, ()); - MOCK_METHOD((void), setTosOrTrafficClass, (uint8_t)); + MOCK_METHOD((folly::Expected), setRecvTos, (bool)); + MOCK_METHOD((folly::Expected), getRecvTos, ()); + MOCK_METHOD( + (folly::Expected), + setTosOrTrafficClass, + (uint8_t)); MOCK_METHOD((bool), isWritableCallbackSet, (), (const)); MOCK_METHOD( - (folly::Expected), + (folly::Expected), resumeWrite, (WriteCallback*)); MOCK_METHOD((void), pauseWrite, ()); diff --git a/quic/common/udpsocket/BUCK b/quic/common/udpsocket/BUCK index 8df77c365..9cef2485f 100644 --- a/quic/common/udpsocket/BUCK +++ b/quic/common/udpsocket/BUCK @@ -12,12 +12,15 @@ mvfst_cpp_library( "QuicAsyncUDPSocket.h", ], exported_deps = [ + "//folly:expected", "//folly:network_address", "//folly:range", "//folly/io:iobuf", "//folly/io:socket_option_map", "//folly/io/async:async_socket_exception", + "//folly/lang:exception", "//folly/portability:sockets", + "//quic:exception", "//quic/common:network_data", "//quic/common:optional", "//quic/common/events:eventbase", @@ -32,6 +35,12 @@ mvfst_cpp_library( headers = [ "QuicAsyncUDPSocketImpl.h", ], + deps = [ + "//folly:likely", + "//folly:string", + "//folly/lang:exception", + "//quic:exception", + ], exported_deps = [ ":quic_async_udp_socket", ], @@ -45,8 +54,16 @@ mvfst_cpp_library( headers = [ "FollyQuicAsyncUDPSocket.h", ], + deps = [ + "//folly:string", + "//folly:unit", + "//folly/io/async:async_socket_exception", + "//folly/lang:exception", + "//quic:exception", + ], exported_deps = [ ":quic_async_udp_socket_impl", + "//folly:expected", "//folly/io/async:async_udp_socket", "//folly/net:network_socket", "//quic/common:network_data", @@ -64,10 +81,14 @@ mvfst_cpp_library( ], labels = ci.labels(ci.remove(ci.windows())), deps = [ + "//folly:string", + "//folly/lang:exception", "//quic/common:optional", ], exported_deps = [ ":quic_async_udp_socket_impl", + "//folly:expected", + "//quic:exception", "//quic/common:network_data", "//quic/common/events:libev_eventbase", ], diff --git a/quic/common/udpsocket/FollyQuicAsyncUDPSocket.cpp b/quic/common/udpsocket/FollyQuicAsyncUDPSocket.cpp index 4154207b1..f3efe5691 100644 --- a/quic/common/udpsocket/FollyQuicAsyncUDPSocket.cpp +++ b/quic/common/udpsocket/FollyQuicAsyncUDPSocket.cpp @@ -5,37 +5,99 @@ * LICENSE file in the root directory of this source tree. */ +#include // For errno +#include +#include +#include +#include +#include // For folly::errnoStr +#include // For QuicError, QuicErrorCode, TransportErrorCode #include #include namespace quic { - -void FollyQuicAsyncUDPSocket::init(sa_family_t family) { - follySocket_.init(family); +folly::Expected FollyQuicAsyncUDPSocket::init( + sa_family_t family) { + try { + follySocket_.init(family); + return folly::unit; + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = "Folly init failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -void FollyQuicAsyncUDPSocket::bind(const folly::SocketAddress& address) { - follySocket_.bind(address); +folly::Expected FollyQuicAsyncUDPSocket::bind( + const folly::SocketAddress& address) { + try { + follySocket_.bind(address); + return folly::unit; + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = "Folly bind failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } [[nodiscard]] bool FollyQuicAsyncUDPSocket::isBound() const { return follySocket_.isBound(); } -void FollyQuicAsyncUDPSocket::connect(const folly::SocketAddress& address) { - follySocket_.connect(address); +folly::Expected FollyQuicAsyncUDPSocket::connect( + const folly::SocketAddress& address) { + try { + follySocket_.connect(address); + return folly::unit; + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = "Folly connect failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -void FollyQuicAsyncUDPSocket::close() { - follySocket_.close(); +folly::Expected FollyQuicAsyncUDPSocket::close() { + try { + follySocket_.close(); + readCallbackWrapper_.reset(); // Ensure wrapper is cleared on close + errCallbackWrapper_.reset(); + return folly::unit; + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = "Folly close failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } void FollyQuicAsyncUDPSocket::resumeRead(ReadCallback* callback) { - // TODO: We could skip this check and rely on the one in AsyncUDPSocket - CHECK(!readCallbackWrapper_) << "Already registered a read callback"; - readCallbackWrapper_ = - std::make_unique(callback, this); - follySocket_.resumeRead(readCallbackWrapper_.get()); + try { + // TODO: We could skip this check and rely on the one in AsyncUDPSocket + CHECK(!readCallbackWrapper_) << "Already registered a read callback"; + readCallbackWrapper_ = + std::make_unique(callback, this); + follySocket_.resumeRead(readCallbackWrapper_.get()); + // TODO: This should return Expected + } catch (const folly::AsyncSocketException& ex) { + // TODO: Convert to QuicError and return folly::makeUnexpected + LOG(ERROR) << "FollyQuicAsyncUDPSocket::resumeRead failed: " << ex.what(); + throw; // Re-throw for now until signature is updated + } } void FollyQuicAsyncUDPSocket::pauseRead() { @@ -43,16 +105,28 @@ void FollyQuicAsyncUDPSocket::pauseRead() { readCallbackWrapper_.reset(); } -void FollyQuicAsyncUDPSocket::setErrMessageCallback( - ErrMessageCallback* callback) { - if (errCallbackWrapper_) { - errCallbackWrapper_.reset(); - } - if (callback) { - errCallbackWrapper_ = std::make_unique(callback); - follySocket_.setErrMessageCallback(errCallbackWrapper_.get()); - } else { - follySocket_.setErrMessageCallback(nullptr); +folly::Expected +FollyQuicAsyncUDPSocket::setErrMessageCallback(ErrMessageCallback* callback) { + try { + if (errCallbackWrapper_) { + errCallbackWrapper_.reset(); + } + if (callback) { + errCallbackWrapper_ = std::make_unique(callback); + follySocket_.setErrMessageCallback(errCallbackWrapper_.get()); + } else { + follySocket_.setErrMessageCallback(nullptr); + } + return folly::unit; + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = + "Folly setErrMessageCallback failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); } } @@ -60,9 +134,16 @@ ssize_t FollyQuicAsyncUDPSocket::write( const folly::SocketAddress& address, const struct iovec* vec, size_t iovec_len) { - folly::AsyncUDPSocket::WriteOptions writeOptions( - 0 /*gsoVal*/, false /* zerocopyVal*/); - return follySocket_.writev(address, vec, iovec_len, writeOptions); + try { + folly::AsyncUDPSocket::WriteOptions writeOptions( + 0 /*gsoVal*/, false /* zerocopyVal*/); + return follySocket_.writev(address, vec, iovec_len, writeOptions); + } catch (const folly::AsyncSocketException& ex) { + // Log the error, set errno, return -1 for syscall-like behavior + errno = ex.getErrno(); + LOG(ERROR) << "FollyQuicAsyncUDPSocket::write failed: " << ex.what(); + return -1; + } } int FollyQuicAsyncUDPSocket::writem( @@ -70,7 +151,14 @@ int FollyQuicAsyncUDPSocket::writem( iovec* iov, size_t* numIovecsInBuffer, size_t count) { - return follySocket_.writemv(addrs, iov, numIovecsInBuffer, count); + try { + return follySocket_.writemv(addrs, iov, numIovecsInBuffer, count); + } catch (const folly::AsyncSocketException& ex) { + // Log the error, set errno, return -1 for syscall-like behavior + errno = ex.getErrno(); + LOG(ERROR) << "FollyQuicAsyncUDPSocket::writem failed: " << ex.what(); + return -1; + } } ssize_t FollyQuicAsyncUDPSocket::writeGSO( @@ -78,10 +166,17 @@ ssize_t FollyQuicAsyncUDPSocket::writeGSO( const struct iovec* vec, size_t iovec_len, WriteOptions options) { - folly::AsyncUDPSocket::WriteOptions follyOptions( - options.gso, options.zerocopy); - follyOptions.txTime = options.txTime; - return follySocket_.writev(address, vec, iovec_len, follyOptions); + try { + folly::AsyncUDPSocket::WriteOptions follyOptions( + options.gso, options.zerocopy); + follyOptions.txTime = options.txTime; + return follySocket_.writev(address, vec, iovec_len, follyOptions); + } catch (const folly::AsyncSocketException& ex) { + // Log the error, set errno, return -1 for syscall-like behavior + errno = ex.getErrno(); + LOG(ERROR) << "FollyQuicAsyncUDPSocket::writeGSO failed: " << ex.what(); + return -1; + } } int FollyQuicAsyncUDPSocket::writemGSO( @@ -89,13 +184,21 @@ int FollyQuicAsyncUDPSocket::writemGSO( const Buf* bufs, size_t count, const WriteOptions* options) { - std::vector follyOptions(count); - for (size_t i = 0; i < count; ++i) { - follyOptions[i].gso = options[i].gso; - follyOptions[i].zerocopy = options[i].zerocopy; - follyOptions[i].txTime = options[i].txTime; + try { + std::vector follyOptions(count); + for (size_t i = 0; i < count; ++i) { + follyOptions[i].gso = options[i].gso; + follyOptions[i].zerocopy = options[i].zerocopy; + follyOptions[i].txTime = options[i].txTime; + } + return follySocket_.writemGSO(addrs, bufs, count, follyOptions.data()); + } catch (const folly::AsyncSocketException& ex) { + // Log the error, set errno, return -1 for syscall-like behavior + errno = ex.getErrno(); + LOG(ERROR) << "FollyQuicAsyncUDPSocket::writemGSO(IOBuf) failed: " + << ex.what(); + return -1; } - return follySocket_.writemGSO(addrs, bufs, count, follyOptions.data()); } int FollyQuicAsyncUDPSocket::writemGSO( @@ -104,18 +207,33 @@ int FollyQuicAsyncUDPSocket::writemGSO( size_t* numIovecsInBuffer, size_t count, const WriteOptions* options) { - std::vector follyOptions(count); - for (size_t i = 0; i < count; ++i) { - follyOptions[i].gso = options[i].gso; - follyOptions[i].zerocopy = options[i].zerocopy; - follyOptions[i].txTime = options[i].txTime; + try { + std::vector follyOptions(count); + for (size_t i = 0; i < count; ++i) { + follyOptions[i].gso = options[i].gso; + follyOptions[i].zerocopy = options[i].zerocopy; + follyOptions[i].txTime = options[i].txTime; + } + return follySocket_.writemGSOv( + addrs, iov, numIovecsInBuffer, count, follyOptions.data()); + } catch (const folly::AsyncSocketException& ex) { + // Log the error, set errno, return -1 for syscall-like behavior + errno = ex.getErrno(); + LOG(ERROR) << "FollyQuicAsyncUDPSocket::writemGSO(iovec) failed: " + << ex.what(); + return -1; } - return follySocket_.writemGSOv( - addrs, iov, numIovecsInBuffer, count, follyOptions.data()); } ssize_t FollyQuicAsyncUDPSocket::recvmsg(struct msghdr* msg, int flags) { - return follySocket_.recvmsg(msg, flags); + try { + return follySocket_.recvmsg(msg, flags); + } catch (const folly::AsyncSocketException& ex) { + // Log the error, set errno, return -1 for syscall-like behavior + errno = ex.getErrno(); + LOG(ERROR) << "FollyQuicAsyncUDPSocket::recvmsg failed: " << ex.what(); + return -1; + } } int FollyQuicAsyncUDPSocket::recvmmsg( @@ -123,39 +241,151 @@ int FollyQuicAsyncUDPSocket::recvmmsg( unsigned int vlen, unsigned int flags, struct timespec* timeout) { - return follySocket_.recvmmsg(msgvec, vlen, flags, timeout); + try { + return follySocket_.recvmmsg(msgvec, vlen, flags, timeout); + } catch (const folly::AsyncSocketException& ex) { + // Log the error, set errno, return -1 for syscall-like behavior + errno = ex.getErrno(); + LOG(ERROR) << "FollyQuicAsyncUDPSocket::recvmmsg failed: " << ex.what(); + return -1; + } } -int FollyQuicAsyncUDPSocket::getGSO() { - return follySocket_.getGSO(); +folly::Expected FollyQuicAsyncUDPSocket::getGSO() { + try { + return follySocket_.getGSO(); + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = "Folly getGSO failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + LOG(ERROR) << "getGSO failed: " << errorMsg; + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -int FollyQuicAsyncUDPSocket::getGRO() { - return follySocket_.getGRO(); +folly::Expected FollyQuicAsyncUDPSocket::getGRO() { + try { + return follySocket_.getGRO(); + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = "Folly getGRO failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -bool FollyQuicAsyncUDPSocket::setGRO(bool bVal) { - return follySocket_.setGRO(bVal); +folly::Expected FollyQuicAsyncUDPSocket::setGRO( + bool bVal) { + try { + if (follySocket_.setGRO(bVal)) { + return folly::unit; + } else { + // Folly's setGRO returns bool, not throwing. Assume failure means error. + int errnoCopy = errno; // Capture errno immediately after failure + std::string errorMsg = "Folly setGRO failed"; + if (errnoCopy != 0) { + errorMsg += ": " + folly::errnoStr(errnoCopy); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } + } catch (const folly::AsyncSocketException& ex) { + // Catch just in case future folly versions throw + std::string errorMsg = "Folly setGRO exception: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -void FollyQuicAsyncUDPSocket::setRecvTos(bool recvTos) { - follySocket_.setRecvTos(recvTos); +folly::Expected FollyQuicAsyncUDPSocket::setRecvTos( + bool recvTos) { + try { + follySocket_.setRecvTos(recvTos); + return folly::unit; + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = "Folly setRecvTos failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -bool FollyQuicAsyncUDPSocket::getRecvTos() { - return follySocket_.getRecvTos(); +folly::Expected FollyQuicAsyncUDPSocket::getRecvTos() { + try { + return follySocket_.getRecvTos(); + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = "Folly getRecvTos failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -void FollyQuicAsyncUDPSocket::setTosOrTrafficClass(uint8_t tos) { - follySocket_.setTosOrTrafficClass(tos); +folly::Expected +FollyQuicAsyncUDPSocket::setTosOrTrafficClass(uint8_t tos) { + try { + follySocket_.setTosOrTrafficClass(tos); + return folly::unit; + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = + "Folly setTosOrTrafficClass failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -[[nodiscard]] const folly::SocketAddress& FollyQuicAsyncUDPSocket::address() +[[nodiscard]] folly::Expected +FollyQuicAsyncUDPSocket::address() const { + try { + return follySocket_.address(); + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = "Folly address() failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } +} + +[[nodiscard]] const folly::SocketAddress& FollyQuicAsyncUDPSocket::addressRef() const { - return follySocket_.address(); + try { + return follySocket_.address(); + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = "Folly address() failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + LOG(FATAL) << errorMsg; + } } -void FollyQuicAsyncUDPSocket::attachEventBase( +void FollyQuicAsyncUDPSocket::attachEventBase( // Keep void, attach/detach + // usually don't throw socket + // errors std::shared_ptr evb) { CHECK(evb != nullptr); std::shared_ptr follyEvb = @@ -165,7 +395,7 @@ void FollyQuicAsyncUDPSocket::attachEventBase( follySocket_.attachEventBase(follyEvb->getBackingEventBase()); } -void FollyQuicAsyncUDPSocket::detachEventBase() { +void FollyQuicAsyncUDPSocket::detachEventBase() { // Keep void follySocket_.detachEventBase(); } @@ -177,50 +407,176 @@ FollyQuicAsyncUDPSocket::getEventBase() const { return evb_; } -void FollyQuicAsyncUDPSocket::setCmsgs(const folly::SocketCmsgMap& cmsgs) { - follySocket_.setCmsgs(cmsgs); +folly::Expected FollyQuicAsyncUDPSocket::setCmsgs( + const folly::SocketCmsgMap& cmsgs) { + try { + follySocket_.setCmsgs(cmsgs); + return folly::unit; + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = "Folly setCmsgs failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -void FollyQuicAsyncUDPSocket::appendCmsgs(const folly::SocketCmsgMap& cmsgs) { - follySocket_.appendCmsgs(cmsgs); +folly::Expected FollyQuicAsyncUDPSocket::appendCmsgs( + const folly::SocketCmsgMap& cmsgs) { + try { + follySocket_.appendCmsgs(cmsgs); + return folly::unit; + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = + "Folly appendCmsgs failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -void FollyQuicAsyncUDPSocket::setAdditionalCmsgsFunc( +folly::Expected +FollyQuicAsyncUDPSocket::setAdditionalCmsgsFunc( folly::Function()>&& additionalCmsgsFunc) { - follySocket_.setAdditionalCmsgsFunc(std::move(additionalCmsgsFunc)); + try { + follySocket_.setAdditionalCmsgsFunc(std::move(additionalCmsgsFunc)); + return folly::unit; + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = + "Folly setAdditionalCmsgsFunc failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -int FollyQuicAsyncUDPSocket::getTimestamping() { - return follySocket_.getTimestamping(); +folly::Expected FollyQuicAsyncUDPSocket::getTimestamping() { + try { + return follySocket_.getTimestamping(); + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = + "Folly getTimestamping failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -void FollyQuicAsyncUDPSocket::setReuseAddr(bool reuseAddr) { - follySocket_.setReuseAddr(reuseAddr); +folly::Expected FollyQuicAsyncUDPSocket::setReuseAddr( + bool reuseAddr) { + try { + follySocket_.setReuseAddr(reuseAddr); + return folly::unit; + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = + "Folly setReuseAddr failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -void FollyQuicAsyncUDPSocket::setDFAndTurnOffPMTU() { - follySocket_.setDFAndTurnOffPMTU(); +folly::Expected +FollyQuicAsyncUDPSocket::setDFAndTurnOffPMTU() { + try { + follySocket_.setDFAndTurnOffPMTU(); + return folly::unit; + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = + "Folly setDFAndTurnOffPMTU failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -void FollyQuicAsyncUDPSocket::applyOptions( +folly::Expected FollyQuicAsyncUDPSocket::applyOptions( const folly::SocketOptionMap& options, folly::SocketOptionKey::ApplyPos pos) { - follySocket_.applyOptions(options, pos); + try { + follySocket_.applyOptions(options, pos); + return folly::unit; + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = + "Folly applyOptions failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -void FollyQuicAsyncUDPSocket::setReusePort(bool reusePort) { - follySocket_.setReusePort(reusePort); +folly::Expected FollyQuicAsyncUDPSocket::setReusePort( + bool reusePort) { + try { + follySocket_.setReusePort(reusePort); + return folly::unit; + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = + "Folly setReusePort failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -void FollyQuicAsyncUDPSocket::setRcvBuf(int rcvBuf) { - follySocket_.setRcvBuf(rcvBuf); +folly::Expected FollyQuicAsyncUDPSocket::setRcvBuf( + int rcvBuf) { + try { + follySocket_.setRcvBuf(rcvBuf); + return folly::unit; + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = "Folly setRcvBuf failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -void FollyQuicAsyncUDPSocket::setSndBuf(int sndBuf) { - follySocket_.setSndBuf(sndBuf); +folly::Expected FollyQuicAsyncUDPSocket::setSndBuf( + int sndBuf) { + try { + follySocket_.setSndBuf(sndBuf); + return folly::unit; + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = "Folly setSndBuf failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } -void FollyQuicAsyncUDPSocket::setFD(int fd, FDOwnership ownership) { +folly::Expected FollyQuicAsyncUDPSocket::setFD( + int fd, + FDOwnership ownership) { folly::AsyncUDPSocket::FDOwnership follyOwnership; switch (ownership) { case FDOwnership::OWNS: @@ -230,7 +586,18 @@ void FollyQuicAsyncUDPSocket::setFD(int fd, FDOwnership ownership) { follyOwnership = folly::AsyncUDPSocket::FDOwnership::SHARED; break; } - follySocket_.setFD(folly::NetworkSocket::fromFd(fd), follyOwnership); + try { + follySocket_.setFD(folly::NetworkSocket::fromFd(fd), follyOwnership); + return folly::unit; + } catch (const folly::AsyncSocketException& ex) { + std::string errorMsg = "Folly setFD failed: " + std::string(ex.what()); + if (ex.getErrno() != 0) { + errorMsg += ": " + folly::errnoStr(ex.getErrno()); + } + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } int FollyQuicAsyncUDPSocket::getFD() { diff --git a/quic/common/udpsocket/FollyQuicAsyncUDPSocket.h b/quic/common/udpsocket/FollyQuicAsyncUDPSocket.h index 8c57bc44a..dfa8e58c9 100644 --- a/quic/common/udpsocket/FollyQuicAsyncUDPSocket.h +++ b/quic/common/udpsocket/FollyQuicAsyncUDPSocket.h @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -54,17 +55,22 @@ class FollyQuicAsyncUDPSocket : public QuicAsyncUDPSocketImpl { } } - void init(sa_family_t family) override; + [[nodiscard]] folly::Expected init( + sa_family_t family) override; - void bind(const folly::SocketAddress& address) override; + [[nodiscard]] folly::Expected bind( + const folly::SocketAddress& address) override; + // TODO: bind should return Expected [[nodiscard]] bool isBound() const override; - void connect(const folly::SocketAddress& address) override; + folly::Expected connect( + const folly::SocketAddress& address) override; - void close() override; + folly::Expected close() override; void resumeRead(ReadCallback* callback) override; + // TODO: resumeRead should return Expected void pauseRead() override; @@ -117,25 +123,32 @@ class FollyQuicAsyncUDPSocket : public QuicAsyncUDPSocketImpl { // generic segmentation offload get/set // negative return value means GSO is not available - int getGSO() override; + folly::Expected getGSO() override; // generic receive offload get/set // negative return value means GRO is not available - int getGRO() override; - bool setGRO(bool bVal) override; + folly::Expected getGRO() override; + folly::Expected setGRO(bool bVal) override; // receive tos cmsgs // if true, the IPv6 Traffic Class/IPv4 Type of Service field should be // populated in OnDataAvailableParams. - void setRecvTos(bool recvTos) override; - bool getRecvTos() override; + folly::Expected setRecvTos(bool recvTos) override; + folly::Expected getRecvTos() override; - void setTosOrTrafficClass(uint8_t tos) override; + folly::Expected setTosOrTrafficClass( + uint8_t tos) override; /** - * Returns the socket server is bound to + * Returns the socket address this socket is bound to and error otherwise. */ - [[nodiscard]] const folly::SocketAddress& address() const override; + [[nodiscard]] folly::Expected address() + const override; + + /** + * Returns the socket address this socket is bound to and crashes otherwise. + */ + [[nodiscard]] virtual const folly::SocketAddress& addressRef() const override; /** * Manage the eventbase driving this socket @@ -147,21 +160,23 @@ class FollyQuicAsyncUDPSocket : public QuicAsyncUDPSocketImpl { /** * Set extra control messages to send */ - void setCmsgs(const folly::SocketCmsgMap& cmsgs) override; - void appendCmsgs(const folly::SocketCmsgMap& cmsgs) override; - void setAdditionalCmsgsFunc( + folly::Expected setCmsgs( + const folly::SocketCmsgMap& cmsgs) override; + folly::Expected appendCmsgs( + const folly::SocketCmsgMap& cmsgs) override; + folly::Expected setAdditionalCmsgsFunc( folly::Function()>&& additionalCmsgsFunc) override; /* * Packet timestamping is currentl not supported. */ - int getTimestamping() override; + folly::Expected getTimestamping() override; /** * Set SO_REUSEADDR flag on the socket. Default is OFF. */ - void setReuseAddr(bool reuseAddr) override; + folly::Expected setReuseAddr(bool reuseAddr) override; /** * Set Dont-Fragment (DF) but ignore Path MTU. @@ -171,32 +186,32 @@ class FollyQuicAsyncUDPSocket : public QuicAsyncUDPSocketImpl { * This may be desirable for apps that has its own PMTU Discovery mechanism. * See http://man7.org/linux/man-pages/man7/ip.7.html for more info. */ - void setDFAndTurnOffPMTU() override; + folly::Expected setDFAndTurnOffPMTU() override; /** * Callback for receiving errors on the UDP sockets */ - void setErrMessageCallback( + folly::Expected setErrMessageCallback( ErrMessageCallback* /* errMessageCallback */) override; - void applyOptions( + folly::Expected applyOptions( const folly::SocketOptionMap& options, folly::SocketOptionKey::ApplyPos pos) override; /** * Set reuse port mode to call bind() on the same address multiple times */ - void setReusePort(bool reusePort) override; + folly::Expected setReusePort(bool reusePort) override; /** * Set SO_RCVBUF option on the socket, if not zero. Default is zero. */ - void setRcvBuf(int rcvBuf) override; + folly::Expected setRcvBuf(int rcvBuf) override; /** * Set SO_SNDBUF option on the socket, if not zero. Default is zero. */ - void setSndBuf(int sndBuf) override; + folly::Expected setSndBuf(int sndBuf) override; /** * Use an already bound file descriptor. You can either transfer ownership @@ -204,7 +219,8 @@ class FollyQuicAsyncUDPSocket : public QuicAsyncUDPSocketImpl { * FDOwnership::SHARED. In case FD is shared, it will not be `close`d in * destructor. */ - void setFD(int fd, FDOwnership ownership) override; + folly::Expected setFD(int fd, FDOwnership ownership) + override; int getFD() override; diff --git a/quic/common/udpsocket/LibevQuicAsyncUDPSocket.cpp b/quic/common/udpsocket/LibevQuicAsyncUDPSocket.cpp index a920d79b5..5dbd0c79d 100644 --- a/quic/common/udpsocket/LibevQuicAsyncUDPSocket.cpp +++ b/quic/common/udpsocket/LibevQuicAsyncUDPSocket.cpp @@ -5,15 +5,18 @@ * LICENSE file in the root directory of this source tree. */ +#include +#include +#include // For folly::errnoStr +#include // For QuicError, QuicErrorCode, TransportErrorCode #include #include #include -#include - #include #include +#include #include #include @@ -34,7 +37,12 @@ LibevQuicAsyncUDPSocket::LibevQuicAsyncUDPSocket( LibevQuicAsyncUDPSocket::~LibevQuicAsyncUDPSocket() { if (fd_ != -1) { - LibevQuicAsyncUDPSocket::close(); + // Use folly::Expected result even in destructor? Best effort close. + auto closeResult = LibevQuicAsyncUDPSocket::close(); + if (closeResult.hasError()) { + LOG(ERROR) << "Error closing socket in destructor: " + << closeResult.error().message; + } } if (evb_) { ev_io_stop(evb_->getLibevLoop(), &readWatcher_); @@ -58,10 +66,11 @@ void LibevQuicAsyncUDPSocket::resumeRead(ReadCallback* cb) { CHECK(cb) << "A non-null callback is required to resume read"; readCallback_ = cb; addEvent(EV_READ); + // TODO: This should return Expected } -folly::Expected -LibevQuicAsyncUDPSocket::resumeWrite(WriteCallback* cob) { +folly::Expected LibevQuicAsyncUDPSocket::resumeWrite( + WriteCallback* cob) { CHECK(!writeCallback_) << "A write callback is already installed"; CHECK_NE(fd_, -1) << "Socket must be initialized before a write callback is attached"; @@ -81,8 +90,11 @@ ssize_t LibevQuicAsyncUDPSocket::write( const struct iovec* vec, size_t iovec_len) { if (fd_ == -1) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::NOT_OPEN, "socket is not initialized"); + // Return error consistent with syscall failure on bad FD + errno = EBADF; + LOG(ERROR) + << "LibevQuicAsyncUDPSocket::write failed: socket not initialized"; + return -1; } sockaddr_storage addrStorage; address.getAddress(&addrStorage); @@ -94,9 +106,12 @@ ssize_t LibevQuicAsyncUDPSocket::write( msg.msg_namelen = address.getActualSize(); } else { if (connectedAddress_ != address) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::BAD_ARGS, - "wrong destination address for connected socket"); + // Return error consistent with syscall failure for wrong address on + // connected socket + errno = EINVAL; // Or maybe EISCONN? EINVAL seems appropriate. + LOG(ERROR) + << "LibevQuicAsyncUDPSocket::write failed: wrong destination for connected socket"; + return -1; } msg.msg_name = nullptr; msg.msg_namelen = 0; @@ -111,39 +126,50 @@ ssize_t LibevQuicAsyncUDPSocket::write( return ::sendmsg(fd_, &msg, msg_flags); } -int LibevQuicAsyncUDPSocket::getGSO() { +folly::Expected LibevQuicAsyncUDPSocket::getGSO() { // TODO: Implement GSO return -1; } int LibevQuicAsyncUDPSocket::writem( - folly::Range, - iovec*, - size_t*, - size_t) { + folly::Range /*addrs*/, + iovec* /*iov*/, + size_t* /*numIovecsInBuffer*/, + size_t /*count*/) { LOG(FATAL) << __func__ << "is not implemented in LibevQuicAsyncUDPSocket"; return -1; } -void LibevQuicAsyncUDPSocket::setAdditionalCmsgsFunc( +folly::Expected +LibevQuicAsyncUDPSocket::setAdditionalCmsgsFunc( folly::Function()>&& /* additionalCmsgsFunc */) { LOG(WARNING) << "Setting an additional cmsgs function is not implemented for LibevQuicAsyncUDPSocket"; + // Return success despite warning, or error if strictness needed + return folly::unit; } bool LibevQuicAsyncUDPSocket::isBound() const { return bound_; } -const folly::SocketAddress& LibevQuicAsyncUDPSocket::address() const { +folly::Expected +LibevQuicAsyncUDPSocket::address() const { if (!bound_) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::NOT_OPEN, "socket is not bound"); + // Return error if not bound + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + "socket is not bound")); } return localAddress_; } +const folly::SocketAddress& LibevQuicAsyncUDPSocket::addressRef() const { + LOG_IF(FATAL, !bound_) << "socket is not bound"; + return localAddress_; +} + void LibevQuicAsyncUDPSocket::attachEventBase( std::shared_ptr /* evb */) { LOG(FATAL) << __func__ << "is not implemented in LibevQuicAsyncUDPSocket"; @@ -154,7 +180,7 @@ LibevQuicAsyncUDPSocket::getEventBase() const { return evb_; } -void LibevQuicAsyncUDPSocket::close() { +folly::Expected LibevQuicAsyncUDPSocket::close() { CHECK(evb_->isInEventBaseThread()); if (readCallback_) { @@ -167,88 +193,117 @@ void LibevQuicAsyncUDPSocket::close() { removeEvent(EV_READ | EV_WRITE); if (fd_ != -1 && ownership_ == FDOwnership::OWNS) { - ::close(fd_); + if (::close(fd_) != 0) { + int errnoCopy = errno; + fd_ = -1; // Mark as closed even if error occurred + std::string errorMsg = + "Failed to close socket: " + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } } fd_ = -1; + bound_ = false; // Reset state + connected_ = false; + return folly::unit; } void LibevQuicAsyncUDPSocket::detachEventBase() { LOG(FATAL) << __func__ << "is not implemented in LibevQuicAsyncUDPSocket"; } -void LibevQuicAsyncUDPSocket::setCmsgs( +folly::Expected LibevQuicAsyncUDPSocket::setCmsgs( const folly::SocketCmsgMap& /* cmsgs */) { - throw std::runtime_error("setCmsgs is not implemented."); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + "setCmsgs is not implemented.")); } -void LibevQuicAsyncUDPSocket::appendCmsgs( +folly::Expected LibevQuicAsyncUDPSocket::appendCmsgs( const folly::SocketCmsgMap& /* cmsgs */) { - throw std::runtime_error("appendCmsgs is not implemented."); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + "appendCmsgs is not implemented.")); } -void LibevQuicAsyncUDPSocket::init(sa_family_t family) { +folly::Expected LibevQuicAsyncUDPSocket::init( + sa_family_t family) { if (fd_ != -1) { // Socket already initialized. - return; + return folly::unit; } if (family != AF_INET && family != AF_INET6) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::NOT_SUPPORTED, - "address family not supported"); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + "address family not supported")); } int fd = ::socket(family, SOCK_DGRAM, IPPROTO_UDP); if (fd == -1) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::NOT_OPEN, "error creating socket", errno); + int errnoCopy = errno; + std::string errorMsg = + "error creating socket: " + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); } - SCOPE_FAIL { - ::close(fd); - }; + // Use RAII to ensure socket is closed on error + auto fdGuard = folly::makeGuard([fd] { ::close(fd); }); int flags = fcntl(fd, F_GETFL, 0); if (flags == -1) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::INTERNAL_ERROR, - "error getting socket flags", - errno); + int errnoCopy = errno; + std::string errorMsg = + "error getting socket flags: " + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); } if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) != 0) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::INTERNAL_ERROR, - "error setting socket nonblocking flag", - errno); + int errnoCopy = errno; + std::string errorMsg = + "error setting socket nonblocking flag: " + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); } int sockOptVal = 1; if (reuseAddr_ && ::setsockopt( fd, SOL_SOCKET, SO_REUSEADDR, &sockOptVal, sizeof(sockOptVal)) != 0) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::INTERNAL_ERROR, - "error setting reuse address on socket", - errno); + int errnoCopy = errno; + std::string errorMsg = + "error setting reuse address on socket: " + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); } if (reusePort_ && ::setsockopt( fd, SOL_SOCKET, SO_REUSEPORT, &sockOptVal, sizeof(sockOptVal)) != 0) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::INTERNAL_ERROR, - "error setting reuse port on socket", - errno); + int errnoCopy = errno; + std::string errorMsg = + "error setting reuse port on socket: " + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); } if (rcvBuf_ > 0) { // Set the size of the buffer for the received messages in rx_queues. int value = rcvBuf_; if (::setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &value, sizeof(value)) != 0) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::NOT_OPEN, - "failed to set SO_RCVBUF on the socket", - errno); + int errnoCopy = errno; + std::string errorMsg = "failed to set SO_RCVBUF on the socket: " + + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); } } @@ -256,28 +311,38 @@ void LibevQuicAsyncUDPSocket::init(sa_family_t family) { // Set the size of the buffer for the sent messages in tx_queues. int value = sndBuf_; if (::setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &value, sizeof(value)) != 0) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::NOT_OPEN, - "failed to set SO_SNDBUF on the socket", - errno); + int errnoCopy = errno; + std::string errorMsg = "failed to set SO_SNDBUF on the socket: " + + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); } } fd_ = fd; ownership_ = FDOwnership::OWNS; + fdGuard.dismiss(); // Don't close the fd now that we've stored it // Update the watchers removeEvent(EV_READ | EV_WRITE); ev_io_set(&readWatcher_, fd_, EV_READ); ev_io_set(&writeWatcher_, fd_, EV_WRITE); + + return folly::unit; } -void LibevQuicAsyncUDPSocket::bind(const folly::SocketAddress& address) { +folly::Expected LibevQuicAsyncUDPSocket::bind( + const folly::SocketAddress& address) { // TODO: remove dependency on folly::SocketAdress since this pulls in // folly::portability and other headers which should be avoidable. if (fd_ == -1) { - init(address.getFamily()); + auto initResult = init(address.getFamily()); + if (initResult.hasError()) { + return initResult; + } } + // bind to the address sockaddr_storage addrStorage; address.getAddress(&addrStorage); @@ -287,38 +352,50 @@ void LibevQuicAsyncUDPSocket::bind(const folly::SocketAddress& address) { (struct sockaddr*)&saddr, saddr.sa_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in)) != 0) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::INTERNAL_ERROR, - "error binding socket to " + address.describe(), - errno); + int errnoCopy = errno; + std::string errorMsg = "error binding socket to " + address.describe() + + ": " + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); } memset(&saddr, 0, sizeof(saddr)); socklen_t len = sizeof(saddr); if (::getsockname(fd_, &saddr, &len) != 0) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::INTERNAL_ERROR, - "error retrieving local address", - errno); + int errnoCopy = errno; + std::string errorMsg = + "error retrieving local address: " + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); } localAddress_.setFromSockaddr(&saddr, len); bound_ = true; + + return folly::unit; } -void LibevQuicAsyncUDPSocket::connect(const folly::SocketAddress& address) { +folly::Expected LibevQuicAsyncUDPSocket::connect( + const folly::SocketAddress& address) { if (fd_ == -1) { - init(address.getFamily()); + auto initResult = init(address.getFamily()); + if (initResult.hasError()) { + return initResult; + } } sockaddr_storage addrStorage; address.getAddress(&addrStorage); auto saddr = reinterpret_cast(addrStorage); if (::connect(fd_, &saddr, sizeof(saddr)) != 0) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::INTERNAL_ERROR, - "error connecting UDP socket to " + address.describe(), - errno); + int errnoCopy = errno; + std::string errorMsg = "Libev connect failed to " + address.describe(); + errorMsg += ": " + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); } connected_ = true; @@ -328,20 +405,26 @@ void LibevQuicAsyncUDPSocket::connect(const folly::SocketAddress& address) { memset(&saddr, 0, sizeof(saddr)); socklen_t len = sizeof(saddr); if (::getsockname(fd_, &saddr, &len) != 0) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::INTERNAL_ERROR, - "error retrieving local address", - errno); + int errnoCopy = errno; + std::string errorMsg = "Libev getsockname failed after connect: " + + folly::errnoStr(errnoCopy); + // Connect succeeded, but getsockname failed. + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); } localAddress_.setFromSockaddr(&saddr, len); } + return folly::unit; } -void LibevQuicAsyncUDPSocket::setDFAndTurnOffPMTU() { +folly::Expected +LibevQuicAsyncUDPSocket::setDFAndTurnOffPMTU() { if (fd_ == -1) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::NOT_OPEN, "socket is not initialized"); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + "socket is not initialized")); } int optname4 = 0; int optval4 = 0; @@ -355,26 +438,44 @@ void LibevQuicAsyncUDPSocket::setDFAndTurnOffPMTU() { optname6 = IPV6_MTU_DISCOVER; optval6 = IPV6_PMTUDISC_PROBE; #endif - if (optname4 && optval4 && address().getFamily() == AF_INET) { + auto familyResult = address(); // address() now returns Expected + if (familyResult.hasError()) { + return folly::makeUnexpected(familyResult.error()); + } + sa_family_t family = familyResult->getFamily(); + + if (optname4 && optval4 && family == AF_INET) { if (::setsockopt(fd_, IPPROTO_IP, optname4, &optval4, sizeof(optval4))) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::NOT_OPEN, - "failed to turn off PMTU discovery (IPv4)", - errno); + int errnoCopy = errno; + std::string errorMsg = "failed to turn off PMTU discovery (IPv4): " + + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); } } - if (optname6 && optval6 && address().getFamily() == AF_INET6) { + if (optname6 && optval6 && family == AF_INET6) { if (::setsockopt(fd_, IPPROTO_IPV6, optname6, &optval6, sizeof(optval6))) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::NOT_OPEN, - "failed to turn off PMTU discovery (IPv6)", - errno); + int errnoCopy = errno; + std::string errorMsg = "failed to turn off PMTU discovery (IPv6): " + + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); } } + // If options are not defined for the family, we succeed silently. + return folly::unit; } -void LibevQuicAsyncUDPSocket::setErrMessageCallback( +folly::Expected +LibevQuicAsyncUDPSocket::setErrMessageCallback( ErrMessageCallback* errMessageCallback) { + if (fd_ == -1) { + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + "socket not initialized for setErrMessageCallback")); + } errMessageCallback_ = errMessageCallback; int optname4 = 0; int optname6 = 0; @@ -384,25 +485,36 @@ void LibevQuicAsyncUDPSocket::setErrMessageCallback( #if defined(IPV6_RECVERR) optname6 = IPV6_RECVERR; #endif + auto familyResult = address(); // address() now returns Expected + if (familyResult.hasError()) { + return folly::makeUnexpected(familyResult.error()); + } + sa_family_t family = familyResult->getFamily(); + errMessageCallback_ = errMessageCallback; int err = (errMessageCallback_ != nullptr); - if (optname4 && address().getFamily() == AF_INET && + if (optname4 && family == AF_INET && ::setsockopt(fd_, IPPROTO_IP, optname4, &err, sizeof(err))) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::NOT_OPEN, - "Failed to set IP_RECVERR", - errno); + int errnoCopy = errno; + std::string errorMsg = + "Failed to set IP_RECVERR: " + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); } - if (optname6 && address().getFamily() == AF_INET6 && + if (optname6 && family == AF_INET6 && ::setsockopt(fd_, IPPROTO_IPV6, optname6, &err, sizeof(err))) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::NOT_OPEN, - "Failed to set IPV6_RECVERR", - errno); + int errnoCopy = errno; + std::string errorMsg = + "Failed to set IPV6_RECVERR: " + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); } + return folly::unit; } -int LibevQuicAsyncUDPSocket::getGRO() { +folly::Expected LibevQuicAsyncUDPSocket::getGRO() { return -1; } @@ -452,13 +564,22 @@ int LibevQuicAsyncUDPSocket::recvmmsg( return static_cast(vlen); } -bool LibevQuicAsyncUDPSocket::setGRO(bool /* bVal */) { - return false; +folly::Expected LibevQuicAsyncUDPSocket::setGRO( + bool /* bVal */) { + // Not supported, return error + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + "setGRO not supported")); } -void LibevQuicAsyncUDPSocket::applyOptions( +folly::Expected LibevQuicAsyncUDPSocket::applyOptions( const folly::SocketOptionMap& options, folly::SocketOptionKey::ApplyPos pos) { + if (fd_ == -1) { + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + "socket not initialized for applyOptions")); + } for (const auto& opt : options) { if (opt.first.applyPos_ == pos) { if (::setsockopt( @@ -467,18 +588,37 @@ void LibevQuicAsyncUDPSocket::applyOptions( opt.first.optname, &opt.second, sizeof(opt.second)) != 0) { - throw folly::AsyncSocketException( - folly::AsyncSocketException::INTERNAL_ERROR, - "failed to apply socket options", - errno); + int errnoCopy = errno; + std::string errorMsg = + "failed to apply socket options: " + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); } } } + return folly::unit; } -void LibevQuicAsyncUDPSocket::setFD(int fd, FDOwnership ownership) { +folly::Expected LibevQuicAsyncUDPSocket::setFD( + int fd, + FDOwnership ownership) { + // TODO: Check if fd is valid? setsockopt? + // TODO: Close existing fd_ if owned? + if (fd_ != -1 && ownership_ == FDOwnership::OWNS) { + LOG(WARNING) << "Closing existing owned FD in setFD"; + auto closeRes = close(); // Close existing owned FD + if (closeRes.hasError()) { + // Log error but continue trying to set the new FD + LOG(ERROR) << "Failed to close existing FD in setFD: " + << closeRes.error().message; + } + } + fd_ = fd; ownership_ = ownership; + bound_ = false; // Assume not bound until checked/bind called + connected_ = false; // Assume not connected // Update the watchers removeEvent(EV_READ | EV_WRITE); @@ -491,6 +631,9 @@ void LibevQuicAsyncUDPSocket::setFD(int fd, FDOwnership ownership) { if (writeCallback_) { addEvent(EV_WRITE); } + // TODO: Check if the FD is actually usable? Maybe getsockopt? + // For now, assume success if we reach here. + return folly::unit; } int LibevQuicAsyncUDPSocket::getFD() { @@ -628,12 +771,59 @@ void LibevQuicAsyncUDPSocket::sockEventsWatcherCallback( } } -void LibevQuicAsyncUDPSocket::setRcvBuf(int rcvBuf) { +folly::Expected LibevQuicAsyncUDPSocket::setRcvBuf( + int rcvBuf) { rcvBuf_ = rcvBuf; + if (fd_ != -1) { + // Apply immediately if socket exists + int value = rcvBuf_; + if (::setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &value, sizeof(value)) != 0) { + int errnoCopy = errno; + std::string errorMsg = + "failed to set SO_RCVBUF: " + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } + } + return folly::unit; } -void LibevQuicAsyncUDPSocket::setSndBuf(int sndBuf) { +folly::Expected LibevQuicAsyncUDPSocket::setSndBuf( + int sndBuf) { sndBuf_ = sndBuf; + if (fd_ != -1) { + // Apply immediately if socket exists + int value = sndBuf_; + if (::setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &value, sizeof(value)) != 0) { + int errnoCopy = errno; + std::string errorMsg = + "failed to set SO_SNDBUF: " + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } + } + return folly::unit; +} + +folly::Expected LibevQuicAsyncUDPSocket::setReuseAddr( + bool reuseAddr) { + reuseAddr_ = reuseAddr; + if (fd_ != -1) { + int sockOptVal = reuseAddr ? 1 : 0; + if (::setsockopt( + fd_, SOL_SOCKET, SO_REUSEADDR, &sockOptVal, sizeof(sockOptVal)) != + 0) { + int errnoCopy = errno; + std::string errorMsg = + "failed to set SO_REUSEADDR: " + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + } + } + return folly::unit; } bool LibevQuicAsyncUDPSocket::isWritableCallbackSet() const { diff --git a/quic/common/udpsocket/LibevQuicAsyncUDPSocket.h b/quic/common/udpsocket/LibevQuicAsyncUDPSocket.h index 03162103c..f215ab422 100644 --- a/quic/common/udpsocket/LibevQuicAsyncUDPSocket.h +++ b/quic/common/udpsocket/LibevQuicAsyncUDPSocket.h @@ -7,6 +7,8 @@ #pragma once +#include +#include // For QuicError #include #include #include @@ -19,18 +21,21 @@ class LibevQuicAsyncUDPSocket : public QuicAsyncUDPSocketImpl { explicit LibevQuicAsyncUDPSocket(std::shared_ptr qEvb); ~LibevQuicAsyncUDPSocket() override; + [[nodiscard]] folly::Expected init( + sa_family_t family) override; - void init(sa_family_t family) override; - - void bind(const folly::SocketAddress& address) override; + [[nodiscard]] folly::Expected bind( + const folly::SocketAddress& address) override; [[nodiscard]] bool isBound() const override; - void connect(const folly::SocketAddress& address) override; + folly::Expected connect( + const folly::SocketAddress& address) override; - void close() override; + folly::Expected close() override; void resumeRead(ReadCallback* callback) override; + // TODO: resumeRead should return Expected void pauseRead() override; @@ -96,32 +101,42 @@ class LibevQuicAsyncUDPSocket : public QuicAsyncUDPSocketImpl { // generic segmentation offload get/set // negative return value means GSO is not available - int getGSO() override; + folly::Expected getGSO() override; // generic receive offload get/set // negative return value means GRO is not available - int getGRO() override; - bool setGRO(bool bVal) override; + folly::Expected getGRO() override; + folly::Expected setGRO(bool bVal) override; // receive tos cmsgs // if true, the IPv6 Traffic Class/IPv4 Type of Service field should be // populated in OnDataAvailableParams. - void setRecvTos(bool /*recvTos*/) override { + folly::Expected setRecvTos( + bool /*recvTos*/) override { LOG(WARNING) << __func__ << " not implemented in LibevQuicAsyncUDPSocket"; + return folly::unit; // Or return error if strictness needed } - bool getRecvTos() override { - return false; + folly::Expected getRecvTos() override { + return false; // Not implemented, return default/false } - void setTosOrTrafficClass(uint8_t /*tos*/) override { + folly::Expected setTosOrTrafficClass( + uint8_t /*tos*/) override { LOG(WARNING) << __func__ << " not implemented in LibevQuicAsyncUDPSocket"; + return folly::unit; // Or return error if strictness needed } /** - * Returns the socket server is bound to + * Returns the socket address this socket is bound to and error otherwise. */ - [[nodiscard]] const folly::SocketAddress& address() const override; + [[nodiscard]] folly::Expected address() + const override; + + /** + * Returns the socket address this socket is bound to and crashes otherwise. + */ + [[nodiscard]] virtual const folly::SocketAddress& addressRef() const override; /** * Manage the eventbase driving this socket @@ -133,35 +148,35 @@ class LibevQuicAsyncUDPSocket : public QuicAsyncUDPSocketImpl { /** * Set extra control messages to send */ - void setCmsgs(const folly::SocketCmsgMap& cmsgs) override; - void appendCmsgs(const folly::SocketCmsgMap& cmsgs) override; - void setAdditionalCmsgsFunc( + folly::Expected setCmsgs( + const folly::SocketCmsgMap& cmsgs) override; + folly::Expected appendCmsgs( + const folly::SocketCmsgMap& cmsgs) override; + folly::Expected setAdditionalCmsgsFunc( folly::Function()>&& additionalCmsgsFunc) override; /* * Packet timestamping is currently not supported. */ - int getTimestamping() override { - return -1; + folly::Expected getTimestamping() override { + return -1; // Keep returning -1 for not supported } /** * Set SO_REUSEADDR flag on the socket. Default is OFF. */ - void setReuseAddr(bool /*reuseAddr*/) override { - LOG(WARNING) << __func__ << " not implemented in LibevQuicAsyncUDPSocket"; - } + folly::Expected setReuseAddr(bool reuseAddr) override; /** * Set SO_RCVBUF option on the socket, if not zero. Default is zero. */ - void setRcvBuf(int rcvBuf) override; + folly::Expected setRcvBuf(int rcvBuf) override; /** * Set SO_SNDBUF option on the socket, if not zero. Default is zero. */ - void setSndBuf(int sndBuf) override; + folly::Expected setSndBuf(int sndBuf) override; /** * Set Dont-Fragment (DF) but ignore Path MTU. * @@ -170,23 +185,25 @@ class LibevQuicAsyncUDPSocket : public QuicAsyncUDPSocketImpl { * This may be desirable for apps that has its own PMTU Discovery mechanism. * See http://man7.org/linux/man-pages/man7/ip.7.html for more info. */ - void setDFAndTurnOffPMTU() override; + folly::Expected setDFAndTurnOffPMTU() override; /** * Callback for receiving errors on the UDP sockets */ - void setErrMessageCallback( + folly::Expected setErrMessageCallback( ErrMessageCallback* /* errMessageCallback */) override; - void applyOptions( + folly::Expected applyOptions( const folly::SocketOptionMap& options, folly::SocketOptionKey::ApplyPos pos) override; /** * Set reuse port mode to call bind() on the same address multiple times */ - void setReusePort(bool /*reusePort*/) override { - LOG(FATAL) << __func__ << " not supported in LibevQuicAsyncUDPSocket"; + folly::Expected setReusePort(bool) override { + LOG(FATAL) << __func__ << " not implemented in LibevQuicAsyncUDPSocket"; + // Return success as it's just a warning, or error if strictness needed + return folly::unit; } /** @@ -195,14 +212,15 @@ class LibevQuicAsyncUDPSocket : public QuicAsyncUDPSocketImpl { * FDOwnership::SHARED. In case FD is shared, it will not be `close`d in * destructor. */ - void setFD(int fd, FDOwnership ownership) override; + folly::Expected setFD(int fd, FDOwnership ownership) + override; int getFD() override; /** * Start listening to writable events on the socket. */ - folly::Expected resumeWrite( + folly::Expected resumeWrite( WriteCallback* /* cob */) override; /** diff --git a/quic/common/udpsocket/QuicAsyncUDPSocket.h b/quic/common/udpsocket/QuicAsyncUDPSocket.h index 99765afa8..91b0860ff 100644 --- a/quic/common/udpsocket/QuicAsyncUDPSocket.h +++ b/quic/common/udpsocket/QuicAsyncUDPSocket.h @@ -9,18 +9,25 @@ #include +#include #include #include #include #include #include +#include // For folly::errnoStr #include +#include // Include for existing QuicError and QuicErrorCode #include #include #include namespace quic { +// Forward declarations are likely in QuicException.h now +// class QuicError; +// enum class TransportErrorCode : uint64_t; +// class QuicErrorCode; /** * QuicAsyncUDPSocket is an abstract class that represents an UDP socket that @@ -86,12 +93,12 @@ class QuicAsyncUDPSocket { }; virtual ~QuicAsyncUDPSocket() = default; - // Initializes underlying socket fd. This is called in bind() and connect() // internally if fd is not yet set at the time of the call. But if there is a // need to apply socket options pre-bind, one can call this function // explicitly before bind()/connect() and socket opts application. - virtual void init(sa_family_t /* family */) = 0; + [[nodiscard]] virtual folly::Expected init( + sa_family_t /* family */) = 0; /** * Bind the socket to the following address. If port is not @@ -99,7 +106,9 @@ class QuicAsyncUDPSocket { * use `address()` method above to get it after this method successfully * returns. */ - virtual void bind(const folly::SocketAddress& address) = 0; + [[nodiscard]] virtual folly::Expected bind( + const folly::SocketAddress& address) = 0; + [[nodiscard]] virtual bool isBound() const = 0; /** @@ -121,17 +130,19 @@ class QuicAsyncUDPSocket { * * Returns the result of calling the connect syscall. */ - virtual void connect(const folly::SocketAddress& /* address */) = 0; + [[nodiscard]] virtual folly::Expected connect( + const folly::SocketAddress& /* address */) = 0; /** * Stop listening on the socket. */ - virtual void close() = 0; + [[nodiscard]] virtual folly::Expected close() = 0; /** * Start reading datagrams */ virtual void resumeRead(ReadCallback* /* cb */) = 0; + // TODO: resumeRead can fail (e.g. bad socket state). Should return Expected. /** * Pause reading datagrams @@ -145,7 +156,7 @@ class QuicAsyncUDPSocket { /** * Start listening to writable events on the socket. */ - virtual folly::Expected resumeWrite( + [[nodiscard]] virtual folly::Expected resumeWrite( WriteCallback* /* cb */) { return folly::unit; } @@ -258,7 +269,7 @@ class QuicAsyncUDPSocket { * recv() result structure. */ struct RecvResult { - RecvResult() = default; + RecvResult() = default; // Default constructor for success case explicit RecvResult(NoReadReason noReadReason) : maybeNoReadReason(noReadReason) {} @@ -266,7 +277,8 @@ class QuicAsyncUDPSocket { Optional maybeNoReadReason; }; - virtual RecvResult recvmmsgNetworkData( + [[nodiscard]] virtual folly::Expected + recvmmsgNetworkData( uint64_t readBufferSize, uint16_t numPackets, NetworkData& networkData, @@ -275,25 +287,34 @@ class QuicAsyncUDPSocket { // generic segmentation offload get/set // negative return value means GSO is not available - virtual int getGSO() = 0; + [[nodiscard]] virtual folly::Expected getGSO() = 0; // generic receive offload get/set // negative return value means GRO is not available - virtual int getGRO() = 0; - virtual bool setGRO(bool /* bVal */) = 0; + [[nodiscard]] virtual folly::Expected getGRO() = 0; + [[nodiscard]] virtual folly::Expected setGRO( + bool /* bVal */) = 0; // receive tos cmsgs // if true, the IPv6 Traffic Class/IPv4 Type of Service field should be // populated in OnDataAvailableParams. - virtual void setRecvTos(bool recvTos) = 0; - virtual bool getRecvTos() = 0; + [[nodiscard]] virtual folly::Expected setRecvTos( + bool recvTos) = 0; + [[nodiscard]] virtual folly::Expected getRecvTos() = 0; - virtual void setTosOrTrafficClass(uint8_t tos) = 0; + [[nodiscard]] virtual folly::Expected + setTosOrTrafficClass(uint8_t tos) = 0; /** - * Returns the socket server is bound to + * Returns the socket address this socket is bound to and error otherwise. */ - [[nodiscard]] virtual const folly::SocketAddress& address() const = 0; + [[nodiscard]] virtual folly::Expected + address() const = 0; + + /** + * Returns the socket address this socket is bound to and crashes otherwise. + */ + [[nodiscard]] virtual const folly::SocketAddress& addressRef() const = 0; /** * Manage the eventbase driving this socket @@ -301,25 +322,27 @@ class QuicAsyncUDPSocket { virtual void attachEventBase(std::shared_ptr /* evb */) = 0; virtual void detachEventBase() = 0; [[nodiscard]] virtual std::shared_ptr getEventBase() const = 0; - /** * Set extra control messages to send */ - virtual void setCmsgs(const folly::SocketCmsgMap& /* cmsgs */) = 0; - virtual void appendCmsgs(const folly::SocketCmsgMap& /* cmsgs */) = 0; - virtual void setAdditionalCmsgsFunc( - folly::Function()>&& - /* additionalCmsgsFunc */) = 0; + [[nodiscard]] virtual folly::Expected setCmsgs( + const folly::SocketCmsgMap& /* cmsgs */) = 0; + [[nodiscard]] virtual folly::Expected appendCmsgs( + const folly::SocketCmsgMap& /* cmsgs */) = 0; + [[nodiscard]] virtual folly::Expected + setAdditionalCmsgsFunc(folly::Function()>&& + /* additionalCmsgsFunc */) = 0; /* * Packet timestamping is currentl not supported. */ - virtual int getTimestamping() = 0; + [[nodiscard]] virtual folly::Expected getTimestamping() = 0; /** * Set SO_REUSEADDR flag on the socket. Default is OFF. */ - virtual void setReuseAddr(bool reuseAddr) = 0; + [[nodiscard]] virtual folly::Expected setReuseAddr( + bool reuseAddr) = 0; /** * Set Dont-Fragment (DF) but ignore Path MTU. @@ -329,32 +352,36 @@ class QuicAsyncUDPSocket { * This may be desirable for apps that has its own PMTU Discovery mechanism. * See http://man7.org/linux/man-pages/man7/ip.7.html for more info. */ - virtual void setDFAndTurnOffPMTU() = 0; + [[nodiscard]] virtual folly::Expected + setDFAndTurnOffPMTU() = 0; /** * Callback for receiving errors on the UDP sockets */ - virtual void setErrMessageCallback( - ErrMessageCallback* /* errMessageCallback */) = 0; + [[nodiscard]] virtual folly::Expected + setErrMessageCallback(ErrMessageCallback* /* errMessageCallback */) = 0; - virtual void applyOptions( + [[nodiscard]] virtual folly::Expected applyOptions( const folly::SocketOptionMap& /* options */, folly::SocketOptionKey::ApplyPos /* pos */) = 0; /** * Set reuse port mode to call bind() on the same address multiple times */ - virtual void setReusePort(bool reusePort) = 0; + [[nodiscard]] virtual folly::Expected setReusePort( + bool reusePort) = 0; /** * Set SO_RCVBUF option on the socket, if not zero. Default is zero. */ - virtual void setRcvBuf(int rcvBuf) = 0; + [[nodiscard]] virtual folly::Expected setRcvBuf( + int rcvBuf) = 0; /** * Set SO_SNDBUF option on the socket, if not zero. Default is zero. */ - virtual void setSndBuf(int sndBuf) = 0; + [[nodiscard]] virtual folly::Expected setSndBuf( + int sndBuf) = 0; enum class FDOwnership { OWNS, SHARED }; @@ -364,12 +391,19 @@ class QuicAsyncUDPSocket { * FDOwnership::SHARED. In case FD is shared, it will not be `close`d in * destructor. */ - virtual void setFD(int /* fd */, FDOwnership /* ownership */) = 0; + [[nodiscard]] virtual folly::Expected setFD( + int /* fd */, + FDOwnership /* ownership */) = 0; virtual int getFD() = 0; - [[nodiscard]] virtual sa_family_t getLocalAddressFamily() const { - return address().getFamily(); + [[nodiscard]] virtual folly::Expected + getLocalAddressFamily() const { + auto addrResult = address(); + if (addrResult.hasError()) { + return folly::makeUnexpected(addrResult.error()); + } + return addrResult->getFamily(); } template < diff --git a/quic/common/udpsocket/QuicAsyncUDPSocketImpl.cpp b/quic/common/udpsocket/QuicAsyncUDPSocketImpl.cpp index 6c121c51c..4d9858ffe 100644 --- a/quic/common/udpsocket/QuicAsyncUDPSocketImpl.cpp +++ b/quic/common/udpsocket/QuicAsyncUDPSocketImpl.cpp @@ -5,6 +5,10 @@ * LICENSE file in the root directory of this source tree. */ +#include +#include +#include // For folly::errnoStr +#include // For QuicError, QuicErrorCode, TransportErrorCode #include namespace { @@ -13,7 +17,8 @@ constexpr socklen_t kAddrLen = sizeof(sockaddr_storage); namespace quic { -QuicAsyncUDPSocket::RecvResult QuicAsyncUDPSocketImpl::recvmmsgNetworkData( +folly::Expected +QuicAsyncUDPSocketImpl::recvmmsgNetworkData( uint64_t readBufferSize, uint16_t numPackets, NetworkData& networkData, @@ -27,9 +32,23 @@ QuicAsyncUDPSocket::RecvResult QuicAsyncUDPSocketImpl::recvmmsgNetworkData( recvmmsgStorage_.resize(numPackets); auto& msgs = recvmmsgStorage_.msgs; int flags = 0; + + // Check socket options using Expected results + auto groResult = getGRO(); + if (FOLLY_UNLIKELY(groResult.hasError())) { + return folly::makeUnexpected(groResult.error()); + } + auto timestampingResult = getTimestamping(); + if (FOLLY_UNLIKELY(timestampingResult.hasError())) { + return folly::makeUnexpected(timestampingResult.error()); + } + auto recvTosResult = getRecvTos(); + if (FOLLY_UNLIKELY(recvTosResult.hasError())) { + return folly::makeUnexpected(recvTosResult.error()); + } #if defined(FOLLY_HAVE_MSG_ERRQUEUE) || defined(_WIN32) - bool useGRO = getGRO() > 0; - bool checkCmsgs = useGRO || getTimestamping() > 0 || getRecvTos(); + bool useGRO = *groResult > 0; + bool checkCmsgs = useGRO || *timestampingResult > 0 || *recvTosResult; std::vector> @@ -55,8 +74,13 @@ QuicAsyncUDPSocket::RecvResult QuicAsyncUDPSocketImpl::recvmmsgNetworkData( } CHECK(readBuffer != nullptr); - auto* rawAddr = reinterpret_cast(&addr); - rawAddr->sa_family = address().getFamily(); + auto localAddrResult = address(); + if (FOLLY_UNLIKELY(localAddrResult.hasError())) { + return folly::makeUnexpected(localAddrResult.error()); + } + auto* rawAddr = + reinterpret_cast(&addr); // Assuming addr is large enough + rawAddr->sa_family = localAddrResult->getFamily(); msg->msg_name = rawAddr; msg->msg_namelen = kAddrLen; #if defined(FOLLY_HAVE_MSG_ERRQUEUE) || defined(_WIN32) @@ -76,9 +100,17 @@ QuicAsyncUDPSocket::RecvResult QuicAsyncUDPSocketImpl::recvmmsgNetworkData( return RecvResult(NoReadReason::RETRIABLE_ERROR); } // If we got a non-retriable error, we might have received - // a packet that we could process, however let's just quit early. + // a packet that we could process, however let's just quit early. Pause read + // might fail too. pauseRead(); - return RecvResult(NoReadReason::NONRETRIABLE_ERROR); + // Return the error from recvmmsg itself + int errnoCopy = errno; + std::string errorMsg = "recvmmsg failed: " + folly::errnoStr(errnoCopy); + return folly::makeUnexpected(QuicError( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::move(errorMsg))); + // Original code returned RecvResult(NoReadReason::NONRETRIABLE_ERROR); + // Returning the actual error seems more informative. } // process msgs (packets) returned by recvmmsg @@ -168,7 +200,7 @@ QuicAsyncUDPSocket::RecvResult QuicAsyncUDPSocketImpl::recvmmsgNetworkData( } } - return {}; + return RecvResult(); // Success case } void QuicAsyncUDPSocketImpl::RecvmmsgStorage::resize(size_t numPackets) { diff --git a/quic/common/udpsocket/QuicAsyncUDPSocketImpl.h b/quic/common/udpsocket/QuicAsyncUDPSocketImpl.h index 21f8865b3..358ed64a4 100644 --- a/quic/common/udpsocket/QuicAsyncUDPSocketImpl.h +++ b/quic/common/udpsocket/QuicAsyncUDPSocketImpl.h @@ -13,7 +13,8 @@ namespace quic { class QuicAsyncUDPSocketImpl : public QuicAsyncUDPSocket { public: - QuicAsyncUDPSocket::RecvResult recvmmsgNetworkData( + folly::Expected + recvmmsgNetworkData( uint64_t readBufferSize, uint16_t numPackets, NetworkData& networkData, diff --git a/quic/common/udpsocket/test/QuicAsyncUDPSocketMock.h b/quic/common/udpsocket/test/QuicAsyncUDPSocketMock.h index ce605130e..d2b6eb7f9 100644 --- a/quic/common/udpsocket/test/QuicAsyncUDPSocketMock.h +++ b/quic/common/udpsocket/test/QuicAsyncUDPSocketMock.h @@ -11,14 +11,19 @@ #include namespace quic::test { - class QuicAsyncUDPSocketMock : public QuicAsyncUDPSocket { public: - MOCK_METHOD((void), init, (sa_family_t)); - MOCK_METHOD((void), bind, (const folly::SocketAddress&)); + MOCK_METHOD((folly::Expected), init, (sa_family_t)); + MOCK_METHOD( + (folly::Expected), + bind, + (const folly::SocketAddress&)); MOCK_METHOD((bool), isBound, (), (const)); - MOCK_METHOD((void), connect, (const folly::SocketAddress&)); - MOCK_METHOD((void), close, ()); + MOCK_METHOD( + (folly::Expected), + connect, + (const folly::SocketAddress&)); + MOCK_METHOD((folly::Expected), close, ()); MOCK_METHOD((void), resumeRead, (ReadCallback*)); MOCK_METHOD((void), pauseRead, ()); MOCK_METHOD( @@ -57,45 +62,70 @@ class QuicAsyncUDPSocketMock : public QuicAsyncUDPSocket { recvmmsg, (struct mmsghdr*, unsigned int, unsigned int, struct timespec*)); MOCK_METHOD( - (RecvResult), + (folly::Expected), recvmmsgNetworkData, (uint64_t, uint16_t, NetworkData&, Optional&, size_t&)); - MOCK_METHOD((int), getGSO, ()); - MOCK_METHOD((int), getGRO, ()); - MOCK_METHOD((bool), setGRO, (bool)); - MOCK_METHOD((const folly::SocketAddress&), setGSO, (), (const)); + MOCK_METHOD((folly::Expected), getGSO, ()); + MOCK_METHOD((folly::Expected), getGRO, ()); + MOCK_METHOD((folly::Expected), setGRO, (bool)); + MOCK_METHOD( + (folly::Expected), + address, + (), + (const)); + MOCK_METHOD((const folly::SocketAddress&), addressRef, (), (const)); MOCK_METHOD((void), attachEventBase, (std::shared_ptr)); MOCK_METHOD((void), detachEventBase, ()); MOCK_METHOD((std::shared_ptr), getEventBase, (), (const)); - MOCK_METHOD((void), setCmsgs, (const folly::SocketCmsgMap&)); - MOCK_METHOD((void), appendCmsgs, (const folly::SocketCmsgMap&)); MOCK_METHOD( - (void), + (folly::Expected), + setCmsgs, + (const folly::SocketCmsgMap&)); + MOCK_METHOD( + (folly::Expected), + appendCmsgs, + (const folly::SocketCmsgMap&)); + MOCK_METHOD( + (folly::Expected), setAdditionalCmsgsFunc, (folly::Function()>&&)); - MOCK_METHOD((int), getTimestamping, ()); - MOCK_METHOD((void), setReuseAddr, (bool)); - MOCK_METHOD((void), setDFAndTurnOffPMTU, (bool)); - MOCK_METHOD((void), setErrMessageCallback, (ErrMessageCallback*)); + MOCK_METHOD((folly::Expected), getTimestamping, ()); + MOCK_METHOD((folly::Expected), setReuseAddr, (bool)); MOCK_METHOD( - (void), + (folly::Expected), + setDFAndTurnOffPMTU, + ()); + MOCK_METHOD( + (folly::Expected), + setErrMessageCallback, + (ErrMessageCallback*)); + MOCK_METHOD( + (folly::Expected), applyOptions, (const folly::SocketOptionMap&, folly::SocketOptionKey::ApplyPos)); - MOCK_METHOD((void), setReusePort, (bool)); - MOCK_METHOD((void), setRcvBuf, (int)); - MOCK_METHOD((void), setSndBuf, (int)); - MOCK_METHOD((void), setFD, (int, FDOwnership)); + MOCK_METHOD((folly::Expected), setReusePort, (bool)); + MOCK_METHOD((folly::Expected), setRcvBuf, (int)); + MOCK_METHOD((folly::Expected), setSndBuf, (int)); + MOCK_METHOD( + (folly::Expected), + setFD, + (int, FDOwnership)); MOCK_METHOD((int), getFD, ()); - MOCK_METHOD((const folly::SocketAddress&), address, (), (const)); - MOCK_METHOD((void), setDFAndTurnOffPMTU, ()); - MOCK_METHOD((void), setRecvTos, (bool)); - MOCK_METHOD((bool), getRecvTos, ()); - MOCK_METHOD((void), setTosOrTrafficClass, (uint8_t)); - MOCK_METHOD((sa_family_t), getLocalAddressFamily, (), (const)); + MOCK_METHOD((folly::Expected), setRecvTos, (bool)); + MOCK_METHOD((folly::Expected), getRecvTos, ()); + MOCK_METHOD( + (folly::Expected), + setTosOrTrafficClass, + (uint8_t)); + MOCK_METHOD( + (folly::Expected), + getLocalAddressFamily, + (), + (const)); }; class MockErrMessageCallback diff --git a/quic/common/udpsocket/test/QuicAsyncUDPSocketTestBase.h b/quic/common/udpsocket/test/QuicAsyncUDPSocketTestBase.h index 2c9359c9c..ef66d8513 100644 --- a/quic/common/udpsocket/test/QuicAsyncUDPSocketTestBase.h +++ b/quic/common/udpsocket/test/QuicAsyncUDPSocketTestBase.h @@ -16,7 +16,7 @@ class QuicAsyncUDPSocketTestBase : public testing::Test { void SetUp() override { udpSocket_ = T::makeQuicAsyncUDPSocket(); addr_ = folly::SocketAddress("127.0.0.1", 0); - udpSocket_->bind(addr_); + CHECK(!udpSocket_->bind(addr_).hasError()); // For QUIC, we're only interested in the shouldOnlyNotify path. EXPECT_CALL(readCb_, shouldOnlyNotify()) @@ -40,7 +40,8 @@ TYPED_TEST_SUITE_P(QuicAsyncUDPSocketTest); TYPED_TEST_P(QuicAsyncUDPSocketTest, ErrToNonExistentServer) { #ifdef FOLLY_HAVE_MSG_ERRQUEUE this->udpSocket_->resumeRead(&this->readCb_); - this->udpSocket_->setErrMessageCallback(&this->errCb_); + ASSERT_FALSE( + this->udpSocket_->setErrMessageCallback(&this->errCb_).hasError()); folly::SocketAddress addr("127.0.0.1", 10000); bool errRecvd = false; @@ -76,8 +77,9 @@ TYPED_TEST_P(QuicAsyncUDPSocketTest, ErrToNonExistentServer) { TYPED_TEST_P(QuicAsyncUDPSocketTest, TestUnsetErrCallback) { #ifdef FOLLY_HAVE_MSG_ERRQUEUE this->udpSocket_->resumeRead(&this->readCb_); - this->udpSocket_->setErrMessageCallback(&this->errCb_); - this->udpSocket_->setErrMessageCallback(nullptr); + ASSERT_FALSE( + this->udpSocket_->setErrMessageCallback(&this->errCb_).hasError()); + ASSERT_FALSE(this->udpSocket_->setErrMessageCallback(nullptr).hasError()); folly::SocketAddress addr("127.0.0.1", 10000); EXPECT_CALL(this->errCb_, errMessage_(testing::_)).Times(0); EXPECT_CALL(this->readCb_, onNotifyDataAvailable_(testing::_)).Times(0); @@ -113,7 +115,8 @@ TYPED_TEST_P(QuicAsyncUDPSocketTest, TestUnsetErrCallback) { TYPED_TEST_P(QuicAsyncUDPSocketTest, CloseInErrorCallback) { #ifdef FOLLY_HAVE_MSG_ERRQUEUE this->udpSocket_->resumeRead(&this->readCb_); - this->udpSocket_->setErrMessageCallback(&this->errCb_); + ASSERT_FALSE( + this->udpSocket_->setErrMessageCallback(&this->errCb_).hasError()); folly::SocketAddress addr("127.0.0.1", 10000); bool errRecvd = false; @@ -123,7 +126,7 @@ TYPED_TEST_P(QuicAsyncUDPSocketTest, CloseInErrorCallback) { EXPECT_CALL(this->errCb_, errMessage_(testing::_)) .WillOnce(testing::Invoke([this, &errRecvd, &evb](auto&) { errRecvd = true; - this->udpSocket_->close(); + ASSERT_FALSE(this->udpSocket_->close().hasError()); evb->terminateLoopSoon(); })); diff --git a/quic/dsr/backend/DSRPacketizer.cpp b/quic/dsr/backend/DSRPacketizer.cpp index 6f8504d2a..68c8b984d 100644 --- a/quic/dsr/backend/DSRPacketizer.cpp +++ b/quic/dsr/backend/DSRPacketizer.cpp @@ -201,11 +201,13 @@ void UdpSocketPacketGroupWriter::rollback() { } bool UdpSocketPacketGroupWriter::send(uint32_t size) { - return ioBufBatch_.write(nullptr /* no need to pass buildBuf */, size); + auto result = ioBufBatch_.write(nullptr /* no need to pass buildBuf */, size); + CHECK(!result.hasError()); + return result.value(); } void UdpSocketPacketGroupWriter::flush() { - ioBufBatch_.flush(); + CHECK(!ioBufBatch_.flush().hasError()); } BufQuicBatchResult UdpSocketPacketGroupWriter::getResult() { diff --git a/quic/dsr/backend/test/DSRPacketizerTest.cpp b/quic/dsr/backend/test/DSRPacketizerTest.cpp index cd5e04250..8e3bf3820 100644 --- a/quic/dsr/backend/test/DSRPacketizerTest.cpp +++ b/quic/dsr/backend/test/DSRPacketizerTest.cpp @@ -92,7 +92,7 @@ TEST_F(DSRPacketizerSingleWriteTest, SingleWrite) { dcid.size() /* dcid */ + 1 /* stream frame initial byte */ + 1 /* stream id */ + length /* actual data */ + aead->getCipherOverhead()); - packetGroupWriter.getIOBufQuicBatch().flush(); + ASSERT_FALSE(packetGroupWriter.getIOBufQuicBatch().flush().hasError()); EXPECT_EQ(1, packetGroupWriter.getIOBufQuicBatch().getPktSent()); } @@ -122,7 +122,7 @@ TEST_F(DSRPacketizerSingleWriteTest, NotEnoughData) { eof, folly::IOBuf::copyBuffer("Clif")); EXPECT_FALSE(ret); - packetGroupWriter.getIOBufQuicBatch().flush(); + ASSERT_FALSE(packetGroupWriter.getIOBufQuicBatch().flush().hasError()); EXPECT_EQ(0, packetGroupWriter.getIOBufQuicBatch().getPktSent()); } @@ -208,8 +208,8 @@ TEST_F(DSRMultiWriteTest, TwoRequestsWithLoss) { requests, [](const PacketizationRequest& req) { return buildRandomInputData(req.len); }); - EXPECT_EQ(2, result.packetsSent); - EXPECT_EQ(2, sentData.size()); + ASSERT_EQ(2, result.packetsSent); + ASSERT_EQ(2, sentData.size()); EXPECT_GT(sentData[0]->computeChainDataLength(), 500); EXPECT_GT(sentData[1]->computeChainDataLength(), 500); } diff --git a/quic/fizz/client/test/QuicClientTransportTest.cpp b/quic/fizz/client/test/QuicClientTransportTest.cpp index e3c413533..03207d5d0 100644 --- a/quic/fizz/client/test/QuicClientTransportTest.cpp +++ b/quic/fizz/client/test/QuicClientTransportTest.cpp @@ -1102,6 +1102,46 @@ TEST_F(QuicClientTransportTest, onNetworkSwitchReplaceAfterHandshake) { auto newSocket = std::make_unique>(qEvb_); + ON_CALL(*newSocket, setReuseAddr(_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setAdditionalCmsgsFunc(_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setDFAndTurnOffPMTU()) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setErrMessageCallback(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setTosOrTrafficClass(_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, init(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, applyOptions(testing::_, testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, bind(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, connect(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, close()).WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, resumeWrite(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setGRO(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setRecvTos(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, getRecvTos()).WillByDefault(testing::Return(false)); + ON_CALL(*newSocket, setCmsgs(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, appendCmsgs(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, getTimestamping()).WillByDefault(testing::Return(0)); + ON_CALL(*newSocket, setReusePort(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setRcvBuf(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setSndBuf(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setFD(testing::_, testing::_)) + .WillByDefault(testing::Return(folly::unit)); + auto newSocketPtr = newSocket.get(); EXPECT_CALL(*sock, pauseRead()); EXPECT_CALL(*sock, close()); @@ -1118,6 +1158,45 @@ TEST_F(QuicClientTransportTest, onNetworkSwitchReplaceAfterHandshake) { TEST_F(QuicClientTransportTest, onNetworkSwitchReplaceNoHandshake) { auto newSocket = std::make_unique>(qEvb_); + ON_CALL(*newSocket, setReuseAddr(_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setAdditionalCmsgsFunc(_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setDFAndTurnOffPMTU()) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setErrMessageCallback(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setTosOrTrafficClass(_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, init(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, applyOptions(testing::_, testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, bind(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, connect(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, close()).WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, resumeWrite(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setGRO(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setRecvTos(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, getRecvTos()).WillByDefault(testing::Return(false)); + ON_CALL(*newSocket, setCmsgs(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, appendCmsgs(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, getTimestamping()).WillByDefault(testing::Return(0)); + ON_CALL(*newSocket, setReusePort(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setRcvBuf(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setSndBuf(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*newSocket, setFD(testing::_, testing::_)) + .WillByDefault(testing::Return(folly::unit)); auto newSocketPtr = newSocket.get(); auto mockQLogger = std::make_shared(VantagePoint::Client); EXPECT_CALL(*mockQLogger, addConnectionMigrationUpdate(true)).Times(0); @@ -1174,7 +1253,7 @@ TEST_F(QuicClientTransportTest, SocketClosedDuringOnTransportReady) { return getTotalIovecLen(vec, iovec_len); ; })); - ON_CALL(*sock, address()).WillByDefault(ReturnRef(serverAddr)); + ON_CALL(*sock, address()).WillByDefault(Return(serverAddr)); client->addNewPeerAddress(serverAddr); setupCryptoLayer(); @@ -1370,6 +1449,49 @@ class QuicClientTransportHappyEyeballsTest EXPECT_EQ(client->getConn().happyEyeballsState.v4PeerAddress, serverAddrV4); setupCryptoLayer(); + + ON_CALL(*secondSock, address()).WillByDefault(testing::Return(serverAddr)); + ON_CALL(*secondSock, setAdditionalCmsgsFunc(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, getGSO).WillByDefault(testing::Return(0)); + ON_CALL(*secondSock, getGRO).WillByDefault(testing::Return(0)); + ON_CALL(*secondSock, init(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, bind(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, connect(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, close()).WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, resumeWrite(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setGRO(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setRecvTos(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, getRecvTos()).WillByDefault(testing::Return(false)); + ON_CALL(*secondSock, setTosOrTrafficClass(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setCmsgs(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, appendCmsgs(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, getTimestamping()).WillByDefault(testing::Return(0)); + ON_CALL(*secondSock, setReuseAddr(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setDFAndTurnOffPMTU()) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setErrMessageCallback(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, applyOptions(testing::_, testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setReusePort(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setRcvBuf(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setSndBuf(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setFD(testing::_, testing::_)) + .WillByDefault(testing::Return(folly::unit)); } protected: @@ -1607,8 +1729,10 @@ class QuicClientTransportHappyEyeballsTest EXPECT_CALL(*sock, write(firstAddress, _, _)); EXPECT_CALL(*secondSock, bind(_)) - .WillOnce(Invoke( - [](const folly::SocketAddress&) { throw std::exception(); })); + .WillOnce(Invoke([](const folly::SocketAddress&) { + return folly::makeUnexpected( + QuicError(TransportErrorCode::INTERNAL_ERROR, "oopsies")); + })); client->start(&clientConnSetupCallback, &clientConnCallback); EXPECT_EQ(conn.peerAddress, firstAddress); EXPECT_EQ(conn.happyEyeballsState.secondPeerAddress, secondAddress); @@ -2396,41 +2520,6 @@ TEST_F(QuicClientTransportAfterStartTest, RetriableErrorLoopCounting) { client->invokeOnNotifyDataAvailable(*sock); } -TEST_F(QuicClientTransportAfterStartTest, ReadLoopTwice) { - auto& conn = client->getNonConstConn(); - auto mockLoopDetectorCallback = std::make_unique(); - auto rawLoopDetectorCallback = mockLoopDetectorCallback.get(); - conn.loopDetectorCallback = std::move(mockLoopDetectorCallback); - - conn.transportSettings.maxRecvBatchSize = 1; - socketReads.emplace_back(TestReadData(EBADF)); - EXPECT_CALL( - *rawLoopDetectorCallback, - onSuspiciousReadLoops(1, NoReadReason::NONRETRIABLE_ERROR)); - client->invokeOnNotifyDataAvailable(*sock); - socketReads.clear(); - - socketReads.emplace_back(TestReadData(EBADF)); - EXPECT_CALL( - *rawLoopDetectorCallback, - onSuspiciousReadLoops(2, NoReadReason::NONRETRIABLE_ERROR)); - client->invokeOnNotifyDataAvailable(*sock); -} - -TEST_F(QuicClientTransportAfterStartTest, NonretriableErrorLoopCounting) { - auto& conn = client->getNonConstConn(); - auto mockLoopDetectorCallback = std::make_unique(); - auto rawLoopDetectorCallback = mockLoopDetectorCallback.get(); - conn.loopDetectorCallback = std::move(mockLoopDetectorCallback); - - conn.transportSettings.maxRecvBatchSize = 1; - socketReads.emplace_back(TestReadData(EBADF)); - EXPECT_CALL( - *rawLoopDetectorCallback, - onSuspiciousReadLoops(1, NoReadReason::NONRETRIABLE_ERROR)); - client->invokeOnNotifyDataAvailable(*sock); -} - TEST_F(QuicClientTransportAfterStartTest, PartialReadLoopCounting) { auto streamId = client->createBidirectionalStream().value(); auto& conn = client->getNonConstConn(); @@ -2457,37 +2546,6 @@ TEST_F(QuicClientTransportAfterStartTest, PartialReadLoopCounting) { client->invokeOnNotifyDataAvailable(*sock); } -TEST_F(QuicClientTransportAfterStartTest, ReadLoopCountingRecvmmsg) { - auto& conn = client->getNonConstConn(); - auto mockLoopDetectorCallback = std::make_unique(); - auto rawLoopDetectorCallback = mockLoopDetectorCallback.get(); - conn.loopDetectorCallback = std::move(mockLoopDetectorCallback); - - conn.transportSettings.shouldUseRecvmmsgForBatchRecv = true; - conn.transportSettings.maxRecvBatchSize = 1; - EXPECT_CALL(*sock, recvmmsg(_, 1, _, nullptr)) - .WillOnce(Invoke( - [](struct mmsghdr*, unsigned int, unsigned int, struct timespec*) { - errno = EAGAIN; - return -1; - })); - EXPECT_CALL( - *rawLoopDetectorCallback, - onSuspiciousReadLoops(1, NoReadReason::RETRIABLE_ERROR)); - client->invokeOnNotifyDataAvailable(*sock); - - EXPECT_CALL(*sock, recvmmsg(_, 1, _, nullptr)) - .WillOnce(Invoke( - [](struct mmsghdr*, unsigned int, unsigned int, struct timespec*) { - errno = EBADF; - return -1; - })); - EXPECT_CALL( - *rawLoopDetectorCallback, - onSuspiciousReadLoops(2, NoReadReason::NONRETRIABLE_ERROR)); - client->invokeOnNotifyDataAvailable(*sock); -} - TEST_F(QuicClientTransportAfterStartTest, ReadStreamMultiplePackets) { StreamId streamId = client->createBidirectionalStream().value(); @@ -5447,7 +5505,50 @@ class QuicZeroRttHappyEyeballsClientTransportTest auto secondSocket = std::make_unique>(qEvb_); secondSock = secondSocket.get(); - + ON_CALL(*secondSock, address()).WillByDefault(testing::Return(serverAddr)); + ON_CALL(*secondSock, setAdditionalCmsgsFunc(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, getGSO).WillByDefault(testing::Return(0)); + ON_CALL(*secondSock, getGRO).WillByDefault(testing::Return(0)); + ON_CALL(*secondSock, init(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, bind(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, connect(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, close()).WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, resumeWrite(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setGRO(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setRecvTos(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, getRecvTos()).WillByDefault(testing::Return(false)); + ON_CALL(*secondSock, setTosOrTrafficClass(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setCmsgs(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, appendCmsgs(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, getTimestamping()).WillByDefault(testing::Return(0)); + ON_CALL(*secondSock, setReuseAddr(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setDFAndTurnOffPMTU()) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setAdditionalCmsgsFunc(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setErrMessageCallback(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, applyOptions(testing::_, testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setReusePort(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setRcvBuf(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setSndBuf(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*secondSock, setFD(testing::_, testing::_)) + .WillByDefault(testing::Return(folly::unit)); client->setHappyEyeballsEnabled(true); client->addNewPeerAddress(firstAddress); client->addNewPeerAddress(secondAddress); @@ -5988,15 +6089,16 @@ TEST(AsyncUDPSocketTest, CloseMultipleTimes) { TransportSettings transportSettings; EmptyErrMessageCallback errMessageCallback; EmptyReadCallback readCallback; - happyEyeballsSetUpSocket( - socket, - none, - folly::SocketAddress("127.0.0.1", 12345), - transportSettings, - 0, // tosValue - &errMessageCallback, - &readCallback, - folly::emptySocketOptionMap); + ASSERT_FALSE(happyEyeballsSetUpSocket( + socket, + none, + folly::SocketAddress("127.0.0.1", 12345), + transportSettings, + 0, // tosValue + &errMessageCallback, + &readCallback, + folly::emptySocketOptionMap) + .hasError()); socket.pauseRead(); socket.close(); diff --git a/quic/fizz/client/test/QuicClientTransportTestUtil.h b/quic/fizz/client/test/QuicClientTransportTestUtil.h index 674fb8a3c..271754511 100644 --- a/quic/fizz/client/test/QuicClientTransportTestUtil.h +++ b/quic/fizz/client/test/QuicClientTransportTestUtil.h @@ -440,6 +440,9 @@ class QuicClientTransportTestBase : public virtual testing::Test { std::make_unique>( qEvb_); sock = socket.get(); + EXPECT_CALL(*sock, setAdditionalCmsgsFunc(testing::_)) + .WillRepeatedly(testing::Return(folly::unit)); + EXPECT_CALL(*sock, close()).WillRepeatedly(testing::Return(folly::unit)); client = TestingQuicClientTransport::newClient( qEvb_, std::move(socket), getFizzClientContext()); @@ -454,7 +457,7 @@ class QuicClientTransportTestBase : public virtual testing::Test { connIdAlgo_ = std::make_unique(); ON_CALL(*sock, resumeRead(testing::_)) .WillByDefault(testing::SaveArg<0>(&networkReadCallback)); - ON_CALL(*sock, address()).WillByDefault(testing::ReturnRef(serverAddr)); + ON_CALL(*sock, address()).WillByDefault(testing::Return(serverAddr)); ON_CALL(*sock, recvmsg(testing::_, testing::_)) .WillByDefault(testing::Invoke([&](struct msghdr* msg, int) -> ssize_t { DCHECK_GT(msg->msg_iovlen, 0); @@ -488,6 +491,44 @@ class QuicClientTransportTestBase : public virtual testing::Test { return testDataLen; })); ON_CALL(*sock, getRecvTos()).WillByDefault(testing::Return(true)); + ON_CALL(*sock, getGSO()).WillByDefault(testing::Return(0)); + ON_CALL(*sock, getGRO()).WillByDefault(testing::Return(0)); + ON_CALL(*sock, getTimestamping()).WillByDefault(testing::Return(0)); + ON_CALL(*sock, setTosOrTrafficClass(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, init(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, bind(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, connect(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, close()).WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setReuseAddr(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setDFAndTurnOffPMTU()) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setErrMessageCallback(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, applyOptions(testing::_, testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setReusePort(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setRcvBuf(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setSndBuf(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, resumeWrite(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setGRO(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setRecvTos(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setCmsgs(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, appendCmsgs(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setFD(testing::_, testing::_)) + .WillByDefault(testing::Return(folly::unit)); EXPECT_EQ(client->getConn().selfConnectionIds.size(), 1); EXPECT_EQ( client->getConn().selfConnectionIds[0].connId, @@ -511,7 +552,7 @@ class QuicClientTransportTestBase : public virtual testing::Test { copyChain(folly::IOBuf::wrapIov(vec, iovec_len))); return getTotalIovecLen(vec, iovec_len); })); - ON_CALL(*sock, address()).WillByDefault(testing::ReturnRef(serverAddr)); + ON_CALL(*sock, address()).WillByDefault(testing::Return(serverAddr)); setupCryptoLayer(); start(); diff --git a/quic/happyeyeballs/QuicHappyEyeballsFunctions.cpp b/quic/happyeyeballs/QuicHappyEyeballsFunctions.cpp index 7233eebdf..697122e8e 100644 --- a/quic/happyeyeballs/QuicHappyEyeballsFunctions.cpp +++ b/quic/happyeyeballs/QuicHappyEyeballsFunctions.cpp @@ -93,18 +93,16 @@ void startHappyEyeballs( evb->scheduleTimeout(&connAttemptDelayTimeout, connAttempDelay); - try { - happyEyeballsSetUpSocket( - *connection.happyEyeballsState.secondSocket, - connection.localAddress, - connection.happyEyeballsState.secondPeerAddress, - connection.transportSettings, - connection.socketTos.value, - errMsgCallback, - readCallback, - options); - } catch (const std::exception&) { - // If second socket bind throws exception, give it up + auto res = happyEyeballsSetUpSocket( + *connection.happyEyeballsState.secondSocket, + connection.localAddress, + connection.happyEyeballsState.secondPeerAddress, + connection.transportSettings, + connection.socketTos.value, + errMsgCallback, + readCallback, + options); + if (res.hasError()) { connAttemptDelayTimeout.cancelTimerCallback(); connection.happyEyeballsState.finished = true; } @@ -121,7 +119,7 @@ void startHappyEyeballs( } } -void happyEyeballsSetUpSocket( +folly::Expected happyEyeballsSetUpSocket( QuicAsyncUDPSocket& socket, Optional localAddress, const folly::SocketAddress& peerAddress, @@ -131,48 +129,88 @@ void happyEyeballsSetUpSocket( QuicAsyncUDPSocket::ReadCallback* readCallback, const folly::SocketOptionMap& options) { auto sockFamily = localAddress.value_or(peerAddress).getFamily(); - socket.setReuseAddr(false); + auto result = socket.setReuseAddr(false); + if (!result) { + return folly::makeUnexpected(result.error()); + } if (transportSettings.readEcnOnIngress) { - socket.setRecvTos(true); + result = socket.setRecvTos(true); + if (!result) { + return folly::makeUnexpected(result.error()); + } } - auto initSockAndApplyOpts = [&]() { - socket.init(sockFamily); - applySocketOptions( + auto initSockAndApplyOpts = [&]() -> folly::Expected { + auto initResult = socket.init(sockFamily); + if (!initResult) { + return folly::makeUnexpected(initResult.error()); + } + auto applyResult = applySocketOptions( socket, options, sockFamily, folly::SocketOptionKey::ApplyPos::PRE_BIND); + if (!applyResult) { + return folly::makeUnexpected(applyResult.error()); + } + return folly::unit; }; if (localAddress.has_value()) { - initSockAndApplyOpts(); - socket.bind(*localAddress); + auto initResult = initSockAndApplyOpts(); + if (!initResult) { + return folly::makeUnexpected(initResult.error()); + } + result = socket.bind(*localAddress); + if (!result) { + return folly::makeUnexpected(result.error()); + } } if (transportSettings.connectUDP) { - initSockAndApplyOpts(); - socket.connect(peerAddress); + auto initResult = initSockAndApplyOpts(); + if (!initResult) { + return folly::makeUnexpected(initResult.error()); + } + result = socket.connect(peerAddress); + if (!result) { + return folly::makeUnexpected(result.error()); + } } if (!socket.isBound()) { auto addr = folly::SocketAddress( peerAddress.getFamily() == AF_INET ? "0.0.0.0" : "::", 0); - initSockAndApplyOpts(); - socket.bind(addr); + auto initResult = initSockAndApplyOpts(); + if (!initResult) { + return folly::makeUnexpected(initResult.error()); + } + result = socket.bind(addr); + if (!result) { + return folly::makeUnexpected(result.error()); + } } // This is called before applySocketOptions to allow the configured socket // options to override the ToS value from transport settings. This is // necessary for applications that currently rely on configuring DSCP through // socket options directly. - socket.setTosOrTrafficClass(socketTos); + result = socket.setTosOrTrafficClass(socketTos); + if (!result) { + return folly::makeUnexpected(result.error()); + } - applySocketOptions( + auto applyResult = applySocketOptions( socket, options, sockFamily, folly::SocketOptionKey::ApplyPos::POST_BIND); + if (!applyResult) { + return folly::makeUnexpected(applyResult.error()); + } #ifdef SO_NOSIGPIPE folly::SocketOptionKey nopipeKey = {SOL_SOCKET, SO_NOSIGPIPE}; if (!options.count(nopipeKey)) { - socket.applyOptions( + result = socket.applyOptions( {{nopipeKey, 1}}, folly::SocketOptionKey::ApplyPos::POST_BIND); + if (!result) { + return folly::makeUnexpected(result.error()); + } } #endif @@ -181,13 +219,20 @@ void happyEyeballsSetUpSocket( } // never fragment, always turn off PMTU - socket.setDFAndTurnOffPMTU(); + result = socket.setDFAndTurnOffPMTU(); + if (!result) { + return folly::makeUnexpected(result.error()); + } if (transportSettings.enableSocketErrMsgCallback) { - socket.setErrMessageCallback(errMsgCallback); + result = socket.setErrMessageCallback(errMsgCallback); + if (!result) { + return folly::makeUnexpected(result.error()); + } } socket.resumeRead(readCallback); + return folly::unit; } void happyEyeballsStartSecondSocket( @@ -217,7 +262,7 @@ void happyEyeballsOnDataReceived( connection.peerAddress = peerAddress; } connection.happyEyeballsState.secondSocket->pauseRead(); - connection.happyEyeballsState.secondSocket->close(); + (void)connection.happyEyeballsState.secondSocket->close(); connection.happyEyeballsState.secondSocket.reset(); } diff --git a/quic/happyeyeballs/QuicHappyEyeballsFunctions.h b/quic/happyeyeballs/QuicHappyEyeballsFunctions.h index 2da807a70..caca5993e 100644 --- a/quic/happyeyeballs/QuicHappyEyeballsFunctions.h +++ b/quic/happyeyeballs/QuicHappyEyeballsFunctions.h @@ -43,7 +43,7 @@ void startHappyEyeballs( QuicAsyncUDPSocket::ReadCallback* readCallback, const folly::SocketOptionMap& options); -void happyEyeballsSetUpSocket( +[[nodiscard]] folly::Expected happyEyeballsSetUpSocket( QuicAsyncUDPSocket& socket, Optional localAddress, const folly::SocketAddress& peerAddress, diff --git a/quic/loss/test/QuicLossFunctionsTest.cpp b/quic/loss/test/QuicLossFunctionsTest.cpp index ec855f961..a9193dbc6 100644 --- a/quic/loss/test/QuicLossFunctionsTest.cpp +++ b/quic/loss/test/QuicLossFunctionsTest.cpp @@ -499,6 +499,7 @@ TEST_F(QuicLossFunctionsTest, TestMarkPacketLoss) { folly::EventBase evb; auto qEvb = std::make_shared(&evb); MockAsyncUDPSocket socket(qEvb); + ON_CALL(socket, getGSO).WillByDefault(testing::Return(0)); auto conn = createConn(); EXPECT_CALL(*quicStats_, onNewQuicStream()).Times(2); auto stream1Id = @@ -591,6 +592,7 @@ TEST_F(QuicLossFunctionsTest, TestMarkPacketLossMerge) { folly::EventBase evb; auto qEvb = std::make_shared(&evb); MockAsyncUDPSocket socket(qEvb); + ON_CALL(socket, getGSO).WillByDefault(testing::Return(0)); auto conn = createConn(); EXPECT_CALL(*quicStats_, onNewQuicStream()).Times(1); auto stream1Id = @@ -651,6 +653,7 @@ TEST_F(QuicLossFunctionsTest, TestMarkPacketLossNoMerge) { folly::EventBase evb; auto qEvb = std::make_shared(&evb); MockAsyncUDPSocket socket(qEvb); + ON_CALL(socket, getGSO).WillByDefault(testing::Return(0)); auto conn = createConn(); EXPECT_CALL(*quicStats_, onNewQuicStream()).Times(1); auto stream1Id = @@ -732,6 +735,7 @@ TEST_F(QuicLossFunctionsTest, RetxBufferSortedAfterLoss) { folly::EventBase evb; auto qEvb = std::make_shared(&evb); MockAsyncUDPSocket socket(qEvb); + ON_CALL(socket, getGSO).WillByDefault(testing::Return(0)); auto conn = createConn(); auto stream = conn->streamManager->createNextBidirectionalStream().value(); auto buf1 = IOBuf::copyBuffer("Worse case scenario"); @@ -770,6 +774,7 @@ TEST_F(QuicLossFunctionsTest, TestMarkPacketLossAfterStreamReset) { folly::EventBase evb; auto qEvb = std::make_shared(&evb); MockAsyncUDPSocket socket(qEvb); + ON_CALL(socket, getGSO).WillByDefault(testing::Return(0)); auto conn = createConn(); auto stream1 = conn->streamManager->createNextBidirectionalStream().value(); auto buf = buildRandomInputData(20); @@ -965,6 +970,7 @@ TEST_F(QuicLossFunctionsTest, TestMarkRstLoss) { folly::EventBase evb; auto qEvb = std::make_shared(&evb); MockAsyncUDPSocket socket(qEvb); + ON_CALL(socket, getGSO).WillByDefault(testing::Return(0)); auto stream = conn->streamManager->createNextBidirectionalStream().value(); auto currentOffset = stream->currentWriteOffset; @@ -1066,6 +1072,7 @@ TEST_F(QuicLossFunctionsTest, TestMarkWindowUpdateLoss) { folly::EventBase evb; auto qEvb = std::make_shared(&evb); MockAsyncUDPSocket socket(qEvb); + ON_CALL(socket, getGSO).WillByDefault(testing::Return(0)); auto stream = conn->streamManager->createNextBidirectionalStream().value(); conn->streamManager->queueWindowUpdate(stream->id); @@ -1597,6 +1604,7 @@ TEST_F(QuicLossFunctionsTest, DetectPacketLossClonedPacketsCounter) { TEST_F(QuicLossFunctionsTest, TestMarkPacketLossProcessedPacket) { auto qEvb = std::make_shared(&evb); MockAsyncUDPSocket socket(qEvb); + ON_CALL(socket, getGSO).WillByDefault(testing::Return(0)); auto conn = createConn(); ASSERT_TRUE(conn->outstandings.packets.empty()); ASSERT_TRUE(conn->outstandings.clonedPacketIdentifiers.empty()); @@ -3355,6 +3363,7 @@ TEST_F(QuicLossFunctionsTest, TestMarkPacketLossRetransmissionDisabled) { folly::EventBase evb; auto qEvb = std::make_shared(&evb); MockAsyncUDPSocket socket(qEvb); + ON_CALL(socket, getGSO).WillByDefault(testing::Return(0)); auto conn = createConn(); conn->transportSettings.advertisedMaxStreamGroups = 16; @@ -3409,6 +3418,7 @@ TEST_F( folly::EventBase evb; auto qEvb = std::make_shared(&evb); MockAsyncUDPSocket socket(qEvb); + ON_CALL(socket, getGSO).WillByDefault(testing::Return(0)); auto conn = createConn(); conn->transportSettings.advertisedMaxStreamGroups = 16; @@ -3461,6 +3471,7 @@ TEST_F(QuicLossFunctionsTest, TestMarkPacketLossRetransmissionPolicyTwoGroups) { folly::EventBase evb; auto qEvb = std::make_shared(&evb); MockAsyncUDPSocket socket(qEvb); + ON_CALL(socket, getGSO).WillByDefault(testing::Return(0)); auto conn = createConn(); conn->transportSettings.advertisedMaxStreamGroups = 16; @@ -3522,6 +3533,7 @@ TEST_F( folly::EventBase evb; auto qEvb = std::make_shared(&evb); MockAsyncUDPSocket socket(qEvb); + ON_CALL(socket, getGSO).WillByDefault(testing::Return(0)); auto conn = createConn(); conn->transportSettings.advertisedMaxStreamGroups = 16; diff --git a/quic/server/QuicServerTransport.cpp b/quic/server/QuicServerTransport.cpp index 0d9ae7c89..689e26920 100644 --- a/quic/server/QuicServerTransport.cpp +++ b/quic/server/QuicServerTransport.cpp @@ -67,7 +67,9 @@ QuicServerTransport::QuicServerTransport( .setFizzServerContext(ctx_) .setCryptoFactory(std::move(cryptoFactory)) .build()); - tempConn->serverAddr = socket_->address(); + auto addrResult = socket_->address(); + CHECK(addrResult.hasValue()); + tempConn->serverAddr = addrResult.value(); serverConn_ = tempConn.get(); conn_.reset(tempConn.release()); conn_->observerContainer = wrappedObserverContainer_.getWeakPtr(); @@ -465,7 +467,7 @@ void QuicServerTransport::onCryptoEventAvailable() noexcept { VLOG(4) << "onCryptoEventAvailable() error " << ex.what() << " " << *this; closeImpl(QuicError(QuicErrorCode(ex.errorCode()), std::string(ex.what()))); } catch (const std::exception& ex) { - VLOG(4) << "read() error " << ex.what() << " " << *this; + LOG(ERROR) << "read() error " << ex.what() << " " << *this; closeImpl(QuicError( QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), std::string(ex.what()))); diff --git a/quic/server/QuicServerWorker.cpp b/quic/server/QuicServerWorker.cpp index 3a45d3d15..859ca63b1 100644 --- a/quic/server/QuicServerWorker.cpp +++ b/quic/server/QuicServerWorker.cpp @@ -92,6 +92,9 @@ void QuicServerWorker::setSocket( void QuicServerWorker::bind( const folly::SocketAddress& address, FollyAsyncUDPSocketAlias::BindOptions bindOptions) { + // TODO get rid of the temporary wrapper + FollyQuicAsyncUDPSocket tmpSock( + std::make_shared(evb_.get()), *socket_); DCHECK(!supportedVersions_.empty()); CHECK(socket_); switch (setEventCallback_) { @@ -108,7 +111,7 @@ void QuicServerWorker::bind( // bind, since bind creates the fd. if (socketOptions_) { applySocketOptions( - *socket_.get(), + tmpSock, *socketOptions_, address.getFamily(), folly::SocketOptionKey::ApplyPos::PRE_BIND); @@ -119,7 +122,7 @@ void QuicServerWorker::bind( socket_->bind(address, bindOptions); if (socketOptions_) { applySocketOptions( - *socket_.get(), + tmpSock, *socketOptions_, address.getFamily(), folly::SocketOptionKey::ApplyPos::POST_BIND); @@ -146,14 +149,17 @@ void QuicServerWorker::bind( void QuicServerWorker::applyAllSocketOptions() { CHECK(socket_); + // TODO get rid of the temporary wrapper + FollyQuicAsyncUDPSocket tmpSock( + std::make_shared(evb_.get()), *socket_); if (socketOptions_) { applySocketOptions( - *socket_, + tmpSock, *socketOptions_, getAddress().getFamily(), folly::SocketOptionKey::ApplyPos::PRE_BIND); applySocketOptions( - *socket_, + tmpSock, *socketOptions_, getAddress().getFamily(), folly::SocketOptionKey::ApplyPos::POST_BIND); diff --git a/quic/server/test/QuicServerTransportTestUtil.h b/quic/server/test/QuicServerTransportTestUtil.h index 2c494b0ce..63978ed31 100644 --- a/quic/server/test/QuicServerTransportTestUtil.h +++ b/quic/server/test/QuicServerTransportTestUtil.h @@ -156,8 +156,48 @@ class QuicServerTransportTestBase : public virtual testing::Test { copyChain(folly::IOBuf::wrapIov(vec, iovec_len))); return getTotalIovecLen(vec, iovec_len); })); - EXPECT_CALL(*sock, address()) - .WillRepeatedly(testing::ReturnRef(serverAddr)); + ON_CALL(*sock, address()).WillByDefault(testing::Return(serverAddr)); + ON_CALL(*sock, setAdditionalCmsgsFunc(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, getGSO).WillByDefault(testing::Return(0)); + ON_CALL(*sock, getGRO).WillByDefault(testing::Return(0)); + ON_CALL(*sock, init(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, bind(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, connect(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, close()).WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, resumeWrite(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setGRO(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setRecvTos(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, getRecvTos()).WillByDefault(testing::Return(false)); + ON_CALL(*sock, setTosOrTrafficClass(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setCmsgs(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, appendCmsgs(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, getTimestamping()).WillByDefault(testing::Return(0)); + ON_CALL(*sock, setReuseAddr(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setDFAndTurnOffPMTU()) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setErrMessageCallback(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, applyOptions(testing::_, testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setReusePort(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setRcvBuf(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setSndBuf(testing::_)) + .WillByDefault(testing::Return(folly::unit)); + ON_CALL(*sock, setFD(testing::_, testing::_)) + .WillByDefault(testing::Return(folly::unit)); supportedVersions = {QuicVersion::MVFST}; serverCtx = createServerCtx(); connIdAlgo_ = std::make_unique(); diff --git a/quic/state/stream/test/StreamStateMachineTest.cpp b/quic/state/stream/test/StreamStateMachineTest.cpp index 4e7de514a..437c5cb02 100644 --- a/quic/state/stream/test/StreamStateMachineTest.cpp +++ b/quic/state/stream/test/StreamStateMachineTest.cpp @@ -163,6 +163,7 @@ TEST_F(QuicOpenStateTest, AckStream) { EventBase evb; auto qEvb = std::make_shared(&evb); auto sock = std::make_unique(qEvb); + ON_CALL(*sock, getGSO).WillByDefault(testing::Return(0)); auto buf = IOBuf::copyBuffer("hello"); writeQuicPacket( @@ -200,6 +201,7 @@ TEST_F(QuicOpenStateTest, AckStreamMulti) { EventBase evb; auto qEvb = std::make_shared(&evb); auto sock = std::make_unique(qEvb); + ON_CALL(*sock, getGSO).WillByDefault(testing::Return(0)); auto buf = IOBuf::copyBuffer("hello"); writeQuicPacket( @@ -264,6 +266,7 @@ TEST_F(QuicOpenStateTest, RetxBufferSortedAfterAck) { EventBase evb; auto qEvb = std::make_shared(&evb); quic::test::MockAsyncUDPSocket socket(qEvb); + ON_CALL(socket, getGSO).WillByDefault(testing::Return(0)); Optional serverChosenConnId = *conn->clientConnectionId; serverChosenConnId.value().data()[0] ^= 0x01; @@ -581,6 +584,7 @@ TEST_F(QuicHalfClosedRemoteStateTest, AckStream) { EventBase evb; auto qEvb = std::make_shared(&evb); auto sock = std::make_unique(qEvb); + ON_CALL(*sock, getGSO).WillByDefault(testing::Return(0)); auto buf = IOBuf::copyBuffer("hello"); writeQuicPacket(