diff --git a/quic/api/QuicSocket.h b/quic/api/QuicSocket.h index 0fd73cf2f..d348b69c6 100644 --- a/quic/api/QuicSocket.h +++ b/quic/api/QuicSocket.h @@ -1124,17 +1124,24 @@ class QuicSocket { * Invoked if the ping times out */ virtual void pingTimeout() noexcept = 0; + + /** + * Invoked when a ping is received + */ + virtual void onPing() noexcept = 0; }; + /** + * Set the ping callback + */ + virtual folly::Expected setPingCallback( + PingCallback* cb) = 0; + /** * Send a ping to the peer. When the ping is acknowledged by the peer or * times out, the transport will invoke the callback. - * - * If 'callback' is nullptr, or pingTimeout is 0, no callback is scheduled. */ - virtual void sendPing( - PingCallback* callback, - std::chrono::milliseconds pingTimeout) = 0; + virtual void sendPing(std::chrono::milliseconds pingTimeout) = 0; /** * Get information on the state of the quic connection. Should only be used diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index 70e241c23..8180320d0 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -1400,7 +1400,16 @@ folly:: } } -void QuicTransportBase::handlePingCallback() { +void QuicTransportBase::handlePingCallbacks() { + if (conn_->pendingEvents.notifyPingReceived && pingCallback_ != nullptr) { + conn_->pendingEvents.notifyPingReceived = false; + runOnEvbAsync([](auto self) { + if (self->pingCallback_) { + self->pingCallback_->onPing(); + } + }); + } + if (!conn_->pendingEvents.cancelPingTimeout) { return; // nothing to cancel } @@ -1411,7 +1420,11 @@ void QuicTransportBase::handlePingCallback() { } pingTimeout_.cancelTimeout(); if (pingCallback_ != nullptr) { - runOnEvbAsync([](auto self) { self->pingCallback_->pingAcknowledged(); }); + runOnEvbAsync([](auto self) { + if (self->pingCallback_) { + self->pingCallback_->pingAcknowledged(); + } + }); } conn_->pendingEvents.cancelPingTimeout = false; } @@ -1704,7 +1717,7 @@ void QuicTransportBase::processCallbacksAfterNetworkData() { } conn_->pendingCallbacks.clear(); - handlePingCallback(); + handlePingCallbacks(); if (closeState_ != CloseState::OPEN) { return; } @@ -2433,9 +2446,19 @@ void QuicTransportBase::checkForClosedStream() { } } -void QuicTransportBase::sendPing( - PingCallback* callback, - std::chrono::milliseconds pingTimeout) { +folly::Expected QuicTransportBase::setPingCallback( + PingCallback* cb) { + if (closeState_ != CloseState::OPEN) { + return folly::makeUnexpected(LocalErrorCode::CONNECTION_CLOSED); + } + VLOG(4) << "Setting ping callback " + << " cb=" << cb << " " << *this; + + pingCallback_ = cb; + return folly::unit; +} + +void QuicTransportBase::sendPing(std::chrono::milliseconds pingTimeout) { /* Step 0: Connection should not be closed */ if (closeState_ == CloseState::CLOSED) { return; @@ -2446,8 +2469,8 @@ void QuicTransportBase::sendPing( updateWriteLooper(true); // Step 2: Schedule the timeout on event base - if (callback && pingTimeout != 0ms) { - schedulePingTimeout(callback, pingTimeout); + if (pingCallback_ && pingTimeout != 0ms) { + schedulePingTimeout(pingCallback_, pingTimeout); } } @@ -2499,7 +2522,11 @@ void QuicTransportBase::pingTimeoutExpired() noexcept { if (pingCallback_ == nullptr) { return; } - runOnEvbAsync([](auto self) { self->pingCallback_->pingTimeout(); }); + runOnEvbAsync([](auto self) { + if (self->pingCallback_) { + self->pingCallback_->pingTimeout(); + } + }); } void QuicTransportBase::pathValidationTimeoutExpired() noexcept { diff --git a/quic/api/QuicTransportBase.h b/quic/api/QuicTransportBase.h index 15c56b670..702f140dd 100644 --- a/quic/api/QuicTransportBase.h +++ b/quic/api/QuicTransportBase.h @@ -208,8 +208,10 @@ class QuicTransportBase : public QuicSocket, QuicStreamPrioritiesObserver { StreamId id, QuicErrorCode error) override; - void sendPing(PingCallback* callback, std::chrono::milliseconds pingTimeout) - override; + folly::Expected setPingCallback( + PingCallback* cb) override; + + void sendPing(std::chrono::milliseconds pingTimeout) override; const QuicConnectionStateBase* getState() const override { return conn_.get(); @@ -702,7 +704,7 @@ class QuicTransportBase : public QuicSocket, QuicStreamPrioritiesObserver { void updateReadLooper(); void updatePeekLooper(); void updateWriteLooper(bool thisIteration); - void handlePingCallback(); + void handlePingCallbacks(); void handleKnobCallbacks(); void handleAckEventCallbacks(); void handleCancelByteEventCallbacks(); diff --git a/quic/api/test/MockQuicSocket.h b/quic/api/test/MockQuicSocket.h index 09b984acb..8ce95d975 100644 --- a/quic/api/test/MockQuicSocket.h +++ b/quic/api/test/MockQuicSocket.h @@ -244,7 +244,10 @@ class MockQuicSocket : public QuicSocket { MOCK_METHOD2( maybeResetStreamFromReadError, folly::Expected(StreamId, QuicErrorCode)); - MOCK_METHOD2(sendPing, void(PingCallback*, std::chrono::milliseconds)); + MOCK_METHOD1( + setPingCallback, + folly::Expected(PingCallback*)); + MOCK_METHOD1(sendPing, void(std::chrono::milliseconds)); MOCK_CONST_METHOD0(getState, const QuicConnectionStateBase*()); MOCK_METHOD0(isDetachable, bool()); MOCK_METHOD1(attachEventBase, void(folly::EventBase*)); diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index 0e666d4b3..35eee6cf0 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -131,6 +131,7 @@ class TestPingCallback : public QuicSocket::PingCallback { public: void pingAcknowledged() noexcept override {} void pingTimeout() noexcept override {} + void onPing() noexcept override {} }; class TestByteEventCallback : public QuicSocket::ByteEventCallback { @@ -310,18 +311,16 @@ class TestQuicTransport ackTimeout_.timeoutExpired(); } - void invokeSendPing( - quic::QuicSocket::PingCallback* cb, - std::chrono::milliseconds interval) { - sendPing(cb, interval); + void invokeSendPing(std::chrono::milliseconds interval) { + sendPing(interval); } void invokeCancelPingTimeout() { pingTimeout_.cancelTimeout(); } - void invokeHandlePingCallback() { - handlePingCallback(); + void invokeHandlePingCallbacks() { + handlePingCallbacks(); } void invokeHandleKnobCallbacks() { @@ -3457,11 +3456,12 @@ TEST_F(QuicTransportImplTest, SuccessfulPing) { auto conn = transport->transportConn; std::chrono::milliseconds interval(10); TestPingCallback pingCallback; - transport->invokeSendPing(&pingCallback, interval); + transport->setPingCallback(&pingCallback); + transport->invokeSendPing(interval); EXPECT_EQ(transport->isPingTimeoutScheduled(), true); EXPECT_EQ(conn->pendingEvents.cancelPingTimeout, false); conn->pendingEvents.cancelPingTimeout = true; - transport->invokeHandlePingCallback(); + transport->invokeHandlePingCallbacks(); evb->loopOnce(); EXPECT_EQ(transport->isPingTimeoutScheduled(), false); EXPECT_EQ(conn->pendingEvents.cancelPingTimeout, false); @@ -3471,12 +3471,13 @@ TEST_F(QuicTransportImplTest, FailedPing) { auto conn = transport->transportConn; std::chrono::milliseconds interval(10); TestPingCallback pingCallback; - transport->invokeSendPing(&pingCallback, interval); + transport->setPingCallback(&pingCallback); + transport->invokeSendPing(interval); EXPECT_EQ(transport->isPingTimeoutScheduled(), true); EXPECT_EQ(conn->pendingEvents.cancelPingTimeout, false); conn->pendingEvents.cancelPingTimeout = true; transport->invokeCancelPingTimeout(); - transport->invokeHandlePingCallback(); + transport->invokeHandlePingCallbacks(); EXPECT_EQ(conn->pendingEvents.cancelPingTimeout, false); } diff --git a/quic/client/QuicClientTransport.cpp b/quic/client/QuicClientTransport.cpp index 87877cb14..f4dfe5433 100644 --- a/quic/client/QuicClientTransport.cpp +++ b/quic/client/QuicClientTransport.cpp @@ -566,6 +566,7 @@ void QuicClientTransport::processPacketData( // Ping isn't retransmittable. But we would like to ack them early. // So, make Ping frames count towards ack policy pktHasRetransmittableData = true; + conn_->pendingEvents.notifyPingReceived = true; break; case QuicFrame::Type::PaddingFrame: break; diff --git a/quic/server/state/ServerStateMachine.cpp b/quic/server/state/ServerStateMachine.cpp index cb990b7f8..6a7ff1926 100644 --- a/quic/server/state/ServerStateMachine.cpp +++ b/quic/server/state/ServerStateMachine.cpp @@ -1149,6 +1149,7 @@ void onServerReadDataFromOpen( // Ping isn't retransmittable data. But we would like to ack them // early. pktHasRetransmittableData = true; + conn.pendingEvents.notifyPingReceived = true; break; case QuicFrame::Type::PaddingFrame: break; diff --git a/quic/state/StateData.h b/quic/state/StateData.h index 5998fd76f..e4f8d85d4 100644 --- a/quic/state/StateData.h +++ b/quic/state/StateData.h @@ -560,6 +560,8 @@ struct QuicConnectionStateBase : public folly::DelayedDestruction { bool cancelPingTimeout{false}; + bool notifyPingReceived{false}; + // close transport when the next packet number reaches kMaxPacketNum bool closeTransport{false};