diff --git a/quic/api/QuicSocket.h b/quic/api/QuicSocket.h index 4f6fa34ac..e26ae81d5 100644 --- a/quic/api/QuicSocket.h +++ b/quic/api/QuicSocket.h @@ -764,6 +764,9 @@ class QuicSocket { virtual folly::Expected notifyPendingWriteOnStream(StreamId id, WriteCallback* wcb) = 0; + virtual folly::Expected + unregisterStreamWriteCallback(StreamId) = 0; + /** * Callback class for receiving ack notifications */ diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index 5c0332f68..7195ef570 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -1746,6 +1746,18 @@ QuicTransportBase::notifyPendingWriteOnConnection(WriteCallback* wcb) { return folly::unit; } +folly::Expected +QuicTransportBase::unregisterStreamWriteCallback(StreamId id) { + if (!conn_->streamManager->streamExists(id)) { + return folly::makeUnexpected(LocalErrorCode::STREAM_NOT_EXISTS); + } + if (pendingWriteCallbacks_.find(id) == pendingWriteCallbacks_.end()) { + return folly::makeUnexpected(LocalErrorCode::INVALID_OPERATION); + } + pendingWriteCallbacks_.erase(id); + return folly::unit; +} + folly::Expected QuicTransportBase::notifyPendingWriteOnStream(StreamId id, WriteCallback* wcb) { if (isReceivingStream(conn_->nodeType, id)) { diff --git a/quic/api/QuicTransportBase.h b/quic/api/QuicTransportBase.h index c259577f5..8f7d13a7d 100644 --- a/quic/api/QuicTransportBase.h +++ b/quic/api/QuicTransportBase.h @@ -183,6 +183,9 @@ class QuicTransportBase : public QuicSocket { folly::Expected notifyPendingWriteOnConnection( WriteCallback* wcb) override; + folly::Expected unregisterStreamWriteCallback( + StreamId id) override; + WriteResult writeChain( StreamId id, Buf data, diff --git a/quic/api/test/MockQuicSocket.h b/quic/api/test/MockQuicSocket.h index 0b13eebe0..4a5e8bba9 100644 --- a/quic/api/test/MockQuicSocket.h +++ b/quic/api/test/MockQuicSocket.h @@ -137,6 +137,9 @@ class MockQuicSocket : public QuicSocket { MOCK_METHOD2( notifyPendingWriteOnStream, folly::Expected(StreamId, WriteCallback*)); + MOCK_METHOD1( + unregisterStreamWriteCallback, + folly::Expected(StreamId)); folly::Expected writeChain( StreamId id, Buf data, diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index ccee390b7..914ef9c7f 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -2517,5 +2517,39 @@ TEST_F(QuicTransportImplTest, FailedPing) { EXPECT_EQ(conn->pendingEvents.cancelPingTimeout, false); } +TEST_F(QuicTransportImplTest, StreamWriteCallbackUnregister) { + auto stream = transport->createBidirectionalStream().value(); + // Unset before set + EXPECT_FALSE(transport->unregisterStreamWriteCallback(stream)); + + // Set + auto wcb = std::make_unique(); + EXPECT_CALL(*wcb, onStreamWriteReady(stream, _)).Times(1); + auto result = transport->notifyPendingWriteOnStream(stream, wcb.get()); + EXPECT_TRUE(result); + evb->loopOnce(); + + // Set then unset + EXPECT_CALL(*wcb, onStreamWriteReady(stream, _)).Times(0); + result = transport->notifyPendingWriteOnStream(stream, wcb.get()); + EXPECT_TRUE(result); + EXPECT_TRUE(transport->unregisterStreamWriteCallback(stream)); + evb->loopOnce(); + + // Set, close, unset + result = transport->notifyPendingWriteOnStream(stream, wcb.get()); + EXPECT_TRUE(result); + MockReadCallback rcb; + transport->setReadCallback(stream, &rcb); + // ReadCallback kills WriteCallback + EXPECT_CALL(rcb, readError(stream, _)) + .WillOnce(Invoke([&](StreamId stream, auto) { + EXPECT_TRUE(transport->unregisterStreamWriteCallback(stream)); + wcb.reset(); + })); + transport->close(folly::none); + evb->loopOnce(); +} + } // namespace test } // namespace quic