1
0
mirror of https://github.com/facebookincubator/mvfst.git synced 2025-08-09 20:42:44 +03:00

Allow specifying error code in setReadCallback

Summary: We have an API behavior where setReadCalback will issue a StopSending on behalf of the app. This is useful but has confusing semantics as it always defaults to GenericApplicationError::NO_ERROR. Instead let the error be specified as part of the API.

Reviewed By: yangchi, lnicco

Differential Revision: D23055196

fbshipit-source-id: 755f4122bf445016c9b5adb23c3090fc23173eb9
This commit is contained in:
Matt Joras
2020-08-13 18:27:17 -07:00
committed by Facebook GitHub Bot
parent acf0d30c16
commit 81756e3d13
6 changed files with 111 additions and 13 deletions

View File

@@ -439,11 +439,19 @@ class QuicSocket {
/** /**
* Set the read callback for the given stream. Note that read callback is * Set the read callback for the given stream. Note that read callback is
* expected to be set all the time. Removing read callback indicates that * expected to be set all the time. Removing read callback indicates that
* stream is no longer intended to be read again. * stream is no longer intended to be read again. This will issue a
* StopSending if cb is being set to nullptr after previously being not
* nullptr. The err parameter is used to control the error sent in the
* StopSending. By default when cb is nullptr this function will cause the
* transport to send a StopSending frame with
* GenericApplicationErrorCode::NO_ERROR. If err is specified to be
* folly::none, no StopSending will be sent.
*/ */
virtual folly::Expected<folly::Unit, LocalErrorCode> setReadCallback( virtual folly::Expected<folly::Unit, LocalErrorCode> setReadCallback(
StreamId id, StreamId id,
ReadCallback* cb) = 0; ReadCallback* cb,
folly::Optional<ApplicationErrorCode> err =
GenericApplicationErrorCode::NO_ERROR) = 0;
/** /**
* Convenience function that sets the read callbacks of all streams to be * Convenience function that sets the read callbacks of all streams to be

View File

@@ -650,7 +650,8 @@ QuicTransportBase::setStreamFlowControlWindow(
folly::Expected<folly::Unit, LocalErrorCode> QuicTransportBase::setReadCallback( folly::Expected<folly::Unit, LocalErrorCode> QuicTransportBase::setReadCallback(
StreamId id, StreamId id,
ReadCallback* cb) { ReadCallback* cb,
folly::Optional<ApplicationErrorCode> err) {
if (isSendingStream(conn_->nodeType, id)) { if (isSendingStream(conn_->nodeType, id)) {
return folly::makeUnexpected(LocalErrorCode::INVALID_OPERATION); return folly::makeUnexpected(LocalErrorCode::INVALID_OPERATION);
} }
@@ -660,12 +661,15 @@ folly::Expected<folly::Unit, LocalErrorCode> QuicTransportBase::setReadCallback(
if (!conn_->streamManager->streamExists(id)) { if (!conn_->streamManager->streamExists(id)) {
return folly::makeUnexpected(LocalErrorCode::STREAM_NOT_EXISTS); return folly::makeUnexpected(LocalErrorCode::STREAM_NOT_EXISTS);
} }
return setReadCallbackInternal(id, cb); return setReadCallbackInternal(id, cb, err);
} }
void QuicTransportBase::unsetAllReadCallbacks() { void QuicTransportBase::unsetAllReadCallbacks() {
for (auto& streamCallbackPair : readCallbacks_) { for (auto& streamCallbackPair : readCallbacks_) {
setReadCallbackInternal(streamCallbackPair.first, nullptr); setReadCallbackInternal(
streamCallbackPair.first,
nullptr,
GenericApplicationErrorCode::NO_ERROR);
} }
} }
@@ -685,7 +689,8 @@ void QuicTransportBase::unsetAllDeliveryCallbacks() {
folly::Expected<folly::Unit, LocalErrorCode> folly::Expected<folly::Unit, LocalErrorCode>
QuicTransportBase::setReadCallbackInternal( QuicTransportBase::setReadCallbackInternal(
StreamId id, StreamId id,
ReadCallback* cb) noexcept { ReadCallback* cb,
folly::Optional<ApplicationErrorCode> err) noexcept {
VLOG(4) << "Setting setReadCallback for stream=" << id << " cb=" << cb << " " VLOG(4) << "Setting setReadCallback for stream=" << id << " cb=" << cb << " "
<< *this; << *this;
auto readCbIt = readCallbacks_.find(id); auto readCbIt = readCallbacks_.find(id);
@@ -702,8 +707,8 @@ QuicTransportBase::setReadCallbackInternal(
return folly::makeUnexpected(LocalErrorCode::INVALID_OPERATION); return folly::makeUnexpected(LocalErrorCode::INVALID_OPERATION);
} else { } else {
readCb = cb; readCb = cb;
if (readCb == nullptr) { if (readCb == nullptr && err) {
return stopSending(id, GenericApplicationErrorCode::NO_ERROR); return stopSending(id, err.value());
} }
} }
updateReadLooper(); updateReadLooper();

View File

@@ -114,7 +114,9 @@ class QuicTransportBase : public QuicSocket {
folly::Expected<folly::Unit, LocalErrorCode> setReadCallback( folly::Expected<folly::Unit, LocalErrorCode> setReadCallback(
StreamId id, StreamId id,
ReadCallback* cb) override; ReadCallback* cb,
folly::Optional<ApplicationErrorCode> err =
GenericApplicationErrorCode::NO_ERROR) override;
void unsetAllReadCallbacks() override; void unsetAllReadCallbacks() override;
void unsetAllPeekCallbacks() override; void unsetAllPeekCallbacks() override;
void unsetAllDeliveryCallbacks() override; void unsetAllDeliveryCallbacks() override;
@@ -610,7 +612,8 @@ class QuicTransportBase : public QuicSocket {
void checkForClosedStream(); void checkForClosedStream();
folly::Expected<folly::Unit, LocalErrorCode> setReadCallbackInternal( folly::Expected<folly::Unit, LocalErrorCode> setReadCallbackInternal(
StreamId id, StreamId id,
ReadCallback* cb) noexcept; ReadCallback* cb,
folly::Optional<ApplicationErrorCode> err) noexcept;
folly::Expected<folly::Unit, LocalErrorCode> setPeekCallbackInternal( folly::Expected<folly::Unit, LocalErrorCode> setPeekCallbackInternal(
StreamId id, StreamId id,
PeekCallback* cb) noexcept; PeekCallback* cb) noexcept;

View File

@@ -84,9 +84,12 @@ class MockQuicSocket : public QuicSocket {
folly::Expected<folly::Unit, LocalErrorCode>(StreamId, uint64_t)); folly::Expected<folly::Unit, LocalErrorCode>(StreamId, uint64_t));
MOCK_METHOD1(setTransportSettings, void(TransportSettings)); MOCK_METHOD1(setTransportSettings, void(TransportSettings));
MOCK_CONST_METHOD0(isPartiallyReliableTransport, bool()); MOCK_CONST_METHOD0(isPartiallyReliableTransport, bool());
MOCK_METHOD2( MOCK_METHOD3(
setReadCallback, setReadCallback,
folly::Expected<folly::Unit, LocalErrorCode>(StreamId, ReadCallback*)); folly::Expected<folly::Unit, LocalErrorCode>(
StreamId,
ReadCallback*,
folly::Optional<ApplicationErrorCode> err));
MOCK_METHOD1(setConnectionCallback, void(ConnectionCallback*)); MOCK_METHOD1(setConnectionCallback, void(ConnectionCallback*));
void setEarlyDataAppParamsFunctions( void setEarlyDataAppParamsFunctions(
folly::Function<bool(const folly::Optional<std::string>&, const Buf&) folly::Function<bool(const folly::Optional<std::string>&, const Buf&)

View File

@@ -888,6 +888,85 @@ TEST_F(QuicTransportTest, StopSending) {
EXPECT_TRUE(foundStopSending); EXPECT_TRUE(foundStopSending);
} }
TEST_F(QuicTransportTest, StopSendingReadCallbackDefault) {
auto streamId = transport_->createBidirectionalStream().value();
NiceMock<MockReadCallback> readCb;
EXPECT_CALL(*socket_, write(_, _)).WillOnce(Invoke(bufLength));
transport_->setReadCallback(streamId, &readCb);
transport_->setReadCallback(streamId, nullptr);
loopForWrites();
EXPECT_EQ(1, transport_->getConnectionState().outstandings.packets.size());
auto packet =
getLastOutstandingPacket(
transport_->getConnectionState(), PacketNumberSpace::AppData)
->packet;
EXPECT_EQ(1, packet.frames.size());
bool foundStopSending = false;
for (auto& frame : packet.frames) {
const QuicSimpleFrame* simpleFrame = frame.asQuicSimpleFrame();
if (!simpleFrame) {
continue;
}
const StopSendingFrame* stopSending = simpleFrame->asStopSendingFrame();
if (!stopSending) {
continue;
}
EXPECT_EQ(streamId, stopSending->streamId);
EXPECT_EQ(GenericApplicationErrorCode::NO_ERROR, stopSending->errorCode);
foundStopSending = true;
}
EXPECT_TRUE(foundStopSending);
}
TEST_F(QuicTransportTest, StopSendingReadCallback) {
auto streamId = transport_->createBidirectionalStream().value();
NiceMock<MockReadCallback> readCb;
EXPECT_CALL(*socket_, write(_, _)).WillOnce(Invoke(bufLength));
transport_->setReadCallback(streamId, &readCb);
transport_->setReadCallback(
streamId, nullptr, GenericApplicationErrorCode::UNKNOWN);
loopForWrites();
EXPECT_EQ(1, transport_->getConnectionState().outstandings.packets.size());
auto packet =
getLastOutstandingPacket(
transport_->getConnectionState(), PacketNumberSpace::AppData)
->packet;
EXPECT_EQ(1, packet.frames.size());
bool foundStopSending = false;
for (auto& frame : packet.frames) {
const QuicSimpleFrame* simpleFrame = frame.asQuicSimpleFrame();
if (!simpleFrame) {
continue;
}
const StopSendingFrame* stopSending = simpleFrame->asStopSendingFrame();
if (!stopSending) {
continue;
}
EXPECT_EQ(streamId, stopSending->streamId);
EXPECT_EQ(GenericApplicationErrorCode::UNKNOWN, stopSending->errorCode);
foundStopSending = true;
}
EXPECT_TRUE(foundStopSending);
}
TEST_F(QuicTransportTest, StopSendingReadCallbackNone) {
auto streamId = transport_->createBidirectionalStream().value();
NiceMock<MockReadCallback> readCb;
transport_->setReadCallback(streamId, &readCb);
transport_->setReadCallback(streamId, nullptr, folly::none);
loopForWrites();
EXPECT_EQ(0, transport_->getConnectionState().outstandings.packets.size());
}
TEST_F(QuicTransportTest, NoStopSendingReadCallback) {
auto streamId = transport_->createBidirectionalStream().value();
NiceMock<MockReadCallback> readCb;
transport_->setReadCallback(streamId, &readCb);
loopForWrites();
EXPECT_EQ(0, transport_->getConnectionState().outstandings.packets.size());
transport_->setReadCallback(streamId, nullptr, folly::none);
}
TEST_F(QuicTransportTest, SendPathChallenge) { TEST_F(QuicTransportTest, SendPathChallenge) {
EXPECT_CALL(*socket_, write(_, _)).WillOnce(Invoke(bufLength)); EXPECT_CALL(*socket_, write(_, _)).WillOnce(Invoke(bufLength));
auto& conn = transport_->getConnectionState(); auto& conn = transport_->getConnectionState();

View File

@@ -27,7 +27,7 @@ class QuicSocketTest : public Test {
} }
void openStream(StreamId) { void openStream(StreamId) {
EXPECT_CALL(*socket_, setReadCallback(3, &handler_)); EXPECT_CALL(*socket_, setReadCallback(3, &handler_, _));
socket_->cb_->onNewBidirectionalStream(3); socket_->cb_->onNewBidirectionalStream(3);
} }