diff --git a/quic/QuicConstants.h b/quic/QuicConstants.h index 9aa53f095..59570396b 100644 --- a/quic/QuicConstants.h +++ b/quic/QuicConstants.h @@ -159,6 +159,7 @@ enum class LocalErrorCode : uint32_t { INVALID_OPERATION = 0x40000017, STREAM_LIMIT_EXCEEDED = 0x40000018, CONNECTION_ABANDONED = 0x40000019, + CALLBACK_ALREADY_INSTALLED = 0x4000001A, }; using QuicErrorCode = diff --git a/quic/QuicException.cpp b/quic/QuicException.cpp index de53e40c1..c8979de31 100644 --- a/quic/QuicException.cpp +++ b/quic/QuicException.cpp @@ -84,6 +84,8 @@ std::string toString(LocalErrorCode code) { return "New version negotiatied"; case LocalErrorCode::INVALID_WRITE_CALLBACK: return "Invalid write callback"; + case LocalErrorCode::CALLBACK_ALREADY_INSTALLED: + return "Callback already installed"; case LocalErrorCode::TLS_HANDSHAKE_FAILED: return "TLS handshake failed"; case LocalErrorCode::APP_ERROR: diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index e66a3e676..6b10de62d 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -1636,10 +1636,19 @@ QuicTransportBase::notifyPendingWriteOnStream(StreamId id, WriteCallback* wcb) { if (!qStream->writable()) { return folly::makeUnexpected(LocalErrorCode::STREAM_CLOSED); } + + if (wcb == nullptr) { + return folly::makeUnexpected(LocalErrorCode::INVALID_WRITE_CALLBACK); + } // Add the callback to the pending write callbacks so that if we are closed // while we are scheduled in the loop, the close will error out the callbacks. - if (!pendingWriteCallbacks_.emplace(id, wcb).second) { - return folly::makeUnexpected(LocalErrorCode::INVALID_WRITE_CALLBACK); + auto wcbEmplaceResult = pendingWriteCallbacks_.emplace(id, wcb); + if (!wcbEmplaceResult.second) { + if ((wcbEmplaceResult.first)->second != wcb) { + return folly::makeUnexpected(LocalErrorCode::INVALID_WRITE_CALLBACK); + } else { + return folly::makeUnexpected(LocalErrorCode::CALLBACK_ALREADY_INSTALLED); + } } runOnEvbAsync([id](auto self) { auto wcbIt = self->pendingWriteCallbacks_.find(id); diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index 3768fd667..8213efd97 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -1214,6 +1214,17 @@ TEST_P(QuicTransportImplTestClose, TestNotifyPendingConnWriteOnCloseWithError) { evb->loopOnce(); } +TEST_F(QuicTransportImplTest, TestNotifyPendingWriteWithActiveCallback) { + auto stream = transport->createBidirectionalStream().value(); + MockWriteCallback wcb; + EXPECT_CALL(wcb, onStreamWriteReady(stream, _)); + auto ok1 = transport->notifyPendingWriteOnStream(stream, &wcb); + EXPECT_TRUE(ok1.hasValue()); + auto ok2 = transport->notifyPendingWriteOnStream(stream, &wcb); + EXPECT_EQ(ok2.error(), quic::LocalErrorCode::CALLBACK_ALREADY_INSTALLED); + evb->loopOnce(); +} + TEST_F(QuicTransportImplTest, TestNotifyPendingWriteOnCloseWithoutError) { auto stream = transport->createBidirectionalStream().value(); MockWriteCallback wcb;