diff --git a/quic/api/Observer.h b/quic/api/Observer.h index a7a8fbfec..c4f1520c6 100644 --- a/quic/api/Observer.h +++ b/quic/api/Observer.h @@ -93,18 +93,23 @@ class InstrumentationObserver { /** * packetLossDetected() is invoked when a packet loss is detected. * + * @param socket Socket when the callback is processed. * @param packet const reference to the packet that was determined to be * lost. */ virtual void packetLossDetected( + QuicSocket*, /* socket */ const struct ObserverLossEvent& /* lossEvent */) {} /** * rttSampleGenerated() is invoked when a RTT sample is made. * - * @param packet const reference to the packet with the RTT + * @param socket Socket when the callback is processed. + * @param packet const reference to the packet with the RTT. */ - virtual void rttSampleGenerated(const PacketRTT& /* RTT sample */) {} + virtual void rttSampleGenerated( + QuicSocket*, /* socket */ + const PacketRTT& /* RTT sample */) {} }; // Container for instrumentation observers. diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index a906e12c4..d8edc2e54 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -1662,7 +1662,7 @@ void QuicTransportBase::processCallbacksAfterNetworkData() { // to call any callbacks added for observers for (const auto& callback : conn_->pendingCallbacks) { - callback(); + callback(this); } conn_->pendingCallbacks.clear(); diff --git a/quic/api/test/Mocks.h b/quic/api/test/Mocks.h index 1d33010a9..143ce7ef2 100644 --- a/quic/api/test/Mocks.h +++ b/quic/api/test/Mocks.h @@ -326,13 +326,18 @@ class MockInstrumentationObserver : public InstrumentationObserver { public: GMOCK_METHOD1_(, noexcept, , observerDetach, void(QuicSocket*)); GMOCK_METHOD1_(, noexcept, , appRateLimited, void(QuicSocket*)); - GMOCK_METHOD1_( + GMOCK_METHOD2_( , noexcept, , packetLossDetected, - void(const ObserverLossEvent&)); - GMOCK_METHOD1_(, noexcept, , rttSampleGenerated, void(const PacketRTT&)); + void(QuicSocket*, const ObserverLossEvent&)); + GMOCK_METHOD2_( + , + noexcept, + , + rttSampleGenerated, + void(QuicSocket*, const PacketRTT&)); static auto getLossPacketMatcher(bool reorderLoss, bool timeoutLoss) { return AllOf( diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index ca6cc259b..8ed4d34f6 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -3657,7 +3657,7 @@ TEST_F( } TEST_F(QuicTransportImplTest, ImplementationObserverCallbacksDeleted) { - auto noopCallback = [] {}; + auto noopCallback = [](QuicSocket*) {}; transport->transportConn->pendingCallbacks.emplace_back(noopCallback); EXPECT_EQ(1, size(transport->transportConn->pendingCallbacks)); transport->invokeProcessCallbacksAfterNetworkData(); @@ -3666,7 +3666,7 @@ TEST_F(QuicTransportImplTest, ImplementationObserverCallbacksDeleted) { TEST_F(QuicTransportImplTest, ImplementationObserverCallbacksInvoked) { uint32_t callbacksInvoked = 0; - auto countingCallback = [&]() { callbacksInvoked++; }; + auto countingCallback = [&](QuicSocket*) { callbacksInvoked++; }; for (int i = 0; i < 2; i++) { transport->transportConn->pendingCallbacks.emplace_back(countingCallback); @@ -3678,5 +3678,22 @@ TEST_F(QuicTransportImplTest, ImplementationObserverCallbacksInvoked) { EXPECT_EQ(0, size(transport->transportConn->pendingCallbacks)); } +TEST_F( + QuicTransportImplTest, + ImplementationObserverCallbacksCorrectQuicSocket) { + QuicSocket* returnedSocket = nullptr; + auto func = [&](QuicSocket* qSocket) { returnedSocket = qSocket; }; + auto ib = MockInstrumentationObserver(); + + EXPECT_EQ(0, size(transport->transportConn->pendingCallbacks)); + transport->transportConn->pendingCallbacks.emplace_back(func); + EXPECT_EQ(1, size(transport->transportConn->pendingCallbacks)); + + transport->invokeProcessCallbacksAfterNetworkData(); + EXPECT_EQ(0, size(transport->transportConn->pendingCallbacks)); + + EXPECT_EQ(transport.get(), returnedSocket); +} + } // namespace test } // namespace quic diff --git a/quic/loss/QuicLossFunctions.h b/quic/loss/QuicLossFunctions.h index 377e1f7d7..d46c81b3b 100644 --- a/quic/loss/QuicLossFunctions.h +++ b/quic/loss/QuicLossFunctions.h @@ -291,9 +291,10 @@ folly::Optional detectLossPackets( // if there are observers, enqueue a function to call it if (observerLossEvent.hasPackets()) { for (const auto& observer : conn.instrumentationObservers_) { - conn.pendingCallbacks.emplace_back([observer, observerLossEvent] { - observer->packetLossDetected(observerLossEvent); - }); + conn.pendingCallbacks.emplace_back( + [observer, observerLossEvent](QuicSocket* qSocket) { + observer->packetLossDetected(qSocket, observerLossEvent); + }); } } diff --git a/quic/loss/test/QuicLossFunctionsTest.cpp b/quic/loss/test/QuicLossFunctionsTest.cpp index aa869dfae..06f97072a 100644 --- a/quic/loss/test/QuicLossFunctionsTest.cpp +++ b/quic/loss/test/QuicLossFunctionsTest.cpp @@ -1922,16 +1922,18 @@ TEST_F(QuicLossFunctionsTest, TestReorderLossObserverCallback) { // 1, 2 and 6 are "lost" due to reodering. None lost due to timeout EXPECT_CALL( ib, - packetLossDetected(Field( - &InstrumentationObserver::ObserverLossEvent::lostPackets, - UnorderedElementsAre( - getLossPacketMatcher(true, false), - getLossPacketMatcher(true, false), - getLossPacketMatcher(true, false))))) + packetLossDetected( + nullptr, + Field( + &InstrumentationObserver::ObserverLossEvent::lostPackets, + UnorderedElementsAre( + getLossPacketMatcher(true, false), + getLossPacketMatcher(true, false), + getLossPacketMatcher(true, false))))) .Times(1); for (auto& callback : conn->pendingCallbacks) { - callback(); + callback(nullptr); } } @@ -1973,20 +1975,22 @@ TEST_F(QuicLossFunctionsTest, TestTimeoutLossObserverCallback) { // expecting all packets to be lost due to timeout EXPECT_CALL( ib, - packetLossDetected(Field( - &InstrumentationObserver::ObserverLossEvent::lostPackets, - UnorderedElementsAre( - getLossPacketMatcher(false, true), - getLossPacketMatcher(false, true), - getLossPacketMatcher(false, true), - getLossPacketMatcher(false, true), - getLossPacketMatcher(false, true), - getLossPacketMatcher(false, true), - getLossPacketMatcher(false, true))))) + packetLossDetected( + nullptr, + Field( + &InstrumentationObserver::ObserverLossEvent::lostPackets, + UnorderedElementsAre( + getLossPacketMatcher(false, true), + getLossPacketMatcher(false, true), + getLossPacketMatcher(false, true), + getLossPacketMatcher(false, true), + getLossPacketMatcher(false, true), + getLossPacketMatcher(false, true), + getLossPacketMatcher(false, true))))) .Times(1); for (auto& callback : conn->pendingCallbacks) { - callback(); + callback(nullptr); } } @@ -2034,17 +2038,19 @@ TEST_F(QuicLossFunctionsTest, TestTimeoutAndReorderLossObserverCallback) { // 7 just timed out EXPECT_CALL( ib, - packetLossDetected(Field( - &InstrumentationObserver::ObserverLossEvent::lostPackets, - UnorderedElementsAre( - getLossPacketMatcher(true, true), - getLossPacketMatcher(true, true), - getLossPacketMatcher(true, true), - getLossPacketMatcher(false, true))))) + packetLossDetected( + nullptr, + Field( + &InstrumentationObserver::ObserverLossEvent::lostPackets, + UnorderedElementsAre( + getLossPacketMatcher(true, true), + getLossPacketMatcher(true, true), + getLossPacketMatcher(true, true), + getLossPacketMatcher(false, true))))) .Times(1); for (auto& callback : conn->pendingCallbacks) { - callback(); + callback(nullptr); } } diff --git a/quic/state/AckHandlers.cpp b/quic/state/AckHandlers.cpp index e5a94013f..6e18b644a 100644 --- a/quic/state/AckHandlers.cpp +++ b/quic/state/AckHandlers.cpp @@ -147,9 +147,10 @@ void processAckFrame( InstrumentationObserver::PacketRTT packetRTT( ackReceiveTimeOrNow, rttSample, frame.ackDelay, *rPacketIt); for (const auto& observer : conn.instrumentationObservers_) { - conn.pendingCallbacks.emplace_back([observer, packetRTT] { - observer->rttSampleGenerated(packetRTT); - }); + conn.pendingCallbacks.emplace_back( + [observer, packetRTT](QuicSocket* qSocket) { + observer->rttSampleGenerated(qSocket, packetRTT); + }); } updateRtt(conn, rttSample, frame.ackDelay); } diff --git a/quic/state/StateData.h b/quic/state/StateData.h index e23ae8a59..8ffae6bdb 100644 --- a/quic/state/StateData.h +++ b/quic/state/StateData.h @@ -907,7 +907,7 @@ struct QuicConnectionStateBase : public folly::DelayedDestruction { InstrumentationObserverVec instrumentationObservers_; // queue of functions to be called in processCallbacksAfterNetworkData - std::vector> pendingCallbacks; + std::vector> pendingCallbacks; }; std::ostream& operator<<(std::ostream& os, const QuicConnectionStateBase& st); diff --git a/quic/state/test/AckHandlersTest.cpp b/quic/state/test/AckHandlersTest.cpp index f21706f48..ebe7d6ec2 100644 --- a/quic/state/test/AckHandlersTest.cpp +++ b/quic/state/test/AckHandlersTest.cpp @@ -1203,22 +1203,26 @@ TEST_P(AckHandlersTest, TestRTTPacketObserverCallback) { ackData.ackTime - packetRcvTime[ackData.endSeq]); EXPECT_CALL( ib, - rttSampleGenerated(AllOf( - Field( - &InstrumentationObserver::PacketRTT::rcvTime, ackData.ackTime), - Field(&InstrumentationObserver::PacketRTT::rttSample, rttSample), - Field( - &InstrumentationObserver::PacketRTT::ackDelay, - ackData.ackDelay), - Field( - &InstrumentationObserver::PacketRTT::metadata, + rttSampleGenerated( + nullptr, + AllOf( Field( - &quic::OutstandingPacketMetadata::inflightBytes, - ackData.endSeq + 1))))); + &InstrumentationObserver::PacketRTT::rcvTime, + ackData.ackTime), + Field( + &InstrumentationObserver::PacketRTT::rttSample, rttSample), + Field( + &InstrumentationObserver::PacketRTT::ackDelay, + ackData.ackDelay), + Field( + &InstrumentationObserver::PacketRTT::metadata, + Field( + &quic::OutstandingPacketMetadata::inflightBytes, + ackData.endSeq + 1))))); } for (auto& callback : conn.pendingCallbacks) { - callback(); + callback(nullptr); } }