diff --git a/quic/api/QuicSocket.h b/quic/api/QuicSocket.h index 07c278435..ec2feba18 100644 --- a/quic/api/QuicSocket.h +++ b/quic/api/QuicSocket.h @@ -1274,6 +1274,22 @@ class QuicSocket { return false; } + /** + * Adds an observer. + * + * If the observer is already added, this is a no-op. + * + * @param observer Observer to add. + * @return Whether the observer was added (fails if no list). + */ + bool addObserver(std::shared_ptr observer) { + if (auto list = getSocketObserverContainer()) { + list->addObserver(std::move(observer)); + return true; + } + return false; + } + /** * Removes an observer. * @@ -1287,6 +1303,19 @@ class QuicSocket { return false; } + /** + * Removes an observer. + * + * @param observer Observer to remove. + * @return Whether the observer was found and removed. + */ + bool removeObserver(std::shared_ptr observer) { + if (auto list = getSocketObserverContainer()) { + return list->removeObserver(std::move(observer)); + } + return false; + } + /** * Get number of observers. * diff --git a/quic/api/test/QuicSocketTest.cpp b/quic/api/test/QuicSocketTest.cpp new file mode 100644 index 000000000..b37be8c7a --- /dev/null +++ b/quic/api/test/QuicSocketTest.cpp @@ -0,0 +1,86 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +using namespace quic; +using namespace testing; + +class QuicSocketTest : public Test { + public: + void SetUp() override { + socket_ = std::make_shared(); + } + + protected: + std::shared_ptr socket_; +}; + +TEST_F(QuicSocketTest, ObserverAddRemoveNoContainer) { + auto obs1 = std::make_shared>(); + EXPECT_CALL(*socket_, getSocketObserverContainer()).WillOnce(Return(nullptr)); + EXPECT_FALSE(socket_->addObserver(obs1)); + + auto obs2 = std::make_shared>(); + EXPECT_CALL(*socket_, getSocketObserverContainer()).WillOnce(Return(nullptr)); + EXPECT_FALSE(socket_->removeObserver(obs1)); +} + +TEST_F(QuicSocketTest, ObserverAddRemoveWithContainer) { + auto observerContainer = + std::make_shared(socket_.get()); + + InSequence s; + + auto obs1 = std::make_unique>(); + EXPECT_CALL(*socket_, getSocketObserverContainer()) + .WillOnce(Return(observerContainer.get())); + EXPECT_CALL(*obs1, observerAttach(socket_.get())); + EXPECT_TRUE(socket_->addObserver(obs1.get())); + + EXPECT_EQ(1, observerContainer->numObservers()); + + auto obs2 = std::make_unique>(); + EXPECT_CALL(*socket_, getSocketObserverContainer()) + .WillOnce(Return(observerContainer.get())); + EXPECT_CALL(*obs1, observerDetach(socket_.get())); + EXPECT_TRUE(socket_->removeObserver(obs1.get())); +} + +TEST_F(QuicSocketTest, ObserverSharedPtrAddRemoveNoContainer) { + auto obs1 = std::make_shared>(); + EXPECT_CALL(*socket_, getSocketObserverContainer()).WillOnce(Return(nullptr)); + EXPECT_FALSE(socket_->addObserver(obs1)); + + auto obs2 = std::make_shared>(); + EXPECT_CALL(*socket_, getSocketObserverContainer()).WillOnce(Return(nullptr)); + EXPECT_FALSE(socket_->removeObserver(obs1)); +} + +TEST_F(QuicSocketTest, ObserverSharedPtrAddRemoveWithContainer) { + auto observerContainer = + std::make_shared(socket_.get()); + + InSequence s; + + auto obs1 = std::make_shared>(); + EXPECT_CALL(*socket_, getSocketObserverContainer()) + .WillOnce(Return(observerContainer.get())); + EXPECT_CALL(*obs1, observerAttach(socket_.get())); + EXPECT_TRUE(socket_->addObserver(obs1)); + + EXPECT_EQ(1, observerContainer->numObservers()); + + auto obs2 = std::make_shared>(); + EXPECT_CALL(*socket_, getSocketObserverContainer()) + .WillOnce(Return(observerContainer.get())); + EXPECT_CALL(*obs1, observerDetach(socket_.get())); + EXPECT_TRUE(socket_->removeObserver(obs1)); +} diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index 3ecbda504..6b814b367 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -3686,7 +3686,7 @@ TEST_P(QuicTransportImplTestBase, StreamWriteCallbackUnregister) { evb->loopOnce(); } -TEST_P(QuicTransportImplTestBase, ObserverAttachRemove) { +TEST_P(QuicTransportImplTestBase, ObserverRemove) { auto cb = std::make_unique>(); EXPECT_CALL(*cb, observerAttach(transport.get())); transport->addObserver(cb.get()); @@ -3697,63 +3697,7 @@ TEST_P(QuicTransportImplTestBase, ObserverAttachRemove) { EXPECT_THAT(transport->getObservers(), IsEmpty()); } -TEST_P(QuicTransportImplTestBase, ObserverAttachRemoveMultiple) { - auto cb1 = std::make_unique>(); - EXPECT_CALL(*cb1, observerAttach(transport.get())); - transport->addObserver(cb1.get()); - Mock::VerifyAndClearExpectations(cb1.get()); - EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb1.get())); - - auto cb2 = std::make_unique>(); - EXPECT_CALL(*cb2, observerAttach(transport.get())); - transport->addObserver(cb2.get()); - Mock::VerifyAndClearExpectations(cb2.get()); - EXPECT_THAT( - transport->getObservers(), UnorderedElementsAre(cb1.get(), cb2.get())); - - EXPECT_CALL(*cb1, observerDetach(transport.get())); - EXPECT_TRUE(transport->removeObserver(cb1.get())); - Mock::VerifyAndClearExpectations(cb1.get()); - EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb2.get())); - - EXPECT_CALL(*cb2, observerDetach(transport.get())); - EXPECT_TRUE(transport->removeObserver(cb2.get())); - Mock::VerifyAndClearExpectations(cb2.get()); - EXPECT_THAT(transport->getObservers(), IsEmpty()); -} - -TEST_P(QuicTransportImplTestBase, ObserverAttachRemoveMultipleReverse) { - auto cb1 = std::make_unique>(); - EXPECT_CALL(*cb1, observerAttach(transport.get())); - transport->addObserver(cb1.get()); - Mock::VerifyAndClearExpectations(cb1.get()); - EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb1.get())); - - auto cb2 = std::make_unique>(); - EXPECT_CALL(*cb2, observerAttach(transport.get())); - transport->addObserver(cb2.get()); - Mock::VerifyAndClearExpectations(cb2.get()); - EXPECT_THAT( - transport->getObservers(), UnorderedElementsAre(cb1.get(), cb2.get())); - - EXPECT_CALL(*cb2, observerDetach(transport.get())); - EXPECT_TRUE(transport->removeObserver(cb2.get())); - Mock::VerifyAndClearExpectations(cb2.get()); - EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb1.get())); - - EXPECT_CALL(*cb1, observerDetach(transport.get())); - EXPECT_TRUE(transport->removeObserver(cb1.get())); - Mock::VerifyAndClearExpectations(cb1.get()); - EXPECT_THAT(transport->getObservers(), IsEmpty()); -} - -TEST_P(QuicTransportImplTestBase, ObserverRemoveMissing) { - auto cb = std::make_unique>(); - EXPECT_FALSE(transport->removeObserver(cb.get())); - EXPECT_THAT(transport->getObservers(), IsEmpty()); -} - -TEST_P(QuicTransportImplTestBase, ObserverDestroyTransport) { +TEST_P(QuicTransportImplTestBase, ObserverDestroy) { auto cb = std::make_unique>(); EXPECT_CALL(*cb, observerAttach(transport.get())); transport->addObserver(cb.get()); @@ -3765,7 +3709,62 @@ TEST_P(QuicTransportImplTestBase, ObserverDestroyTransport) { Mock::VerifyAndClearExpectations(cb.get()); } -TEST_P(QuicTransportImplTestBase, ObserverCloseNoErrorThenDestroyTransport) { +TEST_P(QuicTransportImplTestBase, ObserverRemoveMissing) { + auto cb = std::make_unique>(); + EXPECT_FALSE(transport->removeObserver(cb.get())); + EXPECT_THAT(transport->getObservers(), IsEmpty()); +} + +TEST_P(QuicTransportImplTestBase, ObserverSharedPtrRemove) { + auto cb = std::make_shared>(); + EXPECT_CALL(*cb, observerAttach(transport.get())); + transport->addObserver(cb); + EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb.get())); + EXPECT_CALL(*cb, observerDetach(transport.get())); + EXPECT_TRUE(transport->removeObserver(cb)); + Mock::VerifyAndClearExpectations(cb.get()); + EXPECT_THAT(transport->getObservers(), IsEmpty()); +} + +TEST_P(QuicTransportImplTestBase, ObserverSharedPtrDestroy) { + auto cb = std::make_shared>(); + EXPECT_CALL(*cb, observerAttach(transport.get())); + transport->addObserver(cb); + EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb.get())); + InSequence s; + EXPECT_CALL(*cb, close(transport.get(), _)); + EXPECT_CALL(*cb, destroy(transport.get())); + transport = nullptr; + Mock::VerifyAndClearExpectations(cb.get()); +} + +TEST_P(QuicTransportImplTestBase, ObserverSharedPtrReleasedDestroy) { + auto cb = std::make_shared>(); + EXPECT_CALL(*cb, observerAttach(transport.get())); + transport->addObserver(cb); + EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb.get())); + + // now that observer is attached, we release shared_ptr but keep raw ptr + // since the container holds shared_ptr too, observer should not be destroyed + MockLegacyObserver::Safety dc(*cb.get()); + auto cbRaw = cb.get(); + cb = nullptr; + EXPECT_FALSE(dc.destroyed()); // should still exist + + InSequence s; + EXPECT_CALL(*cbRaw, close(transport.get(), _)); + EXPECT_CALL(*cbRaw, destroy(transport.get())); + transport = nullptr; + Mock::VerifyAndClearExpectations(cb.get()); +} + +TEST_P(QuicTransportImplTestBase, ObserverSharedPtrRemoveMissing) { + auto cb = std::make_shared>(); + EXPECT_FALSE(transport->removeObserver(cb.get())); + EXPECT_THAT(transport->getObservers(), IsEmpty()); +} + +TEST_P(QuicTransportImplTestBase, ObserverCloseNoErrorThenDestroy) { auto cb = std::make_unique>(); EXPECT_CALL(*cb, observerAttach(transport.get())); transport->addObserver(cb.get()); @@ -3784,7 +3783,7 @@ TEST_P(QuicTransportImplTestBase, ObserverCloseNoErrorThenDestroyTransport) { Mock::VerifyAndClearExpectations(cb.get()); } -TEST_P(QuicTransportImplTestBase, ObserverCloseWithErrorThenDestroyTransport) { +TEST_P(QuicTransportImplTestBase, ObserverCloseWithErrorThenDestroy) { auto cb = std::make_unique>(); EXPECT_CALL(*cb, observerAttach(transport.get())); transport->addObserver(cb.get()); @@ -3803,7 +3802,7 @@ TEST_P(QuicTransportImplTestBase, ObserverCloseWithErrorThenDestroyTransport) { Mock::VerifyAndClearExpectations(cb.get()); } -TEST_P(QuicTransportImplTestBase, ObserverDetachObserverImmediately) { +TEST_P(QuicTransportImplTestBase, ObserverDetachImmediately) { auto cb = std::make_unique>(); EXPECT_CALL(*cb, observerAttach(transport.get())); transport->addObserver(cb.get()); @@ -3815,7 +3814,7 @@ TEST_P(QuicTransportImplTestBase, ObserverDetachObserverImmediately) { EXPECT_THAT(transport->getObservers(), IsEmpty()); } -TEST_P(QuicTransportImplTestBase, ObserverDetachObserverAfterTransportClose) { +TEST_P(QuicTransportImplTestBase, ObserverDetachAfterClose) { auto cb = std::make_unique>(); EXPECT_CALL(*cb, observerAttach(transport.get())); transport->addObserver(cb.get()); @@ -3831,9 +3830,7 @@ TEST_P(QuicTransportImplTestBase, ObserverDetachObserverAfterTransportClose) { EXPECT_THAT(transport->getObservers(), IsEmpty()); } -TEST_F( - QuicTransportImplTest, - ObserverDetachObserverOnCloseDuringTransportDestroy) { +TEST_F(QuicTransportImplTest, ObserverDetachOnCloseDuringDestroy) { auto cb = std::make_unique>(); EXPECT_CALL(*cb, observerAttach(transport.get())); transport->addObserver(cb.get()); @@ -3876,7 +3873,59 @@ TEST_P(QuicTransportImplTestBase, ObserverMultipleAttachRemove) { transport = nullptr; } -TEST_P(QuicTransportImplTestBase, ObserverMultipleAttachDestroyTransport) { +TEST_P(QuicTransportImplTestBase, ObserverSharedPtrMultipleAttachRemove) { + auto cb1 = std::make_shared>(); + EXPECT_CALL(*cb1, observerAttach(transport.get())); + transport->addObserver(cb1); + Mock::VerifyAndClearExpectations(cb1.get()); + EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb1.get())); + + auto cb2 = std::make_shared>(); + EXPECT_CALL(*cb2, observerAttach(transport.get())); + transport->addObserver(cb2); + Mock::VerifyAndClearExpectations(cb2.get()); + EXPECT_THAT( + transport->getObservers(), UnorderedElementsAre(cb1.get(), cb2.get())); + + EXPECT_CALL(*cb1, observerDetach(transport.get())); + EXPECT_TRUE(transport->removeObserver(cb1)); + Mock::VerifyAndClearExpectations(cb1.get()); + EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb2.get())); + + EXPECT_CALL(*cb2, observerDetach(transport.get())); + EXPECT_TRUE(transport->removeObserver(cb2)); + Mock::VerifyAndClearExpectations(cb2.get()); + EXPECT_THAT(transport->getObservers(), IsEmpty()); +} + +TEST_P(QuicTransportImplTestBase, ObserverMultipleAttachRemoveReverse) { + auto cb1 = std::make_unique>(); + EXPECT_CALL(*cb1, observerAttach(transport.get())); + transport->addObserver(cb1.get()); + Mock::VerifyAndClearExpectations(cb1.get()); + EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb1.get())); + + auto cb2 = std::make_unique>(); + EXPECT_CALL(*cb2, observerAttach(transport.get())); + transport->addObserver(cb2.get()); + Mock::VerifyAndClearExpectations(cb2.get()); + EXPECT_THAT( + transport->getObservers(), UnorderedElementsAre(cb1.get(), cb2.get())); + + EXPECT_CALL(*cb2, observerDetach(transport.get())); + EXPECT_TRUE(transport->removeObserver(cb2.get())); + Mock::VerifyAndClearExpectations(cb1.get()); + Mock::VerifyAndClearExpectations(cb2.get()); + EXPECT_THAT(transport->getObservers(), UnorderedElementsAre(cb1.get())); + + EXPECT_CALL(*cb1, observerDetach(transport.get())); + EXPECT_TRUE(transport->removeObserver(cb1.get())); + Mock::VerifyAndClearExpectations(cb1.get()); + Mock::VerifyAndClearExpectations(cb2.get()); + EXPECT_THAT(transport->getObservers(), IsEmpty()); +} + +TEST_P(QuicTransportImplTestBase, ObserverMultipleAttachDestroy) { auto cb1 = std::make_unique>(); EXPECT_CALL(*cb1, observerAttach(transport.get())); transport->addObserver(cb1.get());