From 81756e3d13f5229b511c5b600716efe6e036abcc Mon Sep 17 00:00:00 2001 From: Matt Joras Date: Thu, 13 Aug 2020 18:27:17 -0700 Subject: [PATCH] Allow specifying error code in setReadCallback Summary: We have an API behavior where setReadCalback will issue a StopSending on behalf of the app. This is useful but has confusing semantics as it always defaults to GenericApplicationError::NO_ERROR. Instead let the error be specified as part of the API. Reviewed By: yangchi, lnicco Differential Revision: D23055196 fbshipit-source-id: 755f4122bf445016c9b5adb23c3090fc23173eb9 --- quic/api/QuicSocket.h | 12 ++++- quic/api/QuicTransportBase.cpp | 17 ++++--- quic/api/QuicTransportBase.h | 7 ++- quic/api/test/MockQuicSocket.h | 7 ++- quic/api/test/QuicTransportTest.cpp | 79 +++++++++++++++++++++++++++++ quic/server/test/QuicSocketTest.cpp | 2 +- 6 files changed, 111 insertions(+), 13 deletions(-) diff --git a/quic/api/QuicSocket.h b/quic/api/QuicSocket.h index edeff4891..45e42e014 100644 --- a/quic/api/QuicSocket.h +++ b/quic/api/QuicSocket.h @@ -439,11 +439,19 @@ class QuicSocket { /** * Set the read callback for the given stream. Note that read callback is * expected to be set all the time. Removing read callback indicates that - * stream is no longer intended to be read again. + * stream is no longer intended to be read again. This will issue a + * StopSending if cb is being set to nullptr after previously being not + * nullptr. The err parameter is used to control the error sent in the + * StopSending. By default when cb is nullptr this function will cause the + * transport to send a StopSending frame with + * GenericApplicationErrorCode::NO_ERROR. If err is specified to be + * folly::none, no StopSending will be sent. */ virtual folly::Expected setReadCallback( StreamId id, - ReadCallback* cb) = 0; + ReadCallback* cb, + folly::Optional err = + GenericApplicationErrorCode::NO_ERROR) = 0; /** * Convenience function that sets the read callbacks of all streams to be diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index 20459eb54..42ebd83d3 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -650,7 +650,8 @@ QuicTransportBase::setStreamFlowControlWindow( folly::Expected QuicTransportBase::setReadCallback( StreamId id, - ReadCallback* cb) { + ReadCallback* cb, + folly::Optional err) { if (isSendingStream(conn_->nodeType, id)) { return folly::makeUnexpected(LocalErrorCode::INVALID_OPERATION); } @@ -660,12 +661,15 @@ folly::Expected QuicTransportBase::setReadCallback( if (!conn_->streamManager->streamExists(id)) { return folly::makeUnexpected(LocalErrorCode::STREAM_NOT_EXISTS); } - return setReadCallbackInternal(id, cb); + return setReadCallbackInternal(id, cb, err); } void QuicTransportBase::unsetAllReadCallbacks() { for (auto& streamCallbackPair : readCallbacks_) { - setReadCallbackInternal(streamCallbackPair.first, nullptr); + setReadCallbackInternal( + streamCallbackPair.first, + nullptr, + GenericApplicationErrorCode::NO_ERROR); } } @@ -685,7 +689,8 @@ void QuicTransportBase::unsetAllDeliveryCallbacks() { folly::Expected QuicTransportBase::setReadCallbackInternal( StreamId id, - ReadCallback* cb) noexcept { + ReadCallback* cb, + folly::Optional err) noexcept { VLOG(4) << "Setting setReadCallback for stream=" << id << " cb=" << cb << " " << *this; auto readCbIt = readCallbacks_.find(id); @@ -702,8 +707,8 @@ QuicTransportBase::setReadCallbackInternal( return folly::makeUnexpected(LocalErrorCode::INVALID_OPERATION); } else { readCb = cb; - if (readCb == nullptr) { - return stopSending(id, GenericApplicationErrorCode::NO_ERROR); + if (readCb == nullptr && err) { + return stopSending(id, err.value()); } } updateReadLooper(); diff --git a/quic/api/QuicTransportBase.h b/quic/api/QuicTransportBase.h index dc32fdd07..f7254b940 100644 --- a/quic/api/QuicTransportBase.h +++ b/quic/api/QuicTransportBase.h @@ -114,7 +114,9 @@ class QuicTransportBase : public QuicSocket { folly::Expected setReadCallback( StreamId id, - ReadCallback* cb) override; + ReadCallback* cb, + folly::Optional err = + GenericApplicationErrorCode::NO_ERROR) override; void unsetAllReadCallbacks() override; void unsetAllPeekCallbacks() override; void unsetAllDeliveryCallbacks() override; @@ -610,7 +612,8 @@ class QuicTransportBase : public QuicSocket { void checkForClosedStream(); folly::Expected setReadCallbackInternal( StreamId id, - ReadCallback* cb) noexcept; + ReadCallback* cb, + folly::Optional err) noexcept; folly::Expected setPeekCallbackInternal( StreamId id, PeekCallback* cb) noexcept; diff --git a/quic/api/test/MockQuicSocket.h b/quic/api/test/MockQuicSocket.h index 7722e39bc..1ed529800 100644 --- a/quic/api/test/MockQuicSocket.h +++ b/quic/api/test/MockQuicSocket.h @@ -84,9 +84,12 @@ class MockQuicSocket : public QuicSocket { folly::Expected(StreamId, uint64_t)); MOCK_METHOD1(setTransportSettings, void(TransportSettings)); MOCK_CONST_METHOD0(isPartiallyReliableTransport, bool()); - MOCK_METHOD2( + MOCK_METHOD3( setReadCallback, - folly::Expected(StreamId, ReadCallback*)); + folly::Expected( + StreamId, + ReadCallback*, + folly::Optional err)); MOCK_METHOD1(setConnectionCallback, void(ConnectionCallback*)); void setEarlyDataAppParamsFunctions( folly::Function&, const Buf&) diff --git a/quic/api/test/QuicTransportTest.cpp b/quic/api/test/QuicTransportTest.cpp index cd1bd0b79..6641dc852 100644 --- a/quic/api/test/QuicTransportTest.cpp +++ b/quic/api/test/QuicTransportTest.cpp @@ -888,6 +888,85 @@ TEST_F(QuicTransportTest, StopSending) { EXPECT_TRUE(foundStopSending); } +TEST_F(QuicTransportTest, StopSendingReadCallbackDefault) { + auto streamId = transport_->createBidirectionalStream().value(); + NiceMock readCb; + EXPECT_CALL(*socket_, write(_, _)).WillOnce(Invoke(bufLength)); + transport_->setReadCallback(streamId, &readCb); + transport_->setReadCallback(streamId, nullptr); + loopForWrites(); + EXPECT_EQ(1, transport_->getConnectionState().outstandings.packets.size()); + auto packet = + getLastOutstandingPacket( + transport_->getConnectionState(), PacketNumberSpace::AppData) + ->packet; + EXPECT_EQ(1, packet.frames.size()); + bool foundStopSending = false; + for (auto& frame : packet.frames) { + const QuicSimpleFrame* simpleFrame = frame.asQuicSimpleFrame(); + if (!simpleFrame) { + continue; + } + const StopSendingFrame* stopSending = simpleFrame->asStopSendingFrame(); + if (!stopSending) { + continue; + } + EXPECT_EQ(streamId, stopSending->streamId); + EXPECT_EQ(GenericApplicationErrorCode::NO_ERROR, stopSending->errorCode); + foundStopSending = true; + } + EXPECT_TRUE(foundStopSending); +} + +TEST_F(QuicTransportTest, StopSendingReadCallback) { + auto streamId = transport_->createBidirectionalStream().value(); + NiceMock readCb; + EXPECT_CALL(*socket_, write(_, _)).WillOnce(Invoke(bufLength)); + transport_->setReadCallback(streamId, &readCb); + transport_->setReadCallback( + streamId, nullptr, GenericApplicationErrorCode::UNKNOWN); + loopForWrites(); + EXPECT_EQ(1, transport_->getConnectionState().outstandings.packets.size()); + auto packet = + getLastOutstandingPacket( + transport_->getConnectionState(), PacketNumberSpace::AppData) + ->packet; + EXPECT_EQ(1, packet.frames.size()); + bool foundStopSending = false; + for (auto& frame : packet.frames) { + const QuicSimpleFrame* simpleFrame = frame.asQuicSimpleFrame(); + if (!simpleFrame) { + continue; + } + const StopSendingFrame* stopSending = simpleFrame->asStopSendingFrame(); + if (!stopSending) { + continue; + } + EXPECT_EQ(streamId, stopSending->streamId); + EXPECT_EQ(GenericApplicationErrorCode::UNKNOWN, stopSending->errorCode); + foundStopSending = true; + } + EXPECT_TRUE(foundStopSending); +} + +TEST_F(QuicTransportTest, StopSendingReadCallbackNone) { + auto streamId = transport_->createBidirectionalStream().value(); + NiceMock readCb; + transport_->setReadCallback(streamId, &readCb); + transport_->setReadCallback(streamId, nullptr, folly::none); + loopForWrites(); + EXPECT_EQ(0, transport_->getConnectionState().outstandings.packets.size()); +} + +TEST_F(QuicTransportTest, NoStopSendingReadCallback) { + auto streamId = transport_->createBidirectionalStream().value(); + NiceMock readCb; + transport_->setReadCallback(streamId, &readCb); + loopForWrites(); + EXPECT_EQ(0, transport_->getConnectionState().outstandings.packets.size()); + transport_->setReadCallback(streamId, nullptr, folly::none); +} + TEST_F(QuicTransportTest, SendPathChallenge) { EXPECT_CALL(*socket_, write(_, _)).WillOnce(Invoke(bufLength)); auto& conn = transport_->getConnectionState(); diff --git a/quic/server/test/QuicSocketTest.cpp b/quic/server/test/QuicSocketTest.cpp index 306f218b3..e37068057 100644 --- a/quic/server/test/QuicSocketTest.cpp +++ b/quic/server/test/QuicSocketTest.cpp @@ -27,7 +27,7 @@ class QuicSocketTest : public Test { } void openStream(StreamId) { - EXPECT_CALL(*socket_, setReadCallback(3, &handler_)); + EXPECT_CALL(*socket_, setReadCallback(3, &handler_, _)); socket_->cb_->onNewBidirectionalStream(3); }