diff --git a/quic/api/QuicSocket.h b/quic/api/QuicSocket.h index 1a69b270f..363b0f4b7 100644 --- a/quic/api/QuicSocket.h +++ b/quic/api/QuicSocket.h @@ -404,6 +404,14 @@ class QuicSocket { */ virtual void cancelDeliveryCallbacksForStream(StreamId streamId) = 0; + /** + * Invoke onCanceled on all the delivery callbacks registered for streamId for + * offsets lower than the offset provided. + */ + virtual void cancelDeliveryCallbacksForStream( + StreamId streamId, + uint64_t offset) = 0; + /** * Pause/Resume read callback being triggered when data is available. */ diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index 4434a4047..889fa66c7 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -895,6 +895,12 @@ QuicTransportBase::sendDataExpired(StreamId id, uint64_t offset) { return folly::makeUnexpected(LocalErrorCode::STREAM_NOT_EXISTS); } auto newOffset = advanceMinimumRetransmittableOffset(stream, offset); + + // Invoke any delivery callbacks that are set for any offset below newOffset. + if (newOffset) { + cancelDeliveryCallbacksForStream(id, *newOffset); + } + updateWriteLooper(true); return folly::makeExpected(newOffset); } @@ -954,6 +960,14 @@ void QuicTransportBase::invokeDataRejectedCallbacks() { auto dataRejectedCb = callbackData->second.dataRejectedCb; auto stream = CHECK_NOTNULL(conn_->streamManager->getStream(streamId)); + + // Invoke any delivery callbacks that are set for any offset below newly set + // minimumRetransmittableOffset. + if (!stream->streamReadError) { + cancelDeliveryCallbacksForStream( + streamId, stream->minimumRetransmittableOffset); + } + if (dataRejectedCb && !stream->streamReadError) { VLOG(10) << "invoking data rejected callback on stream=" << streamId << " " << *this; @@ -1050,6 +1064,46 @@ void QuicTransportBase::cancelDeliveryCallbacksForStream(StreamId streamId) { deliveryCallbacks_.erase(deliveryCallbackIter); } +void QuicTransportBase::cancelDeliveryCallbacksForStream( + StreamId streamId, + uint64_t offset) { + if (isReceivingStream(conn_->nodeType, streamId)) { + return; + } + + auto deliveryCallbackIter = deliveryCallbacks_.find(streamId); + if (deliveryCallbackIter == deliveryCallbacks_.end()) { + conn_->streamManager->removeDeliverable(streamId); + return; + } + + // Callbacks are kept sorted by offset, so we can just walk the queue and + // invoke those with offset below provided offset. + while (!deliveryCallbackIter->second.empty()) { + auto deliveryCallback = deliveryCallbackIter->second.front(); + auto& cbOffset = deliveryCallback.first; + if (cbOffset < offset) { + deliveryCallbackIter->second.pop_front(); + deliveryCallback.second->onCanceled(streamId, cbOffset); + if (closeState_ != CloseState::OPEN) { + // socket got closed - we can't use deliveryCallbackIter anymore, + // closeImpl should take care of delivering callbacks that are left in + // deliveryCallbackIter->second + return; + } + } else { + // Only larger or equal offsets left, exit the loop. + break; + } + } + + // Clean up state for this stream if no callbacks left to invoke. + if (deliveryCallbackIter->second.empty()) { + conn_->streamManager->removeDeliverable(streamId); + deliveryCallbacks_.erase(deliveryCallbackIter); + } +} + folly::Expected, LocalErrorCode> QuicTransportBase::read( StreamId id, size_t maxLen) { diff --git a/quic/api/QuicTransportBase.h b/quic/api/QuicTransportBase.h index 382eb14c5..ad3199283 100644 --- a/quic/api/QuicTransportBase.h +++ b/quic/api/QuicTransportBase.h @@ -312,6 +312,13 @@ class QuicTransportBase : public QuicSocket { */ void cancelDeliveryCallbacksForStream(StreamId streamId) override; + /** + * Invoke onCanceled on all the delivery callbacks registered for streamId for + * offsets lower than the offset provided. + */ + void cancelDeliveryCallbacksForStream(StreamId streamId, uint64_t offset) + override; + // Timeout functions class LossTimeout : public folly::HHWheelTimer::Callback { public: diff --git a/quic/api/test/MockQuicSocket.h b/quic/api/test/MockQuicSocket.h index f3dbbea02..78ec51cca 100644 --- a/quic/api/test/MockQuicSocket.h +++ b/quic/api/test/MockQuicSocket.h @@ -68,6 +68,9 @@ class MockQuicSocket : public QuicSocket { MOCK_METHOD0(unsetAllPeekCallbacks, void()); MOCK_METHOD0(unsetAllDeliveryCallbacks, void()); MOCK_METHOD1(cancelDeliveryCallbacksForStream, void(StreamId)); + MOCK_METHOD2( + cancelDeliveryCallbacksForStream, + void(StreamId, uint64_t offset)); MOCK_METHOD1( setConnectionFlowControlWindow, folly::Expected(uint64_t)); diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index fedd74cdc..9f86fea19 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -1097,6 +1097,62 @@ TEST_F(QuicTransportImplTest, DeliveryCallbackUnsetOne) { transport->close(folly::none); } +TEST_F(QuicTransportImplTest, DeliveryCallbackOnSendDataExpire) { + InSequence enforceOrder; + + transport->transportConn->partialReliabilityEnabled = true; + + auto stream1 = transport->createBidirectionalStream().value(); + auto stream2 = transport->createBidirectionalStream().value(); + MockDeliveryCallback dcb1; + MockDeliveryCallback dcb2; + + transport->registerDeliveryCallback(stream1, 10, &dcb1); + transport->registerDeliveryCallback(stream2, 20, &dcb2); + + EXPECT_CALL(dcb1, onCanceled(_, _)); + EXPECT_CALL(dcb2, onCanceled(_, _)).Times(0); + + auto res = transport->sendDataExpired(stream1, 11); + EXPECT_EQ(res.hasError(), false); + + Mock::VerifyAndClearExpectations(&dcb1); + Mock::VerifyAndClearExpectations(&dcb2); + + EXPECT_CALL(dcb1, onCanceled(_, _)).Times(0); + EXPECT_CALL(dcb2, onCanceled(_, _)); + + transport->close(folly::none); +} + +TEST_F(QuicTransportImplTest, DeliveryCallbackOnSendDataExpireCallbacksLeft) { + InSequence enforceOrder; + + transport->transportConn->partialReliabilityEnabled = true; + + auto stream1 = transport->createBidirectionalStream().value(); + auto stream2 = transport->createBidirectionalStream().value(); + MockDeliveryCallback dcb1; + MockDeliveryCallback dcb2; + + transport->registerDeliveryCallback(stream1, 10, &dcb1); + transport->registerDeliveryCallback(stream1, 20, &dcb1); + transport->registerDeliveryCallback(stream2, 20, &dcb2); + + EXPECT_CALL(dcb1, onCanceled(_, _)); + EXPECT_CALL(dcb2, onCanceled(_, _)).Times(0); + + auto res = transport->sendDataExpired(stream1, 11); + EXPECT_EQ(res.hasError(), false); + Mock::VerifyAndClearExpectations(&dcb1); + Mock::VerifyAndClearExpectations(&dcb2); + + EXPECT_CALL(dcb2, onCanceled(_, _)); + EXPECT_CALL(dcb1, onCanceled(_, _)).Times(1); + + transport->close(folly::none); +} + TEST_F(QuicTransportImplTest, RegisterDeliveryCallbackLowerThanExpected) { auto stream = transport->createBidirectionalStream().value(); MockDeliveryCallback dcb1; @@ -2166,6 +2222,98 @@ TEST_F(QuicTransportImplTest, DataRejecteddCallbackDataAvailable) { transport.reset(); } +TEST_F(QuicTransportImplTest, DataRejecteddCallbackWithDeliveryCallbacks) { + InSequence enforceOrder; + + transport->transportConn->partialReliabilityEnabled = true; + + auto stream1 = transport->createBidirectionalStream().value(); + auto stream2 = transport->createBidirectionalStream().value(); + + MockDeliveryCallback dcb1; + MockDeliveryCallback dcb2; + MockDataRejectedCallback dataRejectedCb1; + MockDataRejectedCallback dataRejectedCb2; + + transport->registerDeliveryCallback(stream1, 10, &dcb1); + transport->registerDeliveryCallback(stream2, 20, &dcb2); + + transport->setDataRejectedCallback(stream1, &dataRejectedCb1); + transport->setDataRejectedCallback(stream2, &dataRejectedCb2); + + EXPECT_CALL(dcb1, onCanceled(stream1, 10)).Times(1); + EXPECT_CALL(dcb2, onCanceled(_, _)).Times(0); + EXPECT_CALL(dataRejectedCb1, onDataRejected(stream1, 15)); + transport->addMinStreamDataFrameToStream( + MinStreamDataFrame(stream1, kDefaultStreamWindowSize, 15)); + + Mock::VerifyAndClearExpectations(&dcb1); + Mock::VerifyAndClearExpectations(&dcb2); + Mock::VerifyAndClearExpectations(&dataRejectedCb1); + + EXPECT_CALL(dcb1, onCanceled(_, _)).Times(0); + EXPECT_CALL(dcb2, onCanceled(stream2, 20)).Times(1); + EXPECT_CALL(dataRejectedCb2, onDataRejected(stream2, 23)); + transport->addMinStreamDataFrameToStream( + MinStreamDataFrame(stream2, kDefaultStreamWindowSize, 23)); + + Mock::VerifyAndClearExpectations(&dcb1); + Mock::VerifyAndClearExpectations(&dcb2); + Mock::VerifyAndClearExpectations(&dataRejectedCb2); + + EXPECT_CALL(dcb1, onCanceled(_, _)).Times(0); + EXPECT_CALL(dcb2, onCanceled(_, _)).Times(0); + transport->close(folly::none); +} + +TEST_F( + QuicTransportImplTest, + DataRejecteddCallbackWithDeliveryCallbacksSomeLeft) { + InSequence enforceOrder; + + transport->transportConn->partialReliabilityEnabled = true; + + auto stream1 = transport->createBidirectionalStream().value(); + auto stream2 = transport->createBidirectionalStream().value(); + + MockDeliveryCallback dcb1; + MockDeliveryCallback dcb2; + MockDataRejectedCallback dataRejectedCb1; + MockDataRejectedCallback dataRejectedCb2; + + transport->registerDeliveryCallback(stream1, 10, &dcb1); + transport->registerDeliveryCallback(stream1, 25, &dcb1); + transport->registerDeliveryCallback(stream2, 20, &dcb2); + transport->registerDeliveryCallback(stream2, 29, &dcb2); + + transport->setDataRejectedCallback(stream1, &dataRejectedCb1); + transport->setDataRejectedCallback(stream2, &dataRejectedCb2); + + EXPECT_CALL(dcb1, onCanceled(stream1, 10)).Times(1); + EXPECT_CALL(dcb2, onCanceled(_, _)).Times(0); + EXPECT_CALL(dataRejectedCb1, onDataRejected(stream1, 15)); + transport->addMinStreamDataFrameToStream( + MinStreamDataFrame(stream1, kDefaultStreamWindowSize, 15)); + + Mock::VerifyAndClearExpectations(&dcb1); + Mock::VerifyAndClearExpectations(&dcb2); + Mock::VerifyAndClearExpectations(&dataRejectedCb1); + + EXPECT_CALL(dcb1, onCanceled(_, _)).Times(0); + EXPECT_CALL(dcb2, onCanceled(stream2, 20)).Times(1); + EXPECT_CALL(dataRejectedCb2, onDataRejected(stream2, 23)); + transport->addMinStreamDataFrameToStream( + MinStreamDataFrame(stream2, kDefaultStreamWindowSize, 23)); + + Mock::VerifyAndClearExpectations(&dcb1); + Mock::VerifyAndClearExpectations(&dcb2); + Mock::VerifyAndClearExpectations(&dataRejectedCb2); + + EXPECT_CALL(dcb2, onCanceled(stream2, 29)).Times(1); + EXPECT_CALL(dcb1, onCanceled(stream1, 25)).Times(1); + transport->close(folly::none); +} + TEST_F(QuicTransportImplTest, DataRejectedCallbackChangeCallback) { transport->transportConn->partialReliabilityEnabled = true;