diff --git a/quic/api/QuicSocketLite.h b/quic/api/QuicSocketLite.h index b12bd2deb..44c402e18 100644 --- a/quic/api/QuicSocketLite.h +++ b/quic/api/QuicSocketLite.h @@ -675,6 +675,16 @@ class QuicSocketLite { StreamId id, ApplicationErrorCode error) = 0; + /** + * This is used in conjunction with reliable resets. When we send data on a + * stream and want to mark which offset will constitute the reliable size in a + * future call to resetStreamReliably, we call this function. This function + * can potentially be called multiple times on a stream to advance the offset, + * but it is an error to call it after sending a reset. + */ + virtual folly::Expected + updateReliableDeliveryCheckpoint(StreamId id) = 0; + /** * Determine if transport is open and ready to read or write. * diff --git a/quic/api/QuicTransportBaseLite.cpp b/quic/api/QuicTransportBaseLite.cpp index 494076864..a534ab571 100644 --- a/quic/api/QuicTransportBaseLite.cpp +++ b/quic/api/QuicTransportBaseLite.cpp @@ -405,6 +405,23 @@ folly::Expected QuicTransportBaseLite::resetStream( return resetStreamInternal(id, errorCode); } +folly::Expected +QuicTransportBaseLite::updateReliableDeliveryCheckpoint(StreamId id) { + if (!conn_->streamManager->streamExists(id)) { + return folly::makeUnexpected(LocalErrorCode::STREAM_NOT_EXISTS); + } + auto stream = CHECK_NOTNULL(conn_->streamManager->getStream(id)); + if (stream->sendState == StreamSendState::ResetSent) { + // We already sent a reset, so there's really no reason why we should be + // doing any more checkpointing, especially since we cannot + // increase the reliable size in subsequent resets. + return folly::makeUnexpected(LocalErrorCode::INVALID_OPERATION); + } + stream->reliableResetCheckpoint = + stream->currentWriteOffset + stream->pendingWrites.chainLength(); + return folly::Unit(); +} + void QuicTransportBaseLite::cancelDeliveryCallbacksForStream(StreamId id) { cancelByteEventCallbacksForStream(ByteEvent::Type::ACK, id); } diff --git a/quic/api/QuicTransportBaseLite.h b/quic/api/QuicTransportBaseLite.h index 1cc642357..af16c3f4e 100644 --- a/quic/api/QuicTransportBaseLite.h +++ b/quic/api/QuicTransportBaseLite.h @@ -91,6 +91,9 @@ class QuicTransportBaseLite : virtual public QuicSocketLite, StreamId id, ApplicationErrorCode errorCode) override; + folly::Expected updateReliableDeliveryCheckpoint( + StreamId id) override; + /** * Invoke onCanceled on all the delivery callbacks registered for streamId. */ diff --git a/quic/api/test/MockQuicSocket.h b/quic/api/test/MockQuicSocket.h index 1ad66cb13..73071b32d 100644 --- a/quic/api/test/MockQuicSocket.h +++ b/quic/api/test/MockQuicSocket.h @@ -268,6 +268,10 @@ class MockQuicSocket : public QuicSocket { (folly::Expected), resetStream, (StreamId, ApplicationErrorCode)); + MOCK_METHOD( + (folly::Expected), + updateReliableDeliveryCheckpoint, + (StreamId)); MOCK_METHOD( (folly::Expected), maybeResetStreamFromReadError, diff --git a/quic/api/test/QuicTransportTest.cpp b/quic/api/test/QuicTransportTest.cpp index 905fd2e26..2d4715984 100644 --- a/quic/api/test/QuicTransportTest.cpp +++ b/quic/api/test/QuicTransportTest.cpp @@ -2159,6 +2159,123 @@ TEST_F(QuicTransportTest, RstStream) { *transport_->getConnectionState().streamManager, stream->id)); } +TEST_F(QuicTransportTest, CheckpointBeforeAnyWrites) { + auto streamId = transport_->createBidirectionalStream().value(); + auto streamState = + transport_->getConnectionState().streamManager->findStream(streamId); + + auto checkpointResult = + transport_->updateReliableDeliveryCheckpoint(streamId); + EXPECT_FALSE(checkpointResult.hasError()); + + EXPECT_EQ(streamState->reliableResetCheckpoint, 0); +} + +TEST_F(QuicTransportTest, CheckpointAfterWriteBuffered) { + auto streamId = transport_->createBidirectionalStream().value(); + auto streamState = + transport_->getConnectionState().streamManager->findStream(streamId); + + auto buf1 = IOBuf::create(10); + buf1->append(10); + transport_->writeChain(streamId, std::move(buf1), false); + EXPECT_EQ(streamState->pendingWrites.chainLength(), 10); + + auto checkpointResult = + transport_->updateReliableDeliveryCheckpoint(streamId); + EXPECT_FALSE(checkpointResult.hasError()); + + EXPECT_EQ(streamState->reliableResetCheckpoint, 10); +} + +TEST_F(QuicTransportTest, CheckpointAfterWriteWrittenToWire) { + auto streamId = transport_->createBidirectionalStream().value(); + auto streamState = + transport_->getConnectionState().streamManager->findStream(streamId); + + auto buf1 = IOBuf::create(10); + buf1->append(10); + transport_->writeChain(streamId, std::move(buf1), false); + EXPECT_EQ(streamState->pendingWrites.chainLength(), 10); + + EXPECT_CALL(*socket_, write(_, _, _)) + .WillOnce(testing::WithArgs<1, 2>(Invoke(getTotalIovecLen))); + loopForWrites(); + + EXPECT_TRUE(streamState->pendingWrites.empty()); + + auto checkpointResult = + transport_->updateReliableDeliveryCheckpoint(streamId); + EXPECT_FALSE(checkpointResult.hasError()); + + EXPECT_EQ(streamState->reliableResetCheckpoint, 10); +} + +TEST_F(QuicTransportTest, CheckpointAfterWritePartiallyWrittenToWire) { + auto streamId = transport_->createBidirectionalStream().value(); + auto streamState = + transport_->getConnectionState().streamManager->findStream(streamId); + + auto buf1 = IOBuf::create(10); + buf1->append(10); + transport_->writeChain(streamId, std::move(buf1), false); + EXPECT_EQ(streamState->pendingWrites.chainLength(), 10); + + EXPECT_CALL(*socket_, write(_, _, _)) + .WillOnce(testing::WithArgs<1, 2>(Invoke(getTotalIovecLen))); + loopForWrites(); + EXPECT_TRUE(streamState->pendingWrites.empty()); + + auto buf2 = IOBuf::create(5); + buf2->append(5); + transport_->writeChain(streamId, std::move(buf2), false); + EXPECT_EQ(streamState->pendingWrites.chainLength(), 5); + + auto checkpointResult = + transport_->updateReliableDeliveryCheckpoint(streamId); + EXPECT_FALSE(checkpointResult.hasError()); + + EXPECT_EQ(streamState->reliableResetCheckpoint, 15); +} + +TEST_F(QuicTransportTest, CheckpointMultipleTimes) { + auto streamId = transport_->createBidirectionalStream().value(); + auto streamState = + transport_->getConnectionState().streamManager->findStream(streamId); + + auto buf1 = IOBuf::create(10); + buf1->append(10); + transport_->writeChain(streamId, std::move(buf1), false); + auto checkpointResult1 = + transport_->updateReliableDeliveryCheckpoint(streamId); + EXPECT_FALSE(checkpointResult1.hasError()); + EXPECT_EQ(streamState->reliableResetCheckpoint, 10); + + auto buf2 = IOBuf::create(7); + buf2->append(7); + transport_->writeChain(streamId, std::move(buf2), false); + auto checkpointResult2 = + transport_->updateReliableDeliveryCheckpoint(streamId); + EXPECT_FALSE(checkpointResult2.hasError()); + EXPECT_EQ(streamState->reliableResetCheckpoint, 17); + + auto buf3 = IOBuf::create(2); + buf3->append(2); + transport_->writeChain(streamId, std::move(buf3), false); + auto checkpointResult3 = + transport_->updateReliableDeliveryCheckpoint(streamId); + EXPECT_FALSE(checkpointResult3.hasError()); + EXPECT_EQ(streamState->reliableResetCheckpoint, 19); +} + +TEST_F(QuicTransportTest, CheckpointAfterSendingReset) { + auto streamId = transport_->createBidirectionalStream().value(); + transport_->resetStream(streamId, GenericApplicationErrorCode::UNKNOWN); + auto checkpointResult = + transport_->updateReliableDeliveryCheckpoint(streamId); + EXPECT_TRUE(checkpointResult.hasError()); +} + TEST_F(QuicTransportTest, StopSending) { auto streamId = transport_->createBidirectionalStream().value(); EXPECT_CALL(*socket_, write(_, _, _)) diff --git a/quic/state/StreamData.h b/quic/state/StreamData.h index ff2972800..f0e964967 100644 --- a/quic/state/StreamData.h +++ b/quic/state/StreamData.h @@ -184,6 +184,12 @@ struct QuicStreamLike { // subsequently send a RESET_STREAM frame, we reset this value to none. Optional reliableSizeToPeer; + // When the application calls updateReliableDeliveryCheckpoint() on the + // QuicSocket, this is set to the size of the data written to the QuicSocket + // so far. This includes data buffered in the QUIC layer that hasn't yet been + // written out to the wire. + uint64_t reliableResetCheckpoint{0}; + // This is set if we get a RESET_STREAM_AT or RESET_STREAM frame from the // peer. If we get a RESET_STREAM_AT frame, we set this value to the reliable // size in that frame. If we get a RESET_STREAM frame, we set this value to 0.