diff --git a/quic/api/QuicSocket.h b/quic/api/QuicSocket.h index 665bd4497..f7829b002 100644 --- a/quic/api/QuicSocket.h +++ b/quic/api/QuicSocket.h @@ -614,7 +614,17 @@ class QuicSocket { virtual void onDataAvailable( StreamId id, const folly::Range& peekData) noexcept = 0; + + /** + * Called from the transport layer during peek time when there is an error + * on the stream. + */ + virtual void peekError( + StreamId id, + std::pair> + error) noexcept = 0; }; + virtual folly::Expected setPeekCallback( StreamId id, PeekCallback* cb) = 0; diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index 24162c5a0..9fbfc8e96 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -989,7 +989,13 @@ void QuicTransportBase::invokePeekDataAndCallbacks() { } auto peekCb = callback->second.peekCb; auto stream = conn_->streamManager->getStream(streamId); - if (peekCb && !stream->streamReadError && stream->hasPeekableData()) { + if (peekCb && stream->streamReadError) { + VLOG(10) << "invoking peek error callbacks on stream=" << streamId << " " + << *this; + peekCb->peekError( + streamId, std::make_pair(*stream->streamReadError, folly::none)); + } else if ( + peekCb && !stream->streamReadError && stream->hasPeekableData()) { VLOG(10) << "invoking peek callbacks on stream=" << streamId << " " << *this; diff --git a/quic/api/test/Mocks.h b/quic/api/test/Mocks.h index 0d2bdae98..09dc17b66 100644 --- a/quic/api/test/Mocks.h +++ b/quic/api/test/Mocks.h @@ -63,6 +63,14 @@ class MockPeekCallback : public QuicSocket::PeekCallback { , onDataAvailable, void(StreamId, const folly::Range&)); + GMOCK_METHOD2_( + , + noexcept, + , + peekError, + void( + StreamId, + std::pair>)); }; class MockWriteCallback : public QuicSocket::WriteCallback { diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index 734df2a90..a1701518c 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -339,6 +339,8 @@ class TestQuicTransport stream->streamReadError = ex; conn_->streamManager->updateReadableStreams(*stream); conn_->streamManager->updatePeekableStreams(*stream); + // peekableStreams is updated to contain streams with streamReadError + updatePeekLooper(); updateReadLooper(); } @@ -2691,6 +2693,24 @@ TEST_F(QuicTransportImplTest, PeekCallbackDataAvailable) { transport.reset(); } +TEST_F(QuicTransportImplTest, PeekError) { + auto stream1 = transport->createBidirectionalStream().value(); + + NiceMock peekCb1; + transport->setPeekCallback(stream1, &peekCb1); + + transport->addDataToStream( + stream1, StreamBuffer(folly::IOBuf::copyBuffer("actual stream data"), 0)); + transport->addStreamReadError(stream1, LocalErrorCode::STREAM_CLOSED); + + EXPECT_CALL( + peekCb1, peekError(stream1, IsError(LocalErrorCode::STREAM_CLOSED))); + + transport->driveReadCallbacks(); + + transport.reset(); +} + TEST_F(QuicTransportImplTest, PeekCallbackUnsetAll) { auto stream1 = transport->createBidirectionalStream().value(); auto stream2 = transport->createBidirectionalStream().value(); @@ -3057,9 +3077,9 @@ TEST_F(QuicTransportImplTest, UpdatePeekableListWithStreamErrorTest) { transport->addStreamReadError(streamId, LocalErrorCode::NO_ERROR); - // streamId is removed from the list after the call - // because there is an error on the stream. - EXPECT_EQ(0, conn->streamManager->peekableStreams().count(streamId)); + // peekableStreams is updated to allow stream with streamReadError. + // So the streamId shall be in the list + EXPECT_EQ(1, conn->streamManager->peekableStreams().count(streamId)); } TEST_F(QuicTransportImplTest, SuccessfulPing) { diff --git a/quic/state/QuicStreamManager.cpp b/quic/state/QuicStreamManager.cpp index 6875558b4..f20b1935d 100644 --- a/quic/state/QuicStreamManager.cpp +++ b/quic/state/QuicStreamManager.cpp @@ -528,7 +528,9 @@ void QuicStreamManager::updateWritableStreams(QuicStreamState& stream) { } void QuicStreamManager::updatePeekableStreams(QuicStreamState& stream) { - if (stream.hasPeekableData() && !stream.streamReadError.has_value()) { + // In the PeekCallback, the API peekError() is added, so change the condition + // and allow streamReadError in the peekableStreams + if (stream.hasPeekableData() || stream.streamReadError.has_value()) { peekableStreams_.emplace(stream.id); } else { peekableStreams_.erase(stream.id);