diff --git a/quic/api/QuicStreamAsyncTransport.cpp b/quic/api/QuicStreamAsyncTransport.cpp index 00480fff1..20e51d0b0 100644 --- a/quic/api/QuicStreamAsyncTransport.cpp +++ b/quic/api/QuicStreamAsyncTransport.cpp @@ -7,7 +7,6 @@ */ #include - #include namespace quic { @@ -19,34 +18,75 @@ QuicStreamAsyncTransport::createWithNewStream( if (!streamId) { return nullptr; } - UniquePtr ptr( - new QuicStreamAsyncTransport(std::move(sock), streamId.value())); - return ptr; + return createWithExistingStream(std::move(sock), *streamId); } QuicStreamAsyncTransport::UniquePtr QuicStreamAsyncTransport::createWithExistingStream( std::shared_ptr sock, quic::StreamId streamId) { - UniquePtr ptr(new QuicStreamAsyncTransport(std::move(sock), streamId)); + UniquePtr ptr(new QuicStreamAsyncTransport()); + ptr->setSocket(std::move(sock)); + ptr->setStreamId(streamId); return ptr; } -QuicStreamAsyncTransport::QuicStreamAsyncTransport( - std::shared_ptr sock, - quic::StreamId id) - : sock_(std::move(sock)), id_(id) {} +void QuicStreamAsyncTransport::setSocket( + std::shared_ptr sock) { + sock_ = std::move(sock); +} -QuicStreamAsyncTransport::~QuicStreamAsyncTransport() { - sock_->setReadCallback(id_, nullptr); - closeWithReset(); +void QuicStreamAsyncTransport::setStreamId(quic::StreamId id) { + CHECK(!id_.hasValue()) << "stream id can only be set once"; + CHECK(state_ == CloseState::OPEN) << "Current state: " << (int)state_; + + id_ = id; + + // TODO: handle timeout for assigning stream id + + sock_->setReadCallback(*id_, this); + handleRead(); + + if (!writeCallbacks_.empty()) { + // adjust offsets of buffered writes + auto streamWriteOffset = sock_->getStreamWriteOffset(*id_); + if (streamWriteOffset.hasError()) { + folly::AsyncSocketException ex( + folly::AsyncSocketException::UNKNOWN, + folly::to( + "Quic write error: ", toString(streamWriteOffset.error()))); + closeNowImpl(std::move(ex)); + return; + } + for (auto& p : writeCallbacks_) { + p.first += *streamWriteOffset; + } + sock_->notifyPendingWriteOnStream(*id_, this); + } +} + +void QuicStreamAsyncTransport::destroy() { + if (state_ != CloseState::CLOSED) { + state_ = CloseState::CLOSED; + sock_->closeNow(folly::none); + } + // Then call DelayedDestruction::destroy() to take care of + // whether or not we need immediate or delayed destruction + DelayedDestruction::destroy(); } void QuicStreamAsyncTransport::setReadCB( AsyncTransport::ReadCallback* callback) { readCb_ = callback; - // It should be ok to do this immediately, rather than in the loop - handleRead(); + if (id_) { + if (readCb_) { + sock_->setReadCallback(*id_, this); + // It should be ok to do this immediately, rather than in the loop + handleRead(); + } else { + sock_->setReadCallback(*id_, nullptr); + } + } } folly::AsyncTransport::ReadCallback* QuicStreamAsyncTransport::getReadCallback() @@ -56,13 +96,15 @@ folly::AsyncTransport::ReadCallback* QuicStreamAsyncTransport::getReadCallback() void QuicStreamAsyncTransport::addWriteCallback( AsyncTransport::WriteCallback* callback, - size_t offset, - size_t size) { + size_t offset) { + size_t size = writeBuf_.chainLength(); writeCallbacks_.emplace_back(offset + size, callback); - sock_->notifyPendingWriteOnStream(id_, this); + if (id_) { + sock_->notifyPendingWriteOnStream(*id_, this); + } } -void QuicStreamAsyncTransport::handleOffsetError( +void QuicStreamAsyncTransport::handleWriteOffsetError( AsyncTransport::WriteCallback* callback, LocalErrorCode error) { folly::AsyncSocketException ex( @@ -71,18 +113,50 @@ void QuicStreamAsyncTransport::handleOffsetError( callback->writeErr(0, ex); } +bool QuicStreamAsyncTransport::handleWriteStateError( + AsyncTransport::WriteCallback* callback) { + if (writeEOF_ != EOFState::NOT_SEEN) { + folly::AsyncSocketException ex( + folly::AsyncSocketException::UNKNOWN, + "Quic write error: bad EOF state"); + callback->writeErr(0, ex); + return true; + } else if (state_ == CloseState::CLOSED) { + folly::AsyncSocketException ex( + folly::AsyncSocketException::UNKNOWN, "Quic write error: closed state"); + callback->writeErr(0, ex); + return true; + } else if (ex_) { + callback->writeErr(0, *ex_); + return true; + } else { + return false; + } +} + +folly::Expected +QuicStreamAsyncTransport::getStreamWriteOffset() const { + if (!id_) { + return 0; + } + return sock_->getStreamWriteOffset(*id_); +} + void QuicStreamAsyncTransport::write( AsyncTransport::WriteCallback* callback, const void* buf, size_t bytes, folly::WriteFlags /*flags*/) { - auto streamWriteOffset = sock_->getStreamWriteOffset(id_); + if (handleWriteStateError(callback)) { + return; + } + auto streamWriteOffset = getStreamWriteOffset(); if (streamWriteOffset.hasError()) { - handleOffsetError(callback, streamWriteOffset.error()); + handleWriteOffsetError(callback, streamWriteOffset.error()); return; } writeBuf_.append(folly::IOBuf::wrapBuffer(buf, bytes)); - addWriteCallback(callback, *streamWriteOffset, bytes); + addWriteCallback(callback, *streamWriteOffset); } void QuicStreamAsyncTransport::writev( @@ -90,35 +164,41 @@ void QuicStreamAsyncTransport::writev( const iovec* vec, size_t count, folly::WriteFlags /*flags*/) { - auto streamWriteOffset = sock_->getStreamWriteOffset(id_); - if (streamWriteOffset.hasError()) { - handleOffsetError(callback, streamWriteOffset.error()); + if (handleWriteStateError(callback)) { + return; + } + auto streamWriteOffset = getStreamWriteOffset(); + if (streamWriteOffset.hasError()) { + handleWriteOffsetError(callback, streamWriteOffset.error()); return; } - size_t totalBytes = 0; for (size_t i = 0; i < count; i++) { writeBuf_.append(folly::IOBuf::wrapBuffer(vec[i].iov_base, vec[i].iov_len)); - totalBytes += vec[i].iov_len; } - addWriteCallback(callback, *streamWriteOffset, totalBytes); + addWriteCallback(callback, *streamWriteOffset); } void QuicStreamAsyncTransport::writeChain( AsyncTransport::WriteCallback* callback, std::unique_ptr&& buf, folly::WriteFlags /*flags*/) { - auto streamWriteOffset = sock_->getStreamWriteOffset(id_); - if (streamWriteOffset.hasError()) { - handleOffsetError(callback, streamWriteOffset.error()); + if (handleWriteStateError(callback)) { + return; + } + auto streamWriteOffset = getStreamWriteOffset(); + if (streamWriteOffset.hasError()) { + handleWriteOffsetError(callback, streamWriteOffset.error()); return; } - size_t len = buf->computeChainDataLength(); writeBuf_.append(std::move(buf)); - addWriteCallback(callback, *streamWriteOffset, len); + addWriteCallback(callback, *streamWriteOffset); } void QuicStreamAsyncTransport::close() { - sock_->stopSending(id_, quic::GenericApplicationErrorCode::UNKNOWN); + state_ = CloseState::CLOSING; + if (id_) { + sock_->stopSending(*id_, quic::GenericApplicationErrorCode::UNKNOWN); + } shutdownWrite(); if (readCb_ && readEOF_ != EOFState::DELIVERED) { // This is such a bizarre operation. I almost think if we haven't seen @@ -127,41 +207,46 @@ void QuicStreamAsyncTransport::close() { readEOF_ = EOFState::QUEUED; handleRead(); } + sock_->closeGracefully(); } void QuicStreamAsyncTransport::closeNow() { - if (writeBuf_.empty()) { - close(); - } else { - sock_->stopSending(id_, quic::GenericApplicationErrorCode::UNKNOWN); - sock_->resetStream(id_, quic::GenericApplicationErrorCode::UNKNOWN); - VLOG(4) << "Reset stream from closeNow"; - } + folly::AsyncSocketException ex( + folly::AsyncSocketException::UNKNOWN, "Quic closeNow"); + closeNowImpl(std::move(ex)); } void QuicStreamAsyncTransport::closeWithReset() { - sock_->stopSending(id_, quic::GenericApplicationErrorCode::UNKNOWN); - sock_->resetStream(id_, quic::GenericApplicationErrorCode::UNKNOWN); - VLOG(4) << "Reset stream from closeWithReset"; + if (id_) { + sock_->stopSending(*id_, quic::GenericApplicationErrorCode::UNKNOWN); + sock_->resetStream(*id_, quic::GenericApplicationErrorCode::UNKNOWN); + } + folly::AsyncSocketException ex( + folly::AsyncSocketException::UNKNOWN, "Quic closeNow"); + closeNowImpl(std::move(ex)); } void QuicStreamAsyncTransport::shutdownWrite() { if (writeEOF_ == EOFState::NOT_SEEN) { writeEOF_ = EOFState::QUEUED; - sock_->notifyPendingWriteOnStream(id_, this); + if (id_) { + sock_->notifyPendingWriteOnStream(*id_, this); + } } } void QuicStreamAsyncTransport::shutdownWriteNow() { - if (readEOF_ == EOFState::DELIVERED) { + if (writeEOF_ == EOFState::DELIVERED) { // writes already shutdown return; } if (writeBuf_.empty()) { shutdownWrite(); } else { - sock_->resetStream(id_, quic::GenericApplicationErrorCode::UNKNOWN); - VLOG(4) << "Reset stream from shutdownWriteNow"; + if (id_) { + sock_->resetStream(*id_, quic::GenericApplicationErrorCode::UNKNOWN); + VLOG(4) << "Reset stream from shutdownWriteNow"; + } } } @@ -184,7 +269,7 @@ bool QuicStreamAsyncTransport::isPending() const { } bool QuicStreamAsyncTransport::connecting() const { - return false; + return !id_.hasValue() && (state_ == CloseState::OPEN); } bool QuicStreamAsyncTransport::error() const { @@ -205,7 +290,7 @@ void QuicStreamAsyncTransport::detachEventBase() { } bool QuicStreamAsyncTransport::isDetachable() const { - return false; // ? + return false; } void QuicStreamAsyncTransport::setSendTimeout(uint32_t /*milliseconds*/) { @@ -235,27 +320,22 @@ bool QuicStreamAsyncTransport::isEorTrackingEnabled() const { void QuicStreamAsyncTransport::setEorTracking(bool /*track*/) {} size_t QuicStreamAsyncTransport::getAppBytesWritten() const { - auto res = sock_->getStreamWriteOffset(id_); + auto res = getStreamWriteOffset(); // TODO: track written bytes to have it available after QUIC stream closure return res.hasError() ? 0 : res.value(); } size_t QuicStreamAsyncTransport::getRawBytesWritten() const { - auto res = sock_->getStreamWriteOffset(id_); - // TODO: track written bytes to have it available after QUIC stream closure - return res.hasError() ? 0 : res.value(); + return getAppBytesWritten(); } size_t QuicStreamAsyncTransport::getAppBytesReceived() const { - auto res = sock_->getStreamReadOffset(id_); // TODO: track read bytes to have it available after QUIC stream closure - return res.hasError() ? 0 : res.value(); + return 0; } size_t QuicStreamAsyncTransport::getRawBytesReceived() const { - auto res = sock_->getStreamReadOffset(id_); - // TODO: track read bytes to have it available after QUIC stream closure - return res.hasError() ? 0 : res.value(); + return getAppBytesReceived(); } std::string QuicStreamAsyncTransport::getApplicationProtocol() const noexcept { @@ -293,8 +373,8 @@ void QuicStreamAsyncTransport::handleRead() { folly::DelayedDestruction::DestructorGuard dg(this); bool emptyRead = false; size_t numReads = 0; - while (readCb_ && !ex_ && readEOF_ == EOFState::NOT_SEEN && !emptyRead && - ++numReads < 16 /* max reads per event */) { + while (readCb_ && id_ && !ex_ && readEOF_ == EOFState::NOT_SEEN && + !emptyRead && ++numReads < 16 /* max reads per event */) { void* buf = nullptr; size_t len = 0; if (readCb_->isBufferMovable()) { @@ -308,7 +388,7 @@ void QuicStreamAsyncTransport::handleRead() { break; } } - auto readData = sock_->read(id_, len); + auto readData = sock_->read(*id_, len); if (readData.hasError()) { ex_ = folly::AsyncSocketException( folly::AsyncSocketException::UNKNOWN, @@ -332,31 +412,37 @@ void QuicStreamAsyncTransport::handleRead() { } } } - if (readCb_) { - if (ex_) { - auto cb = readCb_; - readCb_ = nullptr; - cb->readErr(*ex_); - } else if (readEOF_ == EOFState::QUEUED) { - auto cb = readCb_; - readCb_ = nullptr; - cb->readEOF(); - readEOF_ = EOFState::DELIVERED; - } + + // in case readCb_ got reset from read callbacks + if (!readCb_) { + return; } - if (readCb_ && readEOF_ == EOFState::NOT_SEEN && !ex_) { - sock_->setReadCallback(id_, this); - } else { - sock_->setReadCallback(id_, nullptr); + + if (ex_) { + auto cb = readCb_; + readCb_ = nullptr; + cb->readErr(*ex_); + } else if (readEOF_ == EOFState::QUEUED) { + auto cb = readCb_; + readCb_ = nullptr; + cb->readEOF(); + readEOF_ = EOFState::DELIVERED; + } + + if (id_) { + if (!readCb_ || readEOF_ != EOFState::NOT_SEEN) { + sock_->setReadCallback(*id_, nullptr); + } } } void QuicStreamAsyncTransport::send(uint64_t maxToSend) { + CHECK(id_); // overkill until there are delivery cbs folly::DelayedDestruction::DestructorGuard dg(this); uint64_t toSend = std::min(maxToSend, folly::to(writeBuf_.chainLength())); - auto streamWriteOffset = sock_->getStreamWriteOffset(id_); + auto streamWriteOffset = sock_->getStreamWriteOffset(*id_); if (streamWriteOffset.hasError()) { // handle error folly::AsyncSocketException ex( @@ -370,7 +456,7 @@ void QuicStreamAsyncTransport::send(uint64_t maxToSend) { uint64_t sentOffset = *streamWriteOffset + toSend; bool writeEOF = (writeEOF_ == EOFState::QUEUED); auto res = sock_->writeChain( - id_, + *id_, writeBuf_.split(toSend), writeEOF, false, @@ -380,15 +466,16 @@ void QuicStreamAsyncTransport::send(uint64_t maxToSend) { folly::AsyncSocketException::UNKNOWN, folly::to("Quic write error: ", toString(res.error()))); failWrites(ex); - } else { - if (writeEOF) { - writeEOF_ = EOFState::DELIVERED; - VLOG(4) << "Closed stream id_=" << id_; - } - // not actually sent. Mirrors AsyncSocket and invokes when data is in - // transport buffers - invokeWriteCallbacks(sentOffset); + return; } + if (writeEOF) { + writeEOF_ = EOFState::DELIVERED; + } else if (writeBuf_.chainLength()) { + sock_->notifyPendingWriteOnStream(*id_, this); + } + // not actually sent. Mirrors AsyncSocket and invokes when data is in + // transport buffers + invokeWriteCallbacks(sentOffset); } void QuicStreamAsyncTransport::invokeWriteCallbacks(size_t sentOffset) { @@ -398,9 +485,13 @@ void QuicStreamAsyncTransport::invokeWriteCallbacks(size_t sentOffset) { writeCallbacks_.pop_front(); wcb->writeSuccess(); } + if (writeEOF_ == EOFState::DELIVERED) { + CHECK(writeCallbacks_.empty()); + } } -void QuicStreamAsyncTransport::failWrites(folly::AsyncSocketException& ex) { +void QuicStreamAsyncTransport::failWrites( + const folly::AsyncSocketException& ex) { while (!writeCallbacks_.empty()) { auto& front = writeCallbacks_.front(); auto wcb = front.second; @@ -411,8 +502,9 @@ void QuicStreamAsyncTransport::failWrites(folly::AsyncSocketException& ex) { } void QuicStreamAsyncTransport::onStreamWriteReady( - quic::StreamId /*id*/, + quic::StreamId id, uint64_t maxToSend) noexcept { + CHECK(id == *id_); if (writeEOF_ == EOFState::DELIVERED && writeBuf_.empty()) { // nothing left to write return; @@ -424,10 +516,26 @@ void QuicStreamAsyncTransport::onStreamWriteError( StreamId /*id*/, std::pair> error) noexcept { - folly::AsyncSocketException ex( + closeNowImpl(folly::AsyncSocketException( folly::AsyncSocketException::UNKNOWN, - folly::to("Quic write error: ", toString(error))); - failWrites(ex); + folly::to("Quic write error: ", toString(error)))); +} + +void QuicStreamAsyncTransport::closeNowImpl(folly::AsyncSocketException&& ex) { + folly::DelayedDestruction::DestructorGuard dg(this); + if (state_ == CloseState::CLOSED) { + return; + } + state_ = CloseState::CLOSED; + ex_ = ex; + readCb_ = nullptr; + if (id_) { + sock_->setReadCallback(*id_, nullptr); + sock_->unregisterStreamWriteCallback(*id_); + id_.reset(); + } + sock_->closeNow(folly::none); + failWrites(*ex_); } } // namespace quic diff --git a/quic/api/QuicStreamAsyncTransport.h b/quic/api/QuicStreamAsyncTransport.h index 4dd4377b8..c6c0e9ca9 100644 --- a/quic/api/QuicStreamAsyncTransport.h +++ b/quic/api/QuicStreamAsyncTransport.h @@ -8,7 +8,6 @@ #pragma once -// #include #include #include @@ -34,26 +33,27 @@ class QuicStreamAsyncTransport : public folly::AsyncTransport, quic::StreamId streamId); protected: - QuicStreamAsyncTransport( - std::shared_ptr sock, - quic::StreamId id); + QuicStreamAsyncTransport() = default; + ~QuicStreamAsyncTransport() override = default; + + void setSocket(std::shared_ptr sock); + + // While stream id is not set, all writes are buffered. + void setStreamId(StreamId id); public: - ~QuicStreamAsyncTransport() override; + // + // folly::DelayedDestruction + // + void destroy() override; + // + // folly::AsyncTransport overrides + // void setReadCB(AsyncTransport::ReadCallback* callback) override; AsyncTransport::ReadCallback* getReadCallback() const override; - void addWriteCallback( - AsyncTransport::WriteCallback* callback, - size_t offset, - size_t size); - - void handleOffsetError( - AsyncTransport::WriteCallback* callback, - LocalErrorCode error); - void write( AsyncTransport::WriteCallback* callback, const void* buf, @@ -123,34 +123,49 @@ class QuicStreamAsyncTransport : public folly::AsyncTransport, std::string getSecurityProtocol() const override; - private: + protected: + // + // QucSocket::ReadCallback overrides + // void readAvailable(quic::StreamId /*streamId*/) noexcept override; - void readError( quic::StreamId /*streamId*/, std::pair> error) noexcept override; - void runLoopCallback() noexcept override; - - void handleRead(); - void send(uint64_t maxToSend); - - void invokeWriteCallbacks(size_t sentOffset); - - void failWrites(folly::AsyncSocketException& ex); - + // + // QucSocket::WriteCallback overrides + // void onStreamWriteReady( quic::StreamId /*id*/, uint64_t maxToSend) noexcept override; - void onStreamWriteError( StreamId /*id*/, std::pair> error) noexcept override; + // + // folly::EventBase::LoopCallback overrides + // + void runLoopCallback() noexcept override; + + // Utils + void addWriteCallback(AsyncTransport::WriteCallback* callback, size_t offset); + void handleWriteOffsetError( + AsyncTransport::WriteCallback* callback, + LocalErrorCode error); + bool handleWriteStateError(AsyncTransport::WriteCallback* callback); + void handleRead(); + void send(uint64_t maxToSend); + folly::Expected getStreamWriteOffset() const; + void invokeWriteCallbacks(size_t sentOffset); + void failWrites(const folly::AsyncSocketException& ex); + void closeNowImpl(folly::AsyncSocketException&& ex); + + enum class CloseState { OPEN, CLOSING, CLOSED }; + CloseState state_{CloseState::OPEN}; std::shared_ptr sock_; - quic::StreamId id_; + folly::Optional id_; enum class EOFState { NOT_SEEN, QUEUED, DELIVERED }; EOFState readEOF_{EOFState::NOT_SEEN}; EOFState writeEOF_{EOFState::NOT_SEEN}; diff --git a/quic/api/test/QuicStreamAsyncTransportTest.cpp b/quic/api/test/QuicStreamAsyncTransportTest.cpp index 186de7a67..764e7f7c0 100644 --- a/quic/api/test/QuicStreamAsyncTransportTest.cpp +++ b/quic/api/test/QuicStreamAsyncTransportTest.cpp @@ -135,6 +135,10 @@ class QuicStreamAsyncTransportTest : public Test { } void TearDown() override { + if (serverAsyncWrapper_) { + serverAsyncWrapper_->getEventBase()->runInEventBaseThreadAndWait( + [&]() { serverAsyncWrapper_.reset(); }); + } server_->shutdown(); server_ = nullptr; clientEvb_.runInEventBaseThreadAndWait([&] { diff --git a/quic/client/CMakeLists.txt b/quic/client/CMakeLists.txt index a7c98ac97..1bbc058ba 100644 --- a/quic/client/CMakeLists.txt +++ b/quic/client/CMakeLists.txt @@ -6,6 +6,7 @@ add_library( mvfst_client STATIC QuicClientTransport.cpp + QuicClientAsyncTransport.cpp handshake/ClientHandshake.cpp state/ClientStateMachine.cpp ) diff --git a/quic/client/QuicClientAsyncTransport.cpp b/quic/client/QuicClientAsyncTransport.cpp new file mode 100644 index 000000000..369da4ecb --- /dev/null +++ b/quic/client/QuicClientAsyncTransport.cpp @@ -0,0 +1,60 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + * + */ + +#include +#include +#include + +namespace quic { + +QuicClientAsyncTransport::QuicClientAsyncTransport( + const std::shared_ptr& clientSock) { + setSocket(clientSock); + clientSock->start(this); +} + +void QuicClientAsyncTransport::onNewBidirectionalStream( + StreamId /*id*/) noexcept { + CHECK(false); +} +void QuicClientAsyncTransport::onNewUnidirectionalStream( + StreamId /*id*/) noexcept { + CHECK(false); +} + +void QuicClientAsyncTransport::onStopSending( + StreamId /*id*/, + ApplicationErrorCode /*error*/) noexcept {} + +void QuicClientAsyncTransport::onConnectionEnd() noexcept { + folly::AsyncSocketException ex( + folly::AsyncSocketException::UNKNOWN, "Quic connection ended"); + // TODO: closeNow inside this callback may actually trigger gracefull close + closeNowImpl(std::move(ex)); +} + +void QuicClientAsyncTransport::onConnectionError( + std::pair code) noexcept { + folly::AsyncSocketException ex( + folly::AsyncSocketException::UNKNOWN, + folly::to("Quic connection error", code.second)); + // TODO: closeNow inside this callback may actually trigger gracefull close + closeNowImpl(std::move(ex)); +} + +void QuicClientAsyncTransport::onTransportReady() noexcept { + auto streamId = sock_->createBidirectionalStream(); + if (!streamId) { + folly::AsyncSocketException ex( + folly::AsyncSocketException::UNKNOWN, "Quic failed to create stream"); + closeNowImpl(std::move(ex)); + } + setStreamId(streamId.value()); +} + +} // namespace quic diff --git a/quic/client/QuicClientAsyncTransport.h b/quic/client/QuicClientAsyncTransport.h new file mode 100644 index 000000000..62fc3d0d0 --- /dev/null +++ b/quic/client/QuicClientAsyncTransport.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + * + */ + +#pragma once + +#include +#include + +namespace quic { + +/** + * Adaptor from QuicClientTransport to folly::AsyncTransport, + * for experiments with QUIC in code using folly::AsyncSockets. + */ +class QuicClientAsyncTransport : public QuicStreamAsyncTransport, + public QuicSocket::ConnectionCallback { + public: + using UniquePtr = std::unique_ptr< + QuicClientAsyncTransport, + folly::DelayedDestruction::Destructor>; + explicit QuicClientAsyncTransport( + const std::shared_ptr& clientSock); + + protected: + ~QuicClientAsyncTransport() override = default; + + // + // QuicSocket::ConnectionCallback + // + void onNewBidirectionalStream(StreamId id) noexcept override; + void onNewUnidirectionalStream(StreamId id) noexcept override; + void onStopSending(StreamId id, ApplicationErrorCode error) noexcept override; + void onConnectionEnd() noexcept override; + void onConnectionError( + std::pair code) noexcept override; + void onTransportReady() noexcept override; +}; +} // namespace quic diff --git a/quic/server/CMakeLists.txt b/quic/server/CMakeLists.txt index 94bbd4d6a..9ca5c3c29 100644 --- a/quic/server/CMakeLists.txt +++ b/quic/server/CMakeLists.txt @@ -86,6 +86,7 @@ file( *.h ) list(FILTER QUIC_API_HEADERS_TOINSTALL EXCLUDE REGEX test/) +list(FILTER QUIC_API_HEADERS_TOINSTALL EXCLUDE REGEX async_tran/) foreach(header ${QUIC_API_HEADERS_TOINSTALL}) get_filename_component(header_dir ${header} DIRECTORY) install(FILES ${header} DESTINATION include/quic/server/${header_dir}) diff --git a/quic/server/async_tran/QuicAsyncTransportAcceptor.cpp b/quic/server/async_tran/QuicAsyncTransportAcceptor.cpp new file mode 100644 index 000000000..4295ee176 --- /dev/null +++ b/quic/server/async_tran/QuicAsyncTransportAcceptor.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + * + */ + +#include +#include +#include + +namespace quic { + +QuicAsyncTransportAcceptor::QuicAsyncTransportAcceptor( + folly::EventBase* evb, + ManagedConnectionFactory connectionFactory) + : wangle::Acceptor(wangle::ServerSocketConfig()), + connectionFactory_(std::move(connectionFactory)), + evb_(evb) { + Acceptor::initDownstreamConnectionManager(evb_); +} + +quic::QuicServerTransport::Ptr QuicAsyncTransportAcceptor::make( + folly::EventBase* evb, + std::unique_ptr sock, + const folly::SocketAddress&, + std::shared_ptr ctx) noexcept { + CHECK_EQ(evb, evb_); + quic::QuicServerAsyncTransport::UniquePtr asyncWrapper( + new quic::QuicServerAsyncTransport()); + auto transport = + quic::QuicServerTransport::make(evb, std::move(sock), *asyncWrapper, ctx); + asyncWrapper->setServerSocket(transport); + wangle::ManagedConnection* managedConnection = + connectionFactory_(std::move(asyncWrapper)); + Acceptor::addConnection(managedConnection); + return transport; +} + +} // namespace quic diff --git a/quic/server/async_tran/QuicAsyncTransportAcceptor.h b/quic/server/async_tran/QuicAsyncTransportAcceptor.h new file mode 100644 index 000000000..b446d98ba --- /dev/null +++ b/quic/server/async_tran/QuicAsyncTransportAcceptor.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + * + */ + +#pragma once + +#include +#include +#include + +namespace quic { + +class QuicAsyncTransportAcceptor : public wangle::Acceptor, + public quic::QuicServerTransportFactory { + public: + using ManagedConnectionFactory = folly::Function; + + QuicAsyncTransportAcceptor( + folly::EventBase* evb, + ManagedConnectionFactory connectionFactory); + ~QuicAsyncTransportAcceptor() override = default; + + // quic::QuicServerTransportFactory + quic::QuicServerTransport::Ptr make( + folly::EventBase* evb, + std::unique_ptr sock, + const folly::SocketAddress&, + std::shared_ptr + ctx) noexcept override; + + private: + ManagedConnectionFactory connectionFactory_; + folly::EventBase* evb_; +}; + +} // namespace quic diff --git a/quic/server/async_tran/QuicAsyncTransportServer.cpp b/quic/server/async_tran/QuicAsyncTransportServer.cpp new file mode 100644 index 000000000..ba2c737be --- /dev/null +++ b/quic/server/async_tran/QuicAsyncTransportServer.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + * + */ + +#include +#include + +namespace quic { + +QuicAsyncTransportServer::QuicAsyncTransportServer( + QuicAsyncTransportAcceptor::ManagedConnectionFactory connectionFactory) + : connectionFactory_(std::move(connectionFactory)), + quicServer_(quic::QuicServer::createQuicServer()) {} + +void QuicAsyncTransportServer::setFizzContext( + std::shared_ptr ctx) { + fizzCtx_ = std::move(ctx); +} + +void QuicAsyncTransportServer::setTransportSettings() { + quic::TransportSettings transportSettings; + uint64_t flowControl = 2024 * 1024 * 1024; + transportSettings.advertisedInitialConnectionWindowSize = flowControl; + transportSettings.advertisedInitialBidiLocalStreamWindowSize = flowControl; + transportSettings.advertisedInitialBidiRemoteStreamWindowSize = flowControl; + transportSettings.advertisedInitialUniStreamWindowSize = flowControl; + quicServer_->setTransportSettings(transportSettings); +} + +void QuicAsyncTransportServer::start( + const folly::SocketAddress& address, + size_t numThreads) { + if (numThreads == 0) { + numThreads = std::thread::hardware_concurrency(); + } + std::vector evbs; + for (size_t i = 0; i < numThreads; ++i) { + auto scopedEvb = std::make_unique(); + workerEvbs_.push_back(std::move(scopedEvb)); + auto workerEvb = workerEvbs_.back()->getEventBase(); + evbs.push_back(workerEvb); + } + setTransportSettings(); + quicServer_->initialize(address, evbs, false /* useDefaultTransport */); + quicServer_->waitUntilInitialized(); + createAcceptors(); + quicServer_->start(); +} + +void QuicAsyncTransportServer::createAcceptors() { + for (auto& worker : workerEvbs_) { + auto evb = worker->getEventBase(); + quicServer_->setFizzContext(evb, fizzCtx_); + auto acceptor = std::make_unique( + evb, [this](folly::AsyncTransport::UniquePtr tran) { + return connectionFactory_(std::move(tran)); + }); + quicServer_->addTransportFactory(evb, acceptor.get()); + acceptors_.push_back(std::move(acceptor)); + } +} + +void QuicAsyncTransportServer::shutdown() { + quicServer_->rejectNewConnections(true); + for (size_t i = 0; i < workerEvbs_.size(); i++) { + workerEvbs_[i]->getEventBase()->runInEventBaseThreadAndWait( + [&] { acceptors_[i]->dropAllConnections(); }); + } + quicServer_->shutdown(); + quicServer_.reset(); +} +} // namespace quic diff --git a/quic/server/async_tran/QuicAsyncTransportServer.h b/quic/server/async_tran/QuicAsyncTransportServer.h new file mode 100644 index 000000000..3b9730fd6 --- /dev/null +++ b/quic/server/async_tran/QuicAsyncTransportServer.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + * + */ + +#pragma once + +#include +#include +#include + +namespace quic { + +/** + * QUIC server with single stream connections wrapped into folly:AsyncTransport + * adaptor. For experiments with QUIC in existing code using + * folly::AsyncServerSocket and wangle::Acceptor. + */ +class QuicAsyncTransportServer { + public: + explicit QuicAsyncTransportServer( + QuicAsyncTransportAcceptor::ManagedConnectionFactory connectionFactory); + virtual ~QuicAsyncTransportServer() = default; + + void setFizzContext( + std::shared_ptr ctx); + + void start(const folly::SocketAddress& address, size_t numThreads = 0); + + quic::QuicServer& quicServer() { + return *quicServer_; + } + + void shutdown(); + + protected: + void setTransportSettings(); + void createAcceptors(); + + QuicAsyncTransportAcceptor::ManagedConnectionFactory connectionFactory_; + std::shared_ptr quicServer_; + std::vector> acceptors_; + std::vector> workerEvbs_; + std::shared_ptr fizzCtx_; +}; + +} // namespace quic diff --git a/quic/server/async_tran/QuicServerAsyncTransport.cpp b/quic/server/async_tran/QuicServerAsyncTransport.cpp new file mode 100644 index 000000000..4317cc836 --- /dev/null +++ b/quic/server/async_tran/QuicServerAsyncTransport.cpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + * + */ + +#include +#include + +namespace quic { + +void QuicServerAsyncTransport::setServerSocket( + std::shared_ptr sock) { + setSocket(std::move(sock)); +} + +void QuicServerAsyncTransport::onNewBidirectionalStream(StreamId id) noexcept { + if (id != 0) { + CHECK(false) << "Only single stream 0 is supported"; + } + setStreamId(id); +} +void QuicServerAsyncTransport::onNewUnidirectionalStream( + StreamId /*id*/) noexcept { + CHECK(false) << "Unidirectional stream not supported"; +} + +void QuicServerAsyncTransport::onStopSending( + StreamId /*id*/, + ApplicationErrorCode /*error*/) noexcept {} + +void QuicServerAsyncTransport::onConnectionEnd() noexcept { + folly::AsyncSocketException ex( + folly::AsyncSocketException::UNKNOWN, "Quic connection ended"); + closeNowImpl(std::move(ex)); +} + +void QuicServerAsyncTransport::onConnectionError( + std::pair code) noexcept { + folly::AsyncSocketException ex( + folly::AsyncSocketException::UNKNOWN, + folly::to("Quic connection error", code.second)); + closeNowImpl(std::move(ex)); +} + +void QuicServerAsyncTransport::onTransportReady() noexcept {} + +} // namespace quic diff --git a/quic/server/async_tran/QuicServerAsyncTransport.h b/quic/server/async_tran/QuicServerAsyncTransport.h new file mode 100644 index 000000000..b99e7941a --- /dev/null +++ b/quic/server/async_tran/QuicServerAsyncTransport.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + * + */ + +#pragma once + +#include + +namespace quic { + +class QuicServerAsyncTransport : public QuicStreamAsyncTransport, + public QuicSocket::ConnectionCallback { + public: + using UniquePtr = std::unique_ptr< + QuicServerAsyncTransport, + folly::DelayedDestruction::Destructor>; + QuicServerAsyncTransport() = default; + void setServerSocket(std::shared_ptr sock); + + protected: + ~QuicServerAsyncTransport() override = default; + + // + // QuicSocket::ConnectionCallback + // + void onNewBidirectionalStream(StreamId id) noexcept override; + void onNewUnidirectionalStream(StreamId id) noexcept override; + void onStopSending(StreamId id, ApplicationErrorCode error) noexcept override; + void onConnectionEnd() noexcept override; + void onConnectionError( + std::pair code) noexcept override; + void onTransportReady() noexcept override; +}; +} // namespace quic diff --git a/quic/server/async_tran/test/QuicAsyncTransportServerTest.cpp b/quic/server/async_tran/test/QuicAsyncTransportServerTest.cpp new file mode 100644 index 000000000..2b3708281 --- /dev/null +++ b/quic/server/async_tran/test/QuicAsyncTransportServerTest.cpp @@ -0,0 +1,164 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + * + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "folly/io/async/AsyncTransport.h" + +using namespace testing; + +namespace quic::test { + +class MockConnection : public wangle::ManagedConnection { + public: + explicit MockConnection(folly::AsyncTransport::UniquePtr sock) + : sock_(std::move(sock)) {} + void timeoutExpired() noexcept final {} + void describe(std::ostream&) const final {} + bool isBusy() const final { + return true; + } + void notifyPendingShutdown() final {} + void closeWhenIdle() final {} + void dropConnection(const std::string& /*errorMsg*/ = "") final { + destroy(); + } + void dumpConnectionState(uint8_t) final {} + + private: + folly::AsyncTransport::UniquePtr sock_; +}; + +class QuicAsyncTransportServerTest : public Test { + public: + void SetUp() override { + folly::ssl::init(); + createServer(); + createClient(); + } + + void createServer() { + EXPECT_CALL(serverReadCB_, isBufferMovable_()) + .WillRepeatedly(Return(false)); + EXPECT_CALL(serverReadCB_, getReadBuffer(_, _)) + .WillRepeatedly(Invoke([&](void** buf, size_t* len) { + *buf = serverBuf_.data(); + *len = serverBuf_.size(); + })); + EXPECT_CALL(serverReadCB_, readDataAvailable_(_)) + .WillOnce(Invoke([&](auto len) { + auto echoData = folly::IOBuf::wrapBuffer(serverBuf_.data(), len); + echoData->appendChain(folly::IOBuf::copyBuffer(" ding dong")); + serverAsyncWrapper_->writeChain(&serverWriteCB_, std::move(echoData)); + serverAsyncWrapper_->shutdownWrite(); + })); + EXPECT_CALL(serverReadCB_, readEOF_()).WillOnce(Return()); + EXPECT_CALL(serverWriteCB_, writeSuccess_()).WillOnce(Return()); + + server_ = std::make_shared([this](auto sock) { + sock->setReadCB(&serverReadCB_); + serverAsyncWrapper_ = std::move(sock); + return new MockConnection(nullptr); + }); + server_->setFizzContext(test::createServerCtx()); + folly::SocketAddress addr("::1", 0); + server_->start(addr, 1); + serverAddr_ = server_->quicServer().getAddress(); + } + + void createClient() { + clientEvbThread_ = std::thread([&]() { clientEvb_.loopForever(); }); + + EXPECT_CALL(clientReadCB_, isBufferMovable_()) + .WillRepeatedly(Return(false)); + EXPECT_CALL(clientReadCB_, getReadBuffer(_, _)) + .WillRepeatedly(Invoke([&](void** buf, size_t* len) { + *buf = clientBuf_.data(); + *len = clientBuf_.size(); + })); + EXPECT_CALL(clientReadCB_, readDataAvailable_(_)) + .WillOnce(Invoke([&](auto len) { + clientReadPromise_.setValue( + std::string(reinterpret_cast(clientBuf_.data()), len)); + })); + EXPECT_CALL(clientReadCB_, readEOF_()).WillOnce(Return()); + EXPECT_CALL(clientWriteCB_, writeSuccess_()).WillOnce(Return()); + + clientEvb_.runInEventBaseThreadAndWait([&]() { + auto sock = std::make_unique(&clientEvb_); + auto fizzClientContext = + FizzClientQuicHandshakeContext::Builder() + .setCertificateVerifier(test::createTestCertificateVerifier()) + .build(); + client_ = std::make_shared( + &clientEvb_, std::move(sock), std::move(fizzClientContext)); + client_->setHostname("echo.com"); + client_->addNewPeerAddress(serverAddr_); + clientAsyncWrapper_.reset(new QuicClientAsyncTransport(client_)); + clientAsyncWrapper_->setReadCB(&clientReadCB_); + }); + } + + void TearDown() override { + server_->shutdown(); + server_ = nullptr; + clientEvb_.runInEventBaseThreadAndWait([&] { + clientAsyncWrapper_ = nullptr; + client_ = nullptr; + }); + clientEvb_.terminateLoopSoon(); + clientEvbThread_.join(); + } + + protected: + std::shared_ptr server_; + folly::SocketAddress serverAddr_; + folly::AsyncTransport::UniquePtr serverAsyncWrapper_; + folly::test::MockWriteCallback serverWriteCB_; + folly::test::MockReadCallback serverReadCB_; + std::array serverBuf_; + + std::shared_ptr client_; + folly::EventBase clientEvb_; + std::thread clientEvbThread_; + QuicClientAsyncTransport::UniquePtr clientAsyncWrapper_; + folly::test::MockWriteCallback clientWriteCB_; + folly::test::MockReadCallback clientReadCB_; + std::array clientBuf_; + folly::Promise clientReadPromise_; +}; + +TEST_F(QuicAsyncTransportServerTest, ReadWrite) { + auto [promise, future] = folly::makePromiseContract(); + clientReadPromise_ = std::move(promise); + + std::string msg = "jaja"; + clientEvb_.runInEventBaseThreadAndWait([&] { + clientAsyncWrapper_->write(&clientWriteCB_, msg.data(), msg.size()); + clientAsyncWrapper_->shutdownWrite(); + }); + + std::string clientReadString = std::move(future).get(1s); + EXPECT_EQ(clientReadString, "jaja ding dong"); +} + +} // namespace quic::test