diff --git a/quic/client/QuicClientTransport.cpp b/quic/client/QuicClientTransport.cpp index 4e9e8ef67..90861c93f 100644 --- a/quic/client/QuicClientTransport.cpp +++ b/quic/client/QuicClientTransport.cpp @@ -28,6 +28,12 @@ namespace fsp = folly::portability::sockets; +#ifndef MSG_WAITFORONE +#define RECVMMSG_FLAGS 0 +#else +#define RECVMMSG_FLAGS MSG_WAITFORONE +#endif + namespace quic { QuicClientTransport::QuicClientTransport( @@ -970,6 +976,137 @@ bool QuicClientTransport::shouldOnlyNotify() { return conn_->transportSettings.shouldRecvBatch; } +void QuicClientTransport::recvMsg( + folly::AsyncUDPSocket& sock, + uint64_t readBufferSize, + int numPackets, + NetworkData& networkData, + folly::Optional& server, + size_t& totalData) { + for (int packetNum = 0; packetNum < numPackets; ++packetNum) { + // We create 1 buffer per packet so that it is not shared, this enables + // us to decrypt in place. If the fizz decrypt api could decrypt in-place + // even if shared, then we could allocate one giant IOBuf here. + Buf readBuffer = folly::IOBuf::create(readBufferSize); + struct iovec vec {}; + vec.iov_base = readBuffer->writableData(); + vec.iov_len = readBufferSize; + + sockaddr* rawAddr{nullptr}; + struct sockaddr_storage addrStorage {}; + socklen_t addrLen{sizeof(addrStorage)}; + if (!server) { + rawAddr = reinterpret_cast(&addrStorage); + rawAddr->sa_family = sock.address().getFamily(); + } + + int flags = 0; + int gro = -1; + struct msghdr msg {}; + msg.msg_name = rawAddr; + msg.msg_namelen = size_t(addrLen); + msg.msg_iov = &vec; + msg.msg_iovlen = 1; +#ifdef FOLLY_HAVE_MSG_ERRQUEUE + char control[CMSG_SPACE(sizeof(uint16_t))] = {}; + bool useGRO = sock.getGRO() > 0; + + if (useGRO) { + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + // we need to consider MSG_TRUNC too + flags |= MSG_TRUNC; + } +#endif + + ssize_t ret = sock.recvmsg(&msg, flags); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + // If we got a retriable error, let us continue. + if (conn_->loopDetectorCallback) { + conn_->readDebugState.noReadReason = NoReadReason::RETRIABLE_ERROR; + } + break; + } + // If we got a non-retriable error, we might have received + // a packet that we could process, however let's just quit early. + sock.pauseRead(); + if (conn_->loopDetectorCallback) { + conn_->readDebugState.noReadReason = NoReadReason::NONRETRIABLE_ERROR; + } + return onReadError(folly::AsyncSocketException( + folly::AsyncSocketException::INTERNAL_ERROR, + "::recvmsg() failed", + errno)); + } else if (ret == 0) { + break; + } +#ifdef FOLLY_HAVE_MSG_ERRQUEUE + if (useGRO) { + for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr; + cmsg = CMSG_NXTHDR(&msg, cmsg)) { + if (cmsg->cmsg_level == SOL_UDP && cmsg->cmsg_type == UDP_GRO) { + gro = *((uint16_t*)CMSG_DATA(cmsg)); + break; + } + } + + // truncated + if ((size_t)ret > readBufferSize) { + ret = readBufferSize; + if (gro > 0) { + ret = ret - ret % gro; + } + } + } +#endif + size_t bytesRead = size_t(ret); + totalData += bytesRead; + if (!server) { + server = folly::SocketAddress(); + server->setFromSockaddr(rawAddr, addrLen); + } + VLOG(10) << "Got data from socket peer=" << *server << " len=" << bytesRead; + readBuffer->append(bytesRead); + if (gro > 0) { + size_t len = bytesRead; + size_t remaining = len; + size_t offset = 0; + size_t totalNumPackets = + networkData.packets.size() + ((len + gro - 1) / gro); + networkData.packets.reserve(totalNumPackets); + while (remaining) { + if (static_cast(remaining) > gro) { + auto tmp = readBuffer->cloneOne(); + // start at offset + tmp->trimStart(offset); + // the actual len is len - offset now + // leave gro bytes + tmp->trimEnd(len - offset - gro); + DCHECK_EQ(tmp->length(), gro); + + offset += gro; + remaining -= gro; + networkData.packets.emplace_back(std::move(tmp)); + } else { + // do not clone the last packet + // start at offset, use all the remaining data + readBuffer->trimStart(offset); + DCHECK_EQ(readBuffer->length(), remaining); + remaining = 0; + networkData.packets.emplace_back(std::move(readBuffer)); + } + } + } else { + networkData.packets.emplace_back(std::move(readBuffer)); + } + if (conn_->qLogger) { + conn_->qLogger->addDatagramReceived(bytesRead); + } + } +} + void QuicClientTransport::recvMmsg( folly::AsyncUDPSocket& sock, uint64_t readBufferSize, @@ -1025,7 +1162,8 @@ void QuicClientTransport::recvMmsg( #endif } - int numMsgsRecvd = sock.recvmmsg(msgs.data(), numPackets, flags, nullptr); + int numMsgsRecvd = + sock.recvmmsg(msgs.data(), numPackets, RECVMMSG_FLAGS | flags, nullptr); if (numMsgsRecvd < 0) { if (errno == EAGAIN || errno == EWOULDBLOCK) { // Exit, socket will notify us again when socket is readable. @@ -1144,8 +1282,12 @@ void QuicClientTransport::onNotifyDataAvailable( size_t totalData = 0; folly::Optional server; - recvmmsgStorage_.resize(numPackets); - recvMmsg(sock, readBufferSize, numPackets, networkData, server, totalData); + if (conn_->transportSettings.shouldUseRecvmmsgForBatchRecv) { + recvmmsgStorage_.resize(numPackets); + recvMmsg(sock, readBufferSize, numPackets, networkData, server, totalData); + } else { + recvMsg(sock, readBufferSize, numPackets, networkData, server, totalData); + } if (networkData.packets.empty()) { // recvMmsg and recvMsg might have already set the reason and counter diff --git a/quic/fizz/client/test/QuicClientTransportTest.cpp b/quic/fizz/client/test/QuicClientTransportTest.cpp index 11c3b9ef2..381c83679 100644 --- a/quic/fizz/client/test/QuicClientTransportTest.cpp +++ b/quic/fizz/client/test/QuicClientTransportTest.cpp @@ -1319,41 +1319,30 @@ class QuicClientTransportTest : public Test { ON_CALL(*sock, resumeRead(_)) .WillByDefault(SaveArg<0>(&networkReadCallback)); ON_CALL(*sock, address()).WillByDefault(ReturnRef(serverAddr)); - ON_CALL(*sock, recvmmsg(_, _, _, _)) - .WillByDefault(Invoke( - [&](struct mmsghdr* mmsg, auto numPackets, auto, auto) -> ssize_t { - VLOG(4) << "socketreads size " << socketReads.size(); - struct msghdr* msg; - if (socketReads.empty()) { - errno = EAGAIN; - return -1; - } - auto len = std::min(socketReads.size(), numPackets); - ssize_t i; - auto srItr = socketReads.begin(); - for (i = 0; i < len; i++, mmsg++) { - if (srItr->err) { - errno = *srItr->err; - socketReads.pop_front(); - break; - } - msg = &mmsg->msg_hdr; - auto testData = std::move(srItr->data); - testData->coalesce(); - size_t testDataLen = testData->length(); - CHECK_EQ(msg->msg_iovlen, 1); - CHECK(msg->msg_iov[0].iov_base != nullptr); - memcpy(msg->msg_iov[0].iov_base, testData->data(), testDataLen); - mmsg->msg_len = testDataLen; - if (msg->msg_name) { - socklen_t msg_len = srItr->addr.getAddress( - static_cast(msg->msg_name)); - msg->msg_namelen = msg_len; - } - srItr = socketReads.erase(srItr); - } - return i == 0 ? -1 : i; - })); + ON_CALL(*sock, recvmsg(_, _)) + .WillByDefault(Invoke([&](struct msghdr* msg, int) -> ssize_t { + DCHECK_GT(msg->msg_iovlen, 0); + if (socketReads.empty()) { + errno = EAGAIN; + return -1; + } + if (socketReads[0].err) { + errno = *socketReads[0].err; + return -1; + } + auto testData = std::move(socketReads[0].data); + testData->coalesce(); + size_t testDataLen = testData->length(); + memcpy( + msg->msg_iov[0].iov_base, testData->data(), testData->length()); + if (msg->msg_name) { + socklen_t msg_len = socketReads[0].addr.getAddress( + static_cast(msg->msg_name)); + msg->msg_namelen = msg_len; + } + socketReads.pop_front(); + return testDataLen; + })); EXPECT_EQ(client->getConn().selfConnectionIds.size(), 1); EXPECT_EQ( client->getConn().selfConnectionIds[0].connId, @@ -2968,8 +2957,9 @@ TEST_F(QuicClientTransportAfterStartTest, ReadLoopCountingRecvmmsg) { auto mockLoopDetectorCallback = std::make_unique(); auto rawLoopDetectorCallback = mockLoopDetectorCallback.get(); conn.loopDetectorCallback = std::move(mockLoopDetectorCallback); - conn.transportSettings.maxRecvBatchSize = 1; + conn.transportSettings.shouldUseRecvmmsgForBatchRecv = true; + conn.transportSettings.maxRecvBatchSize = 1; EXPECT_CALL(*sock, recvmmsg(_, 1, _, nullptr)) .WillOnce(Invoke( [](struct mmsghdr*, unsigned int, unsigned int, struct timespec*) { @@ -3102,76 +3092,6 @@ TEST_F( client->close(folly::none); } -TEST_F( - QuicClientTransportAfterStartTest, - ReadStreamMultiplePacketsGreaterThanBatch) { - StreamId streamId = client->createBidirectionalStream().value(); - - uint32_t batchSize = 2; - client->getNonConstConn().transportSettings.maxRecvBatchSize = batchSize; - client->setReadCallback(streamId, &readCb); - - auto data = IOBuf::copyBuffer("hello"); - auto expected = data->clone(); - expected->prependChain(data->clone()); - expected->prependChain(data->clone()); - IOBuf result; - InSequence s; - EXPECT_CALL(readCb, readAvailable(streamId)).WillOnce(Invoke([&](auto) { - auto readData = client->read(streamId, 1000); - result.prependChain(readData->first->clone()); - })); - EXPECT_CALL(readCb, readAvailable(streamId)).WillRepeatedly(Invoke([&](auto) { - auto readData = client->read(streamId, 1000); - result.prependChain(readData->first->clone()); - eventbase_->terminateLoopSoon(); - })); - auto packet1 = packetToBuf(createStreamPacket( - *serverChosenConnId /* src */, - *originalConnId /* dest */, - appDataPacketNum++, - streamId, - *data, - 0 /* cipherOverhead */, - 0 /* largestAcked */, - folly::none /* longHeaderOverride */, - false /* eof */)); - auto packet2 = packetToBuf(createStreamPacket( - *serverChosenConnId /* src */, - *originalConnId /* dest */, - appDataPacketNum++, - streamId, - *data, - 0 /* cipherOverhead */, - 0 /* largestAcked */, - folly::none /* longHeaderOverride */, - false /* eof */, - folly::none /* shortHeaderOverride */, - data->length() /* offset */)); - auto packet3 = packetToBuf(createStreamPacket( - *serverChosenConnId /* src */, - *originalConnId /* dest */, - appDataPacketNum++, - streamId, - *data, - 0 /* cipherOverhead */, - 0 /* largestAcked */, - folly::none /* longHeaderOverride */, - true /* eof */, - folly::none /* shortHeaderOverride */, - data->length() * 2 /* offset */)); - - socketReads.emplace_back(TestReadData(packet1->coalesce(), serverAddr)); - socketReads.emplace_back(TestReadData(packet2->coalesce(), serverAddr)); - deliverData(packet3->coalesce()); - EXPECT_EQ(socketReads.size(), 1); - client->invokeOnNotifyDataAvailable(*sock); - EXPECT_EQ(socketReads.size(), 0); - eventbase_->loopForever(); - EXPECT_TRUE(IOBufEqualTo()(expected.get(), &result)); - client->close(folly::none); -} - TEST_F( QuicClientTransportAfterStartTest, ReadStreamMultiplePacketsWithNonRetriableError) { diff --git a/quic/state/TransportSettings.h b/quic/state/TransportSettings.h index 0b75d2a6c..c015cfe28 100644 --- a/quic/state/TransportSettings.h +++ b/quic/state/TransportSettings.h @@ -161,6 +161,8 @@ struct TransportSettings { size_t maxRecvBatchSize{5}; // Whether or not we should recv data in a batch. bool shouldRecvBatch{false}; + // Whether or not use recvmmsg when shouldRecvBatch is true. + bool shouldUseRecvmmsgForBatchRecv{false}; // Config struct for BBR BbrConfig bbrConfig; // A packet is considered loss when a packet that's sent later by at least