diff --git a/quic/api/QuicTransportBaseLite.cpp b/quic/api/QuicTransportBaseLite.cpp index 6e27bea92..2b10481d8 100644 --- a/quic/api/QuicTransportBaseLite.cpp +++ b/quic/api/QuicTransportBaseLite.cpp @@ -753,7 +753,7 @@ QuicTransportBaseLite::setReadCallback( if (isSendingStream(conn_->nodeType, id)) { return folly::makeUnexpected(LocalErrorCode::INVALID_OPERATION); } - if (closeState_ != CloseState::OPEN) { + if (cb != nullptr && closeState_ != CloseState::OPEN) { return folly::makeUnexpected(LocalErrorCode::CONNECTION_CLOSED); } if (!conn_->streamManager->streamExists(id)) { @@ -2224,20 +2224,28 @@ void QuicTransportBaseLite::cancelAllAppCallbacks( // structure of read callbacks. // TODO: this approach will make the app unable to setReadCallback to // nullptr during the loop. Need to fix that. - // TODO: setReadCallback to nullptr closes the stream, so the app - // may just do that... auto readCallbacksCopy = readCallbacks_; for (auto& cb : readCallbacksCopy) { - readCallbacks_.erase(cb.first); - if (cb.second.readCb) { - auto stream = CHECK_NOTNULL(conn_->streamManager->getStream(cb.first)); + auto streamId = cb.first; + auto it = readCallbacks_.find(streamId); + if (it == readCallbacks_.end()) { + // An earlier call to readError removed the stream from readCallbacks + // May not be possible? + continue; + } + if (it->second.readCb) { + auto stream = CHECK_NOTNULL(conn_->streamManager->getStream(streamId)); if (!stream->groupId) { - cb.second.readCb->readError(cb.first, err); + it->second.readCb->readError(streamId, err); } else { - cb.second.readCb->readErrorWithGroup(cb.first, *stream->groupId, err); + it->second.readCb->readErrorWithGroup(streamId, *stream->groupId, err); } } + readCallbacks_.erase(it); } + // TODO: what if a call to readError installs a new read callback? + LOG_IF(ERROR, !readCallbacks_.empty()) + << readCallbacks_.size() << " read callbacks remaining to be cleared"; VLOG(4) << "Clearing datagram callback"; datagramCallback_ = nullptr; diff --git a/quic/api/test/BUCK b/quic/api/test/BUCK index 11f7556f2..c8a4fc10c 100644 --- a/quic/api/test/BUCK +++ b/quic/api/test/BUCK @@ -48,6 +48,7 @@ cpp_unittest( "//folly/io:iobuf", "//quic:constants", "//quic/api:transport", + "//quic/api:transport_helpers", "//quic/common:buf_util", "//quic/common/events:highres_quic_timer", "//quic/common/test:test_utils", @@ -98,7 +99,7 @@ cpp_unittest( deps = [ ":mocks", "//folly:range", - "//quic/api:transport", + "//quic/api:transport_helpers", "//quic/common/events:folly_eventbase", "//quic/common/test:test_utils", "//quic/common/testutil:mock_async_udp_socket", @@ -119,7 +120,7 @@ cpp_unittest( deps = [ ":mocks", "//folly/portability:gtest", - "//quic/api:transport", + "//quic/api:transport_helpers", "//quic/client:state_and_handshake", "//quic/codec:pktbuilder", "//quic/codec/test:mocks", @@ -140,7 +141,7 @@ cpp_unittest( "IoBufQuicBatchTest.cpp", ], deps = [ - "//quic/api:transport", + "//quic/api:transport_helpers", "//quic/client:state_and_handshake", "//quic/common/events:folly_eventbase", "//quic/common/test:test_utils", @@ -233,6 +234,7 @@ mvfst_cpp_library( ], exported_deps = [ "//quic/api:transport", + "//quic/api:transport_helpers", "//quic/common/test:test_utils", "//quic/dsr/frontend:write_functions", "//quic/fizz/server/handshake:fizz_server_handshake", diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index b05515acc..773ae1a75 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -3094,12 +3094,15 @@ TEST_P(QuicTransportImplTestBase, TestGracefulCloseWithNoActiveStream) { TEST_P(QuicTransportImplTestBase, TestImmediateClose) { auto stream = transport->createBidirectionalStream().value(); + auto stream2 = transport->createBidirectionalStream().value(); NiceMock wcb; NiceMock wcbConn; NiceMock rcb; + NiceMock rcb2; NiceMock pcb; NiceMock deliveryCb; NiceMock txCb; + uint8_t resetCount = 0; EXPECT_CALL( wcb, onStreamWriteError( @@ -3107,8 +3110,21 @@ TEST_P(QuicTransportImplTestBase, TestImmediateClose) { EXPECT_CALL( wcbConn, onConnectionWriteError(IsAppError(GenericApplicationErrorCode::UNKNOWN))); - EXPECT_CALL( - rcb, readError(stream, IsAppError(GenericApplicationErrorCode::UNKNOWN))); + // The first stream to get a reset will clear the other read callback, so only + // one will receive a reset. + ON_CALL( + rcb, readError(stream, IsAppError(GenericApplicationErrorCode::UNKNOWN))) + .WillByDefault(InvokeWithoutArgs([this, stream2, &resetCount] { + transport->setReadCallback(stream2, nullptr); + resetCount++; + })); + ON_CALL( + rcb2, + readError(stream2, IsAppError(GenericApplicationErrorCode::UNKNOWN))) + .WillByDefault(InvokeWithoutArgs([this, stream, &resetCount] { + transport->setReadCallback(stream, nullptr); + resetCount++; + })); EXPECT_CALL( pcb, peekError(stream, IsAppError(GenericApplicationErrorCode::UNKNOWN))); EXPECT_CALL(deliveryCb, onCanceled(stream, _)); @@ -3120,6 +3136,7 @@ TEST_P(QuicTransportImplTestBase, TestImmediateClose) { transport->notifyPendingWriteOnConnection(&wcbConn); transport->notifyPendingWriteOnStream(stream, &wcb); transport->setReadCallback(stream, &rcb); + transport->setReadCallback(stream2, &rcb2); transport->setPeekCallback(stream, &pcb); EXPECT_CALL(*socketPtr, write(_, _, _)) .WillRepeatedly(SetErrnoAndReturn(EAGAIN, -1)); @@ -3151,6 +3168,7 @@ TEST_P(QuicTransportImplTestBase, TestImmediateClose) { EXPECT_EQ( transport->transportConn->streamManager->getStream(stream), nullptr); qEvb->loopOnce(); + EXPECT_EQ(resetCount, 1); } TEST_P(QuicTransportImplTestBase, ResetStreamUnsetWriteCallback) {