diff --git a/quic/api/Observer.h b/quic/api/Observer.h index 0ba331ced..218017a01 100644 --- a/quic/api/Observer.h +++ b/quic/api/Observer.h @@ -11,6 +11,7 @@ #include #include #include +#include namespace folly { class EventBase; @@ -46,6 +47,7 @@ class Observer { bool pmtuEvents{false}; bool rttSamples{false}; bool knobFrameEvents{false}; + bool streamEvents{false}; virtual void enableAllEvents() { evbEvents = true; @@ -56,6 +58,7 @@ class Observer { spuriousLossEvents = true; pmtuEvents = true; knobFrameEvents = true; + streamEvents = true; } /** @@ -235,6 +238,23 @@ class Observer { const quic::KnobFrame knobFrame; }; + struct StreamEvent { + StreamEvent( + const StreamId id, + StreamInitiator initiator, + StreamDirectionality directionality) + : streamId(id), + streamInitiator(initiator), + streamDirectionality(directionality) {} + + const StreamId streamId; + const StreamInitiator streamInitiator; + const StreamDirectionality streamDirectionality; + }; + + using StreamOpenEvent = StreamEvent; + using StreamCloseEvent = StreamEvent; + /** * observerAttach() will be invoked when an observer is added. * @@ -413,6 +433,26 @@ class Observer { QuicSocket*, /* socket */ const KnobFrameEvent& /* event */) {} + /** + * streamOpened() is invoked when a new stream is opened. + * + * @param socket Socket associated with the event. + * @param event Event containing details. + */ + virtual void streamOpened( + QuicSocket*, /* socket */ + const StreamOpenEvent& /* event */) {} + + /** + * streamClosed() is invoked when a stream is closed. + * + * @param socket Socket associated with the event. + * @param event Event containing details. + */ + virtual void streamClosed( + QuicSocket*, /* socket */ + const StreamCloseEvent& /* event */) {} + protected: // observer configuration; cannot be changed post instantiation const Config observerConfig_; diff --git a/quic/api/QuicSocket.h b/quic/api/QuicSocket.h index 35071a788..6117526d5 100644 --- a/quic/api/QuicSocket.h +++ b/quic/api/QuicSocket.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -782,6 +783,11 @@ class QuicSocket { */ virtual bool isServerStream(StreamId stream) noexcept = 0; + /** + * Returns initiator (self or peer) of a stream by ID. + */ + virtual StreamInitiator getStreamInitiator(StreamId stream) noexcept = 0; + /** * Returns whether a stream ID represents a unidirectional stream. */ @@ -792,6 +798,12 @@ class QuicSocket { */ virtual bool isBidirectionalStream(StreamId stream) noexcept = 0; + /** + * Returns directionality (unidirectional or bidirectional) of a stream by ID. + */ + virtual StreamDirectionality getStreamDirectionality( + StreamId stream) noexcept = 0; + /** * Callback class for receiving write readiness notifications */ diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index 8d0e8da67..1b84b69e5 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -1506,13 +1506,22 @@ void QuicTransportBase::handleNewStreamCallbacks( streamStorage = conn_->streamManager->consumeNewPeerStreams(std::move(streamStorage)); - const auto& newPeerStreams = streamStorage; - for (const auto& stream : newPeerStreams) { + const auto& newPeerStreamIds = streamStorage; + for (const auto& streamId : newPeerStreamIds) { CHECK_NOTNULL(connCallback_.get()); - if (isBidirectionalStream(stream)) { - connCallback_->onNewBidirectionalStream(stream); + if (isBidirectionalStream(streamId)) { + connCallback_->onNewBidirectionalStream(streamId); } else { - connCallback_->onNewUnidirectionalStream(stream); + connCallback_->onNewUnidirectionalStream(streamId); + } + const Observer::StreamOpenEvent streamEvent( + streamId, + getStreamInitiator(streamId), + getStreamDirectionality(streamId)); + for (const auto& cb : *observers_) { + if (cb->getConfig().streamEvents) { + cb->streamOpened(this, streamEvent); + } } if (closeState_ != CloseState::OPEN) { @@ -1818,7 +1827,17 @@ QuicTransportBase::createStreamInternal(bool bidirectional) { streamResult = conn_->streamManager->createNextUnidirectionalStream(); } if (streamResult) { - return streamResult.value()->id; + const StreamId streamId = streamResult.value()->id; + const Observer::StreamOpenEvent streamEvent( + streamId, + getStreamInitiator(streamId), + getStreamDirectionality(streamId)); + for (const auto& cb : *observers_) { + if (cb->getConfig().streamEvents) { + cb->streamOpened(this, streamEvent); + } + } + return streamId; } else { return folly::makeUnexpected(streamResult.error()); } @@ -1834,10 +1853,6 @@ QuicTransportBase::createUnidirectionalStream(bool /*replaySafe*/) { return createStreamInternal(false); } -bool QuicTransportBase::isUnidirectionalStream(StreamId stream) noexcept { - return quic::isUnidirectionalStream(stream); -} - bool QuicTransportBase::isClientStream(StreamId stream) noexcept { return quic::isClientStream(stream); } @@ -1846,10 +1861,24 @@ bool QuicTransportBase::isServerStream(StreamId stream) noexcept { return quic::isServerStream(stream); } +StreamInitiator QuicTransportBase::getStreamInitiator( + StreamId stream) noexcept { + return quic::getStreamInitiator(conn_->nodeType, stream); +} + +bool QuicTransportBase::isUnidirectionalStream(StreamId stream) noexcept { + return quic::isUnidirectionalStream(stream); +} + bool QuicTransportBase::isBidirectionalStream(StreamId stream) noexcept { return quic::isBidirectionalStream(stream); } +StreamDirectionality QuicTransportBase::getStreamDirectionality( + StreamId stream) noexcept { + return quic::getStreamDirectionality(stream); +} + folly::Expected QuicTransportBase::notifyPendingWriteOnConnection(WriteCallback* wcb) { if (closeState_ != CloseState::OPEN) { @@ -2306,6 +2335,17 @@ void QuicTransportBase::checkForClosedStream() { } auto itr = conn_->streamManager->closedStreams().begin(); while (itr != conn_->streamManager->closedStreams().end()) { + const auto& streamId = *itr; + const Observer::StreamCloseEvent streamEvent( + streamId, + getStreamInitiator(streamId), + getStreamDirectionality(streamId)); + for (const auto& cb : *observers_) { + if (cb->getConfig().streamEvents) { + cb->streamClosed(this, streamEvent); + } + } + // We may be in an active read cb when we close the stream auto readCbIt = readCallbacks_.find(*itr); if (readCbIt != readCallbacks_.end() && diff --git a/quic/api/QuicTransportBase.h b/quic/api/QuicTransportBase.h index ffc662f2e..00ece1595 100644 --- a/quic/api/QuicTransportBase.h +++ b/quic/api/QuicTransportBase.h @@ -160,8 +160,11 @@ class QuicTransportBase : public QuicSocket, QuicStreamPrioritiesObserver { uint64_t getNumOpenableUnidirectionalStreams() const override; bool isClientStream(StreamId stream) noexcept override; bool isServerStream(StreamId stream) noexcept override; + StreamInitiator getStreamInitiator(StreamId stream) noexcept override; bool isUnidirectionalStream(StreamId stream) noexcept override; bool isBidirectionalStream(StreamId stream) noexcept override; + StreamDirectionality getStreamDirectionality( + StreamId stream) noexcept override; folly::Expected notifyPendingWriteOnStream( StreamId id, diff --git a/quic/api/test/MockQuicSocket.h b/quic/api/test/MockQuicSocket.h index d77e9331a..092afb6d1 100644 --- a/quic/api/test/MockQuicSocket.h +++ b/quic/api/test/MockQuicSocket.h @@ -160,8 +160,15 @@ class MockQuicSocket : public QuicSocket { MOCK_CONST_METHOD0(getNumOpenableUnidirectionalStreams, uint64_t()); GMOCK_METHOD1_(, noexcept, , isClientStream, bool(StreamId)); GMOCK_METHOD1_(, noexcept, , isServerStream, bool(StreamId)); + GMOCK_METHOD1_(, noexcept, , getStreamInitiator, StreamInitiator(StreamId)); GMOCK_METHOD1_(, noexcept, , isBidirectionalStream, bool(StreamId)); GMOCK_METHOD1_(, noexcept, , isUnidirectionalStream, bool(StreamId)); + GMOCK_METHOD1_( + , + noexcept, + , + getStreamDirectionality, + StreamDirectionality(StreamId)); MOCK_METHOD1( notifyPendingWriteOnConnection, folly::Expected(WriteCallback*)); diff --git a/quic/api/test/Mocks.h b/quic/api/test/Mocks.h index e9f5c146b..794c36d05 100644 --- a/quic/api/test/Mocks.h +++ b/quic/api/test/Mocks.h @@ -397,6 +397,18 @@ class MockObserver : public Observer { , knobFrameReceived, void(QuicSocket*, const KnobFrameEvent&)); + GMOCK_METHOD2_( + , + noexcept, + , + streamOpened, + void(QuicSocket*, const StreamOpenEvent&)); + GMOCK_METHOD2_( + , + noexcept, + , + streamClosed, + void(QuicSocket*, const StreamCloseEvent&)); static auto getLossPacketNum(PacketNum packetNum) { return testing::Field( @@ -419,6 +431,17 @@ class MockObserver : public Observer { testing::Field( &Observer::LostPacket::packet, getLossPacketNum(packetNum))); } + + static auto getStreamEventMatcher( + const StreamId id, + StreamInitiator initiator, + StreamDirectionality directionality) { + return AllOf( + testing::Field(&StreamEvent::streamId, testing::Eq(id)), + testing::Field(&StreamEvent::streamInitiator, testing::Eq(initiator)), + testing::Field( + &StreamEvent::streamDirectionality, testing::Eq(directionality))); + } }; inline std::ostream& operator<<(std::ostream& os, const MockQuicTransport&) { diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index 3c3a19df3..5bd22946b 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -3006,6 +3006,20 @@ TEST_F(QuicTransportImplTest, IsBidirectionalStream) { EXPECT_TRUE(transport->isBidirectionalStream(stream)); } +TEST_F(QuicTransportImplTest, GetStreamDirectionalityUnidirectional) { + auto stream = transport->createUnidirectionalStream().value(); + EXPECT_EQ( + StreamDirectionality::Unidirectional, + transport->getStreamDirectionality(stream)); +} + +TEST_F(QuicTransportImplTest, GetStreamDirectionalityBidirectional) { + auto stream = transport->createBidirectionalStream().value(); + EXPECT_EQ( + StreamDirectionality::Bidirectional, + transport->getStreamDirectionality(stream)); +} + TEST_F(QuicTransportImplTest, PeekCallbackDataAvailable) { auto stream1 = transport->createBidirectionalStream().value(); auto stream2 = transport->createBidirectionalStream().value(); diff --git a/quic/api/test/QuicTransportTest.cpp b/quic/api/test/QuicTransportTest.cpp index f03d360d0..c09976f7f 100644 --- a/quic/api/test/QuicTransportTest.cpp +++ b/quic/api/test/QuicTransportTest.cpp @@ -457,6 +457,211 @@ TEST_F(QuicTransportTest, AppLimitedWithObservers) { Mock::VerifyAndClearExpectations(cb2.get()); } +TEST_F(QuicTransportTest, ObserverStreamEventBidirectionalLocalOpenClose) { + Observer::Config configWithStreamEvents = {}; + configWithStreamEvents.streamEvents = true; + auto cb1 = std::make_unique>(configWithStreamEvents); + auto cb2 = std::make_unique>(Observer::Config()); + EXPECT_CALL(*cb1, observerAttach(transport_.get())); + transport_->addObserver(cb1.get()); + EXPECT_CALL(*cb2, observerAttach(transport_.get())); + transport_->addObserver(cb2.get()); + EXPECT_THAT( + transport_->getObservers(), UnorderedElementsAre(cb1.get(), cb2.get())); + + const auto id = 0x01; + const auto streamEventMatcher = MockObserver::getStreamEventMatcher( + id, StreamInitiator::Local, StreamDirectionality::Bidirectional); + + EXPECT_CALL(*cb1, streamOpened(transport_.get(), streamEventMatcher)); + EXPECT_EQ(id, transport_->createBidirectionalStream().value()); + + EXPECT_EQ( + StreamDirectionality::Bidirectional, + transport_->getStreamDirectionality(id)); + EXPECT_EQ(StreamInitiator::Local, transport_->getStreamInitiator(id)); + + EXPECT_CALL(*cb1, streamClosed(transport_.get(), streamEventMatcher)); + auto stream = CHECK_NOTNULL( + transport_->getConnectionState().streamManager->getStream(id)); + stream->sendState = StreamSendState::Closed; + stream->recvState = StreamRecvState::Closed; + transport_->getConnectionState().streamManager->addClosed(id); + transport_->onNetworkData( + SocketAddress("::1", 10000), + NetworkData(IOBuf::copyBuffer("fake data"), Clock::now())); + + EXPECT_CALL(*cb1, close(transport_.get(), _)); + EXPECT_CALL(*cb2, close(transport_.get(), _)); + EXPECT_CALL(*cb1, destroy(transport_.get())); + EXPECT_CALL(*cb2, destroy(transport_.get())); + transport_ = nullptr; +} + +TEST_F(QuicTransportTest, ObserverStreamEventBidirectionalRemoteOpenClose) { + Observer::Config configWithStreamEvents = {}; + configWithStreamEvents.streamEvents = true; + auto cb1 = std::make_unique>(configWithStreamEvents); + auto cb2 = std::make_unique>(Observer::Config()); + EXPECT_CALL(*cb1, observerAttach(transport_.get())); + transport_->addObserver(cb1.get()); + EXPECT_CALL(*cb2, observerAttach(transport_.get())); + transport_->addObserver(cb2.get()); + EXPECT_THAT( + transport_->getObservers(), UnorderedElementsAre(cb1.get(), cb2.get())); + + const auto id = 0x00; + const auto streamEventMatcher = MockObserver::getStreamEventMatcher( + id, StreamInitiator::Remote, StreamDirectionality::Bidirectional); + + EXPECT_CALL(*cb1, streamOpened(transport_.get(), streamEventMatcher)); + auto stream = CHECK_NOTNULL( + transport_->getConnectionState().streamManager->getStream(id)); + EXPECT_THAT(stream, NotNull()); + + EXPECT_EQ( + StreamDirectionality::Bidirectional, + transport_->getStreamDirectionality(id)); + EXPECT_EQ(StreamInitiator::Remote, transport_->getStreamInitiator(id)); + + EXPECT_CALL(*cb1, streamClosed(transport_.get(), streamEventMatcher)); + stream->sendState = StreamSendState::Closed; + stream->recvState = StreamRecvState::Closed; + transport_->getConnectionState().streamManager->addClosed(id); + transport_->onNetworkData( + SocketAddress("::1", 10000), + NetworkData(IOBuf::copyBuffer("fake data"), Clock::now())); + + EXPECT_CALL(*cb1, close(transport_.get(), _)); + EXPECT_CALL(*cb2, close(transport_.get(), _)); + EXPECT_CALL(*cb1, destroy(transport_.get())); + EXPECT_CALL(*cb2, destroy(transport_.get())); + transport_ = nullptr; +} + +TEST_F(QuicTransportTest, ObserverStreamEventUnidirectionalLocalOpenClose) { + Observer::Config configWithStreamEvents = {}; + configWithStreamEvents.streamEvents = true; + auto cb1 = std::make_unique>(configWithStreamEvents); + auto cb2 = std::make_unique>(Observer::Config()); + EXPECT_CALL(*cb1, observerAttach(transport_.get())); + transport_->addObserver(cb1.get()); + EXPECT_CALL(*cb2, observerAttach(transport_.get())); + transport_->addObserver(cb2.get()); + EXPECT_THAT( + transport_->getObservers(), UnorderedElementsAre(cb1.get(), cb2.get())); + + const auto id = 0x03; + const auto streamEventMatcher = MockObserver::getStreamEventMatcher( + id, StreamInitiator::Local, StreamDirectionality::Unidirectional); + + EXPECT_CALL(*cb1, streamOpened(transport_.get(), streamEventMatcher)); + EXPECT_EQ(id, transport_->createUnidirectionalStream().value()); + + EXPECT_EQ( + StreamDirectionality::Unidirectional, + transport_->getStreamDirectionality(id)); + EXPECT_EQ(StreamInitiator::Local, transport_->getStreamInitiator(id)); + + EXPECT_CALL(*cb1, streamClosed(transport_.get(), streamEventMatcher)); + auto stream = CHECK_NOTNULL( + transport_->getConnectionState().streamManager->getStream(id)); + stream->sendState = StreamSendState::Closed; + stream->recvState = StreamRecvState::Closed; + transport_->getConnectionState().streamManager->addClosed(id); + transport_->onNetworkData( + SocketAddress("::1", 10000), + NetworkData(IOBuf::copyBuffer("fake data"), Clock::now())); + + EXPECT_CALL(*cb1, close(transport_.get(), _)); + EXPECT_CALL(*cb2, close(transport_.get(), _)); + EXPECT_CALL(*cb1, destroy(transport_.get())); + EXPECT_CALL(*cb2, destroy(transport_.get())); + transport_ = nullptr; +} + +TEST_F(QuicTransportTest, ObserverStreamEventUnidirectionalRemoteOpenClose) { + Observer::Config configWithStreamEvents = {}; + configWithStreamEvents.streamEvents = true; + auto cb1 = std::make_unique>(configWithStreamEvents); + auto cb2 = std::make_unique>(Observer::Config()); + EXPECT_CALL(*cb1, observerAttach(transport_.get())); + transport_->addObserver(cb1.get()); + EXPECT_CALL(*cb2, observerAttach(transport_.get())); + transport_->addObserver(cb2.get()); + EXPECT_THAT( + transport_->getObservers(), UnorderedElementsAre(cb1.get(), cb2.get())); + + const auto id = 0x02; + const auto streamEventMatcher = MockObserver::getStreamEventMatcher( + id, StreamInitiator::Remote, StreamDirectionality::Unidirectional); + + EXPECT_CALL(*cb1, streamOpened(transport_.get(), streamEventMatcher)); + auto stream = CHECK_NOTNULL( + transport_->getConnectionState().streamManager->getStream(id)); + + EXPECT_EQ( + StreamDirectionality::Unidirectional, + transport_->getStreamDirectionality(id)); + EXPECT_EQ(StreamInitiator::Remote, transport_->getStreamInitiator(id)); + + EXPECT_CALL(*cb1, streamClosed(transport_.get(), streamEventMatcher)); + stream->sendState = StreamSendState::Closed; + stream->recvState = StreamRecvState::Closed; + transport_->getConnectionState().streamManager->addClosed(id); + transport_->onNetworkData( + SocketAddress("::1", 10000), + NetworkData(IOBuf::copyBuffer("fake data"), Clock::now())); + + EXPECT_CALL(*cb1, close(transport_.get(), _)); + EXPECT_CALL(*cb2, close(transport_.get(), _)); + EXPECT_CALL(*cb1, destroy(transport_.get())); + EXPECT_CALL(*cb2, destroy(transport_.get())); + transport_ = nullptr; +} + +TEST_F(QuicTransportTest, StreamBidirectionalLocal) { + const auto id = 0x01; + EXPECT_EQ(id, transport_->createBidirectionalStream().value()); + + EXPECT_EQ( + StreamDirectionality::Bidirectional, + transport_->getStreamDirectionality(id)); + EXPECT_EQ(StreamInitiator::Local, transport_->getStreamInitiator(id)); +} + +TEST_F(QuicTransportTest, StreamBidirectionalRemote) { + const auto id = 0x00; + // trigger tracking of new remote stream via getStream() + CHECK_NOTNULL(transport_->getConnectionState().streamManager->getStream(id)); + + EXPECT_EQ( + StreamDirectionality::Bidirectional, + transport_->getStreamDirectionality(id)); + EXPECT_EQ(StreamInitiator::Remote, transport_->getStreamInitiator(id)); +} + +TEST_F(QuicTransportTest, StreamUnidirectionalLocal) { + const auto id = 0x03; + EXPECT_EQ(id, transport_->createUnidirectionalStream().value()); + + EXPECT_EQ( + StreamDirectionality::Unidirectional, + transport_->getStreamDirectionality(id)); + EXPECT_EQ(StreamInitiator::Local, transport_->getStreamInitiator(id)); +} + +TEST_F(QuicTransportTest, StreamUnidirectionalRemote) { + const auto id = 0x02; + // trigger tracking of new remote stream via getStream() + CHECK_NOTNULL(transport_->getConnectionState().streamManager->getStream(id)); + + EXPECT_EQ( + StreamDirectionality::Unidirectional, + transport_->getStreamDirectionality(id)); + EXPECT_EQ(StreamInitiator::Remote, transport_->getStreamInitiator(id)); +} + TEST_F(QuicTransportTest, WriteSmall) { // Testing writing a small buffer that could be fit in a single packet auto stream = transport_->createBidirectionalStream().value(); diff --git a/quic/state/QuicStreamManager.cpp b/quic/state/QuicStreamManager.cpp index 7470263ca..c705b533f 100644 --- a/quic/state/QuicStreamManager.cpp +++ b/quic/state/QuicStreamManager.cpp @@ -6,9 +6,9 @@ * */ -#include "quic/state/QuicStreamManager.h" - +#include #include +#include namespace quic { diff --git a/quic/state/QuicStreamUtilities.cpp b/quic/state/QuicStreamUtilities.cpp index e23a93895..bc671982c 100644 --- a/quic/state/QuicStreamUtilities.cpp +++ b/quic/state/QuicStreamUtilities.cpp @@ -7,7 +7,6 @@ */ #include -#include namespace quic { @@ -27,6 +26,11 @@ bool isBidirectionalStream(StreamId stream) { return !isUnidirectionalStream(stream); } +StreamDirectionality getStreamDirectionality(StreamId stream) { + return isUnidirectionalStream(stream) ? StreamDirectionality::Unidirectional + : StreamDirectionality::Bidirectional; +} + bool isSendingStream(QuicNodeType nodeType, StreamId stream) { return isUnidirectionalStream(stream) && ((nodeType == QuicNodeType::Client && isClientStream(stream)) || @@ -48,4 +52,10 @@ bool isRemoteStream(QuicNodeType nodeType, StreamId stream) { return (nodeType == QuicNodeType::Client && isServerStream(stream)) || (nodeType == QuicNodeType::Server && isClientStream(stream)); } + +StreamInitiator getStreamInitiator(QuicNodeType nodeType, StreamId stream) { + return isLocalStream(nodeType, stream) ? StreamInitiator::Local + : StreamInitiator::Remote; +} + } // namespace quic diff --git a/quic/state/QuicStreamUtilities.h b/quic/state/QuicStreamUtilities.h index a936f4faa..380ab3ebf 100644 --- a/quic/state/QuicStreamUtilities.h +++ b/quic/state/QuicStreamUtilities.h @@ -6,10 +6,17 @@ * */ -#include +#pragma once + +#include +#include namespace quic { +enum class StreamInitiator : uint8_t { Local, Remote }; + +enum class StreamDirectionality : uint8_t { Unidirectional, Bidirectional }; + /** * Returns whether the given StreamId identifies a client stream. */ @@ -30,6 +37,11 @@ bool isUnidirectionalStream(StreamId stream); */ bool isBidirectionalStream(StreamId stream); +/** + * Returns directionality (unidirectional or bidirectional) of a stream by ID. + */ +StreamDirectionality getStreamDirectionality(StreamId stream); + /** * Returns whether the given QuicNodeType and StreamId indicate a sending * stream, i.e., a stream which only sends data. Note that a bidirectional @@ -56,4 +68,9 @@ bool isLocalStream(QuicNodeType nodeType, StreamId stream); */ bool isRemoteStream(QuicNodeType nodeType, StreamId stream); +/** + * Returns initiator (local or remote) of a stream by ID. + */ +StreamInitiator getStreamInitiator(QuicNodeType nodeType, StreamId stream); + } // namespace quic diff --git a/quic/state/test/QuicStreamFunctionsTest.cpp b/quic/state/test/QuicStreamFunctionsTest.cpp index 8263ceb88..04db9982c 100644 --- a/quic/state/test/QuicStreamFunctionsTest.cpp +++ b/quic/state/test/QuicStreamFunctionsTest.cpp @@ -8,6 +8,7 @@ #include #include +#include "quic/state/QuicStreamUtilities.h" #include #include @@ -1103,6 +1104,18 @@ TEST_F(QuicStreamFunctionsTest, IsBidirectionalStream) { EXPECT_FALSE(isBidirectionalStream(0xff)); } +TEST_F(QuicStreamFunctionsTest, GetStreamDirectionality) { + EXPECT_EQ(StreamDirectionality::Bidirectional, getStreamDirectionality(0x01)); + EXPECT_EQ(StreamDirectionality::Bidirectional, getStreamDirectionality(0xf0)); + EXPECT_EQ(StreamDirectionality::Bidirectional, getStreamDirectionality(0xf1)); + EXPECT_EQ( + StreamDirectionality::Unidirectional, getStreamDirectionality(0x02)); + EXPECT_EQ( + StreamDirectionality::Unidirectional, getStreamDirectionality(0x03)); + EXPECT_EQ( + StreamDirectionality::Unidirectional, getStreamDirectionality(0xff)); +} + TEST_F(QuicStreamFunctionsTest, IsSendingStream) { QuicClientConnectionState clientState( FizzClientQuicHandshakeContext::Builder().build()); @@ -1161,6 +1174,174 @@ TEST_F(QuicStreamFunctionsTest, IsReceivingStream) { EXPECT_TRUE(isReceivingStream(nodeType, id)); } +TEST_F(QuicStreamFunctionsTest, GetStreamInitiatorBidirectional) { + const auto clientStream1Id = + conn.streamManager->createNextBidirectionalStream().value()->id; + EXPECT_EQ(conn.streamManager->streamCount(), 1); + const auto clientStream2Id = + conn.streamManager->createNextBidirectionalStream().value()->id; + EXPECT_EQ(conn.streamManager->streamCount(), 2); + EXPECT_EQ(clientStream1Id, 0x00); + EXPECT_EQ(clientStream2Id, 0x04); + + const auto serverStream1Id = + CHECK_NOTNULL(conn.streamManager->getStream(clientStream1Id + 1))->id; + const auto serverStream2Id = + CHECK_NOTNULL(conn.streamManager->getStream(clientStream2Id + 1))->id; + + EXPECT_EQ( + StreamDirectionality::Bidirectional, + getStreamDirectionality(clientStream1Id)); + EXPECT_EQ( + StreamDirectionality::Bidirectional, + getStreamDirectionality(clientStream2Id)); + EXPECT_EQ( + StreamDirectionality::Bidirectional, + getStreamDirectionality(serverStream1Id)); + EXPECT_EQ( + StreamDirectionality::Bidirectional, + getStreamDirectionality(serverStream2Id)); + + EXPECT_EQ( + StreamInitiator::Local, + getStreamInitiator(conn.nodeType, clientStream1Id)); + EXPECT_EQ( + StreamInitiator::Local, + getStreamInitiator(conn.nodeType, clientStream2Id)); + EXPECT_EQ( + StreamInitiator::Remote, + getStreamInitiator(conn.nodeType, serverStream1Id)); + EXPECT_EQ( + StreamInitiator::Remote, + getStreamInitiator(conn.nodeType, serverStream2Id)); +} + +TEST_F(QuicServerStreamFunctionsTest, GetStreamInitiatorBidirectional) { + const auto serverStream1Id = + conn.streamManager->createNextBidirectionalStream().value()->id; + EXPECT_EQ(conn.streamManager->streamCount(), 1); + const auto serverStream2Id = + conn.streamManager->createNextBidirectionalStream().value()->id; + EXPECT_EQ(conn.streamManager->streamCount(), 2); + EXPECT_EQ(serverStream1Id, 0x01); + EXPECT_EQ(serverStream2Id, 0x05); + + const auto clientStream1Id = + CHECK_NOTNULL(conn.streamManager->getStream(serverStream1Id - 1))->id; + const auto clientStream2Id = + CHECK_NOTNULL(conn.streamManager->getStream(serverStream2Id - 1))->id; + + EXPECT_EQ( + StreamDirectionality::Bidirectional, + getStreamDirectionality(serverStream1Id)); + EXPECT_EQ( + StreamDirectionality::Bidirectional, + getStreamDirectionality(serverStream2Id)); + EXPECT_EQ( + StreamDirectionality::Bidirectional, + getStreamDirectionality(clientStream1Id)); + EXPECT_EQ( + StreamDirectionality::Bidirectional, + getStreamDirectionality(clientStream2Id)); + + EXPECT_EQ( + StreamInitiator::Local, + getStreamInitiator(conn.nodeType, serverStream1Id)); + EXPECT_EQ( + StreamInitiator::Local, + getStreamInitiator(conn.nodeType, serverStream2Id)); + EXPECT_EQ( + StreamInitiator::Remote, + getStreamInitiator(conn.nodeType, clientStream1Id)); + EXPECT_EQ( + StreamInitiator::Remote, + getStreamInitiator(conn.nodeType, clientStream2Id)); +} + +TEST_F(QuicStreamFunctionsTest, GetStreamInitiatorUnidirectional) { + const auto clientStream1Id = + conn.streamManager->createNextUnidirectionalStream().value()->id; + EXPECT_EQ(conn.streamManager->streamCount(), 1); + const auto clientStream2Id = + conn.streamManager->createNextUnidirectionalStream().value()->id; + EXPECT_EQ(conn.streamManager->streamCount(), 2); + EXPECT_EQ(clientStream1Id, 0x02); + EXPECT_EQ(clientStream2Id, 0x06); + + const auto serverStream1Id = + CHECK_NOTNULL(conn.streamManager->getStream(clientStream1Id + 1))->id; + const auto serverStream2Id = + CHECK_NOTNULL(conn.streamManager->getStream(clientStream2Id + 1))->id; + + EXPECT_EQ( + StreamDirectionality::Unidirectional, + getStreamDirectionality(clientStream1Id)); + EXPECT_EQ( + StreamDirectionality::Unidirectional, + getStreamDirectionality(clientStream2Id)); + EXPECT_EQ( + StreamDirectionality::Unidirectional, + getStreamDirectionality(serverStream1Id)); + EXPECT_EQ( + StreamDirectionality::Unidirectional, + getStreamDirectionality(serverStream2Id)); + + EXPECT_EQ( + StreamInitiator::Local, + getStreamInitiator(conn.nodeType, clientStream1Id)); + EXPECT_EQ( + StreamInitiator::Local, + getStreamInitiator(conn.nodeType, clientStream2Id)); + EXPECT_EQ( + StreamInitiator::Remote, + getStreamInitiator(conn.nodeType, serverStream1Id)); + EXPECT_EQ( + StreamInitiator::Remote, + getStreamInitiator(conn.nodeType, serverStream2Id)); +} + +TEST_F(QuicServerStreamFunctionsTest, GetStreamInitiatorUnidirectional) { + const auto serverStream1Id = + conn.streamManager->createNextUnidirectionalStream().value()->id; + EXPECT_EQ(conn.streamManager->streamCount(), 1); + const auto serverStream2Id = + conn.streamManager->createNextUnidirectionalStream().value()->id; + EXPECT_EQ(conn.streamManager->streamCount(), 2); + EXPECT_EQ(serverStream1Id, 0x03); + EXPECT_EQ(serverStream2Id, 0x07); + + const auto clientStream1Id = + CHECK_NOTNULL(conn.streamManager->getStream(serverStream1Id - 1))->id; + const auto clientStream2Id = + CHECK_NOTNULL(conn.streamManager->getStream(serverStream2Id - 1))->id; + + EXPECT_EQ( + StreamDirectionality::Unidirectional, + getStreamDirectionality(serverStream1Id)); + EXPECT_EQ( + StreamDirectionality::Unidirectional, + getStreamDirectionality(serverStream2Id)); + EXPECT_EQ( + StreamDirectionality::Unidirectional, + getStreamDirectionality(clientStream1Id)); + EXPECT_EQ( + StreamDirectionality::Unidirectional, + getStreamDirectionality(clientStream2Id)); + + EXPECT_EQ( + StreamInitiator::Local, + getStreamInitiator(conn.nodeType, serverStream1Id)); + EXPECT_EQ( + StreamInitiator::Local, + getStreamInitiator(conn.nodeType, serverStream2Id)); + EXPECT_EQ( + StreamInitiator::Remote, + getStreamInitiator(conn.nodeType, clientStream1Id)); + EXPECT_EQ( + StreamInitiator::Remote, + getStreamInitiator(conn.nodeType, clientStream2Id)); +} + TEST_F(QuicStreamFunctionsTest, HasReadableDataNoData) { auto stream = conn.streamManager->createNextBidirectionalStream().value(); auto buf1 = IOBuf::copyBuffer("just");