diff --git a/quic/api/QuicStreamAsyncTransport.cpp b/quic/api/QuicStreamAsyncTransport.cpp index 52d950393..5b984f283 100644 --- a/quic/api/QuicStreamAsyncTransport.cpp +++ b/quic/api/QuicStreamAsyncTransport.cpp @@ -43,9 +43,10 @@ void QuicStreamAsyncTransport::setStreamId(quic::StreamId id) { id_ = id; // TODO: handle timeout for assigning stream id - - sock_->setReadCallback(*id_, this); - handleRead(); + if (readCb_) { + sock_->setReadCallback(*id_, this); + handleRead(); + } if (!writeCallbacks_.empty()) { // adjust offsets of buffered writes @@ -67,8 +68,7 @@ void QuicStreamAsyncTransport::setStreamId(quic::StreamId id) { void QuicStreamAsyncTransport::destroy() { if (state_ != CloseState::CLOSED) { - state_ = CloseState::CLOSED; - sock_->closeNow(folly::none); + closeNow(); } // Then call DelayedDestruction::destroy() to take care of // whether or not we need immediate or delayed destruction @@ -207,12 +207,15 @@ void QuicStreamAsyncTransport::close() { readEOF_ = EOFState::QUEUED; handleRead(); } - sock_->closeGracefully(); } void QuicStreamAsyncTransport::closeNow() { folly::AsyncSocketException ex( folly::AsyncSocketException::UNKNOWN, "Quic closeNow"); + if (id_) { + sock_->stopSending(*id_, quic::GenericApplicationErrorCode::UNKNOWN); + shutdownWriteNow(); + } closeNowImpl(std::move(ex)); } @@ -240,13 +243,11 @@ void QuicStreamAsyncTransport::shutdownWriteNow() { // writes already shutdown return; } - if (writeBuf_.empty()) { - shutdownWrite(); - } else { - if (id_) { - sock_->resetStream(*id_, quic::GenericApplicationErrorCode::UNKNOWN); - VLOG(4) << "Reset stream from shutdownWriteNow"; - } + shutdownWrite(); + send(0); + if (id_ && writeEOF_ != EOFState::DELIVERED) { + sock_->resetStream(*id_, quic::GenericApplicationErrorCode::UNKNOWN); + VLOG(4) << "Reset stream from shutdownWriteNow"; } } @@ -453,7 +454,8 @@ void QuicStreamAsyncTransport::send(uint64_t maxToSend) { } uint64_t sentOffset = *streamWriteOffset + toSend; - bool writeEOF = (writeEOF_ == EOFState::QUEUED); + bool writeEOF = + (writeEOF_ == EOFState::QUEUED && writeBuf_.chainLength() == toSend); auto res = sock_->writeChain( *id_, writeBuf_.split(toSend), @@ -531,7 +533,6 @@ void QuicStreamAsyncTransport::closeNowImpl(folly::AsyncSocketException&& ex) { sock_->unregisterStreamWriteCallback(*id_); id_.reset(); } - sock_->closeNow(folly::none); failWrites(*ex_); } diff --git a/quic/api/test/QuicStreamAsyncTransportTest.cpp b/quic/api/test/QuicStreamAsyncTransportTest.cpp index 8495d45bf..dffbc2737 100644 --- a/quic/api/test/QuicStreamAsyncTransportTest.cpp +++ b/quic/api/test/QuicStreamAsyncTransportTest.cpp @@ -5,6 +5,8 @@ * LICENSE file in the root directory of this source tree. */ +#include +#include #include #include #include @@ -26,11 +28,25 @@ using namespace testing; namespace quic::test { class QuicStreamAsyncTransportTest : public Test { + protected: + struct Stream { + Stream() = default; + Stream(const Stream&) = delete; + Stream& operator=(const Stream&) = delete; + Stream(Stream&&) = delete; + Stream& operator=(Stream&&) = delete; + folly::test::MockWriteCallback writeCb; + folly::test::MockReadCallback readCb; + QuicStreamAsyncTransport::UniquePtr transport; + std::array buf; + uint8_t serverDone{2}; // need to finish reads & writes + }; + public: void SetUp() override { folly::ssl::init(); createServer(); - createClient(); + connect(); } void createServer() { @@ -53,33 +69,6 @@ class QuicStreamAsyncTransportTest : public Test { return transport; })); - EXPECT_CALL(serverConnectionCB_, onNewBidirectionalStream(_)) - .WillOnce(Invoke([&](StreamId id) { - serverAsyncWrapper_ = - QuicStreamAsyncTransport::createWithExistingStream( - serverSocket_, id); - serverAsyncWrapper_->setReadCB(&serverReadCB_); - })); - - EXPECT_CALL(serverReadCB_, isBufferMovable_()) - .WillRepeatedly(Return(false)); - EXPECT_CALL(serverReadCB_, getReadBuffer(_, _)) - .WillRepeatedly(Invoke([&](void** buf, size_t* len) { - *buf = serverBuf_.data(); - *len = serverBuf_.size(); - })); - EXPECT_CALL(serverReadCB_, readDataAvailable_(_)) - .WillOnce(Invoke([&](auto len) { - auto echoData = folly::IOBuf::copyBuffer("echo "); - echoData->appendChain( - folly::IOBuf::wrapBuffer(serverBuf_.data(), len)); - serverAsyncWrapper_->writeChain(&serverWriteCB_, std::move(echoData)); - serverAsyncWrapper_->shutdownWrite(); - })); - EXPECT_CALL(serverReadCB_, readEOF_()).WillOnce(Return()); - - EXPECT_CALL(serverWriteCB_, writeSuccess_()).WillOnce(Return()); - server_ = QuicServer::createQuicServer(); auto serverCtx = test::createServerCtx(); server_->setFizzContext(serverCtx); @@ -92,38 +81,83 @@ class QuicStreamAsyncTransportTest : public Test { serverAddr_ = server_->getAddress(); } - void createClient() { - clientEvbThread_ = std::thread([&]() { clientEvb_.loopForever(); }); + void expectNewServerStream() { + EXPECT_CALL(serverConnectionCB_, onNewBidirectionalStream(_)) + .WillOnce(Invoke([&](StreamId id) { + auto res = streams_.emplace( + std::piecewise_construct, + std::forward_as_tuple(id), + std::forward_as_tuple(std::make_unique())); + auto& newStream = *res.first->second; + newStream.transport = + QuicStreamAsyncTransport::createWithExistingStream( + serverSocket_, id); + EXPECT_CALL(newStream.readCb, readEOF_()).WillOnce(Invoke([this, id] { + auto& stream = *streams_[id]; + if (--stream.serverDone == 0) { + stream.transport->close(); + } + })); + EXPECT_CALL(newStream.readCb, isBufferMovable_()) + .WillRepeatedly(Return(false)); + EXPECT_CALL(newStream.readCb, getReadBuffer(_, _)) + .WillRepeatedly(Invoke([this, id](void** buf, size_t* len) { + auto& stream = *streams_[id]; + *buf = stream.buf.data(); + *len = stream.buf.size(); + })); + EXPECT_CALL(newStream.readCb, readDataAvailable_(_)) + .WillRepeatedly(Invoke([this, id](auto len) { + auto& stream = *streams_[id]; + auto echoData = folly::IOBuf::copyBuffer("echo "); + echoData->appendChain( + folly::IOBuf::wrapBuffer(stream.buf.data(), len)); + EXPECT_CALL(stream.writeCb, writeSuccess_()) + .WillOnce(Return()) + .RetiresOnSaturation(); + if (stream.transport->good()) { + // Echo the first readDataAvailable_ only + stream.transport->writeChain( + &stream.writeCb, std::move(echoData)); + stream.transport->shutdownWrite(); + if (--stream.serverDone == 0) { + stream.transport->close(); + } + } + })); + newStream.transport->setReadCB(&newStream.readCb); + })) + .RetiresOnSaturation(); + } - EXPECT_CALL(clientConnectionSetupCB_, onTransportReady()) - .WillOnce(Invoke([&]() { - clientAsyncWrapper_ = - QuicStreamAsyncTransport::createWithNewStream(client_); - ASSERT_TRUE(clientAsyncWrapper_); - clientAsyncWrapper_->setReadCB(&clientReadCB_); - startPromise_.setValue(); - })); + std::unique_ptr createClient(bool setReadCB = true) { + auto clientStream = std::make_unique(); + clientStream->transport = + QuicStreamAsyncTransport::createWithNewStream(client_); + CHECK(clientStream->transport); - EXPECT_CALL(clientReadCB_, isBufferMovable_()) + EXPECT_CALL(clientStream->readCb, isBufferMovable_()) .WillRepeatedly(Return(false)); - EXPECT_CALL(clientReadCB_, getReadBuffer(_, _)) - .WillRepeatedly(Invoke([&](void** buf, size_t* len) { - *buf = clientBuf_.data(); - *len = clientBuf_.size(); - })); - EXPECT_CALL(clientReadCB_, readDataAvailable_(_)) - .WillOnce(Invoke([&](auto len) { - clientReadPromise_.setValue( - std::string(reinterpret_cast(clientBuf_.data()), len)); - })); - EXPECT_CALL(clientReadCB_, readEOF_()).WillOnce(Return()); + EXPECT_CALL(clientStream->readCb, getReadBuffer(_, _)) + .WillRepeatedly(Invoke( + [clientStream = clientStream.get()](void** buf, size_t* len) { + *buf = clientStream->buf.data(); + *len = clientStream->buf.size(); + })); - EXPECT_CALL(clientWriteCB_, writeSuccess_()).WillOnce(Return()); + if (setReadCB) { + clientStream->transport->setReadCB(&clientStream->readCb); + } + return clientStream; + } - auto [promise, future] = folly::makePromiseContract(); - startPromise_ = std::move(promise); + void connect() { + auto [promiseX, future] = folly::makePromiseContract(); + auto promise = std::move(promiseX); + EXPECT_CALL(clientConnectionSetupCB_, onTransportReady()) + .WillOnce(Invoke([&promise]() mutable { promise.setValue(); })); - clientEvb_.runInEventBaseThreadAndWait([&]() { + clientEvb_.runInLoop([&]() { auto sock = std::make_unique(&clientEvb_); auto fizzClientContext = FizzClientQuicHandshakeContext::Builder() @@ -136,22 +170,17 @@ class QuicStreamAsyncTransportTest : public Test { client_->start(&clientConnectionSetupCB_, &clientConnectionCB_); }); - std::move(future).get(1s); + std::move(future).via(&clientEvb_).waitVia(&clientEvb_); } void TearDown() override { - if (serverAsyncWrapper_) { - serverAsyncWrapper_->getEventBase()->runInEventBaseThreadAndWait( - [&]() { serverAsyncWrapper_.reset(); }); + if (client_) { + client_->close(folly::none); } + clientEvb_.loop(); server_->shutdown(); server_ = nullptr; - clientEvb_.runInEventBaseThreadAndWait([&] { - clientAsyncWrapper_ = nullptr; - client_ = nullptr; - }); - clientEvb_.terminateLoopSoon(); - clientEvbThread_.join(); + client_ = nullptr; } protected: @@ -160,36 +189,110 @@ class QuicStreamAsyncTransportTest : public Test { NiceMock serverConnectionSetupCB_; NiceMock serverConnectionCB_; std::shared_ptr serverSocket_; - QuicStreamAsyncTransport::UniquePtr serverAsyncWrapper_; - folly::test::MockWriteCallback serverWriteCB_; - folly::test::MockReadCallback serverReadCB_; - std::array serverBuf_; + folly::F14FastMap> streams_; std::shared_ptr client_; folly::EventBase clientEvb_; - std::thread clientEvbThread_; NiceMock clientConnectionSetupCB_; NiceMock clientConnectionCB_; - QuicStreamAsyncTransport::UniquePtr clientAsyncWrapper_; - folly::Promise startPromise_; - folly::test::MockWriteCallback clientWriteCB_; - folly::test::MockReadCallback clientReadCB_; - std::array clientBuf_; - folly::Promise clientReadPromise_; }; TEST_F(QuicStreamAsyncTransportTest, ReadWrite) { - auto [promise, future] = folly::makePromiseContract(); - clientReadPromise_ = std::move(promise); + expectNewServerStream(); + auto clientStream = createClient(); + EXPECT_CALL(clientStream->readCb, readEOF_()).WillOnce(Return()); + auto [promiseX, future] = folly::makePromiseContract(); + auto promise = std::move(promiseX); + EXPECT_CALL(clientStream->readCb, readDataAvailable_(_)) + .WillOnce(Invoke([&clientStream, &promise](auto len) mutable { + promise.setValue(std::string( + reinterpret_cast(clientStream->buf.data()), len)); + })); std::string msg = "yo yo!"; - clientEvb_.runInEventBaseThreadAndWait([&] { - clientAsyncWrapper_->write(&clientWriteCB_, msg.data(), msg.size()); - clientAsyncWrapper_->shutdownWrite(); - }); + EXPECT_CALL(clientStream->writeCb, writeSuccess_()).WillOnce(Return()); + clientStream->transport->write( + &clientStream->writeCb, msg.data(), msg.size()); + clientStream->transport->shutdownWrite(); - std::string clientReadString = std::move(future).get(1s); - EXPECT_EQ(clientReadString, "echo yo yo!"); + EXPECT_EQ( + std::move(future).via(&clientEvb_).getVia(&clientEvb_), "echo yo yo!"); +} + +TEST_F(QuicStreamAsyncTransportTest, TwoClients) { + std::list> clientStreams; + std::list> futures; + std::string msg = "yo yo!"; + for (auto i = 0; i < 2; i++) { + expectNewServerStream(); + clientStreams.emplace_back(createClient()); + auto& clientStream = clientStreams.back(); + EXPECT_CALL(clientStream->readCb, readEOF_()).WillOnce(Return()); + auto [promiseX, future] = folly::makePromiseContract(); + auto promise = std::move(promiseX); + futures.emplace_back(std::move(future)); + EXPECT_CALL(clientStream->readCb, readDataAvailable_(_)) + .WillOnce(Invoke( + [clientStream = clientStream.get(), + p = folly::MoveWrapper(std::move(promise))](auto len) mutable { + p->setValue(std::string( + reinterpret_cast(clientStream->buf.data()), len)); + })); + + EXPECT_CALL(clientStream->writeCb, writeSuccess_()).WillOnce(Return()); + clientStream->transport->write( + &clientStream->writeCb, msg.data(), msg.size()); + clientStream->transport->shutdownWrite(); + } + for (auto& future : futures) { + EXPECT_EQ( + std::move(future).via(&clientEvb_).getVia(&clientEvb_), "echo yo yo!"); + } +} + +TEST_F(QuicStreamAsyncTransportTest, DelayedSetReadCB) { + expectNewServerStream(); + auto clientStream = createClient(/*setReadCB=*/false); + auto [promiseX, future] = folly::makePromiseContract(); + auto promise = std::move(promiseX); + EXPECT_CALL(clientStream->readCb, readDataAvailable_(_)) + .WillOnce(Invoke([&clientStream, &promise](auto len) mutable { + promise.setValue(std::string( + reinterpret_cast(clientStream->buf.data()), len)); + })); + + std::string msg = "yo yo!"; + EXPECT_CALL(clientStream->writeCb, writeSuccess_()).WillOnce(Return()); + clientStream->transport->write( + &clientStream->writeCb, msg.data(), msg.size()); + clientEvb_.runAfterDelay( + [&clientStream] { + EXPECT_CALL(clientStream->readCb, readEOF_()).WillOnce(Return()); + clientStream->transport->setReadCB(&clientStream->readCb); + clientStream->transport->shutdownWrite(); + }, + 750); + EXPECT_EQ( + std::move(future).via(&clientEvb_).getVia(&clientEvb_), "echo yo yo!"); +} + +TEST_F(QuicStreamAsyncTransportTest, close) { + auto clientStream = createClient(/*setReadCB=*/false); + EXPECT_TRUE(client_->good()); + clientStream->transport->close(); + clientStream->transport.reset(); + EXPECT_TRUE(client_->good()); + clientEvb_.loopOnce(); +} + +TEST_F(QuicStreamAsyncTransportTest, closeNow) { + auto clientStream = createClient(/*setReadCB=*/false); + EXPECT_TRUE(client_->good()); + clientStream->transport->closeNow(); + clientStream->transport.reset(); + // The quic socket is still good + EXPECT_TRUE(client_->good()); + clientEvb_.loopOnce(); } } // namespace quic::test