diff --git a/quic/api/QuicSocket.h b/quic/api/QuicSocket.h index 69fc9c48b..665bd4497 100644 --- a/quic/api/QuicSocket.h +++ b/quic/api/QuicSocket.h @@ -806,7 +806,8 @@ class QuicSocket { Type type; // sRTT at time of event - // TODO(bschlinker): Deprecate, caller can fetch transport state if desired. + // TODO(bschlinker): Deprecate, caller can fetch transport state if + // desired. std::chrono::microseconds srtt{0us}; }; diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index 2024dd184..a70b2a624 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -2095,6 +2095,15 @@ QuicTransportBase::registerByteEventCallback( byteEventMapIt->second.end(), offset, [&](uint64_t o, const ByteEventDetail& p) { return o < p.offset; }); + if (pos != byteEventMapIt->second.begin()) { + auto prev = std::prev(pos); + if ((prev->offset == offset) && (prev->callback == cb)) { + // ByteEvent has been already registered for the same type, id, + // offset and for the same recipient, return an INVALID_OPERATION error + // to prevent duplicate registrations. + return folly::makeUnexpected(LocalErrorCode::INVALID_OPERATION); + } + } byteEventMapIt->second.emplace(pos, offset, cb); } auto stream = conn_->streamManager->getStream(id); diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index a50ec4b42..e7ffe5166 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -30,6 +30,7 @@ namespace test { constexpr uint8_t kStreamIncrement = 0x04; using ByteEvent = QuicTransportBase::ByteEvent; +using ByteEventCancellation = QuicTransportBase::ByteEventCancellation; enum class TestFrameType : uint8_t { STREAM, @@ -111,6 +112,61 @@ class TestPingCallback : public QuicSocket::PingCallback { void pingTimeout() noexcept override {} }; +class TestByteEventCallback : public QuicSocket::ByteEventCallback { + public: + using HashFn = std::function; + using ComparatorFn = std::function; + + enum class Status { REGISTERED = 1, RECEIVED = 2, CANCELLED = 3 }; + + void onByteEventRegistered(ByteEvent event) override { + EXPECT_TRUE(byteEventTracker_.find(event) == byteEventTracker_.end()); + byteEventTracker_[event] = Status::REGISTERED; + } + void onByteEvent(ByteEvent event) override { + EXPECT_TRUE(byteEventTracker_.find(event) != byteEventTracker_.end()); + byteEventTracker_[event] = Status::RECEIVED; + } + void onByteEventCanceled(ByteEventCancellation cancellation) override { + const ByteEvent& event = cancellation; + EXPECT_TRUE(byteEventTracker_.find(event) != byteEventTracker_.end()); + byteEventTracker_[event] = Status::CANCELLED; + } + + std::unordered_map + getByteEventTracker() const { + return byteEventTracker_; + } + + private: + // Custom hash and comparator functions that use only id, offset and types + // (not the srtt) + HashFn hash = [](const ByteEvent& e) { + return folly::hash::hash_combine(e.id, e.offset, e.type); + }; + ComparatorFn comparator = [](const ByteEvent& lhs, const ByteEvent& rhs) { + return ((lhs.id == rhs.id) && (lhs.offset == rhs.offset)); + }; + std::unordered_map byteEventTracker_{ + /* bucket count */ 4, + hash, + comparator}; +}; + +static auto +getByteEventMatcher(ByteEvent::Type type, StreamId id, uint64_t offset) { + return AllOf( + testing::Field(&ByteEvent::type, testing::Eq(type)), + testing::Field(&ByteEvent::id, testing::Eq(id)), + testing::Field(&ByteEvent::offset, testing::Eq(offset))); +} + +static auto getByteEventTrackerMatcher( + ByteEvent event, + TestByteEventCallback::Status status) { + return Pair(getByteEventMatcher(event.type, event.id, event.offset), status); +} + class TestQuicTransport : public QuicTransportBase, public std::enable_shared_from_this { @@ -405,6 +461,7 @@ class QuicTransportImplTest : public Test { protected: std::unique_ptr evb; NiceMock connCallback; + TestByteEventCallback byteEventCallback; std::shared_ptr transport; folly::test::MockAsyncUDPSocket* socketPtr; }; @@ -1354,6 +1411,166 @@ TEST_F(QuicTransportImplTest, DeliveryCallbackUnsetOne) { transport->close(folly::none); } +TEST_F(QuicTransportImplTest, ByteEventCallbacksManagementSingleStream) { + auto stream = transport->createBidirectionalStream().value(); + uint64_t offset1 = 10, offset2 = 20; + + ByteEvent txEvent1 = { + .id = stream, .offset = offset1, .type = ByteEvent::Type::TX}; + ByteEvent txEvent2 = { + .id = stream, .offset = offset2, .type = ByteEvent::Type::TX}; + ByteEvent ackEvent1 = { + .id = stream, .offset = offset1, .type = ByteEvent::Type::ACK}; + ByteEvent ackEvent2 = { + .id = stream, .offset = offset2, .type = ByteEvent::Type::ACK}; + + // Register 2 TX and 2 ACK events for the same stream at 2 different offsets + transport->registerTxCallback( + txEvent1.id, txEvent1.offset, &byteEventCallback); + transport->registerTxCallback( + txEvent2.id, txEvent2.offset, &byteEventCallback); + transport->registerByteEventCallback( + ByteEvent::Type::ACK, ackEvent1.id, ackEvent1.offset, &byteEventCallback); + transport->registerByteEventCallback( + ByteEvent::Type::ACK, ackEvent2.id, ackEvent2.offset, &byteEventCallback); + EXPECT_THAT( + byteEventCallback.getByteEventTracker(), + UnorderedElementsAre( + getByteEventTrackerMatcher( + txEvent1, TestByteEventCallback::Status::REGISTERED), + getByteEventTrackerMatcher( + txEvent2, TestByteEventCallback::Status::REGISTERED), + getByteEventTrackerMatcher( + ackEvent1, TestByteEventCallback::Status::REGISTERED), + getByteEventTrackerMatcher( + ackEvent2, TestByteEventCallback::Status::REGISTERED))); + + // Registering the same events a second time will result in an error. + // as double registrations are not allowed. + folly::Expected ret; + ret = transport->registerTxCallback( + txEvent1.id, txEvent1.offset, &byteEventCallback); + EXPECT_EQ(LocalErrorCode::INVALID_OPERATION, ret.error()); + ret = transport->registerTxCallback( + txEvent2.id, txEvent2.offset, &byteEventCallback); + EXPECT_EQ(LocalErrorCode::INVALID_OPERATION, ret.error()); + ret = transport->registerByteEventCallback( + ByteEvent::Type::ACK, ackEvent1.id, ackEvent1.offset, &byteEventCallback); + EXPECT_EQ(LocalErrorCode::INVALID_OPERATION, ret.error()); + ret = transport->registerByteEventCallback( + ByteEvent::Type::ACK, ackEvent2.id, ackEvent2.offset, &byteEventCallback); + EXPECT_EQ(LocalErrorCode::INVALID_OPERATION, ret.error()); + EXPECT_THAT( + byteEventCallback.getByteEventTracker(), + UnorderedElementsAre( + getByteEventTrackerMatcher( + txEvent1, TestByteEventCallback::Status::REGISTERED), + getByteEventTrackerMatcher( + txEvent2, TestByteEventCallback::Status::REGISTERED), + getByteEventTrackerMatcher( + ackEvent1, TestByteEventCallback::Status::REGISTERED), + getByteEventTrackerMatcher( + ackEvent2, TestByteEventCallback::Status::REGISTERED))); + + // On the ACK events, the transport usually sets the srtt value. This value + // should have NO EFFECT on the ByteEvent's hash and we still should be able + // to identify the previously registered byte event correctly. + ackEvent1.srtt = std::chrono::microseconds(1000); + ackEvent2.srtt = std::chrono::microseconds(2000); + + // Deliver 1 TX and 1 ACK event. Cancel the other TX anc ACK event + byteEventCallback.onByteEvent(txEvent1); + byteEventCallback.onByteEvent(ackEvent2); + byteEventCallback.onByteEventCanceled(txEvent2); + byteEventCallback.onByteEventCanceled((ByteEventCancellation)ackEvent1); + + EXPECT_THAT( + byteEventCallback.getByteEventTracker(), + UnorderedElementsAre( + getByteEventTrackerMatcher( + txEvent1, TestByteEventCallback::Status::RECEIVED), + getByteEventTrackerMatcher( + txEvent2, TestByteEventCallback::Status::CANCELLED), + getByteEventTrackerMatcher( + ackEvent1, TestByteEventCallback::Status::CANCELLED), + getByteEventTrackerMatcher( + ackEvent2, TestByteEventCallback::Status::RECEIVED))); +} + +TEST_F(QuicTransportImplTest, ByteEventCallbacksManagementDifferentStreams) { + auto stream1 = transport->createBidirectionalStream().value(); + auto stream2 = transport->createBidirectionalStream().value(); + + ByteEvent txEvent1 = { + .id = stream1, .offset = 10, .type = ByteEvent::Type::TX}; + ByteEvent txEvent2 = { + .id = stream2, .offset = 20, .type = ByteEvent::Type::TX}; + ByteEvent ackEvent1 = { + .id = stream1, .offset = 10, .type = ByteEvent::Type::ACK}; + ByteEvent ackEvent2 = { + .id = stream2, .offset = 20, .type = ByteEvent::Type::ACK}; + + EXPECT_THAT(byteEventCallback.getByteEventTracker(), IsEmpty()); + // Register 2 TX and 2 ACK events for 2 separate streams. + transport->registerTxCallback( + txEvent1.id, txEvent1.offset, &byteEventCallback); + transport->registerTxCallback( + txEvent2.id, txEvent2.offset, &byteEventCallback); + transport->registerByteEventCallback( + ByteEvent::Type::ACK, ackEvent1.id, ackEvent1.offset, &byteEventCallback); + transport->registerByteEventCallback( + ByteEvent::Type::ACK, ackEvent2.id, ackEvent2.offset, &byteEventCallback); + EXPECT_THAT( + byteEventCallback.getByteEventTracker(), + UnorderedElementsAre( + getByteEventTrackerMatcher( + txEvent1, TestByteEventCallback::Status::REGISTERED), + getByteEventTrackerMatcher( + txEvent2, TestByteEventCallback::Status::REGISTERED), + getByteEventTrackerMatcher( + ackEvent1, TestByteEventCallback::Status::REGISTERED), + getByteEventTrackerMatcher( + ackEvent2, TestByteEventCallback::Status::REGISTERED))); + + // On the ACK events, the transport usually sets the srtt value. This value + // should have NO EFFECT on the ByteEvent's hash and we should still be able + // to identify the previously registered byte event correctly. + ackEvent1.srtt = std::chrono::microseconds(1000); + ackEvent2.srtt = std::chrono::microseconds(2000); + + // Deliver the TX event for stream 1 and cancel the ACK event for stream 2 + byteEventCallback.onByteEvent(txEvent1); + byteEventCallback.onByteEventCanceled((ByteEventCancellation)ackEvent2); + + EXPECT_THAT( + byteEventCallback.getByteEventTracker(), + UnorderedElementsAre( + getByteEventTrackerMatcher( + txEvent1, TestByteEventCallback::Status::RECEIVED), + getByteEventTrackerMatcher( + txEvent2, TestByteEventCallback::Status::REGISTERED), + getByteEventTrackerMatcher( + ackEvent1, TestByteEventCallback::Status::REGISTERED), + getByteEventTrackerMatcher( + ackEvent2, TestByteEventCallback::Status::CANCELLED))); + + // Deliver the TX event for stream 2 and cancel the ACK event for stream 1 + byteEventCallback.onByteEvent(txEvent2); + byteEventCallback.onByteEventCanceled((ByteEventCancellation)ackEvent1); + + EXPECT_THAT( + byteEventCallback.getByteEventTracker(), + UnorderedElementsAre( + getByteEventTrackerMatcher( + txEvent1, TestByteEventCallback::Status::RECEIVED), + getByteEventTrackerMatcher( + txEvent2, TestByteEventCallback::Status::RECEIVED), + getByteEventTrackerMatcher( + ackEvent1, TestByteEventCallback::Status::CANCELLED), + getByteEventTrackerMatcher( + ackEvent2, TestByteEventCallback::Status::CANCELLED))); +} + TEST_F(QuicTransportImplTest, RegisterTxDeliveryCallbackLowerThanExpected) { auto stream = transport->createBidirectionalStream().value(); StrictMock txcb1; diff --git a/quic/api/test/QuicTransportTest.cpp b/quic/api/test/QuicTransportTest.cpp index b416f80eb..2e85f26ac 100644 --- a/quic/api/test/QuicTransportTest.cpp +++ b/quic/api/test/QuicTransportTest.cpp @@ -2240,17 +2240,20 @@ TEST_F(QuicTransportTest, InvokeTxCallbacksSingleByte) { Mock::VerifyAndClearExpectations(&firstByteTxCb); Mock::VerifyAndClearExpectations(&lastByteTxCb); - // even if we register pastlastByte again, it shouldn't be triggered + // Even if we register pastlastByte again, it shouldn't trigger + // onByteEventRegistered because this is a duplicate registration. EXPECT_CALL(pastlastByteTxCb, onByteEventRegistered(getTxMatcher(stream, 1))) - .Times(1); - transport_->registerTxCallback(stream, 1, &pastlastByteTxCb); + .Times(0); + auto ret = transport_->registerTxCallback(stream, 1, &pastlastByteTxCb); + EXPECT_EQ(LocalErrorCode::INVALID_OPERATION, ret.error()); Mock::VerifyAndClearExpectations(&pastlastByteTxCb); // pastlastByteTxCb::onByteEvent will never get called // cancel gets called instead - // onByteEventCanceled called twice, since added twice + // Even though we attempted to register the ByteEvent twice, it resulted in + // an error. So, onByteEventCanceled should be called only once. EXPECT_CALL(pastlastByteTxCb, onByteEventCanceled(getTxMatcher(stream, 1))) - .Times(2); + .Times(1); transport_->close(folly::none); Mock::VerifyAndClearExpectations(&pastlastByteTxCb); }