1
0
mirror of https://github.com/facebookincubator/mvfst.git synced 2025-08-08 09:42:06 +03:00

Use folly::ObserverContainer for socket observer [4/x]

Summary:
This diff is part of larger change to switch to using `folly::ObserverContainer` (introduced in D27062840).

This diff:
- Changes `LegacyObserver` to inherit from the new `Observer` class by adding a compatibility layer. This compatibility layer enables existing observers to continue to be supported.
- Changes generation of observer events so that they are routed through `ObserverContainer::invokeInterfaceMethod`
- Temporarily removes some of the reentrancy protection that previously existed, as it was not being applied consistently — some events had reentrancy protection, some did not. This will be reintroduced in the next diff.
- Improves some unit tests for observers during the transition process.

Differential Revision: D35268271

fbshipit-source-id: 5731c8a9aa8da8a2da1dd23d093e5f2e1a692653
This commit is contained in:
Brandon Schlinker
2022-04-18 04:19:31 -07:00
committed by Facebook GitHub Bot
parent f20eb76865
commit 744ae1f78e
25 changed files with 1266 additions and 1351 deletions

View File

@@ -7,129 +7,15 @@
#pragma once #pragma once
#include <quic/QuicException.h> #include <quic/observer/SocketObserverContainer.h>
#include <quic/common/SmallVec.h>
#include <quic/d6d/Types.h>
#include <quic/observer/SocketObserverInterface.h>
#include <quic/state/AckEvent.h>
#include <quic/state/OutstandingPacket.h>
#include <quic/state/QuicStreamUtilities.h>
namespace folly {
class EventBase;
}
namespace quic { namespace quic {
class QuicSocket;
/** /**
* ===== Observer API ===== * Legacy observer of socket events.
*
* TODO(bschlinker): Complete depreciation.
*/ */
using LegacyObserver = SocketObserverContainer::LegacyObserver;
/**
* Observer of socket events.
*/
class LegacyObserver : public SocketObserverInterface {
public:
/**
* Observer configuration.
*
* Specifies events observer wants to receive.
*/
struct Config {
virtual ~Config() = default;
// following flags enable support for various callbacks.
// observer and socket lifecycle callbacks are always enabled.
bool evbEvents{false};
bool packetsWrittenEvents{false};
bool appRateLimitedEvents{false};
bool lossEvents{false};
bool spuriousLossEvents{false};
bool pmtuEvents{false};
bool rttSamples{false};
bool knobFrameEvents{false};
bool streamEvents{false};
bool acksProcessedEvents{false};
virtual void enableAllEvents() {
evbEvents = true;
packetsWrittenEvents = true;
appRateLimitedEvents = true;
rttSamples = true;
lossEvents = true;
spuriousLossEvents = true;
pmtuEvents = true;
knobFrameEvents = true;
streamEvents = true;
acksProcessedEvents = true;
}
/**
* Returns a config where all events are enabled.
*/
static Config getConfigAllEventsEnabled() {
Config config = {};
config.enableAllEvents();
return config;
}
};
/**
* Constructor for observer, uses default config (all callbacks disabled).
*/
LegacyObserver() : LegacyObserver(Config()) {}
/**
* Constructor for observer.
*
* @param config Config, defaults to auxilary instrumentaton disabled.
*/
explicit LegacyObserver(const Config& observerConfig)
: observerConfig_(observerConfig) {}
~LegacyObserver() override = default;
/**
* Returns observers configuration.
*/
const Config& getConfig() {
return observerConfig_;
}
/**
* observerAttach() will be invoked when an observer is added.
*
* @param socket Socket where observer was installed.
*/
virtual void observerAttach(QuicSocket* /* socket */) noexcept {}
/**
* observerDetach() will be invoked if the observer is uninstalled prior
* to socket destruction.
*
* No further callbacks will be invoked after observerDetach().
*
* @param socket Socket where observer was uninstalled.
*/
virtual void observerDetach(QuicSocket* /* socket */) noexcept {}
/**
* destroy() will be invoked when the QuicSocket's destructor is invoked.
*
* No further callbacks will be invoked after destroy().
*
* @param socket Socket being destroyed.
*/
virtual void destroy(QuicSocket* /* socket */) noexcept {}
protected:
// observer configuration; cannot be changed post instantiation
const Config observerConfig_;
};
// Container for instrumentation observers.
// Avoids heap allocation for up to 2 observers being installed.
using ObserverVec = SmallVec<LegacyObserver*, 2>;
} // namespace quic } // namespace quic

View File

@@ -1181,34 +1181,6 @@ class QuicSocket {
*/ */
virtual void setCongestionControl(CongestionControlType type) = 0; virtual void setCongestionControl(CongestionControlType type) = 0;
/**
* Adds an observer.
*
* Observers can tie their lifetime to aspects of this socket's /
* lifetime and perform inspection at various states.
*
* This enables instrumentation to be added without changing / interfering
* with how the application uses the socket.
*
* @param observer Observer to add (implements Observer).
*/
virtual void addObserver(LegacyObserver* observer) = 0;
/**
* Removes an observer.
*
* @param observer Observer to remove.
* @return Whether observer found and removed from list.
*/
virtual bool removeObserver(LegacyObserver* observer) = 0;
/**
* Returns installed observers.
*
* @return Reference to const vector with installed observers.
*/
FOLLY_NODISCARD virtual const ObserverVec& getObservers() const = 0;
using Observer = SocketObserverContainer::Observer; using Observer = SocketObserverContainer::Observer;
using ManagedObserver = SocketObserverContainer::ManagedObserver; using ManagedObserver = SocketObserverContainer::ManagedObserver;
@@ -1253,6 +1225,18 @@ class QuicSocket {
return 0; return 0;
} }
/**
* Returns list of attached observers.
*
* @return List of observers.
*/
std::vector<Observer*> getObservers() {
if (auto list = getSocketObserverContainer()) {
return list->getObservers();
}
return {};
}
/** /**
* Returns list of attached observers that are of type T. * Returns list of attached observers that are of type T.
* *

View File

@@ -159,9 +159,6 @@ QuicTransportBase::~QuicTransportBase() {
sock->pauseRead(); sock->pauseRead();
sock->close(); sock->close();
} }
for (const auto& cb : *observers_) {
cb->destroy(this);
}
} }
bool QuicTransportBase::good() const { bool QuicTransportBase::good() const {
@@ -249,14 +246,8 @@ void QuicTransportBase::closeImpl(
return; return;
} }
// legacy observer support if (getSocketObserverContainer()) {
for (const auto& cb : *observers_) { getSocketObserverContainer()->invokeInterfaceMethodAllObservers(
cb->close(this, errorCode);
}
// new observer support
if (auto list = getSocketObserverContainer()) {
list->invokeInterfaceMethodAllObservers(
[errorCode](auto observer, auto observed) { [errorCode](auto observer, auto observed) {
observer->close(observed, errorCode); observer->close(observed, errorCode);
}); });
@@ -1493,15 +1484,18 @@ void QuicTransportBase::processCallbacksAfterWriteData() {
void QuicTransportBase::handleKnobCallbacks() { void QuicTransportBase::handleKnobCallbacks() {
for (auto& knobFrame : conn_->pendingEvents.knobs) { for (auto& knobFrame : conn_->pendingEvents.knobs) {
if (knobFrame.knobSpace != kDefaultQuicTransportKnobSpace) { if (knobFrame.knobSpace != kDefaultQuicTransportKnobSpace) {
for (const auto& cb : *observers_) { if (getSocketObserverContainer() &&
if (cb->getConfig().knobFrameEvents) { getSocketObserverContainer()
cb->knobFrameReceived( ->hasObserversForEvent<
this, SocketObserverInterface::Events::knobFrameEvents>()) {
quic::SocketObserverInterface::KnobFrameEvent( getSocketObserverContainer()
Clock::now(), knobFrame)); ->invokeInterfaceMethod<
} SocketObserverInterface::Events::knobFrameEvents>(
[event = quic::SocketObserverInterface::KnobFrameEvent(
Clock::now(), knobFrame)](auto observer, auto observed) {
observer->knobFrameReceived(observed, event);
});
} }
connCallback_->onKnob( connCallback_->onKnob(
knobFrame.knobSpace, knobFrame.id, std::move(knobFrame.blob)); knobFrame.knobSpace, knobFrame.id, std::move(knobFrame.blob));
} else { } else {
@@ -1518,14 +1512,19 @@ void QuicTransportBase::handleAckEventCallbacks() {
return; // nothing to do return; // nothing to do
} }
const auto event = if (getSocketObserverContainer() &&
quic::SocketObserverInterface::AcksProcessedEvent::Builder() getSocketObserverContainer()
.setAckEvents(lastProcessedAckEvents) ->hasObserversForEvent<
.build(); SocketObserverInterface::Events::acksProcessedEvents>()) {
for (const auto& cb : *observers_) { getSocketObserverContainer()
if (cb->getConfig().acksProcessedEvents) { ->invokeInterfaceMethod<
cb->acksProcessed(this, event); SocketObserverInterface::Events::acksProcessedEvents>(
} [event =
quic::SocketObserverInterface::AcksProcessedEvent::Builder()
.setAckEvents(lastProcessedAckEvents)
.build()](auto observer, auto observed) {
observer->acksProcessed(observed, event);
});
} }
} }
@@ -1553,14 +1552,21 @@ void QuicTransportBase::handleNewStreamCallbacks(
} else { } else {
connCallback_->onNewUnidirectionalStream(streamId); connCallback_->onNewUnidirectionalStream(streamId);
} }
const SocketObserverInterface::StreamOpenEvent streamEvent(
streamId, if (getSocketObserverContainer() &&
getStreamInitiator(streamId), getSocketObserverContainer()
getStreamDirectionality(streamId)); ->hasObserversForEvent<
for (const auto& cb : *observers_) { SocketObserverInterface::Events::streamEvents>()) {
if (cb->getConfig().streamEvents) { getSocketObserverContainer()
cb->streamOpened(this, streamEvent); ->invokeInterfaceMethod<
} SocketObserverInterface::Events::streamEvents>(
[event = SocketObserverInterface::StreamOpenEvent(
streamId,
getStreamInitiator(streamId),
getStreamDirectionality(streamId))](
auto observer, auto observed) {
observer->streamOpened(observed, event);
});
} }
if (closeState_ != CloseState::OPEN) { if (closeState_ != CloseState::OPEN) {
@@ -1717,12 +1723,6 @@ void QuicTransportBase::processCallbacksAfterNetworkData() {
return; return;
} }
// to call any callbacks added for observers
for (const auto& callback : conn_->pendingCallbacks) {
callback(this);
}
conn_->pendingCallbacks.clear();
handlePingCallbacks(); handlePingCallbacks();
if (closeState_ != CloseState::OPEN) { if (closeState_ != CloseState::OPEN) {
return; return;
@@ -1895,15 +1895,23 @@ QuicTransportBase::createStreamInternal(bool bidirectional) {
} }
if (streamResult) { if (streamResult) {
const StreamId streamId = streamResult.value()->id; const StreamId streamId = streamResult.value()->id;
const SocketObserverInterface::StreamOpenEvent streamEvent(
streamId, if (getSocketObserverContainer() &&
getStreamInitiator(streamId), getSocketObserverContainer()
getStreamDirectionality(streamId)); ->hasObserversForEvent<
for (const auto& cb : *observers_) { SocketObserverInterface::Events::streamEvents>()) {
if (cb->getConfig().streamEvents) { getSocketObserverContainer()
cb->streamOpened(this, streamEvent); ->invokeInterfaceMethod<
} SocketObserverInterface::Events::streamEvents>(
[event = SocketObserverInterface::StreamOpenEvent(
streamId,
getStreamInitiator(streamId),
getStreamDirectionality(streamId))](
auto observer, auto observed) {
observer->streamOpened(observed, event);
});
} }
return streamId; return streamId;
} else { } else {
return folly::makeUnexpected(streamResult.error()); return folly::makeUnexpected(streamResult.error());
@@ -2403,14 +2411,21 @@ void QuicTransportBase::checkForClosedStream() {
auto itr = conn_->streamManager->closedStreams().begin(); auto itr = conn_->streamManager->closedStreams().begin();
while (itr != conn_->streamManager->closedStreams().end()) { while (itr != conn_->streamManager->closedStreams().end()) {
const auto& streamId = *itr; const auto& streamId = *itr;
const SocketObserverInterface::StreamCloseEvent streamEvent(
streamId, if (getSocketObserverContainer() &&
getStreamInitiator(streamId), getSocketObserverContainer()
getStreamDirectionality(streamId)); ->hasObserversForEvent<
for (const auto& cb : *observers_) { SocketObserverInterface::Events::streamEvents>()) {
if (cb->getConfig().streamEvents) { getSocketObserverContainer()
cb->streamClosed(this, streamEvent); ->invokeInterfaceMethod<
} SocketObserverInterface::Events::streamEvents>(
[event = SocketObserverInterface::StreamCloseEvent(
streamId,
getStreamInitiator(streamId),
getStreamDirectionality(streamId))](
auto observer, auto observed) {
observer->streamClosed(observed, event);
});
} }
// We may be in an active read cb when we close the stream // We may be in an active read cb when we close the stream
@@ -2838,30 +2853,6 @@ void QuicTransportBase::resetNonControlStreams(
} }
} }
void QuicTransportBase::addObserver(LegacyObserver* observer) {
// adding the same observer multiple times is not allowed
CHECK(
std::find(observers_->begin(), observers_->end(), observer) ==
observers_->end());
observers_->push_back(CHECK_NOTNULL(observer));
observer->observerAttach(this);
}
bool QuicTransportBase::removeObserver(LegacyObserver* observer) {
auto it = std::find(observers_->begin(), observers_->end(), observer);
if (it == observers_->end()) {
return false;
}
observer->observerDetach(this);
observers_->erase(it);
return true;
}
const ObserverVec& QuicTransportBase::getObservers() const {
return *observers_;
}
QuicConnectionStats QuicTransportBase::getConnectionsStats() const { QuicConnectionStats QuicTransportBase::getConnectionsStats() const {
QuicConnectionStats connStats; QuicConnectionStats connStats;
if (!conn_) { if (!conn_) {
@@ -3337,10 +3328,15 @@ void QuicTransportBase::attachEventBase(folly::EventBase* evb) {
updatePeekLooper(); updatePeekLooper();
updateWriteLooper(false); updateWriteLooper(false);
for (const auto& cb : *observers_) { if (getSocketObserverContainer() &&
if (cb->getConfig().evbEvents) { getSocketObserverContainer()
cb->evbAttach(this, evb_); ->hasObserversForEvent<
} SocketObserverInterface::Events::evbEvents>()) {
getSocketObserverContainer()
->invokeInterfaceMethod<SocketObserverInterface::Events::evbEvents>(
[evb](auto observer, auto observed) {
observer->evbAttach(observed, evb);
});
} }
} }
@@ -3362,11 +3358,17 @@ void QuicTransportBase::detachEventBase() {
peekLooper_->detachEventBase(); peekLooper_->detachEventBase();
writeLooper_->detachEventBase(); writeLooper_->detachEventBase();
for (const auto& cb : *observers_) { if (getSocketObserverContainer() &&
if (cb->getConfig().evbEvents) { getSocketObserverContainer()
cb->evbDetach(this, evb_); ->hasObserversForEvent<
} SocketObserverInterface::Events::evbEvents>()) {
getSocketObserverContainer()
->invokeInterfaceMethod<SocketObserverInterface::Events::evbEvents>(
[evb = evb_.load()](auto observer, auto observed) {
observer->evbDetach(observed, evb);
});
} }
evb_ = nullptr; evb_ = nullptr;
} }
@@ -3557,26 +3559,33 @@ QuicSocket::WriteResult QuicTransportBase::setDSRPacketizationRequestSender(
} }
void QuicTransportBase::notifyStartWritingFromAppRateLimited() { void QuicTransportBase::notifyStartWritingFromAppRateLimited() {
const auto event = if (getSocketObserverContainer() &&
SocketObserverInterface::AppLimitedEvent::Builder() getSocketObserverContainer()
.setOutstandingPackets(conn_->outstandings.packets) ->hasObserversForEvent<
.setWriteCount(conn_->writeCount) SocketObserverInterface::Events::appRateLimitedEvents>()) {
.setLastPacketSentTime(conn_->lossState.maybeLastPacketSentTime) getSocketObserverContainer()
.setCwndInBytes( ->invokeInterfaceMethod<
conn_->congestionController SocketObserverInterface::Events::appRateLimitedEvents>(
? folly::Optional<uint64_t>( [event = SocketObserverInterface::AppLimitedEvent::Builder()
conn_->congestionController->getCongestionWindow()) .setOutstandingPackets(conn_->outstandings.packets)
: folly::none) .setWriteCount(conn_->writeCount)
.setWritableBytes( .setLastPacketSentTime(
conn_->congestionController conn_->lossState.maybeLastPacketSentTime)
? folly::Optional<uint64_t>( .setCwndInBytes(
conn_->congestionController->getWritableBytes()) conn_->congestionController
: folly::none) ? folly::Optional<uint64_t>(
.build(); conn_->congestionController
for (const auto& cb : *observers_) { ->getCongestionWindow())
if (cb->getConfig().appRateLimitedEvents) { : folly::none)
cb->startWritingFromAppLimited(this, event); .setWritableBytes(
} conn_->congestionController
? folly::Optional<uint64_t>(
conn_->congestionController
->getWritableBytes())
: folly::none)
.build()](auto observer, auto observed) {
observer->startWritingFromAppLimited(observed, event);
});
} }
} }
@@ -3584,53 +3593,68 @@ void QuicTransportBase::notifyPacketsWritten(
uint64_t numPacketsWritten, uint64_t numPacketsWritten,
uint64_t numAckElicitingPacketsWritten, uint64_t numAckElicitingPacketsWritten,
uint64_t numBytesWritten) { uint64_t numBytesWritten) {
const auto event = if (getSocketObserverContainer() &&
SocketObserverInterface::PacketsWrittenEvent::Builder() getSocketObserverContainer()
.setOutstandingPackets(conn_->outstandings.packets) ->hasObserversForEvent<
.setWriteCount(conn_->writeCount) SocketObserverInterface::Events::packetsWrittenEvents>()) {
.setLastPacketSentTime(conn_->lossState.maybeLastPacketSentTime) getSocketObserverContainer()
.setCwndInBytes( ->invokeInterfaceMethod<
conn_->congestionController SocketObserverInterface::Events::packetsWrittenEvents>(
? folly::Optional<uint64_t>( [event = SocketObserverInterface::PacketsWrittenEvent::Builder()
conn_->congestionController->getCongestionWindow()) .setOutstandingPackets(conn_->outstandings.packets)
: folly::none) .setWriteCount(conn_->writeCount)
.setWritableBytes( .setLastPacketSentTime(
conn_->congestionController conn_->lossState.maybeLastPacketSentTime)
? folly::Optional<uint64_t>( .setCwndInBytes(
conn_->congestionController->getWritableBytes()) conn_->congestionController
: folly::none) ? folly::Optional<uint64_t>(
.setNumPacketsWritten(numPacketsWritten) conn_->congestionController
.setNumAckElicitingPacketsWritten(numAckElicitingPacketsWritten) ->getCongestionWindow())
.setNumBytesWritten(numBytesWritten) : folly::none)
.build(); .setWritableBytes(
for (const auto& cb : *observers_) { conn_->congestionController
if (cb->getConfig().packetsWrittenEvents) { ? folly::Optional<uint64_t>(
cb->packetsWritten(this, event); conn_->congestionController
} ->getWritableBytes())
: folly::none)
.setNumPacketsWritten(numPacketsWritten)
.setNumAckElicitingPacketsWritten(
numAckElicitingPacketsWritten)
.setNumBytesWritten(numBytesWritten)
.build()](auto observer, auto observed) {
observer->packetsWritten(observed, event);
});
} }
} }
void QuicTransportBase::notifyAppRateLimited() { void QuicTransportBase::notifyAppRateLimited() {
const auto event = if (getSocketObserverContainer() &&
SocketObserverInterface::AppLimitedEvent::Builder() getSocketObserverContainer()
.setOutstandingPackets(conn_->outstandings.packets) ->hasObserversForEvent<
.setWriteCount(conn_->writeCount) SocketObserverInterface::Events::appRateLimitedEvents>()) {
.setLastPacketSentTime(conn_->lossState.maybeLastPacketSentTime) getSocketObserverContainer()
.setCwndInBytes( ->invokeInterfaceMethod<
conn_->congestionController SocketObserverInterface::Events::appRateLimitedEvents>(
? folly::Optional<uint64_t>( [event = SocketObserverInterface::AppLimitedEvent::Builder()
conn_->congestionController->getCongestionWindow()) .setOutstandingPackets(conn_->outstandings.packets)
: folly::none) .setWriteCount(conn_->writeCount)
.setWritableBytes( .setLastPacketSentTime(
conn_->congestionController conn_->lossState.maybeLastPacketSentTime)
? folly::Optional<uint64_t>( .setCwndInBytes(
conn_->congestionController->getWritableBytes()) conn_->congestionController
: folly::none) ? folly::Optional<uint64_t>(
.build(); conn_->congestionController
for (const auto& cb : *observers_) { ->getCongestionWindow())
if (cb->getConfig().appRateLimitedEvents) { : folly::none)
cb->appRateLimited(this, event); .setWritableBytes(
} conn_->congestionController
? folly::Optional<uint64_t>(
conn_->congestionController
->getWritableBytes())
: folly::none)
.build()](auto observer, auto observed) {
observer->appRateLimited(observed, event);
});
} }
} }

View File

@@ -641,37 +641,6 @@ class QuicTransportBase : public QuicSocket, QuicStreamPrioritiesObserver {
virtual void cancelAllAppCallbacks(const QuicError& error) noexcept; virtual void cancelAllAppCallbacks(const QuicError& error) noexcept;
using QuicSocket::addObserver;
using QuicSocket::removeObserver;
/**
* Adds an observer.
*
* Observers can tie their lifetime to aspects of this socket's lifecycle /
* lifetime and perform inspection at various states.
*
* This enables instrumentation to be added without changing / interfering
* with how the application uses the socket.
*
* @param observer Observer to add (implements Observer).
*/
void addObserver(LegacyObserver* observer) override;
/**
* Removes an observer.
*
* @param observer Observer to remove.
* @return Whether observer found and removed from list.
*/
bool removeObserver(LegacyObserver* observer) override;
/**
* Returns installed observers.
*
* @return Reference to const vector with installed observers.
*/
FOLLY_NODISCARD const ObserverVec& getObservers() const override;
FOLLY_NODISCARD QuicConnectionStats getConnectionsStats() const override; FOLLY_NODISCARD QuicConnectionStats getConnectionsStats() const override;
/** /**
@@ -948,10 +917,6 @@ class QuicTransportBase : public QuicSocket, QuicStreamPrioritiesObserver {
folly::Optional<std::string> exceptionCloseWhat_; folly::Optional<std::string> exceptionCloseWhat_;
// Observers
std::shared_ptr<ObserverVec> observers_{std::make_shared<ObserverVec>()};
uint64_t qlogRefcnt_{0}; uint64_t qlogRefcnt_{0};
// Priority level threshold for background streams // Priority level threshold for background streams

View File

@@ -17,6 +17,8 @@ class MockQuicSocket : public QuicSocket {
public: public:
using SharedBuf = std::shared_ptr<folly::IOBuf>; using SharedBuf = std::shared_ptr<folly::IOBuf>;
MockQuicSocket() = default;
MockQuicSocket( MockQuicSocket(
folly::EventBase* /*eventBase*/, folly::EventBase* /*eventBase*/,
ConnectionSetupCallback* setupCb, ConnectionSetupCallback* setupCb,
@@ -301,16 +303,13 @@ class MockQuicSocket : public QuicSocket {
MOCK_METHOD(void, setCongestionControl, (CongestionControlType)); MOCK_METHOD(void, setCongestionControl, (CongestionControlType));
ConnectionSetupCallback* setupCb_; ConnectionSetupCallback* setupCb_{nullptr};
ConnectionCallback* connCb_; ConnectionCallback* connCb_{nullptr};
folly::Function<bool(const folly::Optional<std::string>&, const Buf&)> folly::Function<bool(const folly::Optional<std::string>&, const Buf&)>
earlyDataAppParamsValidator_; earlyDataAppParamsValidator_;
folly::Function<Buf()> earlyDataAppParamsGetter_; folly::Function<Buf()> earlyDataAppParamsGetter_;
MOCK_METHOD(void, addObserver, (LegacyObserver*));
MOCK_METHOD(bool, removeObserver, (LegacyObserver*));
MOCK_METHOD(const ObserverVec&, getObservers, (), (const));
MOCK_METHOD( MOCK_METHOD(
void, void,
resetNonControlStreams, resetNonControlStreams,
@@ -335,5 +334,10 @@ class MockQuicSocket : public QuicSocket {
(folly::Expected<std::vector<Buf>, LocalErrorCode>), (folly::Expected<std::vector<Buf>, LocalErrorCode>),
readDatagramBufs, readDatagramBufs,
(size_t)); (size_t));
MOCK_METHOD(
SocketObserverContainer*,
getSocketObserverContainer,
(),
(const));
}; };
} // namespace quic } // namespace quic

View File

@@ -277,9 +277,7 @@ class MockObserver : public QuicSocket::ManagedObserver {
class MockLegacyObserver : public LegacyObserver { class MockLegacyObserver : public LegacyObserver {
public: public:
MockLegacyObserver() : LegacyObserver(LegacyObserver::Config()) {} using LegacyObserver::LegacyObserver;
explicit MockLegacyObserver(const LegacyObserver::Config& observerConfig)
: LegacyObserver(observerConfig) {}
MOCK_METHOD((void), observerAttach, (QuicSocket*), (noexcept)); MOCK_METHOD((void), observerAttach, (QuicSocket*), (noexcept));
MOCK_METHOD((void), observerDetach, (QuicSocket*), (noexcept)); MOCK_METHOD((void), observerDetach, (QuicSocket*), (noexcept));
MOCK_METHOD((void), destroy, (QuicSocket*), (noexcept)); MOCK_METHOD((void), destroy, (QuicSocket*), (noexcept));

View File

@@ -200,13 +200,15 @@ class TestQuicTransport
std::unique_ptr<folly::AsyncUDPSocket> socket, std::unique_ptr<folly::AsyncUDPSocket> socket,
ConnectionSetupCallback* connSetupCb, ConnectionSetupCallback* connSetupCb,
ConnectionCallback* connCb) ConnectionCallback* connCb)
: QuicTransportBase(evb, std::move(socket)) { : QuicTransportBase(evb, std::move(socket)),
observerContainer_(std::make_shared<SocketObserverContainer>(this)) {
setConnectionSetupCallback(connSetupCb); setConnectionSetupCallback(connSetupCb);
setConnectionCallback(connCb); setConnectionCallback(connCb);
auto conn = std::make_unique<QuicServerConnectionState>( auto conn = std::make_unique<QuicServerConnectionState>(
FizzServerQuicHandshakeContext::Builder().build()); FizzServerQuicHandshakeContext::Builder().build());
conn->clientConnectionId = ConnectionId({10, 9, 8, 7}); conn->clientConnectionId = ConnectionId({10, 9, 8, 7});
conn->version = QuicVersion::MVFST; conn->version = QuicVersion::MVFST;
conn->observerContainer = observerContainer_;
transportConn = conn.get(); transportConn = conn.get();
conn_.reset(conn.release()); conn_.reset(conn.release());
aead = test::createNoOpAead(); aead = test::createNoOpAead();
@@ -490,12 +492,24 @@ class TestQuicTransport
conn_->datagramState.maxReadBufferSize = 10; conn_->datagramState.maxReadBufferSize = 10;
} }
SocketObserverContainer* getSocketObserverContainer() const override {
return observerContainer_.get();
}
QuicServerConnectionState* transportConn; QuicServerConnectionState* transportConn;
std::unique_ptr<Aead> aead; std::unique_ptr<Aead> aead;
std::unique_ptr<PacketNumberCipher> headerCipher; std::unique_ptr<PacketNumberCipher> headerCipher;
std::unique_ptr<ConnectionIdAlgo> connIdAlgo_; std::unique_ptr<ConnectionIdAlgo> connIdAlgo_;
bool transportClosed{false}; bool transportClosed{false};
PacketNum packetNum_{0}; PacketNum packetNum_{0};
// Container of observers for the socket / transport.
//
// This member MUST be last in the list of members to ensure it is destroyed
// first, before any other members are destroyed. This ensures that observers
// can inspect any socket / transport state available through public methods
// when destruction of the transport begins.
const std::shared_ptr<SocketObserverContainer> observerContainer_;
}; };
class QuicTransportImplTest : public Test { class QuicTransportImplTest : public Test {
@@ -3481,13 +3495,15 @@ TEST_F(QuicTransportImplTest, FailedPing) {
TEST_F(QuicTransportImplTest, HandleKnobCallbacks) { TEST_F(QuicTransportImplTest, HandleKnobCallbacks) {
auto conn = transport->transportConn; auto conn = transport->transportConn;
// attach an observer to the socket LegacyObserver::EventSet eventSet;
LegacyObserver::Config config = {}; eventSet.enable(SocketObserverInterface::Events::knobFrameEvents);
config.knobFrameEvents = true;
auto cb = std::make_unique<StrictMock<MockLegacyObserver>>(config); auto obs1 = std::make_unique<NiceMock<MockLegacyObserver>>();
EXPECT_CALL(*cb, observerAttach(transport.get())); auto obs2 = std::make_unique<NiceMock<MockLegacyObserver>>(eventSet);
transport->addObserver(cb.get()); auto obs3 = std::make_unique<NiceMock<MockLegacyObserver>>(eventSet);
Mock::VerifyAndClearExpectations(cb.get()); transport->addObserver(obs1.get());
transport->addObserver(obs2.get());
transport->addObserver(obs3.get());
// set test knob frame // set test knob frame
uint64_t knobSpace = 0xfaceb00c; uint64_t knobSpace = 0xfaceb00c;
@@ -3501,15 +3517,17 @@ TEST_F(QuicTransportImplTest, HandleKnobCallbacks) {
EXPECT_CALL(connCallback, onKnobMock(knobSpace, knobId, _)) EXPECT_CALL(connCallback, onKnobMock(knobSpace, knobId, _))
.WillOnce(Invoke([](Unused, Unused, Unused) { /* do nothing */ })); .WillOnce(Invoke([](Unused, Unused, Unused) { /* do nothing */ }));
EXPECT_CALL(*cb, knobFrameReceived(transport.get(), _)).Times(1); EXPECT_CALL(*obs1, knobFrameReceived(transport.get(), _)).Times(0);
EXPECT_CALL(*obs2, knobFrameReceived(transport.get(), _)).Times(1);
EXPECT_CALL(*obs3, knobFrameReceived(transport.get(), _)).Times(1);
transport->invokeHandleKnobCallbacks(); transport->invokeHandleKnobCallbacks();
evb->loopOnce(); evb->loopOnce();
EXPECT_EQ(conn->pendingEvents.knobs.size(), 0); EXPECT_EQ(conn->pendingEvents.knobs.size(), 0);
// detach the observer from the socket // detach the observer from the socket
EXPECT_CALL(*cb, observerDetach(transport.get())); EXPECT_TRUE(transport->removeObserver(obs1.get()));
EXPECT_TRUE(transport->removeObserver(cb.get())); EXPECT_TRUE(transport->removeObserver(obs2.get()));
Mock::VerifyAndClearExpectations(cb.get()); EXPECT_TRUE(transport->removeObserver(obs3.get()));
} }
TEST_F(QuicTransportImplTest, StreamWriteCallbackUnregister) { TEST_F(QuicTransportImplTest, StreamWriteCallbackUnregister) {
@@ -3759,78 +3777,51 @@ TEST_F(QuicTransportImplTest, ObserverMultipleAttachDestroyTransport) {
} }
TEST_F(QuicTransportImplTest, ObserverDetachAndAttachEvb) { TEST_F(QuicTransportImplTest, ObserverDetachAndAttachEvb) {
LegacyObserver::Config config = {}; LegacyObserver::EventSet eventSet;
config.evbEvents = true; eventSet.enable(SocketObserverInterface::Events::evbEvents);
auto cb = std::make_unique<StrictMock<MockLegacyObserver>>(config);
auto obs1 = std::make_unique<NiceMock<MockLegacyObserver>>();
auto obs2 = std::make_unique<NiceMock<MockLegacyObserver>>(eventSet);
auto obs3 = std::make_unique<NiceMock<MockLegacyObserver>>(eventSet);
transport->addObserver(obs1.get());
transport->addObserver(obs2.get());
transport->addObserver(obs3.get());
// check the current event base and create a new one
EXPECT_EQ(evb.get(), transport->getEventBase());
folly::EventBase evb2; folly::EventBase evb2;
EXPECT_CALL(*cb, observerAttach(transport.get())); // Detach the event base evb
transport->addObserver(cb.get()); EXPECT_CALL(*obs1, evbDetach(transport.get(), evb.get())).Times(0);
Mock::VerifyAndClearExpectations(cb.get()); EXPECT_CALL(*obs2, evbDetach(transport.get(), evb.get())).Times(1);
EXPECT_CALL(*obs3, evbDetach(transport.get(), evb.get())).Times(1);
// Detach the event base evb and attach a new event base evb2
EXPECT_CALL(*cb, evbDetach(transport.get(), evb.get()));
transport->detachEventBase(); transport->detachEventBase();
EXPECT_EQ(nullptr, transport->getEventBase()); EXPECT_EQ(nullptr, transport->getEventBase());
Mock::VerifyAndClearExpectations(cb.get());
EXPECT_CALL(*cb, evbAttach(transport.get(), &evb2)); // Attach a new event base evb2
EXPECT_CALL(*obs1, evbAttach(transport.get(), &evb2)).Times(0);
EXPECT_CALL(*obs2, evbAttach(transport.get(), &evb2)).Times(1);
EXPECT_CALL(*obs3, evbAttach(transport.get(), &evb2)).Times(1);
transport->attachEventBase(&evb2); transport->attachEventBase(&evb2);
EXPECT_EQ(&evb2, transport->getEventBase()); EXPECT_EQ(&evb2, transport->getEventBase());
Mock::VerifyAndClearExpectations(cb.get());
// Detach the event base evb and re-attach the old event base evb // Detach the event base evb2
EXPECT_CALL(*cb, evbDetach(transport.get(), &evb2)); EXPECT_CALL(*obs1, evbDetach(transport.get(), &evb2)).Times(0);
EXPECT_CALL(*obs2, evbDetach(transport.get(), &evb2)).Times(1);
EXPECT_CALL(*obs3, evbDetach(transport.get(), &evb2)).Times(1);
transport->detachEventBase(); transport->detachEventBase();
EXPECT_EQ(nullptr, transport->getEventBase()); EXPECT_EQ(nullptr, transport->getEventBase());
Mock::VerifyAndClearExpectations(cb.get());
EXPECT_CALL(*cb, evbAttach(transport.get(), evb.get())); // Attach the original event base evb
EXPECT_CALL(*obs1, evbAttach(transport.get(), evb.get())).Times(0);
EXPECT_CALL(*obs2, evbAttach(transport.get(), evb.get())).Times(1);
EXPECT_CALL(*obs3, evbAttach(transport.get(), evb.get())).Times(1);
transport->attachEventBase(evb.get()); transport->attachEventBase(evb.get());
EXPECT_EQ(evb.get(), transport->getEventBase()); EXPECT_EQ(evb.get(), transport->getEventBase());
Mock::VerifyAndClearExpectations(cb.get());
EXPECT_CALL(*cb, observerDetach(transport.get())); EXPECT_TRUE(transport->removeObserver(obs1.get()));
EXPECT_TRUE(transport->removeObserver(cb.get())); EXPECT_TRUE(transport->removeObserver(obs2.get()));
Mock::VerifyAndClearExpectations(cb.get()); EXPECT_TRUE(transport->removeObserver(obs3.get()));
}
TEST_F(QuicTransportImplTest, ImplementationObserverCallbacksDeleted) {
auto noopCallback = [](QuicSocket*) {};
transport->transportConn->pendingCallbacks.emplace_back(noopCallback);
EXPECT_EQ(1, size(transport->transportConn->pendingCallbacks));
transport->invokeProcessCallbacksAfterNetworkData();
EXPECT_EQ(0, size(transport->transportConn->pendingCallbacks));
}
TEST_F(QuicTransportImplTest, ImplementationObserverCallbacksInvoked) {
uint32_t callbacksInvoked = 0;
auto countingCallback = [&](QuicSocket*) { callbacksInvoked++; };
for (int i = 0; i < 2; i++) {
transport->transportConn->pendingCallbacks.emplace_back(countingCallback);
}
EXPECT_EQ(2, size(transport->transportConn->pendingCallbacks));
transport->invokeProcessCallbacksAfterNetworkData();
EXPECT_EQ(2, callbacksInvoked);
EXPECT_EQ(0, size(transport->transportConn->pendingCallbacks));
}
TEST_F(
QuicTransportImplTest,
ImplementationObserverCallbacksCorrectQuicSocket) {
QuicSocket* returnedSocket = nullptr;
auto func = [&](QuicSocket* qSocket) { returnedSocket = qSocket; };
EXPECT_EQ(0, size(transport->transportConn->pendingCallbacks));
transport->transportConn->pendingCallbacks.emplace_back(func);
EXPECT_EQ(1, size(transport->transportConn->pendingCallbacks));
transport->invokeProcessCallbacksAfterNetworkData();
EXPECT_EQ(0, size(transport->transportConn->pendingCallbacks));
EXPECT_EQ(transport.get(), returnedSocket);
} }
TEST_F(QuicTransportImplTest, GetConnectionStatsSmoke) { TEST_F(QuicTransportImplTest, GetConnectionStatsSmoke) {

View File

@@ -376,13 +376,13 @@ TEST_F(QuicTransportTest, ObserverNotAppLimitedWithNoWritableBytes) {
return 0; return 0;
})); }));
LegacyObserver::Config config = {}; LegacyObserver::EventSet eventSet;
config.packetsWrittenEvents = true; eventSet.enable(
config.appRateLimitedEvents = true; SocketObserverInterface::Events::packetsWrittenEvents,
auto cb1 = std::make_unique<StrictMock<MockLegacyObserver>>(config); SocketObserverInterface::Events::appRateLimitedEvents);
auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>(config); auto cb1 = std::make_unique<StrictMock<MockLegacyObserver>>(eventSet);
auto cb3 = std::make_unique<StrictMock<MockLegacyObserver>>( auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>(eventSet);
LegacyObserver::Config()); auto cb3 = std::make_unique<StrictMock<MockLegacyObserver>>();
EXPECT_CALL(*cb1, observerAttach(transport_.get())); EXPECT_CALL(*cb1, observerAttach(transport_.get()));
EXPECT_CALL(*cb2, observerAttach(transport_.get())); EXPECT_CALL(*cb2, observerAttach(transport_.get()));
@@ -422,13 +422,13 @@ TEST_F(QuicTransportTest, ObserverNotAppLimitedWithLargeBuffer) {
EXPECT_CALL(*rawCongestionController, getWritableBytes()) EXPECT_CALL(*rawCongestionController, getWritableBytes())
.WillRepeatedly(Return(5000)); .WillRepeatedly(Return(5000));
LegacyObserver::Config config = {}; LegacyObserver::EventSet eventSet;
config.packetsWrittenEvents = true; eventSet.enable(
config.appRateLimitedEvents = true; SocketObserverInterface::Events::packetsWrittenEvents,
auto cb1 = std::make_unique<StrictMock<MockLegacyObserver>>(config); SocketObserverInterface::Events::appRateLimitedEvents);
auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>(config); auto cb1 = std::make_unique<StrictMock<MockLegacyObserver>>(eventSet);
auto cb3 = std::make_unique<StrictMock<MockLegacyObserver>>( auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>(eventSet);
LegacyObserver::Config()); auto cb3 = std::make_unique<StrictMock<MockLegacyObserver>>();
EXPECT_CALL(*cb1, observerAttach(transport_.get())); EXPECT_CALL(*cb1, observerAttach(transport_.get()));
EXPECT_CALL(*cb2, observerAttach(transport_.get())); EXPECT_CALL(*cb2, observerAttach(transport_.get()));
@@ -460,13 +460,14 @@ TEST_F(QuicTransportTest, ObserverNotAppLimitedWithLargeBuffer) {
} }
TEST_F(QuicTransportTest, ObserverAppLimited) { TEST_F(QuicTransportTest, ObserverAppLimited) {
LegacyObserver::Config config = {}; LegacyObserver::EventSet eventSet;
config.packetsWrittenEvents = true; eventSet.enable(
config.appRateLimitedEvents = true; SocketObserverInterface::Events::packetsWrittenEvents,
auto cb1 = std::make_unique<StrictMock<MockLegacyObserver>>(config); SocketObserverInterface::Events::appRateLimitedEvents);
auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>(config); auto cb1 = std::make_unique<StrictMock<MockLegacyObserver>>(eventSet);
auto cb3 = std::make_unique<StrictMock<MockLegacyObserver>>( auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>(eventSet);
LegacyObserver::Config()); auto cb3 = std::make_unique<StrictMock<MockLegacyObserver>>();
EXPECT_CALL(*cb1, observerAttach(transport_.get())); EXPECT_CALL(*cb1, observerAttach(transport_.get()));
EXPECT_CALL(*cb2, observerAttach(transport_.get())); EXPECT_CALL(*cb2, observerAttach(transport_.get()));
EXPECT_CALL(*cb3, observerAttach(transport_.get())); EXPECT_CALL(*cb3, observerAttach(transport_.get()));
@@ -508,14 +509,14 @@ TEST_F(QuicTransportTest, ObserverAppLimited) {
TEST_F(QuicTransportTest, ObserverPacketsWrittenCycleCheckDetails) { TEST_F(QuicTransportTest, ObserverPacketsWrittenCycleCheckDetails) {
InSequence s; InSequence s;
LegacyObserver::EventSet eventSet;
eventSet.enable(
SocketObserverInterface::Events::packetsWrittenEvents,
SocketObserverInterface::Events::appRateLimitedEvents);
auto cb1 = std::make_unique<StrictMock<MockLegacyObserver>>(eventSet);
auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>(eventSet);
auto cb3 = std::make_unique<StrictMock<MockLegacyObserver>>();
LegacyObserver::Config config = {};
config.packetsWrittenEvents = true;
config.appRateLimitedEvents = true;
auto cb1 = std::make_unique<StrictMock<MockLegacyObserver>>(config);
auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>(config);
auto cb3 = std::make_unique<StrictMock<MockLegacyObserver>>(
LegacyObserver::Config());
const auto invokeForAllObservers = const auto invokeForAllObservers =
[&cb1, &cb2, &cb3](const std::function<void(MockLegacyObserver&)>& fn) { [&cb1, &cb2, &cb3](const std::function<void(MockLegacyObserver&)>& fn) {
fn(*cb1); fn(*cb1);
@@ -893,13 +894,12 @@ TEST_F(QuicTransportTest, ObserverPacketsWrittenCycleCheckDetails) {
TEST_F(QuicTransportTest, ObserverPacketsWrittenCheckBytesSent) { TEST_F(QuicTransportTest, ObserverPacketsWrittenCheckBytesSent) {
InSequence s; InSequence s;
LegacyObserver::EventSet eventSet;
eventSet.enable(SocketObserverInterface::Events::packetsWrittenEvents);
auto cb1 = std::make_unique<StrictMock<MockLegacyObserver>>(eventSet);
auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>(eventSet);
auto cb3 = std::make_unique<StrictMock<MockLegacyObserver>>();
LegacyObserver::Config config = {};
config.packetsWrittenEvents = true;
auto cb1 = std::make_unique<StrictMock<MockLegacyObserver>>(config);
auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>(config);
auto cb3 = std::make_unique<StrictMock<MockLegacyObserver>>(
LegacyObserver::Config());
const auto invokeForAllObservers = const auto invokeForAllObservers =
[&cb1, &cb2, &cb3](const std::function<void(MockLegacyObserver&)>& fn) { [&cb1, &cb2, &cb3](const std::function<void(MockLegacyObserver&)>& fn) {
fn(*cb1); fn(*cb1);
@@ -1106,14 +1106,14 @@ TEST_F(QuicTransportTest, ObserverPacketsWrittenCheckBytesSent) {
TEST_F(QuicTransportTest, ObserverWriteEventsCheckCwndPacketsWritable) { TEST_F(QuicTransportTest, ObserverWriteEventsCheckCwndPacketsWritable) {
InSequence s; InSequence s;
LegacyObserver::EventSet eventSet;
eventSet.enable(
SocketObserverInterface::Events::packetsWrittenEvents,
SocketObserverInterface::Events::appRateLimitedEvents);
auto cb1 = std::make_unique<StrictMock<MockLegacyObserver>>(eventSet);
auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>(eventSet);
auto cb3 = std::make_unique<StrictMock<MockLegacyObserver>>();
LegacyObserver::Config config = {};
config.packetsWrittenEvents = true;
config.appRateLimitedEvents = true;
auto cb1 = std::make_unique<StrictMock<MockLegacyObserver>>(config);
auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>(config);
auto cb3 = std::make_unique<StrictMock<MockLegacyObserver>>(
LegacyObserver::Config());
const auto invokeForAllObservers = const auto invokeForAllObservers =
[&cb1, &cb2, &cb3](const std::function<void(MockLegacyObserver&)>& fn) { [&cb1, &cb2, &cb3](const std::function<void(MockLegacyObserver&)>& fn) {
fn(*cb1); fn(*cb1);
@@ -1386,12 +1386,11 @@ TEST_F(QuicTransportTest, ObserverWriteEventsCheckCwndPacketsWritable) {
} }
TEST_F(QuicTransportTest, ObserverStreamEventBidirectionalLocalOpenClose) { TEST_F(QuicTransportTest, ObserverStreamEventBidirectionalLocalOpenClose) {
LegacyObserver::Config configWithStreamEvents = {}; LegacyObserver::EventSet eventSet;
configWithStreamEvents.streamEvents = true; eventSet.enable(SocketObserverInterface::Events::streamEvents);
auto cb1 = auto cb1 = std::make_unique<StrictMock<MockLegacyObserver>>(eventSet);
std::make_unique<StrictMock<MockLegacyObserver>>(configWithStreamEvents); auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>();
auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>(
LegacyObserver::Config());
EXPECT_CALL(*cb1, observerAttach(transport_.get())); EXPECT_CALL(*cb1, observerAttach(transport_.get()));
transport_->addObserver(cb1.get()); transport_->addObserver(cb1.get());
EXPECT_CALL(*cb2, observerAttach(transport_.get())); EXPECT_CALL(*cb2, observerAttach(transport_.get()));
@@ -1429,12 +1428,11 @@ TEST_F(QuicTransportTest, ObserverStreamEventBidirectionalLocalOpenClose) {
} }
TEST_F(QuicTransportTest, ObserverStreamEventBidirectionalRemoteOpenClose) { TEST_F(QuicTransportTest, ObserverStreamEventBidirectionalRemoteOpenClose) {
LegacyObserver::Config configWithStreamEvents = {}; LegacyObserver::EventSet eventSet;
configWithStreamEvents.streamEvents = true; eventSet.enable(SocketObserverInterface::Events::streamEvents);
auto cb1 = auto cb1 = std::make_unique<StrictMock<MockLegacyObserver>>(eventSet);
std::make_unique<StrictMock<MockLegacyObserver>>(configWithStreamEvents); auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>();
auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>(
LegacyObserver::Config());
EXPECT_CALL(*cb1, observerAttach(transport_.get())); EXPECT_CALL(*cb1, observerAttach(transport_.get()));
transport_->addObserver(cb1.get()); transport_->addObserver(cb1.get());
EXPECT_CALL(*cb2, observerAttach(transport_.get())); EXPECT_CALL(*cb2, observerAttach(transport_.get()));
@@ -1472,12 +1470,11 @@ TEST_F(QuicTransportTest, ObserverStreamEventBidirectionalRemoteOpenClose) {
} }
TEST_F(QuicTransportTest, ObserverStreamEventUnidirectionalLocalOpenClose) { TEST_F(QuicTransportTest, ObserverStreamEventUnidirectionalLocalOpenClose) {
LegacyObserver::Config configWithStreamEvents = {}; LegacyObserver::EventSet eventSet;
configWithStreamEvents.streamEvents = true; eventSet.enable(SocketObserverInterface::Events::streamEvents);
auto cb1 = auto cb1 = std::make_unique<StrictMock<MockLegacyObserver>>(eventSet);
std::make_unique<StrictMock<MockLegacyObserver>>(configWithStreamEvents); auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>();
auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>(
LegacyObserver::Config());
EXPECT_CALL(*cb1, observerAttach(transport_.get())); EXPECT_CALL(*cb1, observerAttach(transport_.get()));
transport_->addObserver(cb1.get()); transport_->addObserver(cb1.get());
EXPECT_CALL(*cb2, observerAttach(transport_.get())); EXPECT_CALL(*cb2, observerAttach(transport_.get()));
@@ -1515,12 +1512,11 @@ TEST_F(QuicTransportTest, ObserverStreamEventUnidirectionalLocalOpenClose) {
} }
TEST_F(QuicTransportTest, ObserverStreamEventUnidirectionalRemoteOpenClose) { TEST_F(QuicTransportTest, ObserverStreamEventUnidirectionalRemoteOpenClose) {
LegacyObserver::Config configWithStreamEvents = {}; LegacyObserver::EventSet eventSet;
configWithStreamEvents.streamEvents = true; eventSet.enable(SocketObserverInterface::Events::streamEvents);
auto cb1 = auto cb1 = std::make_unique<StrictMock<MockLegacyObserver>>(eventSet);
std::make_unique<StrictMock<MockLegacyObserver>>(configWithStreamEvents); auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>();
auto cb2 = std::make_unique<StrictMock<MockLegacyObserver>>(
LegacyObserver::Config());
EXPECT_CALL(*cb1, observerAttach(transport_.get())); EXPECT_CALL(*cb1, observerAttach(transport_.get()));
transport_->addObserver(cb1.get()); transport_->addObserver(cb1.get());
EXPECT_CALL(*cb2, observerAttach(transport_.get())); EXPECT_CALL(*cb2, observerAttach(transport_.get()));

File diff suppressed because it is too large Load Diff

View File

@@ -23,7 +23,8 @@ class TestQuicTransport
std::unique_ptr<folly::AsyncUDPSocket> socket, std::unique_ptr<folly::AsyncUDPSocket> socket,
ConnectionSetupCallback* connSetupCb, ConnectionSetupCallback* connSetupCb,
ConnectionCallback* connCb) ConnectionCallback* connCb)
: QuicTransportBase(evb, std::move(socket)) { : QuicTransportBase(evb, std::move(socket)),
observerContainer_(std::make_shared<SocketObserverContainer>(this)) {
setConnectionSetupCallback(connSetupCb); setConnectionSetupCallback(connSetupCb);
setConnectionCallback(connCb); setConnectionCallback(connCb);
conn_.reset(new QuicServerConnectionState( conn_.reset(new QuicServerConnectionState(
@@ -31,6 +32,7 @@ class TestQuicTransport
conn_->clientConnectionId = ConnectionId({9, 8, 7, 6}); conn_->clientConnectionId = ConnectionId({9, 8, 7, 6});
conn_->serverConnectionId = ConnectionId({1, 2, 3, 4}); conn_->serverConnectionId = ConnectionId({1, 2, 3, 4});
conn_->version = QuicVersion::MVFST; conn_->version = QuicVersion::MVFST;
conn_->observerContainer = observerContainer_;
aead = test::createNoOpAead(); aead = test::createNoOpAead();
headerCipher = test::createNoOpHeaderCipher(); headerCipher = test::createNoOpHeaderCipher();
} }
@@ -157,9 +159,21 @@ class TestQuicTransport
transportReadyNotified_ = transportReadyNotified; transportReadyNotified_ = transportReadyNotified;
} }
SocketObserverContainer* getSocketObserverContainer() const override {
return observerContainer_.get();
}
std::unique_ptr<Aead> aead; std::unique_ptr<Aead> aead;
std::unique_ptr<PacketNumberCipher> headerCipher; std::unique_ptr<PacketNumberCipher> headerCipher;
bool closed{false}; bool closed{false};
// Container of observers for the socket / transport.
//
// This member MUST be last in the list of members to ensure it is destroyed
// first, before any other members are destroyed. This ensures that observers
// can inspect any socket / transport state available through public methods
// when destruction of the transport begins.
const std::shared_ptr<SocketObserverContainer> observerContainer_;
}; };
} // namespace quic } // namespace quic

View File

@@ -66,7 +66,7 @@ QuicClientTransport::QuicClientTransport(
std::make_unique<QuicClientConnectionState>(std::move(handshakeFactory)); std::make_unique<QuicClientConnectionState>(std::move(handshakeFactory));
clientConn_ = tempConn.get(); clientConn_ = tempConn.get();
conn_.reset(tempConn.release()); conn_.reset(tempConn.release());
conn_->observers = observers_; conn_->observerContainer = observerContainer_;
auto srcConnId = connectionIdSize > 0 auto srcConnId = connectionIdSize > 0
? ConnectionId::createRandom(connectionIdSize) ? ConnectionId::createRandom(connectionIdSize)

View File

@@ -28,6 +28,7 @@ std::unique_ptr<QuicClientConnectionState> undoAllClientStateForRetry(
// across stateless retry. // across stateless retry.
auto newConn = std::make_unique<QuicClientConnectionState>( auto newConn = std::make_unique<QuicClientConnectionState>(
std::move(conn->handshakeFactory)); std::move(conn->handshakeFactory));
newConn->observerContainer = conn->observerContainer;
newConn->qLogger = conn->qLogger; newConn->qLogger = conn->qLogger;
newConn->clientConnectionId = conn->clientConnectionId; newConn->clientConnectionId = conn->clientConnectionId;
newConn->initialDestinationConnectionId = newConn->initialDestinationConnectionId =

View File

@@ -5,6 +5,7 @@
* LICENSE file in the root directory of this source tree. * LICENSE file in the root directory of this source tree.
*/ */
#include <quic/api/test/MockQuicSocket.h>
#include <quic/client/handshake/CachedServerTransportParameters.h> #include <quic/client/handshake/CachedServerTransportParameters.h>
#include <quic/client/handshake/ClientHandshake.h> #include <quic/client/handshake/ClientHandshake.h>
#include <quic/client/state/ClientStateMachine.h> #include <quic/client/state/ClientStateMachine.h>
@@ -104,6 +105,34 @@ TEST_F(ClientStateMachineTest, PreserveHappyeyabllsDuringUndo) {
EXPECT_NE(nullptr, newConn->happyEyeballsState.secondSocket); EXPECT_NE(nullptr, newConn->happyEyeballsState.secondSocket);
} }
TEST_F(ClientStateMachineTest, PreserveObserverContainer) {
auto socket = std::make_shared<MockQuicSocket>();
const auto observerContainer =
std::make_shared<SocketObserverContainer>(socket.get());
SocketObserverContainer::ManagedObserver obs;
observerContainer->addObserver(&obs);
client_->clientConnectionId = ConnectionId::createRandom(8);
client_->observerContainer = observerContainer;
EXPECT_EQ(1, client_->observerContainer->numObservers());
EXPECT_THAT(
client_->observerContainer->findObservers(), UnorderedElementsAre(&obs));
auto newConn = undoAllClientStateForRetry(std::move(client_));
EXPECT_EQ(newConn->observerContainer, observerContainer);
EXPECT_EQ(1, newConn->observerContainer->numObservers());
EXPECT_THAT(
newConn->observerContainer->findObservers(), UnorderedElementsAre(&obs));
}
TEST_F(ClientStateMachineTest, PreserveObserverContainerNullptr) {
client_->clientConnectionId = ConnectionId::createRandom(8);
client_->observerContainer = nullptr;
auto newConn = undoAllClientStateForRetry(std::move(client_));
EXPECT_THAT(newConn->observerContainer, IsNull());
}
TEST_F(ClientStateMachineTest, TestProcessMaxDatagramSizeBelowMin) { TEST_F(ClientStateMachineTest, TestProcessMaxDatagramSizeBelowMin) {
QuicClientConnectionState clientConn( QuicClientConnectionState clientConn(
FizzClientQuicHandshakeContext::Builder().build()); FizzClientQuicHandshakeContext::Builder().build());

View File

@@ -6,6 +6,7 @@
*/ */
#include <quic/d6d/QuicD6DStateFunctions.h> #include <quic/d6d/QuicD6DStateFunctions.h>
#include <quic/observer/SocketObserverContainer.h>
namespace quic { namespace quic {
@@ -23,26 +24,24 @@ static TimePoint reportUpperBound(QuicConnectionStateBase& conn) {
auto& d6d = conn.d6d; auto& d6d = conn.d6d;
const auto lastProbeSize = d6d.lastProbe->packetSize; const auto lastProbeSize = d6d.lastProbe->packetSize;
const auto now = Clock::now(); const auto now = Clock::now();
QUIC_STATS(conn.statsCallback, onConnectionPMTUUpperBoundDetected); QUIC_STATS(conn.statsCallback, onConnectionPMTUUpperBoundDetected);
if (conn.observers->size() > 0) { if (conn.getSocketObserverContainer() &&
SocketObserverInterface::PMTUUpperBoundEvent upperBoundEvent( conn.getSocketObserverContainer()
now, ->hasObserversForEvent<
std::chrono::duration_cast<std::chrono::microseconds>( SocketObserverInterface::Events::pmtuEvents>()) {
now - d6d.meta.timeLastNonSearchState), conn.getSocketObserverContainer()
d6d.meta.lastNonSearchState, ->invokeInterfaceMethod<SocketObserverInterface::Events::pmtuEvents>(
lastProbeSize, [event = SocketObserverInterface::PMTUUpperBoundEvent(
d6d.meta.totalTxedProbes, now,
conn.transportSettings.d6dConfig.raiserType); std::chrono::duration_cast<std::chrono::microseconds>(
// enqueue a function for every observer to invoke callback now - d6d.meta.timeLastNonSearchState),
for (const auto& observer : *(conn.observers)) { d6d.meta.lastNonSearchState,
conn.pendingCallbacks.emplace_back( lastProbeSize,
[observer, upperBoundEvent](QuicSocket* qSocket) { d6d.meta.totalTxedProbes,
if (observer->getConfig().pmtuEvents) { conn.transportSettings.d6dConfig.raiserType)](
observer->pmtuUpperBoundDetected(qSocket, upperBoundEvent); auto observer, auto observed) {
} observer->pmtuUpperBoundDetected(observed, event);
}); });
}
} }
return now; return now;
} }
@@ -57,28 +56,24 @@ static TimePoint reportBlackhole(
QUIC_STATS(conn.statsCallback, onConnectionPMTUBlackholeDetected); QUIC_STATS(conn.statsCallback, onConnectionPMTUBlackholeDetected);
auto& d6d = conn.d6d; auto& d6d = conn.d6d;
const auto now = Clock::now(); const auto now = Clock::now();
if (conn.observers->size() > 0) { if (conn.observerContainer &&
SocketObserverInterface::PMTUBlackholeEvent blackholeEvent( conn.observerContainer->hasObserversForEvent<
now, SocketObserverInterface::Events::pmtuEvents>()) {
std::chrono::duration_cast<std::chrono::microseconds>( conn.observerContainer
now - d6d.meta.timeLastNonSearchState), ->invokeInterfaceMethod<SocketObserverInterface::Events::pmtuEvents>(
d6d.meta.lastNonSearchState, [event = SocketObserverInterface::PMTUBlackholeEvent(
d6d.state, now,
conn.udpSendPacketLen, std::chrono::duration_cast<std::chrono::microseconds>(
d6d.lastProbe->packetSize, now - d6d.meta.timeLastNonSearchState),
d6d.thresholdCounter->getWindow(), d6d.meta.lastNonSearchState,
d6d.thresholdCounter->getThreshold(), d6d.state,
packet); conn.udpSendPacketLen,
d6d.lastProbe->packetSize,
// If there are observers, enqueue a function to invoke callback d6d.thresholdCounter->getWindow(),
for (const auto& observer : *(conn.observers)) { d6d.thresholdCounter->getThreshold(),
conn.pendingCallbacks.emplace_back( packet)](auto observer, auto observed) {
[observer, blackholeEvent](QuicSocket* qSocket) { observer->pmtuBlackholeDetected(observed, event);
if (observer->getConfig().pmtuEvents) { });
observer->pmtuBlackholeDetected(qSocket, blackholeEvent);
}
});
}
} }
return now; return now;
} }

View File

@@ -8,6 +8,7 @@
#include <gmock/gmock.h> #include <gmock/gmock.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <quic/api/test/MockQuicSocket.h>
#include <quic/api/test/Mocks.h> #include <quic/api/test/Mocks.h>
#include <quic/common/test/TestUtils.h> #include <quic/common/test/TestUtils.h>
#include <quic/d6d/QuicD6DStateFunctions.h> #include <quic/d6d/QuicD6DStateFunctions.h>
@@ -185,16 +186,21 @@ TEST_F(QuicD6DStateFunctionsTest, D6DProbeTimeoutExpiredOneInError) {
} }
TEST_F(QuicD6DStateFunctionsTest, D6DProbeAckedInBase) { TEST_F(QuicD6DStateFunctionsTest, D6DProbeAckedInBase) {
auto socket = std::make_shared<MockQuicSocket>();
const auto observerContainer =
std::make_shared<SocketObserverContainer>(socket.get());
QuicConnectionStateBase conn(QuicNodeType::Server); QuicConnectionStateBase conn(QuicNodeType::Server);
conn.observerContainer = observerContainer;
LegacyObserver::EventSet eventSet;
eventSet.enable(SocketObserverInterface::Events::pmtuEvents);
auto obs1 = std::make_unique<NiceMock<MockLegacyObserver>>(eventSet);
observerContainer->addObserver(obs1.get());
const uint16_t expectPMTU = 1400; const uint16_t expectPMTU = 1400;
auto& d6d = conn.d6d; auto& d6d = conn.d6d;
auto now = Clock::now(); const auto now = Clock::now();
LegacyObserver::Config config = {};
config.pmtuEvents = true;
auto observer = std::make_unique<StrictMock<MockLegacyObserver>>(config);
auto observers = std::make_shared<ObserverVec>();
observers->emplace_back(observer.get());
conn.observers = observers;
d6d.state = D6DMachineState::BASE; d6d.state = D6DMachineState::BASE;
d6d.outstandingProbes = 1; d6d.outstandingProbes = 1;
d6d.currentProbeSize = d6d.basePMTU; d6d.currentProbeSize = d6d.basePMTU;
@@ -220,29 +226,34 @@ TEST_F(QuicD6DStateFunctionsTest, D6DProbeAckedInBase) {
EXPECT_CALL(*mockRaiser, raiseProbeSize(d6d.currentProbeSize)) EXPECT_CALL(*mockRaiser, raiseProbeSize(d6d.currentProbeSize))
.Times(1) .Times(1)
.WillOnce(Return(expectPMTU)); .WillOnce(Return(expectPMTU));
EXPECT_CALL(*observer, pmtuUpperBoundDetected(_, _)).Times(0); EXPECT_CALL(*obs1, pmtuUpperBoundDetected(_, _)).Times(0);
onD6DLastProbeAcked(conn); onD6DLastProbeAcked(conn);
for (auto& callback : conn.pendingCallbacks) {
callback(nullptr);
}
EXPECT_EQ(d6d.state, D6DMachineState::SEARCHING); EXPECT_EQ(d6d.state, D6DMachineState::SEARCHING);
EXPECT_EQ(d6d.currentProbeSize, expectPMTU); EXPECT_EQ(d6d.currentProbeSize, expectPMTU);
EXPECT_EQ(conn.udpSendPacketLen, d6d.basePMTU); EXPECT_EQ(conn.udpSendPacketLen, d6d.basePMTU);
EXPECT_EQ(d6d.meta.lastNonSearchState, D6DMachineState::BASE); EXPECT_EQ(d6d.meta.lastNonSearchState, D6DMachineState::BASE);
EXPECT_GE(d6d.meta.timeLastNonSearchState, now); EXPECT_GE(d6d.meta.timeLastNonSearchState, now);
observerContainer->removeObserver(obs1.get());
} }
TEST_F(QuicD6DStateFunctionsTest, D6DProbeAckedInSearchingOne) { TEST_F(QuicD6DStateFunctionsTest, D6DProbeAckedInSearchingOne) {
auto socket = std::make_shared<MockQuicSocket>();
const auto observerContainer =
std::make_shared<SocketObserverContainer>(socket.get());
QuicConnectionStateBase conn(QuicNodeType::Server); QuicConnectionStateBase conn(QuicNodeType::Server);
conn.observerContainer = observerContainer;
LegacyObserver::EventSet eventSet;
eventSet.enable(SocketObserverInterface::Events::pmtuEvents);
auto obs1 = std::make_unique<NiceMock<MockLegacyObserver>>(eventSet);
observerContainer->addObserver(obs1.get());
const uint16_t expectPMTU = 1400; const uint16_t expectPMTU = 1400;
auto& d6d = conn.d6d; auto& d6d = conn.d6d;
auto now = Clock::now(); const auto now = Clock::now();
LegacyObserver::Config config = {};
config.pmtuEvents = true;
auto observer = std::make_unique<StrictMock<MockLegacyObserver>>(config);
auto observers = std::make_shared<ObserverVec>();
observers->emplace_back(observer.get());
conn.observers = observers;
d6d.state = D6DMachineState::SEARCHING; d6d.state = D6DMachineState::SEARCHING;
d6d.outstandingProbes = 1; d6d.outstandingProbes = 1;
conn.udpSendPacketLen = 1250; conn.udpSendPacketLen = 1250;
@@ -269,29 +280,34 @@ TEST_F(QuicD6DStateFunctionsTest, D6DProbeAckedInSearchingOne) {
EXPECT_CALL(*mockRaiser, raiseProbeSize(d6d.currentProbeSize)) EXPECT_CALL(*mockRaiser, raiseProbeSize(d6d.currentProbeSize))
.Times(1) .Times(1)
.WillOnce(Return(expectPMTU)); .WillOnce(Return(expectPMTU));
EXPECT_CALL(*observer, pmtuUpperBoundDetected(_, _)).Times(0); EXPECT_CALL(*obs1, pmtuUpperBoundDetected(_, _)).Times(0);
onD6DLastProbeAcked(conn); onD6DLastProbeAcked(conn);
for (auto& callback : conn.pendingCallbacks) {
callback(nullptr);
}
EXPECT_EQ(d6d.state, D6DMachineState::SEARCHING); EXPECT_EQ(d6d.state, D6DMachineState::SEARCHING);
EXPECT_EQ(d6d.currentProbeSize, expectPMTU); EXPECT_EQ(d6d.currentProbeSize, expectPMTU);
EXPECT_EQ(conn.udpSendPacketLen, 1300); EXPECT_EQ(conn.udpSendPacketLen, 1300);
EXPECT_EQ(d6d.meta.lastNonSearchState, D6DMachineState::BASE); EXPECT_EQ(d6d.meta.lastNonSearchState, D6DMachineState::BASE);
EXPECT_EQ(d6d.meta.timeLastNonSearchState, now); EXPECT_EQ(d6d.meta.timeLastNonSearchState, now);
observerContainer->removeObserver(obs1.get());
} }
TEST_F(QuicD6DStateFunctionsTest, D6DProbeAckedInSearchingMax) { TEST_F(QuicD6DStateFunctionsTest, D6DProbeAckedInSearchingMax) {
auto socket = std::make_shared<MockQuicSocket>();
const auto observerContainer =
std::make_shared<SocketObserverContainer>(socket.get());
QuicConnectionStateBase conn(QuicNodeType::Server); QuicConnectionStateBase conn(QuicNodeType::Server);
conn.observerContainer = observerContainer;
LegacyObserver::EventSet eventSet;
eventSet.enable(SocketObserverInterface::Events::pmtuEvents);
auto obs1 = std::make_unique<NiceMock<MockLegacyObserver>>(eventSet);
observerContainer->addObserver(obs1.get());
const uint16_t oversize = 1500; const uint16_t oversize = 1500;
auto& d6d = conn.d6d; auto& d6d = conn.d6d;
auto now = Clock::now(); const auto now = Clock::now();
LegacyObserver::Config config = {};
config.pmtuEvents = true;
auto observer = std::make_unique<StrictMock<MockLegacyObserver>>(config);
auto observers = std::make_shared<ObserverVec>();
observers->emplace_back(observer.get());
conn.observers = observers;
d6d.state = D6DMachineState::SEARCHING; d6d.state = D6DMachineState::SEARCHING;
d6d.outstandingProbes = 3; d6d.outstandingProbes = 3;
conn.udpSendPacketLen = 1400; conn.udpSendPacketLen = 1400;
@@ -319,7 +335,7 @@ TEST_F(QuicD6DStateFunctionsTest, D6DProbeAckedInSearchingMax) {
EXPECT_CALL(*mockRaiser, raiseProbeSize(d6d.currentProbeSize)) EXPECT_CALL(*mockRaiser, raiseProbeSize(d6d.currentProbeSize))
.Times(1) .Times(1)
.WillOnce(Return(oversize)); .WillOnce(Return(oversize));
EXPECT_CALL(*observer, pmtuUpperBoundDetected(_, _)) EXPECT_CALL(*obs1, pmtuUpperBoundDetected(_, _))
.Times(1) .Times(1)
.WillOnce(Invoke( .WillOnce(Invoke(
[&](QuicSocket* /* qSocket */, [&](QuicSocket* /* qSocket */,
@@ -333,26 +349,31 @@ TEST_F(QuicD6DStateFunctionsTest, D6DProbeAckedInSearchingMax) {
ProbeSizeRaiserType::ConstantStep, event.probeSizeRaiserType); ProbeSizeRaiserType::ConstantStep, event.probeSizeRaiserType);
})); }));
onD6DLastProbeAcked(conn); onD6DLastProbeAcked(conn);
for (auto& callback : conn.pendingCallbacks) {
callback(nullptr);
}
EXPECT_EQ(d6d.state, D6DMachineState::SEARCH_COMPLETE); EXPECT_EQ(d6d.state, D6DMachineState::SEARCH_COMPLETE);
EXPECT_EQ(d6d.currentProbeSize, 1450); EXPECT_EQ(d6d.currentProbeSize, 1450);
EXPECT_EQ(conn.udpSendPacketLen, 1450); EXPECT_EQ(conn.udpSendPacketLen, 1450);
EXPECT_EQ(d6d.meta.lastNonSearchState, D6DMachineState::BASE); EXPECT_EQ(d6d.meta.lastNonSearchState, D6DMachineState::BASE);
EXPECT_EQ(d6d.meta.timeLastNonSearchState, now); EXPECT_EQ(d6d.meta.timeLastNonSearchState, now);
observerContainer->removeObserver(obs1.get());
} }
TEST_F(QuicD6DStateFunctionsTest, D6DProbeAckedInError) { TEST_F(QuicD6DStateFunctionsTest, D6DProbeAckedInError) {
auto socket = std::make_shared<MockQuicSocket>();
const auto observerContainer =
std::make_shared<SocketObserverContainer>(socket.get());
QuicConnectionStateBase conn(QuicNodeType::Server); QuicConnectionStateBase conn(QuicNodeType::Server);
conn.observerContainer = observerContainer;
LegacyObserver::EventSet eventSet;
eventSet.enable(SocketObserverInterface::Events::pmtuEvents);
auto obs1 = std::make_unique<NiceMock<MockLegacyObserver>>(eventSet);
observerContainer->addObserver(obs1.get());
auto& d6d = conn.d6d; auto& d6d = conn.d6d;
auto now = Clock::now(); const auto now = Clock::now();
LegacyObserver::Config config = {};
config.pmtuEvents = true;
auto observer = std::make_unique<StrictMock<MockLegacyObserver>>(config);
auto observers = std::make_shared<ObserverVec>();
observers->emplace_back(observer.get());
conn.observers = observers;
d6d.state = D6DMachineState::ERROR; d6d.state = D6DMachineState::ERROR;
d6d.outstandingProbes = 3; d6d.outstandingProbes = 3;
conn.udpSendPacketLen = d6d.basePMTU; conn.udpSendPacketLen = d6d.basePMTU;
@@ -379,28 +400,33 @@ TEST_F(QuicD6DStateFunctionsTest, D6DProbeAckedInError) {
EXPECT_CALL(*mockRaiser, raiseProbeSize(d6d.currentProbeSize)) EXPECT_CALL(*mockRaiser, raiseProbeSize(d6d.currentProbeSize))
.Times(1) .Times(1)
.WillOnce(Return(1300)); // Won't be used .WillOnce(Return(1300)); // Won't be used
EXPECT_CALL(*observer, pmtuUpperBoundDetected(_, _)).Times(0); EXPECT_CALL(*obs1, pmtuUpperBoundDetected(_, _)).Times(0);
onD6DLastProbeAcked(conn); onD6DLastProbeAcked(conn);
for (auto& callback : conn.pendingCallbacks) {
callback(nullptr);
}
EXPECT_EQ(d6d.state, D6DMachineState::BASE); EXPECT_EQ(d6d.state, D6DMachineState::BASE);
EXPECT_EQ(d6d.currentProbeSize, d6d.basePMTU); EXPECT_EQ(d6d.currentProbeSize, d6d.basePMTU);
EXPECT_EQ(conn.udpSendPacketLen, d6d.basePMTU); EXPECT_EQ(conn.udpSendPacketLen, d6d.basePMTU);
EXPECT_EQ(d6d.meta.lastNonSearchState, D6DMachineState::ERROR); EXPECT_EQ(d6d.meta.lastNonSearchState, D6DMachineState::ERROR);
EXPECT_GE(d6d.meta.timeLastNonSearchState, now); EXPECT_GE(d6d.meta.timeLastNonSearchState, now);
observerContainer->removeObserver(obs1.get());
} }
TEST_F(QuicD6DStateFunctionsTest, BlackholeInSearching) { TEST_F(QuicD6DStateFunctionsTest, BlackholeInSearching) {
auto socket = std::make_shared<MockQuicSocket>();
const auto observerContainer =
std::make_shared<SocketObserverContainer>(socket.get());
QuicConnectionStateBase conn(QuicNodeType::Server); QuicConnectionStateBase conn(QuicNodeType::Server);
conn.observerContainer = observerContainer;
LegacyObserver::EventSet eventSet;
eventSet.enable(SocketObserverInterface::Events::pmtuEvents);
auto obs1 = std::make_unique<NiceMock<MockLegacyObserver>>(eventSet);
observerContainer->addObserver(obs1.get());
auto& d6d = conn.d6d; auto& d6d = conn.d6d;
auto now = Clock::now(); const auto now = Clock::now();
LegacyObserver::Config config = {};
config.pmtuEvents = true;
auto observer = std::make_unique<StrictMock<MockLegacyObserver>>(config);
auto observers = std::make_shared<ObserverVec>();
observers->emplace_back(observer.get());
conn.observers = observers;
d6d.state = D6DMachineState::SEARCHING; d6d.state = D6DMachineState::SEARCHING;
d6d.outstandingProbes = 2; d6d.outstandingProbes = 2;
conn.udpSendPacketLen = d6d.basePMTU + 20; conn.udpSendPacketLen = d6d.basePMTU + 20;
@@ -441,7 +467,7 @@ TEST_F(QuicD6DStateFunctionsTest, BlackholeInSearching) {
std::chrono::microseconds(kDefaultD6DBlackholeDetectionWindow).count(), std::chrono::microseconds(kDefaultD6DBlackholeDetectionWindow).count(),
1); // Threshold of 1 will cause window to be set to 0 1); // Threshold of 1 will cause window to be set to 0
EXPECT_CALL(*observer, pmtuBlackholeDetected(_, _)) EXPECT_CALL(*obs1, pmtuBlackholeDetected(_, _))
.Times(1) .Times(1)
.WillOnce( .WillOnce(
Invoke([&](QuicSocket* /* qSocket */, Invoke([&](QuicSocket* /* qSocket */,
@@ -456,29 +482,32 @@ TEST_F(QuicD6DStateFunctionsTest, BlackholeInSearching) {
EXPECT_EQ( EXPECT_EQ(
d6d.basePMTU + 20, event.triggeringPacketMetadata.encodedSize); d6d.basePMTU + 20, event.triggeringPacketMetadata.encodedSize);
})); }));
detectPMTUBlackhole(conn, lostPacket); detectPMTUBlackhole(conn, lostPacket);
for (auto& callback : conn.pendingCallbacks) {
callback(nullptr);
}
EXPECT_EQ(d6d.state, D6DMachineState::BASE); EXPECT_EQ(d6d.state, D6DMachineState::BASE);
EXPECT_EQ(d6d.currentProbeSize, d6d.basePMTU); EXPECT_EQ(d6d.currentProbeSize, d6d.basePMTU);
EXPECT_EQ(conn.udpSendPacketLen, d6d.basePMTU); EXPECT_EQ(conn.udpSendPacketLen, d6d.basePMTU);
EXPECT_EQ(d6d.meta.lastNonSearchState, D6DMachineState::BASE); EXPECT_EQ(d6d.meta.lastNonSearchState, D6DMachineState::BASE);
EXPECT_GE(d6d.meta.timeLastNonSearchState, now); EXPECT_GE(d6d.meta.timeLastNonSearchState, now);
observerContainer->removeObserver(obs1.get());
} }
TEST_F(QuicD6DStateFunctionsTest, BlackholeInSearchComplete) { TEST_F(QuicD6DStateFunctionsTest, BlackholeInSearchComplete) {
auto socket = std::make_shared<MockQuicSocket>();
const auto observerContainer =
std::make_shared<SocketObserverContainer>(socket.get());
QuicConnectionStateBase conn(QuicNodeType::Server); QuicConnectionStateBase conn(QuicNodeType::Server);
conn.observerContainer = observerContainer;
LegacyObserver::EventSet eventSet;
eventSet.enable(SocketObserverInterface::Events::pmtuEvents);
auto obs1 = std::make_unique<NiceMock<MockLegacyObserver>>(eventSet);
observerContainer->addObserver(obs1.get());
auto& d6d = conn.d6d; auto& d6d = conn.d6d;
auto now = Clock::now(); auto now = Clock::now();
LegacyObserver::Config config = {};
config.pmtuEvents = true;
auto observer = std::make_unique<StrictMock<MockLegacyObserver>>(config);
auto observers = std::make_shared<ObserverVec>();
observers->emplace_back(observer.get());
conn.observers = observers;
d6d.state = D6DMachineState::SEARCH_COMPLETE; d6d.state = D6DMachineState::SEARCH_COMPLETE;
conn.udpSendPacketLen = d6d.basePMTU + 20; conn.udpSendPacketLen = d6d.basePMTU + 20;
d6d.currentProbeSize = d6d.basePMTU + 20; d6d.currentProbeSize = d6d.basePMTU + 20;
@@ -519,12 +548,12 @@ TEST_F(QuicD6DStateFunctionsTest, BlackholeInSearchComplete) {
std::chrono::microseconds(kDefaultD6DBlackholeDetectionWindow).count(), std::chrono::microseconds(kDefaultD6DBlackholeDetectionWindow).count(),
1); // Threshold of 1 will cause window to be set to 0 1); // Threshold of 1 will cause window to be set to 0
EXPECT_CALL(*observer, pmtuBlackholeDetected(_, _)) EXPECT_CALL(*obs1, pmtuBlackholeDetected(_, _))
.Times(1) .Times(1)
.WillOnce( .WillOnce(
Invoke([&](QuicSocket* /* qSocket */, Invoke([&](QuicSocket* /* qSocket */,
const SocketObserverInterface::PMTUBlackholeEvent& event) { const SocketObserverInterface::PMTUBlackholeEvent& event) {
EXPECT_EQ(d6d.meta.timeLastNonSearchState, event.blackholeTime); EXPECT_LE(d6d.meta.timeLastNonSearchState, event.blackholeTime);
EXPECT_EQ(D6DMachineState::BASE, event.lastNonSearchState); EXPECT_EQ(D6DMachineState::BASE, event.lastNonSearchState);
EXPECT_EQ(D6DMachineState::SEARCH_COMPLETE, event.currentState); EXPECT_EQ(D6DMachineState::SEARCH_COMPLETE, event.currentState);
EXPECT_EQ(d6d.basePMTU + 20, event.udpSendPacketLen); EXPECT_EQ(d6d.basePMTU + 20, event.udpSendPacketLen);
@@ -534,29 +563,32 @@ TEST_F(QuicD6DStateFunctionsTest, BlackholeInSearchComplete) {
EXPECT_EQ( EXPECT_EQ(
d6d.basePMTU + 20, event.triggeringPacketMetadata.encodedSize); d6d.basePMTU + 20, event.triggeringPacketMetadata.encodedSize);
})); }));
detectPMTUBlackhole(conn, lostPacket); detectPMTUBlackhole(conn, lostPacket);
for (auto& callback : conn.pendingCallbacks) {
callback(nullptr);
}
EXPECT_EQ(d6d.state, D6DMachineState::BASE); EXPECT_EQ(d6d.state, D6DMachineState::BASE);
EXPECT_EQ(d6d.currentProbeSize, d6d.basePMTU); EXPECT_EQ(d6d.currentProbeSize, d6d.basePMTU);
EXPECT_EQ(conn.udpSendPacketLen, d6d.basePMTU); EXPECT_EQ(conn.udpSendPacketLen, d6d.basePMTU);
EXPECT_EQ(d6d.meta.lastNonSearchState, D6DMachineState::SEARCH_COMPLETE); EXPECT_EQ(d6d.meta.lastNonSearchState, D6DMachineState::SEARCH_COMPLETE);
EXPECT_GE(d6d.meta.timeLastNonSearchState, now); EXPECT_GE(d6d.meta.timeLastNonSearchState, now);
observerContainer->removeObserver(obs1.get());
} }
TEST_F(QuicD6DStateFunctionsTest, ReachMaxPMTU) { TEST_F(QuicD6DStateFunctionsTest, ReachMaxPMTU) {
auto socket = std::make_shared<MockQuicSocket>();
const auto observerContainer =
std::make_shared<SocketObserverContainer>(socket.get());
QuicConnectionStateBase conn(QuicNodeType::Server); QuicConnectionStateBase conn(QuicNodeType::Server);
conn.observerContainer = observerContainer;
LegacyObserver::EventSet eventSet;
eventSet.enable(SocketObserverInterface::Events::pmtuEvents);
auto obs1 = std::make_unique<NiceMock<MockLegacyObserver>>(eventSet);
observerContainer->addObserver(obs1.get());
auto& d6d = conn.d6d; auto& d6d = conn.d6d;
auto now = Clock::now(); const auto now = Clock::now();
LegacyObserver::Config config = {};
config.pmtuEvents = true;
auto observer = std::make_unique<StrictMock<MockLegacyObserver>>(config);
auto observers = std::make_shared<ObserverVec>();
observers->emplace_back(observer.get());
conn.observers = observers;
d6d.state = D6DMachineState::SEARCHING; d6d.state = D6DMachineState::SEARCHING;
d6d.maxPMTU = 1452; d6d.maxPMTU = 1452;
d6d.outstandingProbes = 1; d6d.outstandingProbes = 1;
@@ -591,6 +623,8 @@ TEST_F(QuicD6DStateFunctionsTest, ReachMaxPMTU) {
EXPECT_EQ(conn.udpSendPacketLen, 1442); EXPECT_EQ(conn.udpSendPacketLen, 1442);
EXPECT_EQ(d6d.meta.lastNonSearchState, D6DMachineState::BASE); EXPECT_EQ(d6d.meta.lastNonSearchState, D6DMachineState::BASE);
EXPECT_EQ(d6d.meta.timeLastNonSearchState, now); EXPECT_EQ(d6d.meta.timeLastNonSearchState, now);
observerContainer->removeObserver(obs1.get());
} }
TEST_F( TEST_F(

View File

@@ -214,7 +214,9 @@ folly::Optional<CongestionController::LossEvent> detectLossPackets(
<< " " << conn; << " " << conn;
CongestionController::LossEvent lossEvent(lossTime); CongestionController::LossEvent lossEvent(lossTime);
folly::Optional<SocketObserverInterface::LossEvent> observerLossEvent; folly::Optional<SocketObserverInterface::LossEvent> observerLossEvent;
if (!conn.observers->empty()) { if (conn.observerContainer &&
conn.observerContainer->hasObserversForEvent<
SocketObserverInterface::Events::lossEvents>()) {
observerLossEvent.emplace(lossTime); observerLossEvent.emplace(lossTime);
} }
// Note that time based loss detection is also within the same PNSpace. // Note that time based loss detection is also within the same PNSpace.
@@ -313,16 +315,16 @@ folly::Optional<CongestionController::LossEvent> detectLossPackets(
iter++; iter++;
} // while (iter != conn.outstandings.packets.end()) { } // while (iter != conn.outstandings.packets.end()) {
// if there are observers, enqueue a function to call it // notify observers
if (observerLossEvent && observerLossEvent->hasPackets()) { if (observerLossEvent && observerLossEvent->hasPackets() &&
for (const auto& observer : *(conn.observers)) { conn.observerContainer &&
conn.pendingCallbacks.emplace_back( conn.observerContainer->hasObserversForEvent<
[observer, observerLossEvent](QuicSocket* qSocket) { SocketObserverInterface::Events::lossEvents>()) {
if (observer->getConfig().lossEvents) { conn.observerContainer
observer->packetLossDetected(qSocket, *observerLossEvent); ->invokeInterfaceMethod<SocketObserverInterface::Events::lossEvents>(
} [observerLossEvent](auto observer, auto observed) {
}); observer->packetLossDetected(observed, *observerLossEvent);
} });
} }
auto earliest = getFirstOutstandingPacket(conn, pnSpace); auto earliest = getFirstOutstandingPacket(conn, pnSpace);

View File

@@ -13,6 +13,7 @@
#include <folly/io/async/test/MockAsyncUDPSocket.h> #include <folly/io/async/test/MockAsyncUDPSocket.h>
#include <folly/io/async/test/MockTimeoutManager.h> #include <folly/io/async/test/MockTimeoutManager.h>
#include <quic/api/QuicTransportFunctions.h> #include <quic/api/QuicTransportFunctions.h>
#include <quic/api/test/MockQuicSocket.h>
#include <quic/api/test/Mocks.h> #include <quic/api/test/Mocks.h>
#include <quic/client/state/ClientStateMachine.h> #include <quic/client/state/ClientStateMachine.h>
#include <quic/codec/DefaultConnectionIdAlgo.h> #include <quic/codec/DefaultConnectionIdAlgo.h>
@@ -84,6 +85,7 @@ class QuicLossFunctionsTest : public TestWithParam<PacketNumberSpace> {
headerCipher = createNoOpHeaderCipher(); headerCipher = createNoOpHeaderCipher();
quicStats_ = std::make_unique<MockQuicStats>(); quicStats_ = std::make_unique<MockQuicStats>();
connIdAlgo_ = std::make_unique<DefaultConnectionIdAlgo>(); connIdAlgo_ = std::make_unique<DefaultConnectionIdAlgo>();
socket_ = std::make_unique<MockQuicSocket>();
} }
PacketNum sendPacket( PacketNum sendPacket(
@@ -122,7 +124,8 @@ class QuicLossFunctionsTest : public TestWithParam<PacketNumberSpace> {
conn->serverConnectionId = *connIdAlgo_->encodeConnectionId(params); conn->serverConnectionId = *connIdAlgo_->encodeConnectionId(params);
// for canSetLossTimerForAppData() // for canSetLossTimerForAppData()
conn->oneRttWriteCipher = createNoOpAead(); conn->oneRttWriteCipher = createNoOpAead();
conn->observers = std::make_shared<ObserverVec>(); conn->observerContainer =
std::make_shared<SocketObserverContainer>(socket_.get());
return conn; return conn;
} }
@@ -160,6 +163,7 @@ class QuicLossFunctionsTest : public TestWithParam<PacketNumberSpace> {
MockLossTimeout timeout; MockLossTimeout timeout;
std::unique_ptr<MockQuicStats> quicStats_; std::unique_ptr<MockQuicStats> quicStats_;
std::unique_ptr<ConnectionIdAlgo> connIdAlgo_; std::unique_ptr<ConnectionIdAlgo> connIdAlgo_;
std::unique_ptr<MockQuicSocket> socket_;
auto getLossPacketMatcher( auto getLossPacketMatcher(
PacketNum packetNum, PacketNum packetNum,
@@ -2019,17 +2023,15 @@ TEST_F(QuicLossFunctionsTest, PersistentCongestionNoPTO) {
folly::none, currentTime + 1s, currentTime + 8s, ack)); folly::none, currentTime + 1s, currentTime + 8s, ack));
} }
TEST_F(QuicLossFunctionsTest, TestReorderLossObserverCallback) { TEST_F(QuicLossFunctionsTest, ObserverLossEventReorder) {
auto observers = std::make_shared<ObserverVec>();
LegacyObserver::Config config = {};
config.lossEvents = true;
auto ib = MockLegacyObserver(config);
auto conn = createConn(); auto conn = createConn();
// Register 1 observer
observers->emplace_back(&ib);
conn->observers = observers;
auto noopLossVisitor = [](auto&, auto&, bool) {};
LegacyObserver::EventSet eventSet;
eventSet.enable(SocketObserverInterface::Events::lossEvents);
auto obs1 = std::make_unique<NiceMock<MockLegacyObserver>>(eventSet);
conn->observerContainer->addObserver(obs1.get());
// send 7 packets
PacketNum largestSent = 0; PacketNum largestSent = 0;
for (int i = 0; i < 7; ++i) { for (int i = 0; i < 7; ++i) {
largestSent = largestSent =
@@ -2050,56 +2052,67 @@ TEST_F(QuicLossFunctionsTest, TestReorderLossObserverCallback) {
conn->transportSettings.timeReorderingThreshDivisor = 1.0; conn->transportSettings.timeReorderingThreshDivisor = 1.0;
TimePoint checkTime = TimePoint(200ms); TimePoint checkTime = TimePoint(200ms);
detectLossPackets(
*conn,
largestSent + 1,
noopLossVisitor,
checkTime,
PacketNumberSpace::AppData);
// expecting 1 callback to be stacked
EXPECT_EQ(1, size(conn->pendingCallbacks));
// Out of 1, 2, 3, 4, 5, 6, 7 -- we deleted (acked) 3,4,5. // Out of 1, 2, 3, 4, 5, 6, 7 -- we deleted (acked) 3,4,5.
// 1, 2 and 6 are "lost" due to reodering. None lost due to timeout // 1, 2 and 6 are "lost" due to reodering. None lost due to timeout
EXPECT_THAT(
conn->outstandings.packets,
UnorderedElementsAre(
getOutstandingPacketMatcher(1, true, false),
getOutstandingPacketMatcher(2, true, false),
getOutstandingPacketMatcher(6, true, false),
getOutstandingPacketMatcher(7, false, false)));
EXPECT_CALL( EXPECT_CALL(
ib, *obs1,
packetLossDetected( packetLossDetected(
nullptr, socket_.get(),
Field( Field(
&SocketObserverInterface::LossEvent::lostPackets, &SocketObserverInterface::LossEvent::lostPackets,
UnorderedElementsAre( UnorderedElementsAre(
getLossPacketMatcher(1, true, false), getLossPacketMatcher(
getLossPacketMatcher(2, true, false), 1 /* packetNum */,
getLossPacketMatcher(6, true, false))))) true /* lossByReorder */,
false /* lossByTimeout */),
getLossPacketMatcher(
2 /* packetNum */,
true /* lossByReorder */,
false /* lossByTimeout */),
getLossPacketMatcher(
6 /* packetNum */,
true /* lossByReorder */,
false /* lossByTimeout */)))))
.Times(1); .Times(1);
detectLossPackets(
*conn,
largestSent + 1,
[](auto&, auto&, bool) {},
checkTime,
PacketNumberSpace::AppData);
EXPECT_THAT(
conn->outstandings.packets,
UnorderedElementsAre(
getOutstandingPacketMatcher(
1 /* packetNum */,
true /* lossByReorder */,
false /* lossByTimeout */),
getOutstandingPacketMatcher(
2 /* packetNum */,
true /* lossByReorder */,
false /* lossByTimeout */),
getOutstandingPacketMatcher(
6 /* packetNum */,
true /* lossByReorder */,
false /* lossByTimeout */),
getOutstandingPacketMatcher(
7 /* packetNum */,
false /* lossByReorder */,
false /* lossByTimeout */)));
for (auto& callback : conn->pendingCallbacks) { conn->observerContainer->removeObserver(obs1.get());
callback(nullptr);
}
} }
TEST_F(QuicLossFunctionsTest, TestTimeoutLossObserverCallback) { TEST_F(QuicLossFunctionsTest, ObserverLossEventTimeout) {
auto observers = std::make_shared<ObserverVec>();
LegacyObserver::Config config = {};
config.lossEvents = true;
auto ib = MockLegacyObserver(config);
auto conn = createConn(); auto conn = createConn();
// Register 1 observer
observers->emplace_back(&ib);
conn->observers = observers;
auto noopLossVisitor = [](auto&, auto&, bool) {};
PacketNum largestSent = 0; LegacyObserver::EventSet eventSet;
eventSet.enable(SocketObserverInterface::Events::lossEvents);
auto obs1 = std::make_unique<NiceMock<MockLegacyObserver>>(eventSet);
conn->observerContainer->addObserver(obs1.get());
// send 7 packets // send 7 packets
PacketNum largestSent = 0;
for (int i = 0; i < 7; ++i) { for (int i = 0; i < 7; ++i) {
largestSent = largestSent =
sendPacket(*conn, TimePoint(i * 10ms), folly::none, PacketType::OneRtt); sendPacket(*conn, TimePoint(i * 10ms), folly::none, PacketType::OneRtt);
@@ -2115,60 +2128,93 @@ TEST_F(QuicLossFunctionsTest, TestTimeoutLossObserverCallback) {
conn->transportSettings.timeReorderingThreshDivisor = 1.0; conn->transportSettings.timeReorderingThreshDivisor = 1.0;
TimePoint checkTime = TimePoint(500ms); TimePoint checkTime = TimePoint(500ms);
detectLossPackets( // expect all packets to be lost due to timeout
*conn,
largestSent + 1,
noopLossVisitor,
checkTime,
PacketNumberSpace::AppData);
// expecting 1 callback to be stacked
EXPECT_EQ(1, size(conn->pendingCallbacks));
// expecting all packets to be lost due to timeout
EXPECT_THAT(
conn->outstandings.packets,
UnorderedElementsAre(
getOutstandingPacketMatcher(1, false, true),
getOutstandingPacketMatcher(2, false, true),
getOutstandingPacketMatcher(3, false, true),
getOutstandingPacketMatcher(4, false, true),
getOutstandingPacketMatcher(5, false, true),
getOutstandingPacketMatcher(6, false, true),
getOutstandingPacketMatcher(7, false, true)));
EXPECT_CALL( EXPECT_CALL(
ib, *obs1,
packetLossDetected( packetLossDetected(
nullptr, socket_.get(),
Field( Field(
&SocketObserverInterface::LossEvent::lostPackets, &SocketObserverInterface::LossEvent::lostPackets,
UnorderedElementsAre( UnorderedElementsAre(
getLossPacketMatcher(1, false, true), getLossPacketMatcher(
getLossPacketMatcher(2, false, true), 1 /* packetNum */,
getLossPacketMatcher(3, false, true), false /* lossByReorder */,
getLossPacketMatcher(4, false, true), true /* lossByTimeout */),
getLossPacketMatcher(5, false, true), getLossPacketMatcher(
getLossPacketMatcher(6, false, true), 2 /* packetNum */,
getLossPacketMatcher(7, false, true))))) false /* lossByReorder */,
true /* lossByTimeout */),
getLossPacketMatcher(
3 /* packetNum */,
false /* lossByReorder */,
true /* lossByTimeout */),
getLossPacketMatcher(
4 /* packetNum */,
false /* lossByReorder */,
true /* lossByTimeout */),
getLossPacketMatcher(
5 /* packetNum */,
false /* lossByReorder */,
true /* lossByTimeout */),
getLossPacketMatcher(
6 /* packetNum */,
false /* lossByReorder */,
true /* lossByTimeout */),
getLossPacketMatcher(
7 /* packetNum */,
false /* lossByReorder */,
true /* lossByTimeout */)))))
.Times(1); .Times(1);
detectLossPackets(
*conn,
largestSent + 1,
[](auto&, auto&, bool) {},
checkTime,
PacketNumberSpace::AppData);
EXPECT_THAT(
conn->outstandings.packets,
UnorderedElementsAre(
getOutstandingPacketMatcher(
1 /* packetNum */,
false /* lossByReorder */,
true /* lossByTimeout */),
getOutstandingPacketMatcher(
2 /* packetNum */,
false /* lossByReorder */,
true /* lossByTimeout */),
getOutstandingPacketMatcher(
3 /* packetNum */,
false /* lossByReorder */,
true /* lossByTimeout */),
getOutstandingPacketMatcher(
4 /* packetNum */,
false /* lossByReorder */,
true /* lossByTimeout */),
getOutstandingPacketMatcher(
5 /* packetNum */,
false /* lossByReorder */,
true /* lossByTimeout */),
getOutstandingPacketMatcher(
6 /* packetNum */,
false /* lossByReorder */,
true /* lossByTimeout */),
getOutstandingPacketMatcher(
7 /* packetNum */,
false /* lossByReorder */,
true /* lossByTimeout */)));
for (auto& callback : conn->pendingCallbacks) { conn->observerContainer->removeObserver(obs1.get());
callback(nullptr);
}
} }
TEST_F(QuicLossFunctionsTest, TestTimeoutAndReorderLossObserverCallback) { TEST_F(QuicLossFunctionsTest, ObserverLossEventTimeoutAndReorder) {
auto observers = std::make_shared<ObserverVec>();
LegacyObserver::Config config = {};
config.lossEvents = true;
auto ib = MockLegacyObserver(config);
auto conn = createConn(); auto conn = createConn();
// Register 1 observer
observers->emplace_back(&ib);
conn->observers = observers;
auto noopLossVisitor = [](auto&, auto&, bool) {};
LegacyObserver::EventSet eventSet;
eventSet.enable(SocketObserverInterface::Events::lossEvents);
auto obs1 = std::make_unique<NiceMock<MockLegacyObserver>>(eventSet);
conn->observerContainer->addObserver(obs1.get());
// send 7 packets
PacketNum largestSent = 0; PacketNum largestSent = 0;
for (int i = 0; i < 7; ++i) { for (int i = 0; i < 7; ++i) {
largestSent = largestSent =
@@ -2191,79 +2237,60 @@ TEST_F(QuicLossFunctionsTest, TestTimeoutAndReorderLossObserverCallback) {
conn->transportSettings.timeReorderingThreshDivisor = 1.0; conn->transportSettings.timeReorderingThreshDivisor = 1.0;
TimePoint checkTime = TimePoint(500ms); TimePoint checkTime = TimePoint(500ms);
detectLossPackets(
*conn,
largestSent + 1,
noopLossVisitor,
checkTime,
PacketNumberSpace::AppData);
// expecting 1 callback to be stacked
EXPECT_EQ(1, size(conn->pendingCallbacks));
// Out of 1, 2, 3, 4, 5, 6, 7 -- we deleted (acked) 3,4,5. // Out of 1, 2, 3, 4, 5, 6, 7 -- we deleted (acked) 3,4,5.
// 1, 2, 6 are lost due to reodering and timeout. // 1, 2, 6 are lost due to reodering and timeout.
// 7 just timed out // 7 just timed out
EXPECT_THAT(
conn->outstandings.packets,
UnorderedElementsAre(
getOutstandingPacketMatcher(1, true, true),
getOutstandingPacketMatcher(2, true, true),
getOutstandingPacketMatcher(6, true, true),
getOutstandingPacketMatcher(7, false, true)));
EXPECT_CALL( EXPECT_CALL(
ib, *obs1,
packetLossDetected( packetLossDetected(
nullptr, socket_.get(),
Field( Field(
&SocketObserverInterface::LossEvent::lostPackets, &SocketObserverInterface::LossEvent::lostPackets,
UnorderedElementsAre( UnorderedElementsAre(
getLossPacketMatcher(1, true, true), getLossPacketMatcher(
getLossPacketMatcher(2, true, true), 1 /* packetNum */,
getLossPacketMatcher(6, true, true), true /* lossByReorder */,
getLossPacketMatcher(7, false, true))))) true /* lossByTimeout */),
getLossPacketMatcher(
2 /* packetNum */,
true /* lossByReorder */,
true /* lossByTimeout */),
getLossPacketMatcher(
6 /* packetNum */,
true /* lossByReorder */,
true /* lossByTimeout */),
getLossPacketMatcher(
7 /* packetNum */,
false /* lossByReorder */,
true /* lossByTimeout */)))))
.Times(1); .Times(1);
for (auto& callback : conn->pendingCallbacks) {
callback(nullptr);
}
}
TEST_F(QuicLossFunctionsTest, TestNoObserverCallback) {
auto conn = createConn();
auto noopLossVisitor = [](auto&, auto&, bool) {};
PacketNum largestSent = 0;
for (int i = 0; i < 7; ++i) {
largestSent =
sendPacket(*conn, TimePoint(i * 10ms), folly::none, PacketType::OneRtt);
}
// Some packets are already acked
conn->outstandings.packets.erase(
getFirstOutstandingPacket(*conn, PacketNumberSpace::AppData) + 2,
getFirstOutstandingPacket(*conn, PacketNumberSpace::AppData) + 5);
// setting a low reorder threshold
conn->lossState.reorderingThreshold = 1;
// setting time out parameters lower than the time at which detectLossPackets
// is called to make sure all packets timeout
conn->lossState.srtt = 400ms;
conn->lossState.lrtt = 350ms;
conn->transportSettings.timeReorderingThreshDividend = 1.0;
conn->transportSettings.timeReorderingThreshDivisor = 1.0;
TimePoint checkTime = TimePoint(500ms);
detectLossPackets( detectLossPackets(
*conn, *conn,
largestSent + 1, largestSent + 1,
noopLossVisitor, [](auto&, auto&, bool) {},
checkTime, checkTime,
PacketNumberSpace::AppData); PacketNumberSpace::AppData);
EXPECT_THAT(
conn->outstandings.packets,
UnorderedElementsAre(
getOutstandingPacketMatcher(
1 /* packetNum */,
true /* lossByReorder */,
true /* lossByTimeout */),
getOutstandingPacketMatcher(
2 /* packetNum */,
true /* lossByReorder */,
true /* lossByTimeout */),
getOutstandingPacketMatcher(
6 /* packetNum */,
true /* lossByReorder */,
true /* lossByTimeout */),
getOutstandingPacketMatcher(
7 /* packetNum */,
false /* lossByReorder */,
true /* lossByTimeout */)));
// expecting 0 callbacks to be queued conn->observerContainer->removeObserver(obs1.get());
EXPECT_EQ(0, size(conn->pendingCallbacks));
} }
TEST_F(QuicLossFunctionsTest, TotalPacketsMarkedLostByReordering) { TEST_F(QuicLossFunctionsTest, TotalPacketsMarkedLostByReordering) {

View File

@@ -14,14 +14,79 @@ namespace quic {
class QuicSocket; class QuicSocket;
using SocketObserverContainerBaseT = folly::ObserverContainer< using SocketObserverContainerBaseT = folly::ObserverContainer<
SocketObserverInterfaceTransitional, SocketObserverInterface,
QuicSocket, QuicSocket,
folly::ObserverContainerBasePolicyDefault< folly::ObserverContainerBasePolicyDefault<
SocketObserverInterfaceTransitional::Events /* EventEnum */, SocketObserverInterface::Events /* EventEnum */,
32 /* BitsetSize (max number of interface events) */>>; 32 /* BitsetSize (max number of interface events) */>>;
class SocketObserverContainer : public SocketObserverContainerBaseT { class SocketObserverContainer : public SocketObserverContainerBaseT {
public:
using SocketObserverContainerBaseT::SocketObserverContainerBaseT; using SocketObserverContainerBaseT::SocketObserverContainerBaseT;
/**
* Legacy observer for use during transition to folly::ObserverList.
*/
class LegacyObserver : public Observer {
public:
using EventSet = typename Observer::EventSet;
using EventSetBuilder = typename Observer::EventSetBuilder;
using Observer::Observer;
~LegacyObserver() override = default;
/**
* observerAttach() will be invoked when an observer is added.
*
* @param socket Socket where observer was installed.
*/
virtual void observerAttach(QuicSocket* /* socket */) noexcept {}
/**
* observerDetach() will be invoked if the observer is uninstalled prior
* to socket destruction.
*
* No further callbacks will be invoked after observerDetach().
*
* @param socket Socket where observer was uninstalled.
*/
virtual void observerDetach(QuicSocket* /* socket */) noexcept {}
/**
* destroy() will be invoked when the QuicSocket's destructor is invoked.
*
* No further callbacks will be invoked after destroy().
*
* @param socket Socket being destroyed.
*/
virtual void destroy(QuicSocket* /* socket */) noexcept {}
private:
void attached(QuicSocket* obj) noexcept override {
observerAttach(obj);
}
void detached(QuicSocket* obj) noexcept override {
observerDetach(obj);
}
void destroyed(QuicSocket* obj, DestroyContext* /* ctx */) noexcept
override {
destroy(obj);
}
void addedToObserverContainer(
ObserverContainerBase* list) noexcept override {
CHECK(list->getObject());
}
void removedFromObserverContainer(
ObserverContainerBase* list) noexcept override {
CHECK(list->getObject());
}
void movedToObserverContainer(
ObserverContainerBase* oldList,
ObserverContainerBase* newList) noexcept override {
CHECK(oldList->getObject());
CHECK(newList->getObject());
}
};
}; };
} // namespace quic } // namespace quic

View File

@@ -22,39 +22,25 @@ class EventBase;
namespace quic { namespace quic {
class QuicSocket; class QuicSocket;
class QuicTransportBase;
/**
* A temporary interface for use during transition to folly::ObserverContainer.
*
* At the moment, the new observer only supports events in this interface.
*/
class SocketObserverInterfaceTransitional {
public:
enum class Events {};
SocketObserverInterfaceTransitional() = default;
virtual ~SocketObserverInterfaceTransitional() = default;
/**
* close() will be invoked when the socket is being closed.
*
* If the callback handler does not unsubscribe itself upon being called,
* then it may be called multiple times (e.g., by a call to close() by
* the application, and then again when closeNow() is called on
* destruction).
*
* @param socket Socket being closed.
* @param errorOpt Error information, if connection closed due to error.
*/
virtual void close(
QuicSocket* /* socket */,
const folly::Optional<QuicError>& /* errorOpt */) noexcept {}
};
/** /**
* Observer of socket events. * Observer of socket events.
*/ */
class SocketObserverInterface : public SocketObserverInterfaceTransitional { class SocketObserverInterface {
public: public:
enum class Events {
evbEvents = 1,
packetsWrittenEvents = 2,
appRateLimitedEvents = 3,
rttSamples = 4,
lossEvents = 5,
spuriousLossEvents = 6,
pmtuEvents = 7,
knobFrameEvents = 8,
streamEvents = 9,
acksProcessedEvents = 10,
};
virtual ~SocketObserverInterface() = default; virtual ~SocketObserverInterface() = default;
struct WriteEvent { struct WriteEvent {
@@ -109,13 +95,17 @@ class SocketObserverInterface : public SocketObserverInterfaceTransitional {
}; };
// Do not support copy or move given that outstanding packets is a ref. // Do not support copy or move given that outstanding packets is a ref.
WriteEvent(const WriteEvent&) = delete;
WriteEvent(WriteEvent&&) = delete; WriteEvent(WriteEvent&&) = delete;
WriteEvent& operator=(const WriteEvent&) = delete; WriteEvent& operator=(const WriteEvent&) = delete;
WriteEvent& operator=(WriteEvent&& rhs) = delete; WriteEvent& operator=(WriteEvent&& rhs) = delete;
// Use builder to construct. // Use builder to construct.
explicit WriteEvent(const BuilderFields& builderFields); explicit WriteEvent(const BuilderFields& builderFields);
protected:
// Allow QuicTransportBase to use the copy constructor for enqueuing
friend class QuicTransportBase;
WriteEvent(const WriteEvent&) = default;
}; };
struct AppLimitedEvent : public WriteEvent { struct AppLimitedEvent : public WriteEvent {
@@ -351,6 +341,21 @@ class SocketObserverInterface : public SocketObserverInterfaceTransitional {
using StreamOpenEvent = StreamEvent; using StreamOpenEvent = StreamEvent;
using StreamCloseEvent = StreamEvent; using StreamCloseEvent = StreamEvent;
/**
* close() will be invoked when the socket is being closed.
*
* If the callback handler does not unsubscribe itself upon being called,
* then it may be called multiple times (e.g., by a call to close() by
* the application, and then again when closeNow() is called on
* destruction).
*
* @param socket Socket being closed.
* @param errorOpt Error information, if connection closed due to error.
*/
virtual void close(
QuicSocket* /* socket */,
const folly::Optional<QuicError>& /* errorOpt */) noexcept {}
/** /**
* evbAttach() will be invoked when a new event base is attached to this * evbAttach() will be invoked when a new event base is attached to this
* socket. This will be called from the new event base's thread. * socket. This will be called from the new event base's thread.

View File

@@ -60,7 +60,7 @@ QuicServerTransport::QuicServerTransport(
tempConn->serverAddr = socket_->address(); tempConn->serverAddr = socket_->address();
serverConn_ = tempConn.get(); serverConn_ = tempConn.get();
conn_.reset(tempConn.release()); conn_.reset(tempConn.release());
conn_->observers = observers_; conn_->observerContainer = observerContainer_;
setConnectionSetupCallback(connSetupCb); setConnectionSetupCallback(connSetupCb);
setConnectionCallback(connStreamsCb); setConnectionCallback(connStreamsCb);
@@ -626,10 +626,16 @@ void QuicServerTransport::maybeStartD6DProbing() {
// valuable // valuable
conn_->pendingEvents.d6d.sendProbeDelay = kDefaultD6DKickStartDelay; conn_->pendingEvents.d6d.sendProbeDelay = kDefaultD6DKickStartDelay;
QUIC_STATS(conn_->statsCallback, onConnectionD6DStarted); QUIC_STATS(conn_->statsCallback, onConnectionD6DStarted);
for (const auto& cb : *(conn_->observers)) {
if (cb->getConfig().pmtuEvents) { if (getSocketObserverContainer() &&
cb->pmtuProbingStarted(this); getSocketObserverContainer()
} ->hasObserversForEvent<
SocketObserverInterface::Events::pmtuEvents>()) {
getSocketObserverContainer()
->invokeInterfaceMethod<SocketObserverInterface::Events::pmtuEvents>(
[](auto observer, auto observed) {
observer->pmtuProbingStarted(observed);
});
} }
} }
} }

View File

@@ -4189,20 +4189,31 @@ TEST_P(
} }
TEST_P(QuicServerTransportHandshakeTest, TestD6DStartCallback) { TEST_P(QuicServerTransportHandshakeTest, TestD6DStartCallback) {
LegacyObserver::Config config = {}; LegacyObserver::EventSet eventSet;
config.pmtuEvents = true; eventSet.enable(SocketObserverInterface::Events::pmtuEvents);
auto observer = std::make_unique<MockLegacyObserver>(config);
server->addObserver(observer.get()); auto obs1 = std::make_unique<MockLegacyObserver>();
auto obs2 = std::make_unique<MockLegacyObserver>(eventSet);
auto obs3 = std::make_unique<MockLegacyObserver>(eventSet);
server->addObserver(obs1.get());
server->addObserver(obs2.get());
server->addObserver(obs3.get());
// Set oneRttReader so that maybeStartD6DPriobing passes its check // Set oneRttReader so that maybeStartD6DPriobing passes its check
auto codec = std::make_unique<QuicReadCodec>(QuicNodeType::Server); auto codec = std::make_unique<QuicReadCodec>(QuicNodeType::Server);
codec->setOneRttReadCipher(createNoOpAead()); codec->setOneRttReadCipher(createNoOpAead());
server->getNonConstConn().readCodec = std::move(codec); server->getNonConstConn().readCodec = std::move(codec);
// And the state too // And the state too
server->getNonConstConn().d6d.state = D6DMachineState::BASE; server->getNonConstConn().d6d.state = D6DMachineState::BASE;
EXPECT_CALL(*observer, pmtuProbingStarted(_)).Times(1); EXPECT_CALL(*obs1, pmtuProbingStarted(_)).Times(0); // not enabled
EXPECT_CALL(*obs2, pmtuProbingStarted(_)).Times(1);
EXPECT_CALL(*obs3, pmtuProbingStarted(_)).Times(1);
// CHLO should be enough to trigger probing // CHLO should be enough to trigger probing
recvClientHello(); recvClientHello();
server->removeObserver(observer.get());
server->removeObserver(obs1.get());
server->removeObserver(obs2.get());
server->removeObserver(obs3.get());
} }
TEST_F(QuicUnencryptedServerTransportTest, DuplicateOneRttWriteCipher) { TEST_F(QuicUnencryptedServerTransportTest, DuplicateOneRttWriteCipher) {

View File

@@ -76,7 +76,9 @@ AckEvent processAckFrame(
folly::Optional<LegacyObserver::SpuriousLossEvent> spuriousLossEvent; folly::Optional<LegacyObserver::SpuriousLossEvent> spuriousLossEvent;
// Used for debug only. // Used for debug only.
const auto originalPacketCount = conn.outstandings.packetCount; const auto originalPacketCount = conn.outstandings.packetCount;
if (conn.observers->size() > 0) { if (conn.observerContainer &&
conn.observerContainer->hasObserversForEvent<
SocketObserverInterface::Events::spuriousLossEvents>()) {
spuriousLossEvent.emplace(ackReceiveTime); spuriousLossEvent.emplace(ackReceiveTime);
} }
auto ackBlockIt = frame.ackBlocks.cbegin(); auto ackBlockIt = frame.ackBlocks.cbegin();
@@ -198,14 +200,17 @@ AckEvent processAckFrame(
ackReceiveTimeOrNow - rPacketIt->metadata.time); ackReceiveTimeOrNow - rPacketIt->metadata.time);
if (rttSample != rttSample.zero()) { if (rttSample != rttSample.zero()) {
// notify observers // notify observers
const SocketObserverInterface::PacketRTT packetRTT( if (conn.observerContainer &&
ackReceiveTimeOrNow, rttSample, frame.ackDelay, *rPacketIt); conn.observerContainer->hasObserversForEvent<
for (const auto& observer : *(conn.observers)) { SocketObserverInterface::Events::rttSamples>()) {
conn.pendingCallbacks.emplace_back( conn.observerContainer->invokeInterfaceMethod<
[observer, packetRTT](QuicSocket* qSocket) { SocketObserverInterface::Events::rttSamples>(
if (observer->getConfig().rttSamples) { [event = SocketObserverInterface::PacketRTT(
observer->rttSampleGenerated(qSocket, packetRTT); ackReceiveTimeOrNow,
} rttSample,
frame.ackDelay,
*rPacketIt)](auto observer, auto observed) {
observer->rttSampleGenerated(observed, event);
}); });
} }
@@ -461,16 +466,17 @@ AckEvent processAckFrame(
ack.ccState = conn.congestionController->getState(); ack.ccState = conn.congestionController->getState();
} }
clearOldOutstandingPackets(conn, ackReceiveTime, pnSpace); clearOldOutstandingPackets(conn, ackReceiveTime, pnSpace);
if (spuriousLossEvent && spuriousLossEvent->hasPackets()) {
for (const auto& observer : *(conn.observers)) { if (spuriousLossEvent && conn.observerContainer &&
conn.pendingCallbacks.emplace_back( conn.observerContainer->hasObserversForEvent<
[observer, spuriousLossEvent](QuicSocket* qSocket) { SocketObserverInterface::Events::spuriousLossEvents>()) {
if (observer->getConfig().spuriousLossEvents) { conn.observerContainer->invokeInterfaceMethod<
observer->spuriousLossDetected(qSocket, *spuriousLossEvent); SocketObserverInterface::Events::spuriousLossEvents>(
} [spuriousLossEvent](auto observer, auto observed) {
}); observer->spuriousLossDetected(observed, *spuriousLossEvent);
} });
} }
return ack; return ack;
} }

View File

@@ -307,9 +307,7 @@ struct ReadDatagram {
struct QuicConnectionStateBase : public folly::DelayedDestruction { struct QuicConnectionStateBase : public folly::DelayedDestruction {
virtual ~QuicConnectionStateBase() override = default; virtual ~QuicConnectionStateBase() override = default;
explicit QuicConnectionStateBase(QuicNodeType type) : nodeType(type) { explicit QuicConnectionStateBase(QuicNodeType type) : nodeType(type) {}
observers = std::make_shared<ObserverVec>();
}
// Accessor to output buffer for continuous memory GSO writes // Accessor to output buffer for continuous memory GSO writes
BufAccessor* bufAccessor{nullptr}; BufAccessor* bufAccessor{nullptr};
@@ -666,11 +664,15 @@ struct QuicConnectionStateBase : public folly::DelayedDestruction {
*/ */
bool retireAndSwitchPeerConnectionIds(); bool retireAndSwitchPeerConnectionIds();
// queue of functions to be called in processCallbacksAfterNetworkData // SocketObserverContainer
std::vector<std::function<void(QuicSocket*)>> pendingCallbacks; std::shared_ptr<SocketObserverContainer> observerContainer;
// Vector of Observers that are attached to this socket. /**
std::shared_ptr<const ObserverVec> observers; * Returns the SocketObserverContainer or nullptr if not available.
*/
SocketObserverContainer* getSocketObserverContainer() const {
return observerContainer.get();
}
// Recent ACK events, for use in processCallbacksAfterNetworkData. // Recent ACK events, for use in processCallbacksAfterNetworkData.
// Holds the ACK events generated during the last round of ACK processing. // Holds the ACK events generated during the last round of ACK processing.

View File

@@ -9,6 +9,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <quic/QuicConstants.h> #include <quic/QuicConstants.h>
#include <quic/api/test/MockQuicSocket.h>
#include <quic/api/test/Mocks.h> #include <quic/api/test/Mocks.h>
#include <quic/common/test/TestUtils.h> #include <quic/common/test/TestUtils.h>
#include <quic/fizz/server/handshake/FizzServerQuicHandshakeContext.h> #include <quic/fizz/server/handshake/FizzServerQuicHandshakeContext.h>
@@ -2315,21 +2316,23 @@ TEST_P(AckHandlersTest, ImplictAckEventCreation) {
ackTime); ackTime);
} }
TEST_P(AckHandlersTest, TestRTTPacketObserverCallback) { TEST_P(AckHandlersTest, ObserverRttSample) {
auto socket = std::make_shared<MockQuicSocket>();
const auto observerContainer =
std::make_shared<SocketObserverContainer>(socket.get());
QuicServerConnectionState conn( QuicServerConnectionState conn(
FizzServerQuicHandshakeContext::Builder().build()); FizzServerQuicHandshakeContext::Builder().build());
auto mockCongestionController = std::make_unique<MockCongestionController>(); auto mockCongestionController = std::make_unique<MockCongestionController>();
conn.congestionController = std::move(mockCongestionController); conn.congestionController = std::move(mockCongestionController);
conn.observerContainer = observerContainer;
// Register 1 observer LegacyObserver::EventSet eventSet;
LegacyObserver::Config config = {}; eventSet.enable(SocketObserverInterface::Events::rttSamples);
config.rttSamples = true; auto obs1 = std::make_unique<NiceMock<MockLegacyObserver>>(eventSet);
auto ib = MockLegacyObserver(config); observerContainer->addObserver(obs1.get());
auto observers = std::make_shared<ObserverVec>();
observers->emplace_back(&ib);
conn.observers = observers;
// send packet numbers 0 -> 29
PacketNum packetNum = 0; PacketNum packetNum = 0;
StreamId streamid = 0; StreamId streamid = 0;
TimePoint sentTime; TimePoint sentTime;
@@ -2360,13 +2363,13 @@ TEST_P(AckHandlersTest, TestRTTPacketObserverCallback) {
packetNum++; packetNum++;
} }
struct ackPacketData { struct AckFrameWithTestData {
PacketNum startSeq, endSeq; PacketNum startSeq, endSeq;
std::chrono::milliseconds ackDelay; std::chrono::milliseconds ackDelay;
TimePoint ackTime; TimePoint ackTime;
ReadAckFrame ackFrame; ReadAckFrame ackFrame;
explicit ackPacketData( explicit AckFrameWithTestData(
PacketNum startSeqIn, PacketNum startSeqIn,
PacketNum endSeqIn, PacketNum endSeqIn,
std::chrono::milliseconds ackDelayIn) std::chrono::milliseconds ackDelayIn)
@@ -2381,44 +2384,24 @@ TEST_P(AckHandlersTest, TestRTTPacketObserverCallback) {
}; };
// See each emplace as the ACK Block [X, Y] with size (Y-X+1) // See each emplace as the ACK Block [X, Y] with size (Y-X+1)
std::vector<ackPacketData> ackVec; std::vector<AckFrameWithTestData> ackVec;
// Sequential test // Sequential test
ackVec.emplace_back(0, 5, 4ms); // +1 callback ackVec.emplace_back(0, 5, 4ms);
ackVec.emplace_back(6, 10, 5ms); // +1 ackVec.emplace_back(6, 10, 5ms);
ackVec.emplace_back(11, 15, 6ms); // +1 ackVec.emplace_back(11, 15, 6ms);
// Out-of-order test // Out-of-order test
// ackVec.emplace_back(18, 18, 0ms);
// Its important to check the if ackVec.emplace_back(16, 17, 2ms);
// largestAcked - currentPacketNum > reorderingThreshold (currently 3) ackVec.emplace_back(19, 29, 12ms);
// else it can trigger Observer::packetLossDetected
// and increase the number of callbacks
ackVec.emplace_back(18, 18, 0ms); // +1
ackVec.emplace_back(16, 17, 2ms); // +1
ackVec.emplace_back(19, 29, 12ms); // +1 = 6 callbacks
// 0 pending callbacks
EXPECT_EQ(0, size(conn.pendingCallbacks));
for (const auto& ackData : ackVec) {
processAckFrame(
conn,
GetParam(),
ackData.ackFrame,
[](const auto&, const auto&, const auto&) {},
[](auto&, auto&, bool) {},
ackData.ackTime);
}
// see above
EXPECT_EQ(6, size(conn.pendingCallbacks));
// Setup expectations, then process the actual ACKs
for (const auto& ackData : ackVec) { for (const auto& ackData : ackVec) {
auto rttSample = std::chrono::ceil<std::chrono::microseconds>( auto rttSample = std::chrono::ceil<std::chrono::microseconds>(
ackData.ackTime - packetRcvTime[ackData.endSeq]); ackData.ackTime - packetRcvTime[ackData.endSeq]);
EXPECT_CALL( EXPECT_CALL(
ib, *obs1,
rttSampleGenerated( rttSampleGenerated(
nullptr, socket.get(),
AllOf( AllOf(
Field( Field(
&SocketObserverInterface::PacketRTT::rcvTime, &SocketObserverInterface::PacketRTT::rcvTime,
@@ -2434,29 +2417,38 @@ TEST_P(AckHandlersTest, TestRTTPacketObserverCallback) {
&quic::OutstandingPacketMetadata::inflightBytes, &quic::OutstandingPacketMetadata::inflightBytes,
ackData.endSeq + 1))))); ackData.endSeq + 1)))));
} }
for (const auto& ackData : ackVec) {
for (auto& callback : conn.pendingCallbacks) { processAckFrame(
callback(nullptr); conn,
GetParam(),
ackData.ackFrame,
[](const auto&, const auto&, const auto&) {},
[](auto&, auto&, bool) {},
ackData.ackTime);
} }
observerContainer->removeObserver(obs1.get());
} }
TEST_P(AckHandlersTest, TestSpuriousObserverReorder) { TEST_P(AckHandlersTest, ObserverSpuriousLostEventReorderThreshold) {
auto socket = std::make_shared<MockQuicSocket>();
const auto observerContainer =
std::make_shared<SocketObserverContainer>(socket.get());
QuicServerConnectionState conn( QuicServerConnectionState conn(
FizzServerQuicHandshakeContext::Builder().build()); FizzServerQuicHandshakeContext::Builder().build());
auto mockCongestionController = std::make_unique<MockCongestionController>(); auto mockCongestionController = std::make_unique<MockCongestionController>();
conn.congestionController = std::move(mockCongestionController); conn.congestionController = std::move(mockCongestionController);
conn.observerContainer = observerContainer;
// Register 1 observer LegacyObserver::EventSet eventSet;
LegacyObserver::Config config = {}; eventSet.enable(
config.spuriousLossEvents = true; SocketObserverInterface::Events::lossEvents,
config.lossEvents = true; SocketObserverInterface::Events::spuriousLossEvents);
auto ib = MockLegacyObserver(config); auto obs1 = std::make_unique<NiceMock<MockLegacyObserver>>(eventSet);
observerContainer->addObserver(obs1.get());
auto observers = std::make_shared<ObserverVec>();
observers->emplace_back(&ib);
conn.observers = observers;
auto noopLossVisitor = [](auto&, auto&, bool) {};
// send 10 packets
TimePoint startTime = Clock::now(); TimePoint startTime = Clock::now();
emplacePackets(conn, 10, startTime, GetParam()); emplacePackets(conn, 10, startTime, GetParam());
@@ -2475,15 +2467,11 @@ TEST_P(AckHandlersTest, TestSpuriousObserverReorder) {
conn.transportSettings.timeReorderingThreshDivisor = 1.0; conn.transportSettings.timeReorderingThreshDivisor = 1.0;
TimePoint checkTime = startTime + 20ms; TimePoint checkTime = startTime + 20ms;
detectLossPackets(conn, 4, noopLossVisitor, checkTime, GetParam()); // expect packets to be marked lost on call to detectLostPackets
// expecting 1 callback to be stacked
EXPECT_EQ(1, size(conn.pendingCallbacks));
EXPECT_CALL( EXPECT_CALL(
ib, *obs1,
packetLossDetected( packetLossDetected(
nullptr, socket.get(),
Field( Field(
&SocketObserverInterface::LossEvent::lostPackets, &SocketObserverInterface::LossEvent::lostPackets,
UnorderedElementsAre( UnorderedElementsAre(
@@ -2491,8 +2479,21 @@ TEST_P(AckHandlersTest, TestSpuriousObserverReorder) {
MockLegacyObserver::getLossPacketMatcher(1, true, false), MockLegacyObserver::getLossPacketMatcher(1, true, false),
MockLegacyObserver::getLossPacketMatcher(2, true, false))))) MockLegacyObserver::getLossPacketMatcher(2, true, false)))))
.Times(1); .Times(1);
detectLossPackets(
conn, 4, [](auto&, auto&, bool) {}, checkTime, GetParam());
// Here we receive the spurious loss packets in a late ack // now we get acks for packets marked lost, triggering spuriousLossDetected
EXPECT_CALL(
*obs1,
spuriousLossDetected(
socket.get(),
Field(
&SocketObserverInterface::SpuriousLossEvent::spuriousPackets,
UnorderedElementsAre(
MockLegacyObserver::getLossPacketMatcher(0, true, false),
MockLegacyObserver::getLossPacketMatcher(1, true, false),
MockLegacyObserver::getLossPacketMatcher(2, true, false)))))
.Times(1);
{ {
ReadAckFrame ackFrame; ReadAckFrame ackFrame;
ackFrame.largestAcked = 2; ackFrame.largestAcked = 2;
@@ -2507,43 +2508,28 @@ TEST_P(AckHandlersTest, TestSpuriousObserverReorder) {
startTime + 30ms); startTime + 30ms);
} }
// Spurious loss observer call added observerContainer->removeObserver(obs1.get());
EXPECT_EQ(2, size(conn.pendingCallbacks));
EXPECT_CALL(
ib,
spuriousLossDetected(
nullptr,
Field(
&SocketObserverInterface::SpuriousLossEvent::spuriousPackets,
UnorderedElementsAre(
MockLegacyObserver::getLossPacketMatcher(0, true, false),
MockLegacyObserver::getLossPacketMatcher(1, true, false),
MockLegacyObserver::getLossPacketMatcher(2, true, false)))))
.Times(1);
for (auto& callback : conn.pendingCallbacks) {
callback(nullptr);
}
} }
TEST_P(AckHandlersTest, TestSpuriousObserverTimeout) { TEST_P(AckHandlersTest, ObserverSpuriousLostEventTimeout) {
auto socket = std::make_shared<MockQuicSocket>();
const auto observerContainer =
std::make_shared<SocketObserverContainer>(socket.get());
QuicServerConnectionState conn( QuicServerConnectionState conn(
FizzServerQuicHandshakeContext::Builder().build()); FizzServerQuicHandshakeContext::Builder().build());
auto mockCongestionController = std::make_unique<MockCongestionController>(); auto mockCongestionController = std::make_unique<MockCongestionController>();
conn.congestionController = std::move(mockCongestionController); conn.congestionController = std::move(mockCongestionController);
conn.observerContainer = observerContainer;
// Register 1 observer LegacyObserver::EventSet eventSet;
LegacyObserver::Config config = {}; eventSet.enable(
config.spuriousLossEvents = true; SocketObserverInterface::Events::lossEvents,
config.lossEvents = true; SocketObserverInterface::Events::spuriousLossEvents);
auto ib = MockLegacyObserver(config); auto obs1 = std::make_unique<NiceMock<MockLegacyObserver>>(eventSet);
observerContainer->addObserver(obs1.get());
auto observers = std::make_shared<ObserverVec>();
observers->emplace_back(&ib);
conn.observers = observers;
auto noopLossVisitor = [](auto&, auto&, bool) {};
// send 10 packets
TimePoint startTime = Clock::now(); TimePoint startTime = Clock::now();
emplacePackets(conn, 10, startTime, GetParam()); emplacePackets(conn, 10, startTime, GetParam());
@@ -2562,15 +2548,11 @@ TEST_P(AckHandlersTest, TestSpuriousObserverTimeout) {
conn.transportSettings.timeReorderingThreshDivisor = 1.0; conn.transportSettings.timeReorderingThreshDivisor = 1.0;
TimePoint checkTime = startTime + 500ms; TimePoint checkTime = startTime + 500ms;
detectLossPackets(conn, 10, noopLossVisitor, checkTime, GetParam()); // expect packets to be marked lost on call to detectLostPackets
// expecting 1 callback to be stacked
EXPECT_EQ(1, size(conn.pendingCallbacks));
EXPECT_CALL( EXPECT_CALL(
ib, *obs1,
packetLossDetected( packetLossDetected(
nullptr, socket.get(),
Field( Field(
&SocketObserverInterface::LossEvent::lostPackets, &SocketObserverInterface::LossEvent::lostPackets,
UnorderedElementsAre( UnorderedElementsAre(
@@ -2580,8 +2562,23 @@ TEST_P(AckHandlersTest, TestSpuriousObserverTimeout) {
MockLegacyObserver::getLossPacketMatcher(8, false, true), MockLegacyObserver::getLossPacketMatcher(8, false, true),
MockLegacyObserver::getLossPacketMatcher(9, false, true))))) MockLegacyObserver::getLossPacketMatcher(9, false, true)))))
.Times(1); .Times(1);
detectLossPackets(
conn, 10, [](auto&, auto&, bool) {}, checkTime, GetParam());
// Here we receive the spurious loss packets in a late ack // now we get acks for packets marked lost, triggering spuriousLossDetected
EXPECT_CALL(
*obs1,
spuriousLossDetected(
socket.get(),
Field(
&SocketObserverInterface::SpuriousLossEvent::spuriousPackets,
UnorderedElementsAre(
MockLegacyObserver::getLossPacketMatcher(5, false, true),
MockLegacyObserver::getLossPacketMatcher(6, false, true),
MockLegacyObserver::getLossPacketMatcher(7, false, true),
MockLegacyObserver::getLossPacketMatcher(8, false, true),
MockLegacyObserver::getLossPacketMatcher(9, false, true)))))
.Times(1);
{ {
ReadAckFrame ackFrame; ReadAckFrame ackFrame;
ackFrame.largestAcked = 9; ackFrame.largestAcked = 9;
@@ -2596,26 +2593,7 @@ TEST_P(AckHandlersTest, TestSpuriousObserverTimeout) {
startTime + 510ms); startTime + 510ms);
} }
// Spurious loss observer call added observerContainer->removeObserver(obs1.get());
EXPECT_EQ(2, size(conn.pendingCallbacks));
EXPECT_CALL(
ib,
spuriousLossDetected(
nullptr,
Field(
&SocketObserverInterface::SpuriousLossEvent::spuriousPackets,
UnorderedElementsAre(
MockLegacyObserver::getLossPacketMatcher(5, false, true),
MockLegacyObserver::getLossPacketMatcher(6, false, true),
MockLegacyObserver::getLossPacketMatcher(7, false, true),
MockLegacyObserver::getLossPacketMatcher(8, false, true),
MockLegacyObserver::getLossPacketMatcher(9, false, true)))))
.Times(1);
for (auto& callback : conn.pendingCallbacks) {
callback(nullptr);
}
} }
TEST_P(AckHandlersTest, SubMicrosecondRTT) { TEST_P(AckHandlersTest, SubMicrosecondRTT) {

View File

@@ -141,8 +141,8 @@ ProbeSizeRaiserType parseRaiserType(uint32_t type) {
class TPerfObserver : public LegacyObserver { class TPerfObserver : public LegacyObserver {
public: public:
explicit TPerfObserver(const LegacyObserver::Config& config) using LegacyObserver::LegacyObserver;
: LegacyObserver(config) {}
void appRateLimited( void appRateLimited(
QuicSocket* /* socket */, QuicSocket* /* socket */,
const quic::SocketObserverInterface:: const quic::SocketObserverInterface::
@@ -202,12 +202,13 @@ class TPerfAcceptObserver : public AcceptObserver {
TPerfAcceptObserver() { TPerfAcceptObserver() {
// Create an observer config, only enabling events we are interested in // Create an observer config, only enabling events we are interested in
// receiving. // receiving.
LegacyObserver::Config config = {}; LegacyObserver::EventSet eventSet;
config.appRateLimitedEvents = true; eventSet.enable(
config.pmtuEvents = true; SocketObserverInterface::Events::appRateLimitedEvents,
config.rttSamples = true; SocketObserverInterface::Events::pmtuEvents,
config.lossEvents = true; SocketObserverInterface::Events::rttSamples,
tperfObserver_ = std::make_unique<TPerfObserver>(config); SocketObserverInterface::Events::lossEvents);
tperfObserver_ = std::make_unique<TPerfObserver>(eventSet);
} }
void accept(QuicTransportBase* transport) noexcept override { void accept(QuicTransportBase* transport) noexcept override {