diff --git a/quic/api/QuicTransportBase.h b/quic/api/QuicTransportBase.h index 97c3784f7..c259577f5 100644 --- a/quic/api/QuicTransportBase.h +++ b/quic/api/QuicTransportBase.h @@ -8,9 +8,6 @@ #pragma once -#include -#include -#include #include #include #include @@ -22,6 +19,10 @@ #include #include +#include +#include +#include + namespace quic { enum class CloseState { OPEN, GRACEFUL_CLOSING, CLOSED }; @@ -561,7 +562,9 @@ class QuicTransportBase : public QuicSocket { std::unique_ptr socket_; ConnectionCallback* connCallback_{nullptr}; - std::unique_ptr conn_; + std:: + unique_ptr + conn_; struct ReadCallbackData { ReadCallback* readCb; diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index 14e4a8315..89f631af3 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -136,7 +136,7 @@ class TestQuicTransport conn->clientConnectionId = ConnectionId({10, 9, 8, 7}); conn->version = QuicVersion::MVFST; transportConn = conn.get(); - conn_ = std::move(conn); + conn_.reset(conn.release()); aead = test::createNoOpAead(); headerCipher = test::createNoOpHeaderCipher(); connIdAlgo_ = std::make_unique(); diff --git a/quic/api/test/QuicTransportTest.cpp b/quic/api/test/QuicTransportTest.cpp index 63a6b337c..9b2a213c7 100644 --- a/quic/api/test/QuicTransportTest.cpp +++ b/quic/api/test/QuicTransportTest.cpp @@ -42,7 +42,7 @@ class TestQuicTransport ConnectionCallback& cb) : QuicTransportBase(evb, std::move(socket)) { setConnectionCallback(&cb); - conn_ = std::make_unique(); + conn_.reset(new QuicServerConnectionState()); conn_->clientConnectionId = ConnectionId({9, 8, 7, 6}); conn_->serverConnectionId = ConnectionId({1, 2, 3, 4}); conn_->version = QuicVersion::MVFST; diff --git a/quic/client/QuicClientTransport.cpp b/quic/client/QuicClientTransport.cpp index 6fcacc526..8cfa699f8 100644 --- a/quic/client/QuicClientTransport.cpp +++ b/quic/client/QuicClientTransport.cpp @@ -54,7 +54,7 @@ QuicClientTransport::QuicClientTransport( auto tempConn = std::make_unique(std::move(handshakeFactory)); clientConn_ = tempConn.get(); - conn_ = std::move(tempConn); + conn_.reset(tempConn.release()); std::vector connIdData( std::max(kMinInitialDestinationConnIdLength, connectionIdSize)); folly::Random::secureRandom(connIdData.data(), connIdData.size()); @@ -208,7 +208,7 @@ void QuicClientTransport::processPacketData( auto tempConn = undoAllClientStateForRetry(std::move(uniqueClient)); clientConn_ = tempConn.get(); - conn_ = std::move(tempConn); + conn_.reset(tempConn.release()); clientConn_->retryToken = longHeader->getToken(); diff --git a/quic/client/state/ClientStateMachine.h b/quic/client/state/ClientStateMachine.h index eb03a6552..e66f9e370 100644 --- a/quic/client/state/ClientStateMachine.h +++ b/quic/client/state/ClientStateMachine.h @@ -60,7 +60,7 @@ struct QuicClientConnectionState : public QuicConnectionStateBase { auto tmpClientHandshake = handshakeFactory->makeClientHandshake(*cryptoState); clientHandshakeLayer = tmpClientHandshake.get(); - handshakeLayer.reset(tmpClientHandshake.release()); + handshakeLayer = std::move(tmpClientHandshake); // We shouldn't normally need to set this until we're starting the // transport, however writing unit tests is much easier if we set this here. updateFlowControlStateWithSettings(flowControlState, transportSettings); diff --git a/quic/client/test/QuicClientTransportTest.cpp b/quic/client/test/QuicClientTransportTest.cpp index 654be4ea1..2c4dea106 100644 --- a/quic/client/test/QuicClientTransportTest.cpp +++ b/quic/client/test/QuicClientTransportTest.cpp @@ -1354,8 +1354,6 @@ class QuicClientTransportTest : public Test { new FakeOneRttHandshakeLayer(*client->getNonConstConn().cryptoState); client->getNonConstConn().clientHandshakeLayer = mockClientHandshake; client->getNonConstConn().handshakeLayer.reset(mockClientHandshake); - handshakeDG = std::make_unique( - mockClientHandshake); setFakeHandshakeCiphers(); // Allow ignoring path mtu for testing negotiation. client->getNonConstConn().transportSettings.canIgnorePathMTU = true; @@ -1690,7 +1688,6 @@ class QuicClientTransportTest : public Test { std::unique_ptr eventbase_; SocketAddress serverAddr{"127.0.0.1", 443}; AsyncUDPSocket::ReadCallback* networkReadCallback{nullptr}; - std::unique_ptr handshakeDG; FakeOneRttHandshakeLayer* mockClientHandshake; std::shared_ptr client; PacketNum initialPacketNum{0}, handshakePacketNum{0}, appDataPacketNum{0}; diff --git a/quic/handshake/HandshakeLayer.h b/quic/handshake/HandshakeLayer.h index 25fbf8272..c992325cd 100644 --- a/quic/handshake/HandshakeLayer.h +++ b/quic/handshake/HandshakeLayer.h @@ -8,8 +8,6 @@ #pragma once -#include - #include #include #include @@ -20,13 +18,12 @@ constexpr folly::StringPiece kQuicKeyLabel = "quic key"; constexpr folly::StringPiece kQuicIVLabel = "quic iv"; constexpr folly::StringPiece kQuicPNLabel = "quic hp"; -class Handshake : public folly::DelayedDestruction { +class Handshake { public: + virtual ~Handshake() = default; + virtual const folly::Optional& getApplicationProtocol() const = 0; - - protected: - virtual ~Handshake() = default; }; constexpr folly::StringPiece kQuicDraft17Salt = diff --git a/quic/server/QuicServerTransport.cpp b/quic/server/QuicServerTransport.cpp index ff030ec05..4cae6cb96 100644 --- a/quic/server/QuicServerTransport.cpp +++ b/quic/server/QuicServerTransport.cpp @@ -24,7 +24,7 @@ QuicServerTransport::QuicServerTransport( auto tempConn = std::make_unique(); tempConn->serverAddr = socket_->address(); serverConn_ = tempConn.get(); - conn_ = std::move(tempConn); + conn_.reset(tempConn.release()); // TODO: generate this when we can encode the packet sequence number // correctly. // conn_->nextSequenceNum = folly::Random::secureRandom(); diff --git a/quic/server/handshake/ServerHandshake.cpp b/quic/server/handshake/ServerHandshake.cpp index bb7db4128..e38be3d33 100644 --- a/quic/server/handshake/ServerHandshake.cpp +++ b/quic/server/handshake/ServerHandshake.cpp @@ -8,14 +8,20 @@ #include -#include #include #include #include +#include + namespace quic { -ServerHandshake::ServerHandshake(QuicCryptoState& cryptoState) - : actionGuard_(nullptr), cryptoState_(cryptoState), visitor_(*this) {} +ServerHandshake::ServerHandshake( + QuicConnectionStateBase* conn, + QuicCryptoState& cryptoState) + : conn_(conn), + actionGuard_(nullptr), + cryptoState_(cryptoState), + visitor_(*this) {} void ServerHandshake::accept( std::shared_ptr transportParams) { @@ -238,7 +244,7 @@ void ServerHandshake::addProcessingActions(fizz::server::AsyncActions actions) { "Processing action while pending", TransportErrorCode::INTERNAL_ERROR)); return; } - actionGuard_ = folly::DelayedDestruction::DestructorGuard(this); + actionGuard_ = folly::DelayedDestruction::DestructorGuard(conn_); startActions(std::move(actions)); } @@ -257,7 +263,7 @@ void ServerHandshake::processActions( fizz::server::ServerStateMachine::CompletedActions actions) { // This extra DestructorGuard is needed due to the gap between clearing // actionGuard_ and potentially processing another action. - folly::DelayedDestruction::DestructorGuard dg(this); + folly::DelayedDestruction::DestructorGuard dg(conn_); for (auto& action : actions) { switch (action.type()) { @@ -307,7 +313,7 @@ void ServerHandshake::processPendingEvents() { return; } - folly::DelayedDestruction::DestructorGuard dg(this); + folly::DelayedDestruction::DestructorGuard dg(conn_); inProcessPendingEvents_ = true; SCOPE_EXIT { inProcessPendingEvents_ = false; @@ -316,7 +322,7 @@ void ServerHandshake::processPendingEvents() { while (!actionGuard_ && !error_) { folly::Optional actions; - actionGuard_ = folly::DelayedDestruction::DestructorGuard(this); + actionGuard_ = folly::DelayedDestruction::DestructorGuard(conn_); if (!waitForData_) { switch (state_.readRecordLayer()->getEncryptionLevel()) { case fizz::EncryptionLevel::Plaintext: diff --git a/quic/server/handshake/ServerHandshake.h b/quic/server/handshake/ServerHandshake.h index 201ee1046..2fb78bdbb 100644 --- a/quic/server/handshake/ServerHandshake.h +++ b/quic/server/handshake/ServerHandshake.h @@ -26,6 +26,8 @@ namespace quic { +// struct QuicConnectionStateBase; + /** * ServerHandshake abstracts details of the TLS 1.3 fizz crypto handshake. The * TLS handshake can be async, so ServerHandshake provides an API to deal with @@ -69,7 +71,9 @@ class ServerHandshake : public Handshake { */ enum class Phase { Handshake, KeysDerived, Established }; - explicit ServerHandshake(QuicCryptoState& cryptoState); + explicit ServerHandshake( + QuicConnectionStateBase* conn, + QuicCryptoState& cryptoState); /** * Starts accepting the TLS connection. @@ -247,6 +251,7 @@ class ServerHandshake : public Handshake { fizz::server::State state_; fizz::server::ServerStateMachine machine_; + QuicConnectionStateBase* conn_; folly::DelayedDestruction::DestructorGuard actionGuard_; folly::Executor* executor_; std::shared_ptr context_; diff --git a/quic/server/handshake/test/ServerHandshakeTest.cpp b/quic/server/handshake/test/ServerHandshakeTest.cpp index 26e24c95b..305c865bc 100644 --- a/quic/server/handshake/test/ServerHandshakeTest.cpp +++ b/quic/server/handshake/test/ServerHandshakeTest.cpp @@ -48,10 +48,8 @@ class MockServerHandshakeCallback : public ServerHandshake::HandshakeCallback { GMOCK_METHOD0_(, noexcept, , onCryptoEventAvailable, void()); }; -class TestingServerHandshake : public ServerHandshake { - public: - explicit TestingServerHandshake(QuicCryptoState& cryptoState) - : ServerHandshake(cryptoState) {} +struct TestingServerConnectionState : public QuicServerConnectionState { + explicit TestingServerConnectionState() : QuicServerConnectionState() {} uint32_t getDestructorGuardCount() const { return folly::DelayedDestruction::getDestructorGuardCount(); @@ -74,14 +72,15 @@ class ServerHandshakeTest : public Test { void SetUp() override { folly::ssl::init(); - cryptoState = std::make_unique(); + conn.reset(new TestingServerConnectionState()); + cryptoState = conn->cryptoState.get(); clientCtx = std::make_shared(); clientCtx->setOmitEarlyRecordLayer(true); clientCtx->setFactory(std::make_shared()); clientCtx->setClock(std::make_shared()); serverCtx = quic::test::createServerCtx(); setupClientAndServerContext(); - handshake.reset(new TestingServerHandshake(*cryptoState)); + handshake = conn->serverHandshakeLayer; hostname = kTestHostname.str(); verifier = std::make_shared(); @@ -320,8 +319,12 @@ class ServerHandshakeTest : public Test { std::unique_ptr dg; folly::EventBase evb; - std::unique_ptr handshake; - std::unique_ptr cryptoState; + std::unique_ptr< + TestingServerConnectionState, + folly::DelayedDestruction::Destructor> + conn{nullptr}; + ServerHandshake* handshake; + QuicCryptoState* cryptoState; fizz::client::State clientState; std::unique_ptrcancel(); // Let's destroy the crypto state to make sure it is not referenced. - cryptoState.reset(); + conn->cryptoState.reset(); promise.setValue(); evb.loop(); @@ -641,12 +644,12 @@ TEST_F(ServerHandshakeAsyncTest, TestAsyncCancel) { handshake->cancel(); // Let's destroy the crypto state to make sure it is not referenced. - cryptoState.reset(); + conn->cryptoState.reset(); promise.setValue(); evb.loop(); - EXPECT_EQ(handshake->getDestructorGuardCount(), 0); + EXPECT_EQ(conn->getDestructorGuardCount(), 0); } class ServerHandshakeAsyncErrorTest : public ServerHandshakePskTest { @@ -685,7 +688,7 @@ TEST_F(ServerHandshakeAsyncErrorTest, TestCancelOnAsyncError) { .WillRepeatedly(Invoke([&] { handshake->cancel(); // Let's destroy the crypto state to make sure it is not referenced. - cryptoState.reset(); + conn->cryptoState.reset(); })); promise.setValue(); evb.loop(); @@ -696,7 +699,7 @@ TEST_F(ServerHandshakeAsyncErrorTest, TestCancelWhileWaitingAsyncError) { clientServerRound(); handshake->cancel(); // Let's destroy the crypto state to make sure it is not referenced. - cryptoState.reset(); + conn->cryptoState.reset(); promise.setValue(); evb.loop(); diff --git a/quic/server/state/ServerStateMachine.h b/quic/server/state/ServerStateMachine.h index 05ed48621..13c33c23d 100644 --- a/quic/server/state/ServerStateMachine.h +++ b/quic/server/state/ServerStateMachine.h @@ -133,7 +133,7 @@ struct QuicServerConnectionState : public QuicConnectionStateBase { QuicVersion::QUIC_DRAFT, QuicVersion::QUIC_DRAFT_23}}; originalVersion = QuicVersion::MVFST; - serverHandshakeLayer = new ServerHandshake(*cryptoState); + serverHandshakeLayer = new ServerHandshake(this, *cryptoState); handshakeLayer.reset(serverHandshakeLayer); // We shouldn't normally need to set this until we're starting the // transport, however writing unit tests is much easier if we set this here. diff --git a/quic/server/test/QuicServerTransportTest.cpp b/quic/server/test/QuicServerTransportTest.cpp index fb256069f..53acc6213 100644 --- a/quic/server/test/QuicServerTransportTest.cpp +++ b/quic/server/test/QuicServerTransportTest.cpp @@ -44,7 +44,7 @@ class FakeServerHandshake : public ServerHandshake { bool chloSync = false, bool cfinSync = false, folly::Optional clientActiveConnectionIdLimit = folly::none) - : ServerHandshake(*conn.cryptoState), + : ServerHandshake(&conn, *conn.cryptoState), conn_(conn), chloSync_(chloSync), cfinSync_(cfinSync), diff --git a/quic/state/StateData.h b/quic/state/StateData.h index 7eb435e04..001a998fa 100644 --- a/quic/state/StateData.h +++ b/quic/state/StateData.h @@ -8,10 +8,6 @@ #pragma once -#include -#include -#include -#include #include #include #include @@ -25,6 +21,13 @@ #include #include #include + +#include +#include +#include +#include +#include + #include #include #include @@ -489,7 +492,7 @@ class CongestionControllerFactory; class LoopDetectorCallback; class PendingPathRateLimiter; -struct QuicConnectionStateBase { +struct QuicConnectionStateBase : public folly::DelayedDestruction { virtual ~QuicConnectionStateBase() = default; explicit QuicConnectionStateBase(QuicNodeType type) : nodeType(type) {} @@ -497,8 +500,7 @@ struct QuicConnectionStateBase { // Type of node owning this connection (client or server). QuicNodeType nodeType; - std::unique_ptr - handshakeLayer; + std::unique_ptr handshakeLayer; // Crypto stream std::unique_ptr cryptoState;