1
0
mirror of https://github.com/facebookincubator/mvfst.git synced 2025-08-08 09:42:06 +03:00

Observer stream events

Summary: Provide observers with visibility into stream events triggered by local and peer (e.g., new stream opened locally or by peer).

Reviewed By: mjoras

Differential Revision: D31886978

fbshipit-source-id: 7556fef0f336bd0f190b4474f1a7b0120aae6ef1
This commit is contained in:
Brandon Schlinker
2021-12-03 11:15:58 -08:00
committed by Facebook GitHub Bot
parent 7b40794598
commit d89e8d344e
12 changed files with 566 additions and 14 deletions

View File

@@ -11,6 +11,7 @@
#include <quic/common/SmallVec.h> #include <quic/common/SmallVec.h>
#include <quic/d6d/Types.h> #include <quic/d6d/Types.h>
#include <quic/state/OutstandingPacket.h> #include <quic/state/OutstandingPacket.h>
#include <quic/state/QuicStreamUtilities.h>
namespace folly { namespace folly {
class EventBase; class EventBase;
@@ -46,6 +47,7 @@ class Observer {
bool pmtuEvents{false}; bool pmtuEvents{false};
bool rttSamples{false}; bool rttSamples{false};
bool knobFrameEvents{false}; bool knobFrameEvents{false};
bool streamEvents{false};
virtual void enableAllEvents() { virtual void enableAllEvents() {
evbEvents = true; evbEvents = true;
@@ -56,6 +58,7 @@ class Observer {
spuriousLossEvents = true; spuriousLossEvents = true;
pmtuEvents = true; pmtuEvents = true;
knobFrameEvents = true; knobFrameEvents = true;
streamEvents = true;
} }
/** /**
@@ -235,6 +238,23 @@ class Observer {
const quic::KnobFrame knobFrame; const quic::KnobFrame knobFrame;
}; };
struct StreamEvent {
StreamEvent(
const StreamId id,
StreamInitiator initiator,
StreamDirectionality directionality)
: streamId(id),
streamInitiator(initiator),
streamDirectionality(directionality) {}
const StreamId streamId;
const StreamInitiator streamInitiator;
const StreamDirectionality streamDirectionality;
};
using StreamOpenEvent = StreamEvent;
using StreamCloseEvent = StreamEvent;
/** /**
* observerAttach() will be invoked when an observer is added. * observerAttach() will be invoked when an observer is added.
* *
@@ -413,6 +433,26 @@ class Observer {
QuicSocket*, /* socket */ QuicSocket*, /* socket */
const KnobFrameEvent& /* event */) {} const KnobFrameEvent& /* event */) {}
/**
* streamOpened() is invoked when a new stream is opened.
*
* @param socket Socket associated with the event.
* @param event Event containing details.
*/
virtual void streamOpened(
QuicSocket*, /* socket */
const StreamOpenEvent& /* event */) {}
/**
* streamClosed() is invoked when a stream is closed.
*
* @param socket Socket associated with the event.
* @param event Event containing details.
*/
virtual void streamClosed(
QuicSocket*, /* socket */
const StreamCloseEvent& /* event */) {}
protected: protected:
// observer configuration; cannot be changed post instantiation // observer configuration; cannot be changed post instantiation
const Config observerConfig_; const Config observerConfig_;

View File

@@ -18,6 +18,7 @@
#include <quic/common/SmallVec.h> #include <quic/common/SmallVec.h>
#include <quic/state/QuicConnectionStats.h> #include <quic/state/QuicConnectionStats.h>
#include <quic/state/QuicPriorityQueue.h> #include <quic/state/QuicPriorityQueue.h>
#include <quic/state/QuicStreamUtilities.h>
#include <quic/state/StateData.h> #include <quic/state/StateData.h>
#include <chrono> #include <chrono>
@@ -782,6 +783,11 @@ class QuicSocket {
*/ */
virtual bool isServerStream(StreamId stream) noexcept = 0; virtual bool isServerStream(StreamId stream) noexcept = 0;
/**
* Returns initiator (self or peer) of a stream by ID.
*/
virtual StreamInitiator getStreamInitiator(StreamId stream) noexcept = 0;
/** /**
* Returns whether a stream ID represents a unidirectional stream. * Returns whether a stream ID represents a unidirectional stream.
*/ */
@@ -792,6 +798,12 @@ class QuicSocket {
*/ */
virtual bool isBidirectionalStream(StreamId stream) noexcept = 0; virtual bool isBidirectionalStream(StreamId stream) noexcept = 0;
/**
* Returns directionality (unidirectional or bidirectional) of a stream by ID.
*/
virtual StreamDirectionality getStreamDirectionality(
StreamId stream) noexcept = 0;
/** /**
* Callback class for receiving write readiness notifications * Callback class for receiving write readiness notifications
*/ */

View File

@@ -1506,13 +1506,22 @@ void QuicTransportBase::handleNewStreamCallbacks(
streamStorage = streamStorage =
conn_->streamManager->consumeNewPeerStreams(std::move(streamStorage)); conn_->streamManager->consumeNewPeerStreams(std::move(streamStorage));
const auto& newPeerStreams = streamStorage; const auto& newPeerStreamIds = streamStorage;
for (const auto& stream : newPeerStreams) { for (const auto& streamId : newPeerStreamIds) {
CHECK_NOTNULL(connCallback_.get()); CHECK_NOTNULL(connCallback_.get());
if (isBidirectionalStream(stream)) { if (isBidirectionalStream(streamId)) {
connCallback_->onNewBidirectionalStream(stream); connCallback_->onNewBidirectionalStream(streamId);
} else { } else {
connCallback_->onNewUnidirectionalStream(stream); connCallback_->onNewUnidirectionalStream(streamId);
}
const Observer::StreamOpenEvent streamEvent(
streamId,
getStreamInitiator(streamId),
getStreamDirectionality(streamId));
for (const auto& cb : *observers_) {
if (cb->getConfig().streamEvents) {
cb->streamOpened(this, streamEvent);
}
} }
if (closeState_ != CloseState::OPEN) { if (closeState_ != CloseState::OPEN) {
@@ -1818,7 +1827,17 @@ QuicTransportBase::createStreamInternal(bool bidirectional) {
streamResult = conn_->streamManager->createNextUnidirectionalStream(); streamResult = conn_->streamManager->createNextUnidirectionalStream();
} }
if (streamResult) { if (streamResult) {
return streamResult.value()->id; const StreamId streamId = streamResult.value()->id;
const Observer::StreamOpenEvent streamEvent(
streamId,
getStreamInitiator(streamId),
getStreamDirectionality(streamId));
for (const auto& cb : *observers_) {
if (cb->getConfig().streamEvents) {
cb->streamOpened(this, streamEvent);
}
}
return streamId;
} else { } else {
return folly::makeUnexpected(streamResult.error()); return folly::makeUnexpected(streamResult.error());
} }
@@ -1834,10 +1853,6 @@ QuicTransportBase::createUnidirectionalStream(bool /*replaySafe*/) {
return createStreamInternal(false); return createStreamInternal(false);
} }
bool QuicTransportBase::isUnidirectionalStream(StreamId stream) noexcept {
return quic::isUnidirectionalStream(stream);
}
bool QuicTransportBase::isClientStream(StreamId stream) noexcept { bool QuicTransportBase::isClientStream(StreamId stream) noexcept {
return quic::isClientStream(stream); return quic::isClientStream(stream);
} }
@@ -1846,10 +1861,24 @@ bool QuicTransportBase::isServerStream(StreamId stream) noexcept {
return quic::isServerStream(stream); return quic::isServerStream(stream);
} }
StreamInitiator QuicTransportBase::getStreamInitiator(
StreamId stream) noexcept {
return quic::getStreamInitiator(conn_->nodeType, stream);
}
bool QuicTransportBase::isUnidirectionalStream(StreamId stream) noexcept {
return quic::isUnidirectionalStream(stream);
}
bool QuicTransportBase::isBidirectionalStream(StreamId stream) noexcept { bool QuicTransportBase::isBidirectionalStream(StreamId stream) noexcept {
return quic::isBidirectionalStream(stream); return quic::isBidirectionalStream(stream);
} }
StreamDirectionality QuicTransportBase::getStreamDirectionality(
StreamId stream) noexcept {
return quic::getStreamDirectionality(stream);
}
folly::Expected<folly::Unit, LocalErrorCode> folly::Expected<folly::Unit, LocalErrorCode>
QuicTransportBase::notifyPendingWriteOnConnection(WriteCallback* wcb) { QuicTransportBase::notifyPendingWriteOnConnection(WriteCallback* wcb) {
if (closeState_ != CloseState::OPEN) { if (closeState_ != CloseState::OPEN) {
@@ -2306,6 +2335,17 @@ void QuicTransportBase::checkForClosedStream() {
} }
auto itr = conn_->streamManager->closedStreams().begin(); auto itr = conn_->streamManager->closedStreams().begin();
while (itr != conn_->streamManager->closedStreams().end()) { while (itr != conn_->streamManager->closedStreams().end()) {
const auto& streamId = *itr;
const Observer::StreamCloseEvent streamEvent(
streamId,
getStreamInitiator(streamId),
getStreamDirectionality(streamId));
for (const auto& cb : *observers_) {
if (cb->getConfig().streamEvents) {
cb->streamClosed(this, streamEvent);
}
}
// We may be in an active read cb when we close the stream // We may be in an active read cb when we close the stream
auto readCbIt = readCallbacks_.find(*itr); auto readCbIt = readCallbacks_.find(*itr);
if (readCbIt != readCallbacks_.end() && if (readCbIt != readCallbacks_.end() &&

View File

@@ -160,8 +160,11 @@ class QuicTransportBase : public QuicSocket, QuicStreamPrioritiesObserver {
uint64_t getNumOpenableUnidirectionalStreams() const override; uint64_t getNumOpenableUnidirectionalStreams() const override;
bool isClientStream(StreamId stream) noexcept override; bool isClientStream(StreamId stream) noexcept override;
bool isServerStream(StreamId stream) noexcept override; bool isServerStream(StreamId stream) noexcept override;
StreamInitiator getStreamInitiator(StreamId stream) noexcept override;
bool isUnidirectionalStream(StreamId stream) noexcept override; bool isUnidirectionalStream(StreamId stream) noexcept override;
bool isBidirectionalStream(StreamId stream) noexcept override; bool isBidirectionalStream(StreamId stream) noexcept override;
StreamDirectionality getStreamDirectionality(
StreamId stream) noexcept override;
folly::Expected<folly::Unit, LocalErrorCode> notifyPendingWriteOnStream( folly::Expected<folly::Unit, LocalErrorCode> notifyPendingWriteOnStream(
StreamId id, StreamId id,

View File

@@ -160,8 +160,15 @@ class MockQuicSocket : public QuicSocket {
MOCK_CONST_METHOD0(getNumOpenableUnidirectionalStreams, uint64_t()); MOCK_CONST_METHOD0(getNumOpenableUnidirectionalStreams, uint64_t());
GMOCK_METHOD1_(, noexcept, , isClientStream, bool(StreamId)); GMOCK_METHOD1_(, noexcept, , isClientStream, bool(StreamId));
GMOCK_METHOD1_(, noexcept, , isServerStream, bool(StreamId)); GMOCK_METHOD1_(, noexcept, , isServerStream, bool(StreamId));
GMOCK_METHOD1_(, noexcept, , getStreamInitiator, StreamInitiator(StreamId));
GMOCK_METHOD1_(, noexcept, , isBidirectionalStream, bool(StreamId)); GMOCK_METHOD1_(, noexcept, , isBidirectionalStream, bool(StreamId));
GMOCK_METHOD1_(, noexcept, , isUnidirectionalStream, bool(StreamId)); GMOCK_METHOD1_(, noexcept, , isUnidirectionalStream, bool(StreamId));
GMOCK_METHOD1_(
,
noexcept,
,
getStreamDirectionality,
StreamDirectionality(StreamId));
MOCK_METHOD1( MOCK_METHOD1(
notifyPendingWriteOnConnection, notifyPendingWriteOnConnection,
folly::Expected<folly::Unit, LocalErrorCode>(WriteCallback*)); folly::Expected<folly::Unit, LocalErrorCode>(WriteCallback*));

View File

@@ -397,6 +397,18 @@ class MockObserver : public Observer {
, ,
knobFrameReceived, knobFrameReceived,
void(QuicSocket*, const KnobFrameEvent&)); void(QuicSocket*, const KnobFrameEvent&));
GMOCK_METHOD2_(
,
noexcept,
,
streamOpened,
void(QuicSocket*, const StreamOpenEvent&));
GMOCK_METHOD2_(
,
noexcept,
,
streamClosed,
void(QuicSocket*, const StreamCloseEvent&));
static auto getLossPacketNum(PacketNum packetNum) { static auto getLossPacketNum(PacketNum packetNum) {
return testing::Field( return testing::Field(
@@ -419,6 +431,17 @@ class MockObserver : public Observer {
testing::Field( testing::Field(
&Observer::LostPacket::packet, getLossPacketNum(packetNum))); &Observer::LostPacket::packet, getLossPacketNum(packetNum)));
} }
static auto getStreamEventMatcher(
const StreamId id,
StreamInitiator initiator,
StreamDirectionality directionality) {
return AllOf(
testing::Field(&StreamEvent::streamId, testing::Eq(id)),
testing::Field(&StreamEvent::streamInitiator, testing::Eq(initiator)),
testing::Field(
&StreamEvent::streamDirectionality, testing::Eq(directionality)));
}
}; };
inline std::ostream& operator<<(std::ostream& os, const MockQuicTransport&) { inline std::ostream& operator<<(std::ostream& os, const MockQuicTransport&) {

View File

@@ -3006,6 +3006,20 @@ TEST_F(QuicTransportImplTest, IsBidirectionalStream) {
EXPECT_TRUE(transport->isBidirectionalStream(stream)); EXPECT_TRUE(transport->isBidirectionalStream(stream));
} }
TEST_F(QuicTransportImplTest, GetStreamDirectionalityUnidirectional) {
auto stream = transport->createUnidirectionalStream().value();
EXPECT_EQ(
StreamDirectionality::Unidirectional,
transport->getStreamDirectionality(stream));
}
TEST_F(QuicTransportImplTest, GetStreamDirectionalityBidirectional) {
auto stream = transport->createBidirectionalStream().value();
EXPECT_EQ(
StreamDirectionality::Bidirectional,
transport->getStreamDirectionality(stream));
}
TEST_F(QuicTransportImplTest, PeekCallbackDataAvailable) { TEST_F(QuicTransportImplTest, PeekCallbackDataAvailable) {
auto stream1 = transport->createBidirectionalStream().value(); auto stream1 = transport->createBidirectionalStream().value();
auto stream2 = transport->createBidirectionalStream().value(); auto stream2 = transport->createBidirectionalStream().value();

View File

@@ -457,6 +457,211 @@ TEST_F(QuicTransportTest, AppLimitedWithObservers) {
Mock::VerifyAndClearExpectations(cb2.get()); Mock::VerifyAndClearExpectations(cb2.get());
} }
TEST_F(QuicTransportTest, ObserverStreamEventBidirectionalLocalOpenClose) {
Observer::Config configWithStreamEvents = {};
configWithStreamEvents.streamEvents = true;
auto cb1 = std::make_unique<StrictMock<MockObserver>>(configWithStreamEvents);
auto cb2 = std::make_unique<StrictMock<MockObserver>>(Observer::Config());
EXPECT_CALL(*cb1, observerAttach(transport_.get()));
transport_->addObserver(cb1.get());
EXPECT_CALL(*cb2, observerAttach(transport_.get()));
transport_->addObserver(cb2.get());
EXPECT_THAT(
transport_->getObservers(), UnorderedElementsAre(cb1.get(), cb2.get()));
const auto id = 0x01;
const auto streamEventMatcher = MockObserver::getStreamEventMatcher(
id, StreamInitiator::Local, StreamDirectionality::Bidirectional);
EXPECT_CALL(*cb1, streamOpened(transport_.get(), streamEventMatcher));
EXPECT_EQ(id, transport_->createBidirectionalStream().value());
EXPECT_EQ(
StreamDirectionality::Bidirectional,
transport_->getStreamDirectionality(id));
EXPECT_EQ(StreamInitiator::Local, transport_->getStreamInitiator(id));
EXPECT_CALL(*cb1, streamClosed(transport_.get(), streamEventMatcher));
auto stream = CHECK_NOTNULL(
transport_->getConnectionState().streamManager->getStream(id));
stream->sendState = StreamSendState::Closed;
stream->recvState = StreamRecvState::Closed;
transport_->getConnectionState().streamManager->addClosed(id);
transport_->onNetworkData(
SocketAddress("::1", 10000),
NetworkData(IOBuf::copyBuffer("fake data"), Clock::now()));
EXPECT_CALL(*cb1, close(transport_.get(), _));
EXPECT_CALL(*cb2, close(transport_.get(), _));
EXPECT_CALL(*cb1, destroy(transport_.get()));
EXPECT_CALL(*cb2, destroy(transport_.get()));
transport_ = nullptr;
}
TEST_F(QuicTransportTest, ObserverStreamEventBidirectionalRemoteOpenClose) {
Observer::Config configWithStreamEvents = {};
configWithStreamEvents.streamEvents = true;
auto cb1 = std::make_unique<StrictMock<MockObserver>>(configWithStreamEvents);
auto cb2 = std::make_unique<StrictMock<MockObserver>>(Observer::Config());
EXPECT_CALL(*cb1, observerAttach(transport_.get()));
transport_->addObserver(cb1.get());
EXPECT_CALL(*cb2, observerAttach(transport_.get()));
transport_->addObserver(cb2.get());
EXPECT_THAT(
transport_->getObservers(), UnorderedElementsAre(cb1.get(), cb2.get()));
const auto id = 0x00;
const auto streamEventMatcher = MockObserver::getStreamEventMatcher(
id, StreamInitiator::Remote, StreamDirectionality::Bidirectional);
EXPECT_CALL(*cb1, streamOpened(transport_.get(), streamEventMatcher));
auto stream = CHECK_NOTNULL(
transport_->getConnectionState().streamManager->getStream(id));
EXPECT_THAT(stream, NotNull());
EXPECT_EQ(
StreamDirectionality::Bidirectional,
transport_->getStreamDirectionality(id));
EXPECT_EQ(StreamInitiator::Remote, transport_->getStreamInitiator(id));
EXPECT_CALL(*cb1, streamClosed(transport_.get(), streamEventMatcher));
stream->sendState = StreamSendState::Closed;
stream->recvState = StreamRecvState::Closed;
transport_->getConnectionState().streamManager->addClosed(id);
transport_->onNetworkData(
SocketAddress("::1", 10000),
NetworkData(IOBuf::copyBuffer("fake data"), Clock::now()));
EXPECT_CALL(*cb1, close(transport_.get(), _));
EXPECT_CALL(*cb2, close(transport_.get(), _));
EXPECT_CALL(*cb1, destroy(transport_.get()));
EXPECT_CALL(*cb2, destroy(transport_.get()));
transport_ = nullptr;
}
TEST_F(QuicTransportTest, ObserverStreamEventUnidirectionalLocalOpenClose) {
Observer::Config configWithStreamEvents = {};
configWithStreamEvents.streamEvents = true;
auto cb1 = std::make_unique<StrictMock<MockObserver>>(configWithStreamEvents);
auto cb2 = std::make_unique<StrictMock<MockObserver>>(Observer::Config());
EXPECT_CALL(*cb1, observerAttach(transport_.get()));
transport_->addObserver(cb1.get());
EXPECT_CALL(*cb2, observerAttach(transport_.get()));
transport_->addObserver(cb2.get());
EXPECT_THAT(
transport_->getObservers(), UnorderedElementsAre(cb1.get(), cb2.get()));
const auto id = 0x03;
const auto streamEventMatcher = MockObserver::getStreamEventMatcher(
id, StreamInitiator::Local, StreamDirectionality::Unidirectional);
EXPECT_CALL(*cb1, streamOpened(transport_.get(), streamEventMatcher));
EXPECT_EQ(id, transport_->createUnidirectionalStream().value());
EXPECT_EQ(
StreamDirectionality::Unidirectional,
transport_->getStreamDirectionality(id));
EXPECT_EQ(StreamInitiator::Local, transport_->getStreamInitiator(id));
EXPECT_CALL(*cb1, streamClosed(transport_.get(), streamEventMatcher));
auto stream = CHECK_NOTNULL(
transport_->getConnectionState().streamManager->getStream(id));
stream->sendState = StreamSendState::Closed;
stream->recvState = StreamRecvState::Closed;
transport_->getConnectionState().streamManager->addClosed(id);
transport_->onNetworkData(
SocketAddress("::1", 10000),
NetworkData(IOBuf::copyBuffer("fake data"), Clock::now()));
EXPECT_CALL(*cb1, close(transport_.get(), _));
EXPECT_CALL(*cb2, close(transport_.get(), _));
EXPECT_CALL(*cb1, destroy(transport_.get()));
EXPECT_CALL(*cb2, destroy(transport_.get()));
transport_ = nullptr;
}
TEST_F(QuicTransportTest, ObserverStreamEventUnidirectionalRemoteOpenClose) {
Observer::Config configWithStreamEvents = {};
configWithStreamEvents.streamEvents = true;
auto cb1 = std::make_unique<StrictMock<MockObserver>>(configWithStreamEvents);
auto cb2 = std::make_unique<StrictMock<MockObserver>>(Observer::Config());
EXPECT_CALL(*cb1, observerAttach(transport_.get()));
transport_->addObserver(cb1.get());
EXPECT_CALL(*cb2, observerAttach(transport_.get()));
transport_->addObserver(cb2.get());
EXPECT_THAT(
transport_->getObservers(), UnorderedElementsAre(cb1.get(), cb2.get()));
const auto id = 0x02;
const auto streamEventMatcher = MockObserver::getStreamEventMatcher(
id, StreamInitiator::Remote, StreamDirectionality::Unidirectional);
EXPECT_CALL(*cb1, streamOpened(transport_.get(), streamEventMatcher));
auto stream = CHECK_NOTNULL(
transport_->getConnectionState().streamManager->getStream(id));
EXPECT_EQ(
StreamDirectionality::Unidirectional,
transport_->getStreamDirectionality(id));
EXPECT_EQ(StreamInitiator::Remote, transport_->getStreamInitiator(id));
EXPECT_CALL(*cb1, streamClosed(transport_.get(), streamEventMatcher));
stream->sendState = StreamSendState::Closed;
stream->recvState = StreamRecvState::Closed;
transport_->getConnectionState().streamManager->addClosed(id);
transport_->onNetworkData(
SocketAddress("::1", 10000),
NetworkData(IOBuf::copyBuffer("fake data"), Clock::now()));
EXPECT_CALL(*cb1, close(transport_.get(), _));
EXPECT_CALL(*cb2, close(transport_.get(), _));
EXPECT_CALL(*cb1, destroy(transport_.get()));
EXPECT_CALL(*cb2, destroy(transport_.get()));
transport_ = nullptr;
}
TEST_F(QuicTransportTest, StreamBidirectionalLocal) {
const auto id = 0x01;
EXPECT_EQ(id, transport_->createBidirectionalStream().value());
EXPECT_EQ(
StreamDirectionality::Bidirectional,
transport_->getStreamDirectionality(id));
EXPECT_EQ(StreamInitiator::Local, transport_->getStreamInitiator(id));
}
TEST_F(QuicTransportTest, StreamBidirectionalRemote) {
const auto id = 0x00;
// trigger tracking of new remote stream via getStream()
CHECK_NOTNULL(transport_->getConnectionState().streamManager->getStream(id));
EXPECT_EQ(
StreamDirectionality::Bidirectional,
transport_->getStreamDirectionality(id));
EXPECT_EQ(StreamInitiator::Remote, transport_->getStreamInitiator(id));
}
TEST_F(QuicTransportTest, StreamUnidirectionalLocal) {
const auto id = 0x03;
EXPECT_EQ(id, transport_->createUnidirectionalStream().value());
EXPECT_EQ(
StreamDirectionality::Unidirectional,
transport_->getStreamDirectionality(id));
EXPECT_EQ(StreamInitiator::Local, transport_->getStreamInitiator(id));
}
TEST_F(QuicTransportTest, StreamUnidirectionalRemote) {
const auto id = 0x02;
// trigger tracking of new remote stream via getStream()
CHECK_NOTNULL(transport_->getConnectionState().streamManager->getStream(id));
EXPECT_EQ(
StreamDirectionality::Unidirectional,
transport_->getStreamDirectionality(id));
EXPECT_EQ(StreamInitiator::Remote, transport_->getStreamInitiator(id));
}
TEST_F(QuicTransportTest, WriteSmall) { TEST_F(QuicTransportTest, WriteSmall) {
// Testing writing a small buffer that could be fit in a single packet // Testing writing a small buffer that could be fit in a single packet
auto stream = transport_->createBidirectionalStream().value(); auto stream = transport_->createBidirectionalStream().value();

View File

@@ -6,9 +6,9 @@
* *
*/ */
#include "quic/state/QuicStreamManager.h" #include <quic/state/QuicStreamManager.h>
#include <quic/state/QuicStreamUtilities.h> #include <quic/state/QuicStreamUtilities.h>
#include <quic/state/StateData.h>
namespace quic { namespace quic {

View File

@@ -7,7 +7,6 @@
*/ */
#include <quic/state/QuicStreamUtilities.h> #include <quic/state/QuicStreamUtilities.h>
#include <quic/state/StateData.h>
namespace quic { namespace quic {
@@ -27,6 +26,11 @@ bool isBidirectionalStream(StreamId stream) {
return !isUnidirectionalStream(stream); return !isUnidirectionalStream(stream);
} }
StreamDirectionality getStreamDirectionality(StreamId stream) {
return isUnidirectionalStream(stream) ? StreamDirectionality::Unidirectional
: StreamDirectionality::Bidirectional;
}
bool isSendingStream(QuicNodeType nodeType, StreamId stream) { bool isSendingStream(QuicNodeType nodeType, StreamId stream) {
return isUnidirectionalStream(stream) && return isUnidirectionalStream(stream) &&
((nodeType == QuicNodeType::Client && isClientStream(stream)) || ((nodeType == QuicNodeType::Client && isClientStream(stream)) ||
@@ -48,4 +52,10 @@ bool isRemoteStream(QuicNodeType nodeType, StreamId stream) {
return (nodeType == QuicNodeType::Client && isServerStream(stream)) || return (nodeType == QuicNodeType::Client && isServerStream(stream)) ||
(nodeType == QuicNodeType::Server && isClientStream(stream)); (nodeType == QuicNodeType::Server && isClientStream(stream));
} }
StreamInitiator getStreamInitiator(QuicNodeType nodeType, StreamId stream) {
return isLocalStream(nodeType, stream) ? StreamInitiator::Local
: StreamInitiator::Remote;
}
} // namespace quic } // namespace quic

View File

@@ -6,10 +6,17 @@
* *
*/ */
#include <quic/state/StateData.h> #pragma once
#include <quic/QuicConstants.h>
#include <quic/codec/Types.h>
namespace quic { namespace quic {
enum class StreamInitiator : uint8_t { Local, Remote };
enum class StreamDirectionality : uint8_t { Unidirectional, Bidirectional };
/** /**
* Returns whether the given StreamId identifies a client stream. * Returns whether the given StreamId identifies a client stream.
*/ */
@@ -30,6 +37,11 @@ bool isUnidirectionalStream(StreamId stream);
*/ */
bool isBidirectionalStream(StreamId stream); bool isBidirectionalStream(StreamId stream);
/**
* Returns directionality (unidirectional or bidirectional) of a stream by ID.
*/
StreamDirectionality getStreamDirectionality(StreamId stream);
/** /**
* Returns whether the given QuicNodeType and StreamId indicate a sending * Returns whether the given QuicNodeType and StreamId indicate a sending
* stream, i.e., a stream which only sends data. Note that a bidirectional * stream, i.e., a stream which only sends data. Note that a bidirectional
@@ -56,4 +68,9 @@ bool isLocalStream(QuicNodeType nodeType, StreamId stream);
*/ */
bool isRemoteStream(QuicNodeType nodeType, StreamId stream); bool isRemoteStream(QuicNodeType nodeType, StreamId stream);
/**
* Returns initiator (local or remote) of a stream by ID.
*/
StreamInitiator getStreamInitiator(QuicNodeType nodeType, StreamId stream);
} // namespace quic } // namespace quic

View File

@@ -8,6 +8,7 @@
#include <gmock/gmock.h> #include <gmock/gmock.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "quic/state/QuicStreamUtilities.h"
#include <quic/state/QuicStreamFunctions.h> #include <quic/state/QuicStreamFunctions.h>
#include <quic/state/QuicStreamUtilities.h> #include <quic/state/QuicStreamUtilities.h>
@@ -1103,6 +1104,18 @@ TEST_F(QuicStreamFunctionsTest, IsBidirectionalStream) {
EXPECT_FALSE(isBidirectionalStream(0xff)); EXPECT_FALSE(isBidirectionalStream(0xff));
} }
TEST_F(QuicStreamFunctionsTest, GetStreamDirectionality) {
EXPECT_EQ(StreamDirectionality::Bidirectional, getStreamDirectionality(0x01));
EXPECT_EQ(StreamDirectionality::Bidirectional, getStreamDirectionality(0xf0));
EXPECT_EQ(StreamDirectionality::Bidirectional, getStreamDirectionality(0xf1));
EXPECT_EQ(
StreamDirectionality::Unidirectional, getStreamDirectionality(0x02));
EXPECT_EQ(
StreamDirectionality::Unidirectional, getStreamDirectionality(0x03));
EXPECT_EQ(
StreamDirectionality::Unidirectional, getStreamDirectionality(0xff));
}
TEST_F(QuicStreamFunctionsTest, IsSendingStream) { TEST_F(QuicStreamFunctionsTest, IsSendingStream) {
QuicClientConnectionState clientState( QuicClientConnectionState clientState(
FizzClientQuicHandshakeContext::Builder().build()); FizzClientQuicHandshakeContext::Builder().build());
@@ -1161,6 +1174,174 @@ TEST_F(QuicStreamFunctionsTest, IsReceivingStream) {
EXPECT_TRUE(isReceivingStream(nodeType, id)); EXPECT_TRUE(isReceivingStream(nodeType, id));
} }
TEST_F(QuicStreamFunctionsTest, GetStreamInitiatorBidirectional) {
const auto clientStream1Id =
conn.streamManager->createNextBidirectionalStream().value()->id;
EXPECT_EQ(conn.streamManager->streamCount(), 1);
const auto clientStream2Id =
conn.streamManager->createNextBidirectionalStream().value()->id;
EXPECT_EQ(conn.streamManager->streamCount(), 2);
EXPECT_EQ(clientStream1Id, 0x00);
EXPECT_EQ(clientStream2Id, 0x04);
const auto serverStream1Id =
CHECK_NOTNULL(conn.streamManager->getStream(clientStream1Id + 1))->id;
const auto serverStream2Id =
CHECK_NOTNULL(conn.streamManager->getStream(clientStream2Id + 1))->id;
EXPECT_EQ(
StreamDirectionality::Bidirectional,
getStreamDirectionality(clientStream1Id));
EXPECT_EQ(
StreamDirectionality::Bidirectional,
getStreamDirectionality(clientStream2Id));
EXPECT_EQ(
StreamDirectionality::Bidirectional,
getStreamDirectionality(serverStream1Id));
EXPECT_EQ(
StreamDirectionality::Bidirectional,
getStreamDirectionality(serverStream2Id));
EXPECT_EQ(
StreamInitiator::Local,
getStreamInitiator(conn.nodeType, clientStream1Id));
EXPECT_EQ(
StreamInitiator::Local,
getStreamInitiator(conn.nodeType, clientStream2Id));
EXPECT_EQ(
StreamInitiator::Remote,
getStreamInitiator(conn.nodeType, serverStream1Id));
EXPECT_EQ(
StreamInitiator::Remote,
getStreamInitiator(conn.nodeType, serverStream2Id));
}
TEST_F(QuicServerStreamFunctionsTest, GetStreamInitiatorBidirectional) {
const auto serverStream1Id =
conn.streamManager->createNextBidirectionalStream().value()->id;
EXPECT_EQ(conn.streamManager->streamCount(), 1);
const auto serverStream2Id =
conn.streamManager->createNextBidirectionalStream().value()->id;
EXPECT_EQ(conn.streamManager->streamCount(), 2);
EXPECT_EQ(serverStream1Id, 0x01);
EXPECT_EQ(serverStream2Id, 0x05);
const auto clientStream1Id =
CHECK_NOTNULL(conn.streamManager->getStream(serverStream1Id - 1))->id;
const auto clientStream2Id =
CHECK_NOTNULL(conn.streamManager->getStream(serverStream2Id - 1))->id;
EXPECT_EQ(
StreamDirectionality::Bidirectional,
getStreamDirectionality(serverStream1Id));
EXPECT_EQ(
StreamDirectionality::Bidirectional,
getStreamDirectionality(serverStream2Id));
EXPECT_EQ(
StreamDirectionality::Bidirectional,
getStreamDirectionality(clientStream1Id));
EXPECT_EQ(
StreamDirectionality::Bidirectional,
getStreamDirectionality(clientStream2Id));
EXPECT_EQ(
StreamInitiator::Local,
getStreamInitiator(conn.nodeType, serverStream1Id));
EXPECT_EQ(
StreamInitiator::Local,
getStreamInitiator(conn.nodeType, serverStream2Id));
EXPECT_EQ(
StreamInitiator::Remote,
getStreamInitiator(conn.nodeType, clientStream1Id));
EXPECT_EQ(
StreamInitiator::Remote,
getStreamInitiator(conn.nodeType, clientStream2Id));
}
TEST_F(QuicStreamFunctionsTest, GetStreamInitiatorUnidirectional) {
const auto clientStream1Id =
conn.streamManager->createNextUnidirectionalStream().value()->id;
EXPECT_EQ(conn.streamManager->streamCount(), 1);
const auto clientStream2Id =
conn.streamManager->createNextUnidirectionalStream().value()->id;
EXPECT_EQ(conn.streamManager->streamCount(), 2);
EXPECT_EQ(clientStream1Id, 0x02);
EXPECT_EQ(clientStream2Id, 0x06);
const auto serverStream1Id =
CHECK_NOTNULL(conn.streamManager->getStream(clientStream1Id + 1))->id;
const auto serverStream2Id =
CHECK_NOTNULL(conn.streamManager->getStream(clientStream2Id + 1))->id;
EXPECT_EQ(
StreamDirectionality::Unidirectional,
getStreamDirectionality(clientStream1Id));
EXPECT_EQ(
StreamDirectionality::Unidirectional,
getStreamDirectionality(clientStream2Id));
EXPECT_EQ(
StreamDirectionality::Unidirectional,
getStreamDirectionality(serverStream1Id));
EXPECT_EQ(
StreamDirectionality::Unidirectional,
getStreamDirectionality(serverStream2Id));
EXPECT_EQ(
StreamInitiator::Local,
getStreamInitiator(conn.nodeType, clientStream1Id));
EXPECT_EQ(
StreamInitiator::Local,
getStreamInitiator(conn.nodeType, clientStream2Id));
EXPECT_EQ(
StreamInitiator::Remote,
getStreamInitiator(conn.nodeType, serverStream1Id));
EXPECT_EQ(
StreamInitiator::Remote,
getStreamInitiator(conn.nodeType, serverStream2Id));
}
TEST_F(QuicServerStreamFunctionsTest, GetStreamInitiatorUnidirectional) {
const auto serverStream1Id =
conn.streamManager->createNextUnidirectionalStream().value()->id;
EXPECT_EQ(conn.streamManager->streamCount(), 1);
const auto serverStream2Id =
conn.streamManager->createNextUnidirectionalStream().value()->id;
EXPECT_EQ(conn.streamManager->streamCount(), 2);
EXPECT_EQ(serverStream1Id, 0x03);
EXPECT_EQ(serverStream2Id, 0x07);
const auto clientStream1Id =
CHECK_NOTNULL(conn.streamManager->getStream(serverStream1Id - 1))->id;
const auto clientStream2Id =
CHECK_NOTNULL(conn.streamManager->getStream(serverStream2Id - 1))->id;
EXPECT_EQ(
StreamDirectionality::Unidirectional,
getStreamDirectionality(serverStream1Id));
EXPECT_EQ(
StreamDirectionality::Unidirectional,
getStreamDirectionality(serverStream2Id));
EXPECT_EQ(
StreamDirectionality::Unidirectional,
getStreamDirectionality(clientStream1Id));
EXPECT_EQ(
StreamDirectionality::Unidirectional,
getStreamDirectionality(clientStream2Id));
EXPECT_EQ(
StreamInitiator::Local,
getStreamInitiator(conn.nodeType, serverStream1Id));
EXPECT_EQ(
StreamInitiator::Local,
getStreamInitiator(conn.nodeType, serverStream2Id));
EXPECT_EQ(
StreamInitiator::Remote,
getStreamInitiator(conn.nodeType, clientStream1Id));
EXPECT_EQ(
StreamInitiator::Remote,
getStreamInitiator(conn.nodeType, clientStream2Id));
}
TEST_F(QuicStreamFunctionsTest, HasReadableDataNoData) { TEST_F(QuicStreamFunctionsTest, HasReadableDataNoData) {
auto stream = conn.streamManager->createNextBidirectionalStream().value(); auto stream = conn.streamManager->createNextBidirectionalStream().value();
auto buf1 = IOBuf::copyBuffer("just"); auto buf1 = IOBuf::copyBuffer("just");