diff --git a/quic/api/QuicSocket.h b/quic/api/QuicSocket.h index aa79e219a..2edb82f3b 100644 --- a/quic/api/QuicSocket.h +++ b/quic/api/QuicSocket.h @@ -977,6 +977,15 @@ class QuicSocket { */ virtual void cancelByteEventCallbacks(const ByteEvent::Type type) = 0; + /** + * Reset or send a stop sending on all non-control streams. Leaves the + * connection otherwise unmodified. Note this will also trigger the + * onStreamWriteError and readError callbacks immediately. + */ + virtual void resetNonControlStreams( + ApplicationErrorCode error, + folly::StringPiece errorMsg) = 0; + /** * Get the number of pending byte events for the given stream. */ diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index d6131d84b..bbfe0fe0e 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -2673,6 +2673,42 @@ void QuicTransportBase::cancelAllAppCallbacks( } } +void QuicTransportBase::resetNonControlStreams( + ApplicationErrorCode error, + folly::StringPiece errorMsg) { + std::vector nonControlStreamIds; + nonControlStreamIds.reserve(conn_->streamManager->streamCount()); + conn_->streamManager->streamStateForEach( + [&nonControlStreamIds](const auto& stream) { + if (!stream.isControl) { + nonControlStreamIds.push_back(stream.id); + } + }); + for (auto id : nonControlStreamIds) { + if (isSendingStream(conn_->nodeType, id) || isBidirectionalStream(id)) { + auto writeCallbackIt = pendingWriteCallbacks_.find(id); + if (writeCallbackIt != pendingWriteCallbacks_.end()) { + writeCallbackIt->second->onStreamWriteError(id, {error, errorMsg}); + } + if (conn_->partialReliabilityEnabled) { + dataRejectedCallbacks_.erase(id); + } + resetStream(id, error); + } + if (isReceivingStream(conn_->nodeType, id) || isBidirectionalStream(id)) { + auto readCallbackIt = readCallbacks_.find(id); + if (readCallbackIt != readCallbacks_.end()) { + readCallbackIt->second.readCb->readError(id, {error, errorMsg}); + } + if (conn_->partialReliabilityEnabled) { + dataExpiredCallbacks_.erase(id); + } + peekCallbacks_.erase(id); + stopSending(id, error); + } + } +} + void QuicTransportBase::addObserver(Observer* observer) { observers_->push_back(CHECK_NOTNULL(observer)); observer->observerAttach(this); diff --git a/quic/api/QuicTransportBase.h b/quic/api/QuicTransportBase.h index e9dc0ca8b..c9517a7b6 100644 --- a/quic/api/QuicTransportBase.h +++ b/quic/api/QuicTransportBase.h @@ -386,6 +386,15 @@ class QuicTransportBase : public QuicSocket { */ void cancelByteEventCallbacks(const ByteEvent::Type type) override; + /** + * Reset or send a stop sending on all non-control streams. Leaves the + * connection otherwise unmodified. Note this will also trigger the + * onStreamWriteError and readError callbacks immediately. + */ + void resetNonControlStreams( + ApplicationErrorCode error, + folly::StringPiece errorMsg) override; + /** * Get the number of pending byte events for the given stream. */ diff --git a/quic/api/test/MockQuicSocket.h b/quic/api/test/MockQuicSocket.h index 2de4409ac..cb6e15046 100644 --- a/quic/api/test/MockQuicSocket.h +++ b/quic/api/test/MockQuicSocket.h @@ -284,5 +284,8 @@ class MockQuicSocket : public QuicSocket { MOCK_METHOD1(addObserver, void(Observer*)); MOCK_METHOD1(removeObserver, bool(Observer*)); MOCK_CONST_METHOD0(getObservers, const ObserverVec&()); + MOCK_METHOD2( + resetNonControlStreams, + void(ApplicationErrorCode, folly::StringPiece)); }; } // namespace quic diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index a43f1bbbf..2ac302e1f 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -2250,6 +2250,37 @@ TEST_F(QuicTransportImplTest, ResetStreamUnsetWriteCallback) { evb->loopOnce(); } +TEST_F(QuicTransportImplTest, ResetAllNonControlStreams) { + auto stream1 = transport->createBidirectionalStream().value(); + ASSERT_FALSE(transport->setControlStream(stream1)); + NiceMock wcb1; + NiceMock rcb1; + EXPECT_CALL(wcb1, onStreamWriteError(stream1, _)).Times(0); + EXPECT_CALL(rcb1, readError(stream1, _)).Times(0); + transport->notifyPendingWriteOnStream(stream1, &wcb1); + transport->setReadCallback(stream1, &rcb1); + + auto stream2 = transport->createBidirectionalStream().value(); + NiceMock wcb2; + NiceMock rcb2; + EXPECT_CALL(wcb2, onStreamWriteError(stream2, _)).Times(1); + EXPECT_CALL(rcb2, readError(stream2, _)).Times(1); + transport->notifyPendingWriteOnStream(stream2, &wcb2); + transport->setReadCallback(stream2, &rcb2); + + auto stream3 = transport->createUnidirectionalStream().value(); + NiceMock wcb3; + transport->notifyPendingWriteOnStream(stream3, &wcb3); + EXPECT_CALL(wcb3, onStreamWriteError(stream3, _)).Times(1); + + transport->resetNonControlStreams( + GenericApplicationErrorCode::UNKNOWN, "bye bye"); + evb->loopOnce(); + + // Have to manually unset the read callbacks so they aren't use-after-freed. + transport->unsetAllReadCallbacks(); +} + TEST_F(QuicTransportImplTest, DestroyWithoutClosing) { EXPECT_CALL(connCallback, onConnectionError(_)).Times(0); EXPECT_CALL(connCallback, onConnectionEnd()).Times(0);