diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index 89b449f9c..fc8e6b1cc 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -145,20 +145,10 @@ const folly::SocketAddress& QuicTransportBase::getLocalAddress() const { QuicTransportBase::~QuicTransportBase() { resetConnectionCallbacks(); - closeImpl( - QuicError( - QuicErrorCode(LocalErrorCode::SHUTTING_DOWN), - std::string("Closing from base destructor")), - false); - // If a drainTimeout is already scheduled, then closeNow above - // won't do anything. We have to manually clean up the socket. Timeout will be - // canceled by timer's destructor. - if (socket_) { - auto sock = std::move(socket_); - socket_ = nullptr; - sock->pauseRead(); - sock->close(); - } + // closeImpl and closeUdpSocket should have been triggered by destructor of + // derived class to ensure that observers are properly notified + DCHECK_NE(CloseState::OPEN, closeState_); + DCHECK(!socket_.get()); // should be no socket } bool QuicTransportBase::good() const { @@ -247,9 +237,11 @@ void QuicTransportBase::closeImpl( } if (getSocketObserverContainer()) { + SocketObserverInterface::CloseStartedEvent event; + event.maybeCloseReason = errorCode; getSocketObserverContainer()->invokeInterfaceMethodAllObservers( - [errorCode](auto observer, auto observed) { - observer->close(observed, errorCode); + [&event](auto observer, auto observed) { + observer->closeStarted(observed, event); }); } @@ -437,6 +429,23 @@ void QuicTransportBase::closeImpl( } } +void QuicTransportBase::closeUdpSocket() { + if (!socket_) { + return; + } + if (getSocketObserverContainer()) { + SocketObserverInterface::ClosingEvent event; // empty for now + getSocketObserverContainer()->invokeInterfaceMethodAllObservers( + [&event](auto observer, auto observed) { + observer->closing(observed, event); + }); + } + auto sock = std::move(socket_); + socket_ = nullptr; + sock->pauseRead(); + sock->close(); +} + bool QuicTransportBase::processCancelCode(const QuicError& cancelCode) { bool noError = false; switch (cancelCode.code.type()) { @@ -494,12 +503,7 @@ void QuicTransportBase::processConnectionCallbacks( } void QuicTransportBase::drainTimeoutExpired() noexcept { - if (socket_) { - auto sock = std::move(socket_); - socket_ = nullptr; - sock->pauseRead(); - sock->close(); - } + closeUdpSocket(); unbindConnection(); } diff --git a/quic/api/QuicTransportBase.h b/quic/api/QuicTransportBase.h index fe2edbc39..fb73b0270 100644 --- a/quic/api/QuicTransportBase.h +++ b/quic/api/QuicTransportBase.h @@ -738,6 +738,7 @@ class QuicTransportBase : public QuicSocket, QuicStreamPrioritiesObserver { folly::Optional error, bool drainConnection = true, bool sendCloseImmediately = true); + void closeUdpSocket(); folly::Expected pauseOrResumeRead( StreamId id, bool resume); diff --git a/quic/api/test/Mocks.h b/quic/api/test/Mocks.h index 9643fc2a6..3c808547d 100644 --- a/quic/api/test/Mocks.h +++ b/quic/api/test/Mocks.h @@ -300,9 +300,10 @@ class MockObserver : public QuicSocket::ManagedObserver { (noexcept)); MOCK_METHOD( (void), - close, - (QuicSocket*, const folly::Optional&), + closeStarted, + (QuicSocket*, const CloseStartedEvent&), (noexcept)); + MOCK_METHOD((void), closing, (QuicSocket*, const ClosingEvent&), (noexcept)); }; class MockLegacyObserver : public LegacyObserver { @@ -311,13 +312,14 @@ class MockLegacyObserver : public LegacyObserver { MOCK_METHOD((void), observerAttach, (QuicSocket*), (noexcept)); MOCK_METHOD((void), observerDetach, (QuicSocket*), (noexcept)); MOCK_METHOD((void), destroy, (QuicSocket*), (noexcept)); - MOCK_METHOD((void), evbAttach, (QuicSocket*, folly::EventBase*), (noexcept)); - MOCK_METHOD((void), evbDetach, (QuicSocket*, folly::EventBase*), (noexcept)); MOCK_METHOD( (void), - close, - (QuicSocket*, const folly::Optional&), + closeStarted, + (QuicSocket*, const CloseStartedEvent&), (noexcept)); + MOCK_METHOD((void), closing, (QuicSocket*, const ClosingEvent&), (noexcept)); + MOCK_METHOD((void), evbAttach, (QuicSocket*, folly::EventBase*), (noexcept)); + MOCK_METHOD((void), evbDetach, (QuicSocket*, folly::EventBase*), (noexcept)); MOCK_METHOD( (void), startWritingFromAppLimited, diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index 6b814b367..e82b84836 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -247,11 +247,14 @@ class TestQuicTransport ~TestQuicTransport() override { resetConnectionCallbacks(); // we need to call close in the derived class. + resetConnectionCallbacks(); closeImpl( QuicError( QuicErrorCode(LocalErrorCode::SHUTTING_DOWN), std::string("shutdown")), - false); + false /* drainConnection */); + // closeImpl may have been called earlier with drain = true, so force close. + closeUdpSocket(); } std::chrono::milliseconds getLossTimeoutRemainingTime() const { @@ -3703,7 +3706,8 @@ TEST_P(QuicTransportImplTestBase, ObserverDestroy) { transport->addObserver(cb.get()); EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb.get())); InSequence s; - EXPECT_CALL(*cb, close(transport.get(), _)); + EXPECT_CALL(*cb, closeStarted(transport.get(), _)); + EXPECT_CALL(*cb, closing(transport.get(), _)); EXPECT_CALL(*cb, destroy(transport.get())); transport = nullptr; Mock::VerifyAndClearExpectations(cb.get()); @@ -3732,7 +3736,8 @@ TEST_P(QuicTransportImplTestBase, ObserverSharedPtrDestroy) { transport->addObserver(cb); EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb.get())); InSequence s; - EXPECT_CALL(*cb, close(transport.get(), _)); + EXPECT_CALL(*cb, closeStarted(transport.get(), _)); + EXPECT_CALL(*cb, closing(transport.get(), _)); EXPECT_CALL(*cb, destroy(transport.get())); transport = nullptr; Mock::VerifyAndClearExpectations(cb.get()); @@ -3752,7 +3757,8 @@ TEST_P(QuicTransportImplTestBase, ObserverSharedPtrReleasedDestroy) { EXPECT_FALSE(dc.destroyed()); // should still exist InSequence s; - EXPECT_CALL(*cbRaw, close(transport.get(), _)); + EXPECT_CALL(*cbRaw, closeStarted(transport.get(), _)); + EXPECT_CALL(*cbRaw, closing(transport.get(), _)); EXPECT_CALL(*cbRaw, destroy(transport.get())); transport = nullptr; Mock::VerifyAndClearExpectations(cb.get()); @@ -3764,44 +3770,6 @@ TEST_P(QuicTransportImplTestBase, ObserverSharedPtrRemoveMissing) { EXPECT_THAT(transport->getObservers(), IsEmpty()); } -TEST_P(QuicTransportImplTestBase, ObserverCloseNoErrorThenDestroy) { - auto cb = std::make_unique>(); - EXPECT_CALL(*cb, observerAttach(transport.get())); - transport->addObserver(cb.get()); - EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb.get())); - - const QuicError defaultError = QuicError( - GenericApplicationErrorCode::NO_ERROR, - toString(GenericApplicationErrorCode::NO_ERROR)); - EXPECT_CALL( - *cb, close(transport.get(), folly::Optional(defaultError))); - transport->close(folly::none); - Mock::VerifyAndClearExpectations(cb.get()); - InSequence s; - EXPECT_CALL(*cb, destroy(transport.get())); - transport = nullptr; - Mock::VerifyAndClearExpectations(cb.get()); -} - -TEST_P(QuicTransportImplTestBase, ObserverCloseWithErrorThenDestroy) { - auto cb = std::make_unique>(); - EXPECT_CALL(*cb, observerAttach(transport.get())); - transport->addObserver(cb.get()); - EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb.get())); - - const auto testError = QuicError( - QuicErrorCode(LocalErrorCode::CONNECTION_RESET), - std::string("testError")); - EXPECT_CALL( - *cb, close(transport.get(), folly::Optional(testError))); - transport->close(testError); - Mock::VerifyAndClearExpectations(cb.get()); - InSequence s; - EXPECT_CALL(*cb, destroy(transport.get())); - transport = nullptr; - Mock::VerifyAndClearExpectations(cb.get()); -} - TEST_P(QuicTransportImplTestBase, ObserverDetachImmediately) { auto cb = std::make_unique>(); EXPECT_CALL(*cb, observerAttach(transport.get())); @@ -3815,12 +3783,20 @@ TEST_P(QuicTransportImplTestBase, ObserverDetachImmediately) { } TEST_P(QuicTransportImplTestBase, ObserverDetachAfterClose) { + // disable draining to ensure closing() event occurs immediately after close() + { + auto transportSettings = transport->getTransportSettings(); + transportSettings.shouldDrain = false; + transport->setTransportSettings(transportSettings); + } + auto cb = std::make_unique>(); EXPECT_CALL(*cb, observerAttach(transport.get())); transport->addObserver(cb.get()); EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb.get())); - EXPECT_CALL(*cb, close(transport.get(), _)); + EXPECT_CALL(*cb, closeStarted(transport.get(), _)); + EXPECT_CALL(*cb, closing(transport.get(), _)); transport->close(folly::none); Mock::VerifyAndClearExpectations(cb.get()); @@ -3830,14 +3806,33 @@ TEST_P(QuicTransportImplTestBase, ObserverDetachAfterClose) { EXPECT_THAT(transport->getObservers(), IsEmpty()); } -TEST_F(QuicTransportImplTest, ObserverDetachOnCloseDuringDestroy) { +TEST_F(QuicTransportImplTest, ObserverDetachOnCloseStartedDuringDestroy) { auto cb = std::make_unique>(); EXPECT_CALL(*cb, observerAttach(transport.get())); transport->addObserver(cb.get()); EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb.get())); InSequence s; - EXPECT_CALL(*cb, close(transport.get(), _)) + + EXPECT_CALL(*cb, closeStarted(transport.get(), _)) + .WillOnce(Invoke([&cb](auto callbackTransport, auto /* errorOpt */) { + EXPECT_TRUE(callbackTransport->removeObserver(cb.get())); + })); + EXPECT_CALL(*cb, observerDetach(transport.get())); + transport = nullptr; + Mock::VerifyAndClearExpectations(cb.get()); +} + +TEST_F(QuicTransportImplTest, ObserverDetachOnClosingDuringDestroy) { + auto cb = std::make_unique>(); + EXPECT_CALL(*cb, observerAttach(transport.get())); + transport->addObserver(cb.get()); + EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb.get())); + + InSequence s; + + EXPECT_CALL(*cb, closeStarted(transport.get(), _)); + EXPECT_CALL(*cb, closing(transport.get(), _)) .WillOnce(Invoke([&cb](auto callbackTransport, auto /* errorOpt */) { EXPECT_TRUE(callbackTransport->removeObserver(cb.get())); })); @@ -3938,8 +3933,10 @@ TEST_P(QuicTransportImplTestBase, ObserverMultipleAttachDestroy) { transport->getObservers(), UnorderedElementsAre(cb1.get(), cb2.get())); InSequence s; - EXPECT_CALL(*cb1, close(transport.get(), _)); - EXPECT_CALL(*cb2, close(transport.get(), _)); + EXPECT_CALL(*cb1, closeStarted(transport.get(), _)); + EXPECT_CALL(*cb2, closeStarted(transport.get(), _)); + EXPECT_CALL(*cb1, closing(transport.get(), _)); + EXPECT_CALL(*cb2, closing(transport.get(), _)); EXPECT_CALL(*cb1, destroy(transport.get())); EXPECT_CALL(*cb2, destroy(transport.get())); transport = nullptr; diff --git a/quic/api/test/QuicTransportTest.cpp b/quic/api/test/QuicTransportTest.cpp index 7e0fe036d..cf5aa9137 100644 --- a/quic/api/test/QuicTransportTest.cpp +++ b/quic/api/test/QuicTransportTest.cpp @@ -405,9 +405,12 @@ TEST_F(QuicTransportTest, ObserverNotAppLimitedWithNoWritableBytes) { loopForWrites(); Mock::VerifyAndClearExpectations(cb1.get()); Mock::VerifyAndClearExpectations(cb2.get()); - EXPECT_CALL(*cb1, close(transport_.get(), _)); - EXPECT_CALL(*cb2, close(transport_.get(), _)); - EXPECT_CALL(*cb3, close(transport_.get(), _)); + EXPECT_CALL(*cb1, closeStarted(transport_.get(), _)); + EXPECT_CALL(*cb2, closeStarted(transport_.get(), _)); + EXPECT_CALL(*cb3, closeStarted(transport_.get(), _)); + EXPECT_CALL(*cb1, closing(transport_.get(), _)); + EXPECT_CALL(*cb2, closing(transport_.get(), _)); + EXPECT_CALL(*cb3, closing(transport_.get(), _)); EXPECT_CALL(*cb1, destroy(transport_.get())); EXPECT_CALL(*cb2, destroy(transport_.get())); EXPECT_CALL(*cb3, destroy(transport_.get())); @@ -451,9 +454,12 @@ TEST_F(QuicTransportTest, ObserverNotAppLimitedWithLargeBuffer) { loopForWrites(); Mock::VerifyAndClearExpectations(cb1.get()); Mock::VerifyAndClearExpectations(cb2.get()); - EXPECT_CALL(*cb1, close(transport_.get(), _)); - EXPECT_CALL(*cb2, close(transport_.get(), _)); - EXPECT_CALL(*cb3, close(transport_.get(), _)); + EXPECT_CALL(*cb1, closeStarted(transport_.get(), _)); + EXPECT_CALL(*cb2, closeStarted(transport_.get(), _)); + EXPECT_CALL(*cb3, closeStarted(transport_.get(), _)); + EXPECT_CALL(*cb1, closing(transport_.get(), _)); + EXPECT_CALL(*cb2, closing(transport_.get(), _)); + EXPECT_CALL(*cb3, closing(transport_.get(), _)); EXPECT_CALL(*cb1, destroy(transport_.get())); EXPECT_CALL(*cb2, destroy(transport_.get())); EXPECT_CALL(*cb3, destroy(transport_.get())); @@ -499,9 +505,12 @@ TEST_F(QuicTransportTest, ObserverAppLimited) { Mock::VerifyAndClearExpectations(cb1.get()); Mock::VerifyAndClearExpectations(cb2.get()); Mock::VerifyAndClearExpectations(cb3.get()); - EXPECT_CALL(*cb1, close(transport_.get(), _)); - EXPECT_CALL(*cb2, close(transport_.get(), _)); - EXPECT_CALL(*cb3, close(transport_.get(), _)); + EXPECT_CALL(*cb1, closeStarted(transport_.get(), _)); + EXPECT_CALL(*cb2, closeStarted(transport_.get(), _)); + EXPECT_CALL(*cb3, closeStarted(transport_.get(), _)); + EXPECT_CALL(*cb1, closing(transport_.get(), _)); + EXPECT_CALL(*cb2, closing(transport_.get(), _)); + EXPECT_CALL(*cb3, closing(transport_.get(), _)); EXPECT_CALL(*cb1, destroy(transport_.get())); EXPECT_CALL(*cb2, destroy(transport_.get())); EXPECT_CALL(*cb3, destroy(transport_.get())); @@ -885,7 +894,10 @@ TEST_F(QuicTransportTest, ObserverPacketsWrittenCycleCheckDetails) { loopForWrites(); invokeForAllObservers(([this](MockLegacyObserver& observer) { - EXPECT_CALL(observer, close(transport_.get(), _)); + EXPECT_CALL(observer, closeStarted(transport_.get(), _)); + })); + invokeForAllObservers(([this](MockLegacyObserver& observer) { + EXPECT_CALL(observer, closing(transport_.get(), _)); })); invokeForAllObservers(([this](MockLegacyObserver& observer) { EXPECT_CALL(observer, destroy(transport_.get())); @@ -1097,7 +1109,10 @@ TEST_F(QuicTransportTest, ObserverPacketsWrittenCheckBytesSent) { } invokeForAllObservers(([this](MockLegacyObserver& observer) { - EXPECT_CALL(observer, close(transport_.get(), _)); + EXPECT_CALL(observer, closeStarted(transport_.get(), _)); + })); + invokeForAllObservers(([this](MockLegacyObserver& observer) { + EXPECT_CALL(observer, closing(transport_.get(), _)); })); invokeForAllObservers(([this](MockLegacyObserver& observer) { EXPECT_CALL(observer, destroy(transport_.get())); @@ -1378,7 +1393,10 @@ TEST_F(QuicTransportTest, ObserverWriteEventsCheckCwndPacketsWritable) { } invokeForAllObservers(([this](MockLegacyObserver& observer) { - EXPECT_CALL(observer, close(transport_.get(), _)); + EXPECT_CALL(observer, closeStarted(transport_.get(), _)); + })); + invokeForAllObservers(([this](MockLegacyObserver& observer) { + EXPECT_CALL(observer, closing(transport_.get(), _)); })); invokeForAllObservers(([this](MockLegacyObserver& observer) { EXPECT_CALL(observer, destroy(transport_.get())); @@ -1422,8 +1440,10 @@ TEST_F(QuicTransportTest, ObserverStreamEventBidirectionalLocalOpenClose) { SocketAddress("::1", 10000), NetworkData(IOBuf::copyBuffer("fake data"), Clock::now())); - EXPECT_CALL(*cb1, close(transport_.get(), _)); - EXPECT_CALL(*cb2, close(transport_.get(), _)); + EXPECT_CALL(*cb1, closeStarted(transport_.get(), _)); + EXPECT_CALL(*cb2, closeStarted(transport_.get(), _)); + EXPECT_CALL(*cb1, closing(transport_.get(), _)); + EXPECT_CALL(*cb2, closing(transport_.get(), _)); EXPECT_CALL(*cb1, destroy(transport_.get())); EXPECT_CALL(*cb2, destroy(transport_.get())); transport_ = nullptr; @@ -1464,8 +1484,10 @@ TEST_F(QuicTransportTest, ObserverStreamEventBidirectionalRemoteOpenClose) { SocketAddress("::1", 10000), NetworkData(IOBuf::copyBuffer("fake data"), Clock::now())); - EXPECT_CALL(*cb1, close(transport_.get(), _)); - EXPECT_CALL(*cb2, close(transport_.get(), _)); + EXPECT_CALL(*cb1, closeStarted(transport_.get(), _)); + EXPECT_CALL(*cb2, closeStarted(transport_.get(), _)); + EXPECT_CALL(*cb1, closing(transport_.get(), _)); + EXPECT_CALL(*cb2, closing(transport_.get(), _)); EXPECT_CALL(*cb1, destroy(transport_.get())); EXPECT_CALL(*cb2, destroy(transport_.get())); transport_ = nullptr; @@ -1506,8 +1528,10 @@ TEST_F(QuicTransportTest, ObserverStreamEventUnidirectionalLocalOpenClose) { SocketAddress("::1", 10000), NetworkData(IOBuf::copyBuffer("fake data"), Clock::now())); - EXPECT_CALL(*cb1, close(transport_.get(), _)); - EXPECT_CALL(*cb2, close(transport_.get(), _)); + EXPECT_CALL(*cb1, closeStarted(transport_.get(), _)); + EXPECT_CALL(*cb2, closeStarted(transport_.get(), _)); + EXPECT_CALL(*cb1, closing(transport_.get(), _)); + EXPECT_CALL(*cb2, closing(transport_.get(), _)); EXPECT_CALL(*cb1, destroy(transport_.get())); EXPECT_CALL(*cb2, destroy(transport_.get())); transport_ = nullptr; @@ -1547,8 +1571,10 @@ TEST_F(QuicTransportTest, ObserverStreamEventUnidirectionalRemoteOpenClose) { SocketAddress("::1", 10000), NetworkData(IOBuf::copyBuffer("fake data"), Clock::now())); - EXPECT_CALL(*cb1, close(transport_.get(), _)); - EXPECT_CALL(*cb2, close(transport_.get(), _)); + EXPECT_CALL(*cb1, closeStarted(transport_.get(), _)); + EXPECT_CALL(*cb2, closeStarted(transport_.get(), _)); + EXPECT_CALL(*cb1, closing(transport_.get(), _)); + EXPECT_CALL(*cb2, closing(transport_.get(), _)); EXPECT_CALL(*cb1, destroy(transport_.get())); EXPECT_CALL(*cb2, destroy(transport_.get())); transport_ = nullptr; diff --git a/quic/api/test/QuicTypedTransportTest.cpp b/quic/api/test/QuicTypedTransportTest.cpp index fc5795612..54f979add 100644 --- a/quic/api/test/QuicTypedTransportTest.cpp +++ b/quic/api/test/QuicTypedTransportTest.cpp @@ -26,8 +26,8 @@ using namespace testing; namespace { using TransportTypes = testing::Types< - quic::test::QuicClientTransportAfterStartTestBase, - quic::test::QuicServerTransportAfterStartTestBase>; + quic::test::QuicClientTransportTestBase, + quic::test::QuicServerTransportTestBase>; class TransportTypeNames { public: @@ -51,14 +51,30 @@ template class QuicTypedTransportTest : public virtual testing::Test, public QuicTypedTransportTestBase { public: + ~QuicTypedTransportTest() override = default; void SetUp() override { // trigger setup of the underlying transport QuicTypedTransportTestBase::SetUp(); } }; +template +class QuicTypedTransportAfterStartTest : public QuicTypedTransportTest { + public: + ~QuicTypedTransportAfterStartTest() override = default; + void SetUp() override { + QuicTypedTransportTest::SetUp(); + QuicTypedTransportTestBase::startTransport(); + } +}; + TYPED_TEST_SUITE( - QuicTypedTransportTest, + QuicTypedTransportAfterStartTest, + ::TransportTypes, + ::TransportTypeNames); + +TYPED_TEST_SUITE( + QuicTypedTransportAfterStartTest, ::TransportTypes, ::TransportTypeNames); @@ -67,7 +83,7 @@ TYPED_TEST_SUITE( * * Currently tests mrtt, mrttNoAckDelay, lrttRaw, lrttRawAckDelay */ -TYPED_TEST(QuicTypedTransportTest, TransportInfoRttSignals) { +TYPED_TEST(QuicTypedTransportAfterStartTest, TransportInfoRttSignals) { // lambda to send and ACK a packet const auto sendAndAckPacket = [&](const auto& rttIn, const auto& ackDelayIn) { auto streamId = this->getTransport()->createBidirectionalStream().value(); @@ -116,12 +132,10 @@ TYPED_TEST(QuicTypedTransportTest, TransportInfoRttSignals) { const auto expectedMinRtt = 31ms; const auto expectedMinRttNoAckDelay = 26ms; - if constexpr (std::is_same_v< - TypeParam, - QuicClientTransportAfterStartTestBase>) { - } else if constexpr (std::is_same_v< + if constexpr (std::is_base_of_v) { + } else if constexpr (std::is_base_of_v< TypeParam, - QuicServerTransportAfterStartTestBase>) { + QuicServerTransportTestBase>) { const auto tInfo = this->getTransport()->getTransportInfo(); EXPECT_EQ(rtt, tInfo.maybeLrtt); EXPECT_EQ(ackDelay, tInfo.maybeLrttAckDelay); @@ -140,12 +154,10 @@ TYPED_TEST(QuicTypedTransportTest, TransportInfoRttSignals) { const auto expectedMinRtt = 30ms; const auto expectedMinRttNoAckDelay = 26ms; - if constexpr (std::is_same_v< - TypeParam, - QuicClientTransportAfterStartTestBase>) { - } else if constexpr (std::is_same_v< + if constexpr (std::is_base_of_v) { + } else if constexpr (std::is_base_of_v< TypeParam, - QuicServerTransportAfterStartTestBase>) { + QuicServerTransportTestBase>) { const auto tInfo = this->getTransport()->getTransportInfo(); EXPECT_EQ(rtt, tInfo.maybeLrtt); EXPECT_EQ(ackDelay, tInfo.maybeLrttAckDelay); @@ -164,12 +176,10 @@ TYPED_TEST(QuicTypedTransportTest, TransportInfoRttSignals) { const auto expectedMinRtt = 30ms; const auto expectedMinRttNoAckDelay = 22ms; - if constexpr (std::is_same_v< - TypeParam, - QuicClientTransportAfterStartTestBase>) { - } else if constexpr (std::is_same_v< + if constexpr (std::is_base_of_v) { + } else if constexpr (std::is_base_of_v< TypeParam, - QuicServerTransportAfterStartTestBase>) { + QuicServerTransportTestBase>) { const auto tInfo = this->getTransport()->getTransportInfo(); EXPECT_EQ(rtt, tInfo.maybeLrtt); EXPECT_EQ(ackDelay, tInfo.maybeLrttAckDelay); @@ -188,12 +198,10 @@ TYPED_TEST(QuicTypedTransportTest, TransportInfoRttSignals) { const auto expectedMinRtt = 30ms; const auto expectedMinRttNoAckDelay = 22ms; - if constexpr (std::is_same_v< - TypeParam, - QuicClientTransportAfterStartTestBase>) { - } else if constexpr (std::is_same_v< + if constexpr (std::is_base_of_v) { + } else if constexpr (std::is_base_of_v< TypeParam, - QuicServerTransportAfterStartTestBase>) { + QuicServerTransportTestBase>) { const auto tInfo = this->getTransport()->getTransportInfo(); EXPECT_EQ(rtt, tInfo.maybeLrtt); EXPECT_EQ(ackDelay, tInfo.maybeLrttAckDelay); @@ -212,12 +220,10 @@ TYPED_TEST(QuicTypedTransportTest, TransportInfoRttSignals) { const auto expectedMinRtt = 25ms; const auto expectedMinRttNoAckDelay = 22ms; - if constexpr (std::is_same_v< - TypeParam, - QuicClientTransportAfterStartTestBase>) { - } else if constexpr (std::is_same_v< + if constexpr (std::is_base_of_v) { + } else if constexpr (std::is_base_of_v< TypeParam, - QuicServerTransportAfterStartTestBase>) { + QuicServerTransportTestBase>) { const auto tInfo = this->getTransport()->getTransportInfo(); EXPECT_EQ(rtt, tInfo.maybeLrtt); EXPECT_EQ(ackDelay, tInfo.maybeLrttAckDelay); @@ -236,12 +242,10 @@ TYPED_TEST(QuicTypedTransportTest, TransportInfoRttSignals) { const auto expectedMinRtt = 25ms; const auto expectedMinRttNoAckDelay = 21ms; - if constexpr (std::is_same_v< - TypeParam, - QuicClientTransportAfterStartTestBase>) { - } else if constexpr (std::is_same_v< + if constexpr (std::is_base_of_v) { + } else if constexpr (std::is_base_of_v< TypeParam, - QuicServerTransportAfterStartTestBase>) { + QuicServerTransportTestBase>) { const auto tInfo = this->getTransport()->getTransportInfo(); EXPECT_EQ(rtt, tInfo.maybeLrtt); EXPECT_EQ(ackDelay, tInfo.maybeLrttAckDelay); @@ -260,12 +264,10 @@ TYPED_TEST(QuicTypedTransportTest, TransportInfoRttSignals) { const auto expectedMinRtt = 20ms; const auto expectedMinRttNoAckDelay = 20ms; - if constexpr (std::is_same_v< - TypeParam, - QuicClientTransportAfterStartTestBase>) { - } else if constexpr (std::is_same_v< + if constexpr (std::is_base_of_v) { + } else if constexpr (std::is_base_of_v< TypeParam, - QuicServerTransportAfterStartTestBase>) { + QuicServerTransportTestBase>) { const auto tInfo = this->getTransport()->getTransportInfo(); EXPECT_EQ(rtt, tInfo.maybeLrtt); EXPECT_EQ(ackDelay, tInfo.maybeLrttAckDelay); @@ -282,7 +284,7 @@ TYPED_TEST(QuicTypedTransportTest, TransportInfoRttSignals) { /** * Test case where the ACK delay is equal to the RTT sample. */ -TYPED_TEST(QuicTypedTransportTest, RttSampleAckDelayEqual) { +TYPED_TEST(QuicTypedTransportAfterStartTest, RttSampleAckDelayEqual) { // lambda to send and ACK a packet const auto sendAndAckPacket = [&](const auto& rttIn, const auto& ackDelayIn) { auto streamId = this->getTransport()->createBidirectionalStream().value(); @@ -323,12 +325,10 @@ TYPED_TEST(QuicTypedTransportTest, RttSampleAckDelayEqual) { const auto expectedMinRtt = 25ms; const auto expectedMinRttNoAckDelay = 0ms; - if constexpr (std::is_same_v< - TypeParam, - QuicClientTransportAfterStartTestBase>) { - } else if constexpr (std::is_same_v< + if constexpr (std::is_base_of_v) { + } else if constexpr (std::is_base_of_v< TypeParam, - QuicServerTransportAfterStartTestBase>) { + QuicServerTransportTestBase>) { const auto tInfo = this->getTransport()->getTransportInfo(); EXPECT_EQ(rtt, tInfo.maybeLrtt); EXPECT_EQ(ackDelay, tInfo.maybeLrttAckDelay); @@ -345,7 +345,7 @@ TYPED_TEST(QuicTypedTransportTest, RttSampleAckDelayEqual) { /** * Test case where the ACK delay is greater than the RTT sample. */ -TYPED_TEST(QuicTypedTransportTest, RttSampleAckDelayGreater) { +TYPED_TEST(QuicTypedTransportAfterStartTest, RttSampleAckDelayGreater) { // lambda to send and ACK a packet const auto sendAndAckPacket = [&](const auto& rttIn, const auto& ackDelayIn) { auto streamId = this->getTransport()->createBidirectionalStream().value(); @@ -385,12 +385,10 @@ TYPED_TEST(QuicTypedTransportTest, RttSampleAckDelayGreater) { sendAndAckPacket(rtt, ackDelay); const auto expectedMinRtt = 25ms; - if constexpr (std::is_same_v< - TypeParam, - QuicClientTransportAfterStartTestBase>) { - } else if constexpr (std::is_same_v< + if constexpr (std::is_base_of_v) { + } else if constexpr (std::is_base_of_v< TypeParam, - QuicServerTransportAfterStartTestBase>) { + QuicServerTransportTestBase>) { const auto tInfo = this->getTransport()->getTransportInfo(); EXPECT_EQ(rtt, tInfo.maybeLrtt); EXPECT_EQ(ackDelay, tInfo.maybeLrttAckDelay); @@ -410,7 +408,7 @@ TYPED_TEST(QuicTypedTransportTest, RttSampleAckDelayGreater) { * In this case, we should fallback to using system clock timestamp, and thus * should end up with a non-zero RTT. */ -TYPED_TEST(QuicTypedTransportTest, RttSampleZeroTime) { +TYPED_TEST(QuicTypedTransportAfterStartTest, RttSampleZeroTime) { // lambda to send and ACK a packet const auto sendAndAckPacket = [&](const auto& rttIn, const auto& ackDelayIn) { auto streamId = this->getTransport()->createBidirectionalStream().value(); @@ -448,12 +446,10 @@ TYPED_TEST(QuicTypedTransportTest, RttSampleZeroTime) { const auto rtt = 0ms; const auto ackDelay = 0ms; sendAndAckPacket(rtt, ackDelay); - if constexpr (std::is_same_v< - TypeParam, - QuicClientTransportAfterStartTestBase>) { - } else if constexpr (std::is_same_v< + if constexpr (std::is_base_of_v) { + } else if constexpr (std::is_base_of_v< TypeParam, - QuicServerTransportAfterStartTestBase>) { + QuicServerTransportTestBase>) { const auto tInfo = this->getTransport()->getTransportInfo(); EXPECT_LE(0ms, tInfo.maybeLrtt.value()); EXPECT_GE(500ms, tInfo.maybeLrtt.value()); @@ -471,7 +467,9 @@ TYPED_TEST(QuicTypedTransportTest, RttSampleZeroTime) { /** * Verify vector used to store ACK events has no capacity if no pkts in flight. */ -TYPED_TEST(QuicTypedTransportTest, AckEventsNoAllocatedSpaceWhenNoOutstanding) { +TYPED_TEST( + QuicTypedTransportAfterStartTest, + AckEventsNoAllocatedSpaceWhenNoOutstanding) { // clear any outstanding packets this->getNonConstConn().outstandings.reset(); @@ -503,7 +501,7 @@ TYPED_TEST(QuicTypedTransportTest, AckEventsNoAllocatedSpaceWhenNoOutstanding) { * Two packets to give opportunity for packets in flight. */ TYPED_TEST( - QuicTypedTransportTest, + QuicTypedTransportAfterStartTest, AckEventsNoAllocatedSpaceWhenNoOutstandingTwoInFlight) { // clear any outstanding packets this->getNonConstConn().outstandings.reset(); @@ -560,7 +558,7 @@ TYPED_TEST( * Two packets ACKed in reverse to give opportunity for packets in flight. */ TYPED_TEST( - QuicTypedTransportTest, + QuicTypedTransportAfterStartTest, AckEventsNoAllocatedSpaceWhenNoOutstandingTwoInFlightReverse) { // prevent packets from being marked as lost this->getNonConstConn().lossState.reorderingThreshold = 10; @@ -620,7 +618,9 @@ TYPED_TEST( /** * Verify PacketProcessor callbacks when sending a packet and its ack */ -TYPED_TEST(QuicTypedTransportTest, PacketProcessorSendSingleDataPacketWithAck) { +TYPED_TEST( + QuicTypedTransportAfterStartTest, + PacketProcessorSendSingleDataPacketWithAck) { // clear any outstanding packets this->getNonConstConn().outstandings.reset(); auto mockPacketProcessor = std::make_unique(); @@ -674,7 +674,9 @@ TYPED_TEST(QuicTypedTransportTest, PacketProcessorSendSingleDataPacketWithAck) { * Verify PacketProcessor callbacks when sending two data packets and receiving * one ack */ -TYPED_TEST(QuicTypedTransportTest, PacketProcessorSendTwoDataPacketsWithAck) { +TYPED_TEST( + QuicTypedTransportAfterStartTest, + PacketProcessorSendTwoDataPacketsWithAck) { // clear any outstanding packets this->getNonConstConn().outstandings.reset(); auto mockPacketProcessor = std::make_unique(); @@ -734,7 +736,9 @@ TYPED_TEST(QuicTypedTransportTest, PacketProcessorSendTwoDataPacketsWithAck) { this->destroyTransport(); } -TYPED_TEST(QuicTypedTransportTest, StreamAckedIntervalsDeliveryCallbacks) { +TYPED_TEST( + QuicTypedTransportAfterStartTest, + StreamAckedIntervalsDeliveryCallbacks) { // clear any outstanding packets this->getNonConstConn().outstandings.reset(); @@ -788,7 +792,7 @@ TYPED_TEST(QuicTypedTransportTest, StreamAckedIntervalsDeliveryCallbacks) { } TYPED_TEST( - QuicTypedTransportTest, + QuicTypedTransportAfterStartTest, StreamAckedIntervalsDeliveryCallbacksFinOnly) { // clear any outstanding packets this->getNonConstConn().outstandings.reset(); @@ -819,7 +823,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTest, + QuicTypedTransportAfterStartTest, StreamAckedIntervalsDeliveryCallbacksSingleByteNoFin) { // clear any outstanding packets this->getNonConstConn().outstandings.reset(); @@ -850,6 +854,228 @@ TYPED_TEST( this->destroyTransport(); } +template +struct AckEventMatcherBuilder { + using Builder = AckEventMatcherBuilder; + Builder&& setExpectedAckedIntervals( + std::vector< + typename QuicTypedTransportTest::NewOutstandingPacketInterval> + expectedAckedIntervals) { + maybeExpectedAckedIntervals = std::move(expectedAckedIntervals); + return std::move(*this); + } + Builder&& setExpectedAckedIntervals( + std::vector::NewOutstandingPacketInterval>> + expectedAckedIntervalsOpt) { + std::vector< + typename QuicTypedTransportTest::NewOutstandingPacketInterval> + expectedAckedIntervals; + for (const auto& maybeInterval : expectedAckedIntervalsOpt) { + CHECK(maybeInterval.has_value()); + expectedAckedIntervals.push_back(maybeInterval.value()); + } + maybeExpectedAckedIntervals = std::move(expectedAckedIntervals); + return std::move(*this); + } + Builder&& setExpectedNumAckedPackets(const uint64_t expectedNumAckedPackets) { + maybeExpectedNumAckedPackets = expectedNumAckedPackets; + return std::move(*this); + } + Builder&& setAckTime(TimePoint ackTime) { + maybeAckTime = ackTime; + return std::move(*this); + } + Builder&& setAckDelay(std::chrono::microseconds ackDelay) { + maybeAckDelay = ackDelay; + return std::move(*this); + } + Builder&& setLargestAckedPacket(quic::PacketNum largestAckedPacketIn) { + maybeLargestAckedPacket = largestAckedPacketIn; + return std::move(*this); + } + Builder&& setLargestNewlyAckedPacket( + quic::PacketNum largestNewlyAckedPacketIn) { + maybeLargestNewlyAckedPacket = largestNewlyAckedPacketIn; + return std::move(*this); + } + Builder&& setRtt(const folly::Optional& rttIn) { + maybeRtt = rttIn; + CHECK(!noRtt); + return std::move(*this); + } + Builder&& setRttNoAckDelay( + const folly::Optional& rttNoAckDelayIn) { + maybeRttNoAckDelay = rttNoAckDelayIn; + CHECK(!noRtt); + CHECK(!noRttWithNoAckDelay); + return std::move(*this); + } + Builder&& setNoRtt() { + noRtt = true; + CHECK(!maybeRtt); + CHECK(!maybeRttNoAckDelay); + return std::move(*this); + } + Builder&& setNoRttWithNoAckDelay() { + noRttWithNoAckDelay = true; + CHECK(!maybeRttNoAckDelay); + return std::move(*this); + } + auto build() && { + CHECK( + noRtt || + (maybeRtt.has_value() && + (noRttWithNoAckDelay || maybeRttNoAckDelay.has_value()))); + + CHECK(maybeExpectedAckedIntervals.has_value()); + const auto& expectedAckedIntervals = *maybeExpectedAckedIntervals; + CHECK_LT(0, expectedAckedIntervals.size()); + + CHECK(maybeExpectedNumAckedPackets.has_value()); + const auto& expectedNumAckedPackets = *maybeExpectedNumAckedPackets; + + CHECK(maybeAckTime.has_value()); + const auto& ackTime = *maybeAckTime; + + CHECK(maybeAckDelay.has_value()); + const auto& ackDelay = *maybeAckDelay; + + CHECK(maybeLargestAckedPacket.has_value()); + const auto& largestAckedPacket = *maybeLargestAckedPacket; + + CHECK(maybeLargestNewlyAckedPacket.has_value()); + const auto& largestNewlyAckedPacket = *maybeLargestNewlyAckedPacket; + + // sanity check expectedNumAckedPackets and expectedAckedIntervals + // reduces potential of error in test design + { + uint64_t expectedNumAckedPacketsFromIntervals = 0; + std::vector< + typename QuicTypedTransportTest::NewOutstandingPacketInterval> + processedExpectedAckedIntervals; + + for (const auto& interval : expectedAckedIntervals) { + CHECK_LE(interval.start, interval.end); + CHECK_LE(0, interval.end); + expectedNumAckedPacketsFromIntervals += + interval.end - interval.start + 1; + + // should not overlap with existing intervals + for (const auto& processedInterval : processedExpectedAckedIntervals) { + CHECK( + processedInterval.end < interval.start || + processedInterval.start < interval.end); + } + + processedExpectedAckedIntervals.push_back(interval); + } + CHECK_EQ(expectedNumAckedPacketsFromIntervals, expectedNumAckedPackets); + } + + if constexpr (std::is_base_of_v) { + return testing::Property( + &quic::SocketObserverInterface::AcksProcessedEvent::getAckEvents, + testing::ElementsAre(testing::AllOf( + // ack time, adjusted ack time, RTT not supported for client now + testing::Field(&quic::AckEvent::ackDelay, testing::Eq(ackDelay)), + testing::Field( + &quic::AckEvent::largestAckedPacket, + testing::Eq(largestAckedPacket)), + testing::Field( + &quic::AckEvent::largestNewlyAckedPacket, + testing::Eq(largestNewlyAckedPacket)), + testing::Field( + &quic::AckEvent::ackedPackets, + testing::SizeIs(expectedNumAckedPackets))))); + } else if constexpr (std::is_base_of_v) { + return testing::Property( + &quic::SocketObserverInterface::AcksProcessedEvent::getAckEvents, + testing::ElementsAre(testing::AllOf( + testing::Field(&quic::AckEvent::ackTime, testing::Eq(ackTime)), + testing::Field( + &quic::AckEvent::adjustedAckTime, + testing::Eq(ackTime - ackDelay)), + testing::Field(&quic::AckEvent::ackDelay, testing::Eq(ackDelay)), + testing::Field( + &quic::AckEvent::largestAckedPacket, + testing::Eq(largestAckedPacket)), + testing::Field( + &quic::AckEvent::largestNewlyAckedPacket, + testing::Eq(largestNewlyAckedPacket)), + testing::Field( + &quic::AckEvent::ackedPackets, + testing::SizeIs(expectedNumAckedPackets)), + testing::Field(&quic::AckEvent::rttSample, testing::Eq(maybeRtt)), + testing::Field( + &quic::AckEvent::rttSampleNoAckDelay, + testing::Eq(maybeRttNoAckDelay))))); + } else { + FAIL(); // unhandled typed test + } + } + explicit AckEventMatcherBuilder() = default; + + folly::Optional::NewOutstandingPacketInterval>> + maybeExpectedAckedIntervals; + folly::Optional maybeExpectedNumAckedPackets; + folly::Optional maybeAckTime; + folly::Optional maybeAckDelay; + folly::Optional maybeLargestAckedPacket; + folly::Optional maybeLargestNewlyAckedPacket; + folly::Optional maybeRtt; + folly::Optional maybeRttNoAckDelay; + bool noRtt{false}; + bool noRttWithNoAckDelay{false}; +}; + +template +struct ReceivedPacketMatcherBuilder { + using Builder = ReceivedPacketMatcherBuilder; + using Obj = + quic::SocketObserverInterface::PacketsReceivedEvent::ReceivedPacket; + Builder&& setExpectedPacketReceiveTime( + const TimePoint expectedPacketReceiveTime) { + maybeExpectedPacketReceiveTime = expectedPacketReceiveTime; + return std::move(*this); + } + Builder&& setExpectedPacketNumBytes(const uint64_t expectedPacketNumBytes) { + maybeExpectedPacketNumBytes = expectedPacketNumBytes; + return std::move(*this); + } + auto build() && { + CHECK(maybeExpectedPacketReceiveTime.has_value()); + const auto& packetReceiveTime = *maybeExpectedPacketReceiveTime; + + CHECK(maybeExpectedPacketNumBytes.has_value()); + const auto& packetNumBytes = *maybeExpectedPacketNumBytes; + + if constexpr (std::is_base_of_v) { + return testing::AllOf( + // client does not currently support socket RX timestamps, so we + // expect ts >= now() at time of matcher build + testing::Field( + &Obj::packetReceiveTime, + testing::AnyOf( + testing::Eq(packetReceiveTime), + testing::Ge(TimePoint::clock::now()))), + testing::Field(&Obj::packetNumBytes, testing::Eq(packetNumBytes))); + } else if constexpr (std::is_base_of_v) { + return testing::AllOf( + testing::Field( + &Obj::packetReceiveTime, testing::Eq(packetReceiveTime)), + testing::Field(&Obj::packetNumBytes, testing::Eq(packetNumBytes))); + } else { + FAIL(); // unhandled typed test + } + } + explicit ReceivedPacketMatcherBuilder() = default; + + folly::Optional maybeExpectedPacketReceiveTime; + folly::Optional maybeExpectedPacketNumBytes; +}; + template class QuicTypedTransportTestForObservers : public QuicTypedTransportTest { public: @@ -857,235 +1083,6 @@ class QuicTypedTransportTestForObservers : public QuicTypedTransportTest { QuicTypedTransportTest::SetUp(); } - struct AckEventMatcherBuilder { - using Builder = AckEventMatcherBuilder; - Builder&& setExpectedAckedIntervals( - std::vector< - typename QuicTypedTransportTest::NewOutstandingPacketInterval> - expectedAckedIntervals) { - maybeExpectedAckedIntervals = std::move(expectedAckedIntervals); - return std::move(*this); - } - Builder&& setExpectedAckedIntervals( - std::vector::NewOutstandingPacketInterval>> - expectedAckedIntervalsOpt) { - std::vector< - typename QuicTypedTransportTest::NewOutstandingPacketInterval> - expectedAckedIntervals; - for (const auto& maybeInterval : expectedAckedIntervalsOpt) { - CHECK(maybeInterval.has_value()); - expectedAckedIntervals.push_back(maybeInterval.value()); - } - maybeExpectedAckedIntervals = std::move(expectedAckedIntervals); - return std::move(*this); - } - Builder&& setExpectedNumAckedPackets( - const uint64_t expectedNumAckedPackets) { - maybeExpectedNumAckedPackets = expectedNumAckedPackets; - return std::move(*this); - } - Builder&& setAckTime(TimePoint ackTime) { - maybeAckTime = ackTime; - return std::move(*this); - } - Builder&& setAckDelay(std::chrono::microseconds ackDelay) { - maybeAckDelay = ackDelay; - return std::move(*this); - } - Builder&& setLargestAckedPacket(quic::PacketNum largestAckedPacketIn) { - maybeLargestAckedPacket = largestAckedPacketIn; - return std::move(*this); - } - Builder&& setLargestNewlyAckedPacket( - quic::PacketNum largestNewlyAckedPacketIn) { - maybeLargestNewlyAckedPacket = largestNewlyAckedPacketIn; - return std::move(*this); - } - Builder&& setRtt(const folly::Optional& rttIn) { - maybeRtt = rttIn; - CHECK(!noRtt); - return std::move(*this); - } - Builder&& setRttNoAckDelay( - const folly::Optional& rttNoAckDelayIn) { - maybeRttNoAckDelay = rttNoAckDelayIn; - CHECK(!noRtt); - CHECK(!noRttWithNoAckDelay); - return std::move(*this); - } - Builder&& setNoRtt() { - noRtt = true; - CHECK(!maybeRtt); - CHECK(!maybeRttNoAckDelay); - return std::move(*this); - } - Builder&& setNoRttWithNoAckDelay() { - noRttWithNoAckDelay = true; - CHECK(!maybeRttNoAckDelay); - return std::move(*this); - } - auto build() && { - CHECK( - noRtt || - (maybeRtt.has_value() && - (noRttWithNoAckDelay || maybeRttNoAckDelay.has_value()))); - - CHECK(maybeExpectedAckedIntervals.has_value()); - const auto& expectedAckedIntervals = *maybeExpectedAckedIntervals; - CHECK_LT(0, expectedAckedIntervals.size()); - - CHECK(maybeExpectedNumAckedPackets.has_value()); - const auto& expectedNumAckedPackets = *maybeExpectedNumAckedPackets; - - CHECK(maybeAckTime.has_value()); - const auto& ackTime = *maybeAckTime; - - CHECK(maybeAckDelay.has_value()); - const auto& ackDelay = *maybeAckDelay; - - CHECK(maybeLargestAckedPacket.has_value()); - const auto& largestAckedPacket = *maybeLargestAckedPacket; - - CHECK(maybeLargestNewlyAckedPacket.has_value()); - const auto& largestNewlyAckedPacket = *maybeLargestNewlyAckedPacket; - - // sanity check expectedNumAckedPackets and expectedAckedIntervals - // reduces potential of error in test design - { - uint64_t expectedNumAckedPacketsFromIntervals = 0; - std::vector< - typename QuicTypedTransportTest::NewOutstandingPacketInterval> - processedExpectedAckedIntervals; - - for (const auto& interval : expectedAckedIntervals) { - CHECK_LE(interval.start, interval.end); - CHECK_LE(0, interval.end); - expectedNumAckedPacketsFromIntervals += - interval.end - interval.start + 1; - - // should not overlap with existing intervals - for (const auto& processedInterval : - processedExpectedAckedIntervals) { - CHECK( - processedInterval.end < interval.start || - processedInterval.start < interval.end); - } - - processedExpectedAckedIntervals.push_back(interval); - } - CHECK_EQ(expectedNumAckedPacketsFromIntervals, expectedNumAckedPackets); - } - - if constexpr (std::is_same_v) { - return testing::Property( - &quic::SocketObserverInterface::AcksProcessedEvent::getAckEvents, - testing::ElementsAre(testing::AllOf( - // ack time, adjusted ack time, RTT not supported for client now - testing::Field( - &quic::AckEvent::ackDelay, testing::Eq(ackDelay)), - testing::Field( - &quic::AckEvent::largestAckedPacket, - testing::Eq(largestAckedPacket)), - testing::Field( - &quic::AckEvent::largestNewlyAckedPacket, - testing::Eq(largestNewlyAckedPacket)), - testing::Field( - &quic::AckEvent::ackedPackets, - testing::SizeIs(expectedNumAckedPackets))))); - } else if constexpr (std::is_same_v< - T, - QuicServerTransportAfterStartTestBase>) { - return testing::Property( - &quic::SocketObserverInterface::AcksProcessedEvent::getAckEvents, - testing::ElementsAre(testing::AllOf( - testing::Field(&quic::AckEvent::ackTime, testing::Eq(ackTime)), - testing::Field( - &quic::AckEvent::adjustedAckTime, - testing::Eq(ackTime - ackDelay)), - testing::Field( - &quic::AckEvent::ackDelay, testing::Eq(ackDelay)), - testing::Field( - &quic::AckEvent::largestAckedPacket, - testing::Eq(largestAckedPacket)), - testing::Field( - &quic::AckEvent::largestNewlyAckedPacket, - testing::Eq(largestNewlyAckedPacket)), - testing::Field( - &quic::AckEvent::ackedPackets, - testing::SizeIs(expectedNumAckedPackets)), - testing::Field( - &quic::AckEvent::rttSample, testing::Eq(maybeRtt)), - testing::Field( - &quic::AckEvent::rttSampleNoAckDelay, - testing::Eq(maybeRttNoAckDelay))))); - } else { - FAIL(); // unhandled typed test - } - } - explicit AckEventMatcherBuilder() = default; - - folly::Optional::NewOutstandingPacketInterval>> - maybeExpectedAckedIntervals; - folly::Optional maybeExpectedNumAckedPackets; - folly::Optional maybeAckTime; - folly::Optional maybeAckDelay; - folly::Optional maybeLargestAckedPacket; - folly::Optional maybeLargestNewlyAckedPacket; - folly::Optional maybeRtt; - folly::Optional maybeRttNoAckDelay; - bool noRtt{false}; - bool noRttWithNoAckDelay{false}; - }; - - struct ReceivedPacketMatcherBuilder { - using Builder = ReceivedPacketMatcherBuilder; - using Obj = - quic::SocketObserverInterface::PacketsReceivedEvent::ReceivedPacket; - Builder&& setExpectedPacketReceiveTime( - const TimePoint expectedPacketReceiveTime) { - maybeExpectedPacketReceiveTime = expectedPacketReceiveTime; - return std::move(*this); - } - Builder&& setExpectedPacketNumBytes(const uint64_t expectedPacketNumBytes) { - maybeExpectedPacketNumBytes = expectedPacketNumBytes; - return std::move(*this); - } - auto build() && { - CHECK(maybeExpectedPacketReceiveTime.has_value()); - const auto& packetReceiveTime = *maybeExpectedPacketReceiveTime; - - CHECK(maybeExpectedPacketNumBytes.has_value()); - const auto& packetNumBytes = *maybeExpectedPacketNumBytes; - - if constexpr (std::is_same_v) { - return testing::AllOf( - // client does not currently support socket RX timestamps, so we - // expect ts >= now() at time of matcher build - testing::Field( - &Obj::packetReceiveTime, - testing::AnyOf( - testing::Eq(packetReceiveTime), - testing::Ge(TimePoint::clock::now()))), - testing::Field(&Obj::packetNumBytes, testing::Eq(packetNumBytes))); - } else if constexpr (std::is_same_v< - T, - QuicServerTransportAfterStartTestBase>) { - return testing::AllOf( - testing::Field( - &Obj::packetReceiveTime, testing::Eq(packetReceiveTime)), - testing::Field(&Obj::packetNumBytes, testing::Eq(packetNumBytes))); - } else { - FAIL(); // unhandled typed test - } - } - explicit ReceivedPacketMatcherBuilder() = default; - - folly::Optional maybeExpectedPacketReceiveTime; - folly::Optional maybeExpectedPacketNumBytes; - }; - auto getStreamEventMatcherOpt( const StreamId streamId, const StreamInitiator streamInitiator, @@ -1103,12 +1100,31 @@ class QuicTypedTransportTestForObservers : public QuicTypedTransportTest { } }; +template +class QuicTypedTransportAfterStartTestForObservers + : public QuicTypedTransportTestForObservers { + public: + ~QuicTypedTransportAfterStartTestForObservers() override = default; + void SetUp() override { + QuicTypedTransportTestForObservers::SetUp(); + QuicTypedTransportTestForObservers::startTransport(); + } +}; + TYPED_TEST_SUITE( QuicTypedTransportTestForObservers, ::TransportTypes, ::TransportTypeNames); +TYPED_TEST_SUITE( + QuicTypedTransportAfterStartTestForObservers, + ::TransportTypes, + ::TransportTypeNames); + TYPED_TEST(QuicTypedTransportTestForObservers, AttachThenDetach) { + this->startTransport(); + + InSequence s; auto transport = this->getTransport(); auto observer = std::make_unique>(); @@ -1129,8 +1145,16 @@ TYPED_TEST(QuicTypedTransportTestForObservers, AttachThenDetach) { TYPED_TEST( QuicTypedTransportTestForObservers, - CloseNoErrorThenDestroyTransport) { + CloseNoDrainNoErrorThenDestroyTransport) { auto transport = this->getTransport(); + { + auto transportSettings = transport->getTransportSettings(); + transportSettings.shouldDrain = false; + transport->setTransportSettings(transportSettings); + } + this->startTransport(); + + InSequence s; auto observer = std::make_unique>(); EXPECT_CALL(*observer, attached(transport)); @@ -1141,10 +1165,19 @@ TYPED_TEST( GenericApplicationErrorCode::NO_ERROR, toString(GenericApplicationErrorCode::NO_ERROR)); EXPECT_CALL( - *observer, close(transport, folly::Optional(defaultError))); + *observer, + closeStarted( + transport, + AllOf( + // should not be equal to an empty event + testing::Ne(SocketObserverInterface::CloseStartedEvent{ + .maybeCloseReason = folly::none}), + // should be equal to a populated event with default error + testing::Eq(SocketObserverInterface::CloseStartedEvent{ + .maybeCloseReason = defaultError})))); + EXPECT_CALL(*observer, closing(transport, _)); transport->close(folly::none); Mock::VerifyAndClearExpectations(observer.get()); - InSequence s; EXPECT_CALL(*observer, destroyed(transport, IsNull())); this->destroyTransport(); Mock::VerifyAndClearExpectations(observer.get()); @@ -1152,87 +1185,177 @@ TYPED_TEST( TYPED_TEST( QuicTypedTransportTestForObservers, - CloseWithErrorThenDestroyTransport) { + CloseNoErrorDrainEnabled_DrainThenDestroyTransport) { auto transport = this->getTransport(); + { + auto transportSettings = transport->getTransportSettings(); + transportSettings.shouldDrain = true; + transport->setTransportSettings(transportSettings); + } + this->startTransport(); + + InSequence s; auto observer = std::make_unique>(); EXPECT_CALL(*observer, attached(transport)); transport->addObserver(observer.get()); EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(observer.get())); - const auto testError = QuicError( - QuicErrorCode(LocalErrorCode::CONNECTION_RESET), - std::string("testError")); - EXPECT_CALL( - *observer, close(transport, folly::Optional(testError))); - transport->close(testError); - Mock::VerifyAndClearExpectations(observer.get()); - InSequence s; - EXPECT_CALL(*observer, destroyed(transport, IsNull())); - this->destroyTransport(); - Mock::VerifyAndClearExpectations(observer.get()); -} - -TYPED_TEST(QuicTypedTransportTestForObservers, LegacyAttachThenDetach) { - auto transport = this->getTransport(); - auto observer = std::make_unique>(); - - EXPECT_CALL(*observer, observerAttach(transport)); - transport->addObserver(observer.get()); - EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(observer.get())); - EXPECT_CALL(*observer, observerDetach(transport)); - EXPECT_TRUE(transport->removeObserver(observer.get())); - Mock::VerifyAndClearExpectations(observer.get()); - EXPECT_THAT(transport->getObservers(), IsEmpty()); -} - -TYPED_TEST( - QuicTypedTransportTestForObservers, - LegacyCloseNoErrorThenDestroyTransport) { - auto transport = this->getTransport(); - auto observer = std::make_unique>(); - - EXPECT_CALL(*observer, observerAttach(transport)); - transport->addObserver(observer.get()); - EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(observer.get())); - const QuicError defaultError = QuicError( GenericApplicationErrorCode::NO_ERROR, toString(GenericApplicationErrorCode::NO_ERROR)); EXPECT_CALL( - *observer, close(transport, folly::Optional(defaultError))); + *observer, + closeStarted( + transport, + AllOf( + // should not be equal to an empty event + testing::Ne(SocketObserverInterface::CloseStartedEvent{ + .maybeCloseReason = folly::none}), + // should be equal to a populated event with default error + testing::Eq(SocketObserverInterface::CloseStartedEvent{ + .maybeCloseReason = defaultError})))); transport->close(folly::none); + + // wait for the drain + EXPECT_CALL(*observer, closing(transport, _)); + transport->getEventBase()->timer().scheduleTimeoutFn( + [&] { transport->getEventBase()->terminateLoopSoon(); }, + folly::chrono::ceil( + 1ms + kDrainFactor * calculatePTO(this->getConn()))); + transport->getEventBase()->loop(); Mock::VerifyAndClearExpectations(observer.get()); - InSequence s; - EXPECT_CALL(*observer, destroy(transport)); + + EXPECT_CALL(*observer, destroyed(transport, IsNull())); this->destroyTransport(); - Mock::VerifyAndClearExpectations(observer.get()); } TYPED_TEST( QuicTypedTransportTestForObservers, - LegacyCloseWithErrorThenDestroyTransport) { + CloseNoErrorDrainEnabled_DestroyTransport) { auto transport = this->getTransport(); - auto observer = std::make_unique>(); + { + auto transportSettings = transport->getTransportSettings(); + transportSettings.shouldDrain = true; + transport->setTransportSettings(transportSettings); + } + this->startTransport(); - EXPECT_CALL(*observer, observerAttach(transport)); + InSequence s; + auto observer = std::make_unique>(); + + EXPECT_CALL(*observer, attached(transport)); transport->addObserver(observer.get()); EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(observer.get())); - const auto testError = QuicError( + const QuicError defaultError = QuicError( + GenericApplicationErrorCode::NO_ERROR, + toString(GenericApplicationErrorCode::NO_ERROR)); + EXPECT_CALL( + *observer, + closeStarted( + transport, + AllOf( + // should not be equal to an empty event + testing::Ne(SocketObserverInterface::CloseStartedEvent{ + .maybeCloseReason = folly::none}), + // should be equal to a populated event with default error + testing::Eq(SocketObserverInterface::CloseStartedEvent{ + .maybeCloseReason = defaultError})))); + transport->close(folly::none); + Mock::VerifyAndClearExpectations(observer.get()); + + // destroy transport without waiting for drain + EXPECT_CALL(*observer, closing(transport, _)); + EXPECT_CALL(*observer, destroyed(transport, IsNull())); + this->destroyTransport(); +} + +TYPED_TEST( + QuicTypedTransportTestForObservers, + CloseWithErrorDrainDisabled_DestroyTransport) { + auto transport = this->getTransport(); + { + auto transportSettings = transport->getTransportSettings(); + transportSettings.shouldDrain = false; + transport->setTransportSettings(transportSettings); + } + this->startTransport(); + + InSequence s; + auto observer = std::make_unique>(); + + EXPECT_CALL(*observer, attached(transport)); + transport->addObserver(observer.get()); + EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(observer.get())); + + const QuicError testError = QuicError( QuicErrorCode(LocalErrorCode::CONNECTION_RESET), std::string("testError")); EXPECT_CALL( - *observer, close(transport, folly::Optional(testError))); + *observer, + closeStarted( + transport, + AllOf( + // should not be equal to an empty event + testing::Ne(SocketObserverInterface::CloseStartedEvent{ + .maybeCloseReason = folly::none}), + // should be equal to a populated event with default error + testing::Eq(SocketObserverInterface::CloseStartedEvent{ + .maybeCloseReason = testError})))); + EXPECT_CALL(*observer, closing(transport, _)); transport->close(testError); Mock::VerifyAndClearExpectations(observer.get()); - InSequence s; - EXPECT_CALL(*observer, destroy(transport)); + + EXPECT_CALL(*observer, destroyed(transport, IsNull())); this->destroyTransport(); - Mock::VerifyAndClearExpectations(observer.get()); } -TYPED_TEST(QuicTypedTransportTestForObservers, StreamEventsLocalOpenedStream) { +TYPED_TEST( + QuicTypedTransportTestForObservers, + CloseWithErrorDrainEnabled_DestroyTransport) { + auto transport = this->getTransport(); + { + auto transportSettings = transport->getTransportSettings(); + transportSettings.shouldDrain = true; + transport->setTransportSettings(transportSettings); + } + this->startTransport(); + + InSequence s; + auto observer = std::make_unique>(); + + EXPECT_CALL(*observer, attached(transport)); + transport->addObserver(observer.get()); + EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(observer.get())); + + const QuicError testError = QuicError( + QuicErrorCode(LocalErrorCode::CONNECTION_RESET), + std::string("testError")); + + // because of the error, we won't wait for the drain despite it being enabled. + EXPECT_CALL( + *observer, + closeStarted( + transport, + AllOf( + // should not be equal to an empty event + testing::Ne(SocketObserverInterface::CloseStartedEvent{ + .maybeCloseReason = folly::none}), + // should be equal to a populated event with default error + testing::Eq(SocketObserverInterface::CloseStartedEvent{ + .maybeCloseReason = testError})))); + EXPECT_CALL(*observer, closing(transport, _)); + transport->close(testError); + Mock::VerifyAndClearExpectations(observer.get()); + + EXPECT_CALL(*observer, destroyed(transport, IsNull())); + this->destroyTransport(); +} + +TYPED_TEST( + QuicTypedTransportAfterStartTestForObservers, + StreamEventsLocalOpenedStream) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::streamEvents); auto transport = this->getTransport(); @@ -1301,7 +1424,7 @@ TYPED_TEST(QuicTypedTransportTestForObservers, StreamEventsLocalOpenedStream) { } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, StreamEventsLocalOpenedStreamImmediateEofLocal) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::streamEvents); @@ -1364,7 +1487,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, StreamEventsLocalOpenedStreamImmediateEofLocalRemote) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::streamEvents); @@ -1423,7 +1546,9 @@ TYPED_TEST( this->destroyTransport(); } -TYPED_TEST(QuicTypedTransportTestForObservers, StreamEventsPeerOpenedStream) { +TYPED_TEST( + QuicTypedTransportAfterStartTestForObservers, + StreamEventsPeerOpenedStream) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::streamEvents); @@ -1487,7 +1612,7 @@ TYPED_TEST(QuicTypedTransportTestForObservers, StreamEventsPeerOpenedStream) { } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, StreamEventsPeerOpenedStreamImmediateEofRemote) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::streamEvents); @@ -1548,7 +1673,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, StreamEventsPeerOpenedStreamImmediateEofLocalRemote) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::streamEvents); @@ -1603,7 +1728,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, StreamEventsPeerOpenedStreamStopSendingPlusRstTriggersRst) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::streamEvents); @@ -1682,7 +1807,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, StreamEventsPeerOpenedStreamStopSendingPlusRstTriggersRstBytesInFlight) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::streamEvents); @@ -1759,7 +1884,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, StreamEventsPeerOpenedStreamImmediateEorStopSendingTriggersRstBytesInFlight) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::streamEvents); @@ -1832,7 +1957,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, WriteEventsOutstandingPacketSent) { InSequence s; @@ -2007,7 +2132,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, WriteEventsOutstandingPacketSentWroteMoreThanCwnd) { InSequence s; @@ -2150,7 +2275,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, WriteEventsOutstandingPacketsSentCwndLimited) { InSequence s; @@ -2388,7 +2513,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, WriteEventsOutstandingPacketSentNoCongestionController) { InSequence s; @@ -2501,7 +2626,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, AckEventsOutstandingPacketSentThenAckedNoAckDelay) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::acksProcessedEvents); @@ -2542,7 +2667,7 @@ TYPED_TEST( const auto ackRecvTime = sentTime + 27ms; const auto ackDelay = 0us; const auto matcher = - typename TestFixture::AckEventMatcherBuilder() + AckEventMatcherBuilder() .setExpectedAckedIntervals({maybeWrittenPackets}) .setExpectedNumAckedPackets(1) .setAckTime(ackRecvTime) @@ -2574,7 +2699,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, AckEventsOutstandingPacketSentThenAckedWithAckDelay) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::acksProcessedEvents); @@ -2615,7 +2740,7 @@ TYPED_TEST( const auto ackRecvTime = sentTime + 50ms; const auto ackDelay = 5ms; const auto matcher = - typename TestFixture::AckEventMatcherBuilder() + AckEventMatcherBuilder() .setExpectedAckedIntervals({maybeWrittenPackets}) .setExpectedNumAckedPackets(1) .setAckTime(ackRecvTime) @@ -2647,7 +2772,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, AckEventsOutstandingPacketSentThenAckedWithAckDelayEqRtt) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::acksProcessedEvents); @@ -2688,7 +2813,7 @@ TYPED_TEST( const auto ackRecvTime = sentTime + 50ms; const auto ackDelay = ackRecvTime - sentTime; // ack delay == RTT! const auto matcher = - typename TestFixture::AckEventMatcherBuilder() + AckEventMatcherBuilder() .setExpectedAckedIntervals({maybeWrittenPackets}) .setExpectedNumAckedPackets(1) .setAckTime(ackRecvTime) @@ -2744,7 +2869,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, AckEventsOutstandingPacketSentThenAckedWithTooLargeAckDelay) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::acksProcessedEvents); @@ -2785,7 +2910,7 @@ TYPED_TEST( const auto ackRecvTime = sentTime + 50ms; const auto ackDelay = ackRecvTime + 1ms - sentTime; // ack delay >> RTT! const auto matcher = - typename TestFixture::AckEventMatcherBuilder() + AckEventMatcherBuilder() .setExpectedAckedIntervals({maybeWrittenPackets}) .setExpectedNumAckedPackets(1) .setAckTime(ackRecvTime) @@ -2839,7 +2964,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, AckEventsThreeOutstandingPacketsSentThenAllAckedAtOnce) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::acksProcessedEvents); @@ -2890,7 +3015,7 @@ TYPED_TEST( const auto ackRecvTime = sentTime + 27ms; const auto ackDelay = 5ms; const auto matcher = - typename TestFixture::AckEventMatcherBuilder() + AckEventMatcherBuilder() .setExpectedAckedIntervals( {maybeWrittenPackets1, maybeWrittenPackets2, @@ -2925,7 +3050,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, AckEventsThreeOutstandingPacketsSentAndAckedSequentially) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::acksProcessedEvents); @@ -2961,7 +3086,7 @@ TYPED_TEST( const auto ackRecvTime = sentTime + 27ms; const auto ackDelay = 5ms; const auto matcher = - typename TestFixture::AckEventMatcherBuilder() + AckEventMatcherBuilder() .setExpectedAckedIntervals({maybeWrittenPackets}) .setExpectedNumAckedPackets(1) .setAckTime(ackRecvTime) @@ -3001,7 +3126,7 @@ TYPED_TEST( const auto ackRecvTime = sentTime + 443ms; const auto ackDelay = 7ms; const auto matcher = - typename TestFixture::AckEventMatcherBuilder() + AckEventMatcherBuilder() .setExpectedAckedIntervals({maybeWrittenPackets}) .setExpectedNumAckedPackets(1) .setAckTime(ackRecvTime) @@ -3042,7 +3167,7 @@ TYPED_TEST( const auto ackRecvTime = sentTime + 62ms; const auto ackDelay = 3ms; const auto matcher = - typename TestFixture::AckEventMatcherBuilder() + AckEventMatcherBuilder() .setExpectedAckedIntervals({maybeWrittenPackets}) .setExpectedNumAckedPackets(1) .setAckTime(ackRecvTime) @@ -3077,7 +3202,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, AckEventsThreeOutstandingPacketsSentThenAckedSequentially) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::acksProcessedEvents); @@ -3129,7 +3254,7 @@ TYPED_TEST( const auto ackRecvTime = sentTime + 122ms; const auto ackDelay = 3ms; const auto matcher = - typename TestFixture::AckEventMatcherBuilder() + AckEventMatcherBuilder() .setExpectedAckedIntervals({maybeWrittenPackets1}) .setExpectedNumAckedPackets(1) .setAckTime(ackRecvTime) @@ -3166,7 +3291,7 @@ TYPED_TEST( const auto ackRecvTime = sentTime + 62ms; const auto ackDelay = 1ms; const auto matcher = - typename TestFixture::AckEventMatcherBuilder() + AckEventMatcherBuilder() .setExpectedAckedIntervals({maybeWrittenPackets2}) .setExpectedNumAckedPackets(1) .setAckTime(ackRecvTime) @@ -3203,7 +3328,7 @@ TYPED_TEST( const auto ackRecvTime = sentTime + 82ms; const auto ackDelay = 20ms; const auto matcher = - typename TestFixture::AckEventMatcherBuilder() + AckEventMatcherBuilder() .setExpectedAckedIntervals({maybeWrittenPackets3}) .setExpectedNumAckedPackets(1) .setAckTime(ackRecvTime) @@ -3238,7 +3363,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, AckEventsThreeOutstandingPacketsSentThenFirstLastAckedSequentiallyThenSecondAcked) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::acksProcessedEvents); @@ -3297,7 +3422,7 @@ TYPED_TEST( const auto ackRecvTime = sentTime + 20ms; const auto ackDelay = 5ms; const auto matcher = - typename TestFixture::AckEventMatcherBuilder() + AckEventMatcherBuilder() .setExpectedAckedIntervals({maybeWrittenPackets1}) .setExpectedNumAckedPackets(1) .setAckTime(ackRecvTime) @@ -3334,7 +3459,7 @@ TYPED_TEST( const auto ackRecvTime = sentTime + 11ms; const auto ackDelay = 4ms; const auto matcher = - typename TestFixture::AckEventMatcherBuilder() + AckEventMatcherBuilder() .setExpectedAckedIntervals({maybeWrittenPackets1}) .setExpectedNumAckedPackets(1) .setAckTime(ackRecvTime) @@ -3372,7 +3497,7 @@ TYPED_TEST( const auto ackRecvTime = maybeWrittenPackets3->sentTime + 11ms; const auto ackDelay = 2ms; const auto matcher = - typename TestFixture::AckEventMatcherBuilder() + AckEventMatcherBuilder() .setExpectedAckedIntervals({maybeWrittenPackets1}) .setExpectedNumAckedPackets(1) .setAckTime(ackRecvTime) @@ -3403,7 +3528,7 @@ TYPED_TEST( } TYPED_TEST( - QuicTypedTransportTestForObservers, + QuicTypedTransportAfterStartTestForObservers, AckEventsThreeOutstandingPacketsSentThenFirstLastAckedAtOnceThenSecondAcked) { LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::acksProcessedEvents); @@ -3462,7 +3587,7 @@ TYPED_TEST( const auto ackRecvTime = sentTime + 20ms; const auto ackDelay = 5ms; const auto matcher = - typename TestFixture::AckEventMatcherBuilder() + AckEventMatcherBuilder() .setExpectedAckedIntervals( {maybeWrittenPackets1, maybeWrittenPackets3}) .setExpectedNumAckedPackets(2) @@ -3501,7 +3626,7 @@ TYPED_TEST( const auto ackRecvTime = maybeWrittenPackets3->sentTime + 11ms; const auto ackDelay = 2ms; const auto matcher = - typename TestFixture::AckEventMatcherBuilder() + AckEventMatcherBuilder() .setExpectedAckedIntervals({maybeWrittenPackets1}) .setExpectedNumAckedPackets(1) .setAckTime(ackRecvTime) @@ -3531,7 +3656,9 @@ TYPED_TEST( this->destroyTransport(); } -TYPED_TEST(QuicTypedTransportTestForObservers, PacketsReceivedEventsSingle) { +TYPED_TEST( + QuicTypedTransportAfterStartTestForObservers, + PacketsReceivedEventsSingle) { using Event = quic::SocketObserverInterface::PacketsReceivedEvent; LegacyObserver::EventSet eventSet; eventSet.enable(SocketObserverInterface::Events::packetsReceivedEvents); @@ -3568,11 +3695,10 @@ TYPED_TEST(QuicTypedTransportTestForObservers, PacketsReceivedEventsSingle) { testing::Field(&Event::receivedPackets, testing::SizeIs(1)), testing::Field( &Event::receivedPackets, - testing::ElementsAre( - typename TestFixture::ReceivedPacketMatcherBuilder() - .setExpectedPacketReceiveTime(pkt1RecvTime) - .setExpectedPacketNumBytes(pkt1NumBytes) - .build()))); + testing::ElementsAre(ReceivedPacketMatcherBuilder() + .setExpectedPacketReceiveTime(pkt1RecvTime) + .setExpectedPacketNumBytes(pkt1NumBytes) + .build()))); EXPECT_CALL(*obs1, packetsReceived(_, _)).Times(0); EXPECT_CALL(*obs2, packetsReceived(transport, matcher)); @@ -3595,11 +3721,10 @@ TYPED_TEST(QuicTypedTransportTestForObservers, PacketsReceivedEventsSingle) { testing::Field(&Event::receivedPackets, testing::SizeIs(1)), testing::Field( &Event::receivedPackets, - testing::ElementsAre( - typename TestFixture::ReceivedPacketMatcherBuilder() - .setExpectedPacketReceiveTime(pkt2RecvTime) - .setExpectedPacketNumBytes(pkt2NumBytes) - .build()))); + testing::ElementsAre(ReceivedPacketMatcherBuilder() + .setExpectedPacketReceiveTime(pkt2RecvTime) + .setExpectedPacketNumBytes(pkt2NumBytes) + .build()))); EXPECT_CALL(*obs1, packetsReceived(_, _)).Times(0); EXPECT_CALL(*obs2, packetsReceived(transport, matcher)); @@ -3610,12 +3735,12 @@ TYPED_TEST(QuicTypedTransportTestForObservers, PacketsReceivedEventsSingle) { this->destroyTransport(); } -TYPED_TEST(QuicTypedTransportTestForObservers, PacketsReceivedEventsMulti) { +TYPED_TEST( + QuicTypedTransportAfterStartTestForObservers, + PacketsReceivedEventsMulti) { // skip for client transport tests for now as supporting test foundation // does not properly support batch delivery - if constexpr (std::is_same_v< - TypeParam, - QuicClientTransportAfterStartTestBase>) { + if constexpr (std::is_base_of_v) { return; } @@ -3667,12 +3792,12 @@ TYPED_TEST(QuicTypedTransportTestForObservers, PacketsReceivedEventsMulti) { &Event::receivedPackets, testing::ElementsAre( // pkt1 - typename TestFixture::ReceivedPacketMatcherBuilder() + ReceivedPacketMatcherBuilder() .setExpectedPacketReceiveTime(pktBatch1RecvTime) .setExpectedPacketNumBytes(pkt1NumBytes) .build(), // pkt2 - typename TestFixture::ReceivedPacketMatcherBuilder() + ReceivedPacketMatcherBuilder() .setExpectedPacketReceiveTime(pktBatch1RecvTime) .setExpectedPacketNumBytes(pkt2NumBytes) .build()))); @@ -3710,12 +3835,12 @@ TYPED_TEST(QuicTypedTransportTestForObservers, PacketsReceivedEventsMulti) { &Event::receivedPackets, testing::ElementsAre( // pkt1 - typename TestFixture::ReceivedPacketMatcherBuilder() + ReceivedPacketMatcherBuilder() .setExpectedPacketReceiveTime(pktBatch2RecvTime) .setExpectedPacketNumBytes(pkt3NumBytes) .build(), // pkt2 - typename TestFixture::ReceivedPacketMatcherBuilder() + ReceivedPacketMatcherBuilder() .setExpectedPacketReceiveTime(pktBatch2RecvTime) .setExpectedPacketNumBytes(pkt4NumBytes) .build()))); diff --git a/quic/api/test/TestQuicTransport.h b/quic/api/test/TestQuicTransport.h index 0ef654311..fcd56c252 100644 --- a/quic/api/test/TestQuicTransport.h +++ b/quic/api/test/TestQuicTransport.h @@ -45,7 +45,9 @@ class TestQuicTransport QuicError( QuicErrorCode(LocalErrorCode::SHUTTING_DOWN), std::string("shutdown")), - false); + false /* drainConnection */); + // closeImpl may have been called earlier with drain = true, so force close. + closeUdpSocket(); } QuicVersion getVersion() { diff --git a/quic/client/QuicClientTransport.cpp b/quic/client/QuicClientTransport.cpp index b126aeba8..a246cc6fe 100644 --- a/quic/client/QuicClientTransport.cpp +++ b/quic/client/QuicClientTransport.cpp @@ -103,7 +103,9 @@ QuicClientTransport::~QuicClientTransport() { QuicError( QuicErrorCode(LocalErrorCode::SHUTTING_DOWN), std::string("Closing from client destructor")), - false); + false /* drainConnection */); + // closeImpl may have been called earlier with drain = true, so force close. + closeUdpSocket(); if (clientConn_->happyEyeballsState.secondSocket) { auto sock = std::move(clientConn_->happyEyeballsState.secondSocket); diff --git a/quic/observer/SocketObserverInterface.h b/quic/observer/SocketObserverInterface.h index 2fc7f59bd..cf73e6dfc 100644 --- a/quic/observer/SocketObserverInterface.h +++ b/quic/observer/SocketObserverInterface.h @@ -44,6 +44,38 @@ class SocketObserverInterface { }; virtual ~SocketObserverInterface() = default; + /** + * Event structures. + */ + + struct CloseStartedEvent { + // Error code provided when close() or closeNow() called. + // + // The presence of an error code does NOT indicate that a "problem" caused + // the socket to close, since an error code can be an application timeout. + folly::Optional maybeCloseReason; + + // Default equality comparator available in C++20. + // + // mvfst currently supports C++17 onwards. However, we can enable this for + // unit tests and other code that we expect to run in C++20. +#if FOLLY_CPLUSPLUS >= 202002L + friend auto operator<=>( + const CloseStartedEvent&, + const CloseStartedEvent&) = default; +#endif + }; + + struct ClosingEvent { + // Default equality comparator available in C++20. + // + // mvfst currently supports C++17 onwards. However, we can enable this for + // unit tests and other code that we expect to run in C++20. +#if FOLLY_CPLUSPLUS >= 202002L + friend auto operator<=>(const ClosingEvent&, const ClosingEvent&) = default; +#endif + }; + struct WriteEvent { [[nodiscard]] const std::deque& getOutstandingPackets() const { @@ -408,19 +440,38 @@ class SocketObserverInterface { using StreamCloseEvent = StreamEvent; /** - * close() will be invoked when the socket is being closed. - * - * If the callback handler does not unsubscribe itself upon being called, - * then it may be called multiple times (e.g., by a call to close() by - * the application, and then again when closeNow() is called on - * destruction). - * - * @param socket Socket being closed. - * @param errorOpt Error information, if connection closed due to error. + * Events. */ - virtual void close( + + /** + * closeStarted() is invoked when socket close begins. + * + * The socket may stay open for some time after this event to drain. + * + * @param socket Socket being closed. + * @param event CloseStartedEvent with details. + */ + virtual void closeStarted( QuicSocket* /* socket */, - const folly::Optional& /* errorOpt */) noexcept {} + const CloseStartedEvent& /* event */) noexcept {} + + /** + * closing() is invoked right before the transport is unbound from UDP socket. + * + * closeStarted() should have been invoked prior to this event as this event + * marks the completion of the socket being closed and the last opportunity + * to capture state from the socket. + * + * Called immediately BEFORE the transport is unbound from the UDP socket to + * be consistent with TCP sockets, for which the closing() event would mark + * the last opportunity to get information (such as TCP_INFO) from the socket. + * + * @param socket Socket being closed. + * @param event ClosingEvent with details. + */ + virtual void closing( + QuicSocket* /* socket */, + const ClosingEvent& /* event */) noexcept {} /** * evbAttach() will be invoked when a new event base is attached to this diff --git a/quic/server/QuicServerTransport.cpp b/quic/server/QuicServerTransport.cpp index ac66f5950..5e66e77d7 100644 --- a/quic/server/QuicServerTransport.cpp +++ b/quic/server/QuicServerTransport.cpp @@ -79,7 +79,9 @@ QuicServerTransport::~QuicServerTransport() { QuicError( QuicErrorCode(LocalErrorCode::SHUTTING_DOWN), std::string("Closing from server destructor")), - false); + false /* drainConnection */); + // closeImpl may have been called earlier with drain = true, so force close. + closeUdpSocket(); } QuicServerTransport::Ptr QuicServerTransport::make(