diff --git a/quic/client/QuicClientTransportLite.cpp b/quic/client/QuicClientTransportLite.cpp index 1a8cbcb45..fc4b20b60 100644 --- a/quic/client/QuicClientTransportLite.cpp +++ b/quic/client/QuicClientTransportLite.cpp @@ -1195,6 +1195,10 @@ QuicClientTransportLite::startCryptoHandshake() { conn_->initialHeaderCipher = std::move(clientHeaderCipherResult.value()); customTransportParameters_ = getSupportedExtTransportParams(*conn_); + if (conn_->transportSettings.supportDirectEncap) { + customTransportParameters_.push_back( + encodeEmptyParameter(TransportParameterId::client_direct_encap)); + } auto paramsExtension = std::make_shared( conn_->originalVersion.value(), diff --git a/quic/fizz/client/handshake/test/FizzClientHandshakeTest.cpp b/quic/fizz/client/handshake/test/FizzClientHandshakeTest.cpp index 3f97bd64a..111c5690f 100644 --- a/quic/fizz/client/handshake/test/FizzClientHandshakeTest.cpp +++ b/quic/fizz/client/handshake/test/FizzClientHandshakeTest.cpp @@ -111,7 +111,8 @@ class ClientHandshakeTest : public Test, public boost::static_visitor<> { generateStatelessResetToken(), ConnectionId::createAndMaybeCrash( std::vector{0xff, 0xfe, 0xfd, 0xfc}), - ConnectionId::createZeroLength()); + ConnectionId::createZeroLength(), + *conn); fizzServer.reset( new fizz::server:: FizzServer( diff --git a/quic/handshake/TransportParameters.cpp b/quic/handshake/TransportParameters.cpp index 14ca80cc8..1639cdaa1 100644 --- a/quic/handshake/TransportParameters.cpp +++ b/quic/handshake/TransportParameters.cpp @@ -89,6 +89,12 @@ folly::Expected encodeIntegerParameter( return TransportParameter{id, std::move(data)}; } +TransportParameter encodeIPAddressParameter( + TransportParameterId id, + const folly::IPAddress& addr) { + return {id, BufHelpers::copyBuffer(addr.bytes(), addr.byteCount())}; +} + std::vector getSupportedExtTransportParams( const QuicConnectionStateBase& conn) { using TpId = TransportParameterId; diff --git a/quic/handshake/TransportParameters.h b/quic/handshake/TransportParameters.h index 75abb4c49..b42255113 100644 --- a/quic/handshake/TransportParameters.h +++ b/quic/handshake/TransportParameters.h @@ -148,6 +148,10 @@ folly::Expected encodeIntegerParameter( TransportParameterId id, uint64_t value); +TransportParameter encodeIPAddressParameter( + TransportParameterId id, + const folly::IPAddress& addr); + inline TransportParameter encodeEmptyParameter(TransportParameterId id) { TransportParameter param; param.parameter = id; diff --git a/quic/handshake/test/BUCK b/quic/handshake/test/BUCK index 88b23c9b0..4a0f2d49d 100644 --- a/quic/handshake/test/BUCK +++ b/quic/handshake/test/BUCK @@ -1,4 +1,4 @@ -load("@fbcode//quic:defs.bzl", "mvfst_cpp_library") +load("@fbcode//quic:defs.bzl", "mvfst_cpp_library", "mvfst_cpp_test") oncall("traffic_protocols") @@ -18,3 +18,21 @@ mvfst_cpp_library( "//quic/handshake:handshake", ], ) + +mvfst_cpp_test( + name = "TransportParametersTest", + srcs = [ + "TransportParametersTest.cpp", + ], + supports_static_listing = False, + deps = [ + "//folly:network_address", + "//folly/portability:gmock", + "//folly/portability:gtest", + "//quic/client:state_and_handshake", + "//quic/fizz/client/handshake:fizz_client_handshake", + "//quic/fizz/server/handshake:fizz_server_handshake", + "//quic/handshake:transport_parameters", + "//quic/server/state:server", + ], +) diff --git a/quic/handshake/test/TransportParametersTest.cpp b/quic/handshake/test/TransportParametersTest.cpp new file mode 100644 index 000000000..59cec61ad --- /dev/null +++ b/quic/handshake/test/TransportParametersTest.cpp @@ -0,0 +1,196 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace ::testing; + +namespace quic { +namespace test { + +class TransportParametersTest : public Test { + protected: + // Helper function that simulates client transport parameter generation + // This mirrors what QuicClientTransportLite does - gets base params and adds + // client-specific ones + std::vector getClientTransportParams( + const QuicClientConnectionState& conn) { + auto params = getSupportedExtTransportParams(conn); + if (conn.transportSettings.supportDirectEncap) { + params.push_back( + encodeEmptyParameter(TransportParameterId::client_direct_encap)); + } + return params; + } +}; + +// Test client-side direct encap parameter generation +TEST_F(TransportParametersTest, ClientDirectEncapEnabled) { + QuicClientConnectionState clientConn( + FizzClientQuicHandshakeContext::Builder().build()); + clientConn.transportSettings.supportDirectEncap = true; + + auto customTransportParams = getClientTransportParams(clientConn); + + auto it = findParameter( + customTransportParams, TransportParameterId::client_direct_encap); + EXPECT_TRUE(it != customTransportParams.end()); + EXPECT_TRUE(it->value->empty()); +} + +TEST_F(TransportParametersTest, ClientDirectEncapDisabled) { + QuicClientConnectionState clientConn( + FizzClientQuicHandshakeContext::Builder().build()); + clientConn.transportSettings.supportDirectEncap = false; + + auto customTransportParams = getClientTransportParams(clientConn); + + EXPECT_THAT( + customTransportParams, + Not(Contains(Field( + &TransportParameter::parameter, + Eq(TransportParameterId::client_direct_encap))))); +} + +// Test server-side direct encap parameter generation with IPv4 +TEST_F(TransportParametersTest, ServerDirectEncapIPv4) { + QuicServerConnectionState serverConn( + FizzServerQuicHandshakeContext::Builder().build()); + serverConn.transportSettings.directEncapAddress = + folly::IPAddress("192.168.1.1"); + + // Create client parameters containing client_direct_encap + std::vector clientParams; + clientParams.push_back( + encodeEmptyParameter(TransportParameterId::client_direct_encap)); + + auto customTransportParams = + getClientDependentExtTransportParams(serverConn, clientParams); + + auto it = findParameter( + customTransportParams, TransportParameterId::server_direct_encap); + EXPECT_TRUE(it != customTransportParams.end()); + EXPECT_EQ(it->value->length(), 4); // IPv4 is 4 bytes + + // Verify the IP address bytes + auto expectedAddr = folly::IPAddress("192.168.1.1"); + auto expectedBytes = expectedAddr.bytes(); + auto actualRange = it->value->coalesce(); + EXPECT_EQ(actualRange.size(), 4); + EXPECT_EQ(memcmp(actualRange.data(), expectedBytes, 4), 0); +} + +// Test server-side direct encap parameter generation with IPv6 +TEST_F(TransportParametersTest, ServerDirectEncapIPv6) { + QuicServerConnectionState serverConn( + FizzServerQuicHandshakeContext::Builder().build()); + serverConn.transportSettings.directEncapAddress = + folly::IPAddress("2001:db8::1"); + + // Create client parameters containing client_direct_encap + std::vector clientParams; + clientParams.push_back( + encodeEmptyParameter(TransportParameterId::client_direct_encap)); + + auto customTransportParams = + getClientDependentExtTransportParams(serverConn, clientParams); + + auto it = findParameter( + customTransportParams, TransportParameterId::server_direct_encap); + EXPECT_TRUE(it != customTransportParams.end()); + EXPECT_EQ(it->value->length(), 16); // IPv6 is 16 bytes + + // Verify the IP address bytes + auto expectedAddr = folly::IPAddress("2001:db8::1"); + auto expectedBytes = expectedAddr.bytes(); + auto actualRange = it->value->coalesce(); + EXPECT_EQ(actualRange.size(), 16); + EXPECT_EQ(memcmp(actualRange.data(), expectedBytes, 16), 0); +} + +// Test server doesn't send server_direct_encap when address not configured +TEST_F(TransportParametersTest, ServerDirectEncapNoAddress) { + QuicServerConnectionState serverConn( + FizzServerQuicHandshakeContext::Builder().build()); + // Don't set directEncapAddress + + // Create client parameters containing client_direct_encap + std::vector clientParams; + clientParams.push_back( + encodeEmptyParameter(TransportParameterId::client_direct_encap)); + + auto customTransportParams = + getClientDependentExtTransportParams(serverConn, clientParams); + + EXPECT_THAT( + customTransportParams, + Not(Contains(Field( + &TransportParameter::parameter, + Eq(TransportParameterId::server_direct_encap))))); +} + +// Test server doesn't send server_direct_encap when client doesn't support it +TEST_F(TransportParametersTest, ServerDirectEncapClientNotSupported) { + QuicServerConnectionState serverConn( + FizzServerQuicHandshakeContext::Builder().build()); + serverConn.transportSettings.directEncapAddress = + folly::IPAddress("192.168.1.1"); + + // Create client parameters WITHOUT client_direct_encap + std::vector clientParams; + // Add some other parameter to make sure we're not just testing empty list + auto paramResult = + encodeIntegerParameter(TransportParameterId::idle_timeout, 5000); + ASSERT_FALSE(paramResult.hasError()); + clientParams.push_back(paramResult.value()); + + auto customTransportParams = + getClientDependentExtTransportParams(serverConn, clientParams); + + EXPECT_THAT( + customTransportParams, + Not(Contains(Field( + &TransportParameter::parameter, + Eq(TransportParameterId::server_direct_encap))))); +} + +// Test IP address encoding helper function directly +TEST_F(TransportParametersTest, EncodeIPAddressParameterIPv4) { + folly::IPAddress addr("10.0.0.1"); + auto param = + encodeIPAddressParameter(TransportParameterId::server_direct_encap, addr); + + EXPECT_EQ(param.parameter, TransportParameterId::server_direct_encap); + EXPECT_EQ(param.value->length(), 4); + + auto expectedBytes = addr.bytes(); + auto actualRange = param.value->coalesce(); + EXPECT_EQ(memcmp(actualRange.data(), expectedBytes, 4), 0); +} + +TEST_F(TransportParametersTest, EncodeIPAddressParameterIPv6) { + folly::IPAddress addr("::1"); + auto param = + encodeIPAddressParameter(TransportParameterId::server_direct_encap, addr); + + EXPECT_EQ(param.parameter, TransportParameterId::server_direct_encap); + EXPECT_EQ(param.value->length(), 16); + + auto expectedBytes = addr.bytes(); + auto actualRange = param.value->coalesce(); + EXPECT_EQ(memcmp(actualRange.data(), expectedBytes, 16), 0); +} + +} // namespace test +} // namespace quic diff --git a/quic/server/handshake/BUCK b/quic/server/handshake/BUCK index e2be74e95..83577e262 100644 --- a/quic/server/handshake/BUCK +++ b/quic/server/handshake/BUCK @@ -39,6 +39,7 @@ mvfst_cpp_library( ":stateless_reset_generator", "//fizz/server:server_extensions", "//quic/fizz/handshake:fizz_handshake", + "//quic/state:quic_state_machine", ], ) @@ -51,7 +52,9 @@ mvfst_cpp_library( "AppToken.h", ], exported_deps = [ + "//folly:expected", "//quic:constants", + "//quic:exception", "//quic/handshake:transport_parameters", ], ) diff --git a/quic/server/handshake/ServerTransportParametersExtension.h b/quic/server/handshake/ServerTransportParametersExtension.h index 5d9c1237c..b7891d04a 100644 --- a/quic/server/handshake/ServerTransportParametersExtension.h +++ b/quic/server/handshake/ServerTransportParametersExtension.h @@ -10,6 +10,33 @@ #include #include #include +#include + +namespace { + +std::vector getClientDependentExtTransportParams( + const quic::QuicConnectionStateBase& conn, + const std::vector& clientParams) { + using TpId = quic::TransportParameterId; + std::vector params; + + // Server-side direct encap logic + if (conn.transportSettings.directEncapAddress.has_value()) { + // Check if client sent client_direct_encap parameter + auto clientDirectEncapIt = + findParameter(clientParams, TpId::client_direct_encap); + if (clientDirectEncapIt != clientParams.end()) { + // Client supports direct encap and server has address configured + params.push_back(encodeIPAddressParameter( + TpId::server_direct_encap, + conn.transportSettings.directEncapAddress.value())); + } + } + + return params; +} + +} // namespace namespace quic { @@ -30,6 +57,7 @@ class ServerTransportParametersExtension : public fizz::ServerExtensions { const StatelessResetToken& token, ConnectionId initialSourceCid, ConnectionId originalDestinationCid, + const QuicConnectionStateBase& conn, std::vector customTransportParameters = std::vector()) : encodingVersion_(encodingVersion), @@ -46,7 +74,8 @@ class ServerTransportParametersExtension : public fizz::ServerExtensions { token_(token), initialSourceCid_(initialSourceCid), originalDestinationCid_(originalDestinationCid), - customTransportParameters_(std::move(customTransportParameters)) {} + customTransportParameters_(std::move(customTransportParameters)), + conn_(conn) {} ~ServerTransportParametersExtension() override = default; @@ -181,6 +210,15 @@ class ServerTransportParametersExtension : public fizz::ServerExtensions { params.parameters.push_back(customParameter); } + // Add direct encap parameters if connection state is available + if (clientTransportParameters_.has_value()) { + auto additionalParams = getClientDependentExtTransportParams( + conn_, clientTransportParameters_->parameters); + for (const auto& param : additionalParams) { + params.parameters.push_back(param); + } + } + exts.push_back(encodeExtension(params, encodingVersion_)); return exts; } @@ -206,5 +244,6 @@ class ServerTransportParametersExtension : public fizz::ServerExtensions { ConnectionId initialSourceCid_; ConnectionId originalDestinationCid_; std::vector customTransportParameters_; + const QuicConnectionStateBase& conn_; }; } // namespace quic diff --git a/quic/server/handshake/test/ServerHandshakeTest.cpp b/quic/server/handshake/test/ServerHandshakeTest.cpp index fcaf38432..9ca14185d 100644 --- a/quic/server/handshake/test/ServerHandshakeTest.cpp +++ b/quic/server/handshake/test/ServerHandshakeTest.cpp @@ -122,7 +122,8 @@ class ServerHandshakeTest : public Test { generateStatelessResetToken(), ConnectionId::createAndMaybeCrash( std::vector{0xff, 0xfe, 0xfd, 0xfc}), - ConnectionId::createZeroLength()); + ConnectionId::createZeroLength(), + *conn); initialize(); handshake->accept(params); diff --git a/quic/server/handshake/test/ServerTransportParametersTest.cpp b/quic/server/handshake/test/ServerTransportParametersTest.cpp index a3f1f144b..3be2cef22 100644 --- a/quic/server/handshake/test/ServerTransportParametersTest.cpp +++ b/quic/server/handshake/test/ServerTransportParametersTest.cpp @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -36,6 +37,8 @@ static ClientHello getClientHello(QuicVersion version) { } TEST(ServerTransportParametersTest, TestGetExtensions) { + QuicServerConnectionState conn( + FizzServerQuicHandshakeContext::Builder().build()); ServerTransportParametersExtension ext( QuicVersion::MVFST, kDefaultConnectionFlowControlWindow, @@ -51,7 +54,8 @@ TEST(ServerTransportParametersTest, TestGetExtensions) { generateStatelessResetToken(), ConnectionId::createAndMaybeCrash( std::vector{0xff, 0xfe, 0xfd, 0xfc}), - ConnectionId::createZeroLength()); + ConnectionId::createZeroLength(), + conn); auto extensions = ext.getExtensions(getClientHello(QuicVersion::MVFST)); EXPECT_EQ(extensions.size(), 1); @@ -60,6 +64,8 @@ TEST(ServerTransportParametersTest, TestGetExtensions) { } TEST(ServerTransportParametersTest, TestGetExtensionsMissingClientParams) { + QuicServerConnectionState conn( + FizzServerQuicHandshakeContext::Builder().build()); ServerTransportParametersExtension ext( QuicVersion::MVFST, kDefaultConnectionFlowControlWindow, @@ -75,11 +81,14 @@ TEST(ServerTransportParametersTest, TestGetExtensionsMissingClientParams) { generateStatelessResetToken(), ConnectionId::createAndMaybeCrash( std::vector{0xff, 0xfe, 0xfd, 0xfc}), - ConnectionId::createZeroLength()); + ConnectionId::createZeroLength(), + conn); EXPECT_THROW(ext.getExtensions(TestMessages::clientHello()), FizzException); } TEST(ServerTransportParametersTest, TestQuicV1RejectDraftExtensionNumber) { + QuicServerConnectionState conn( + FizzServerQuicHandshakeContext::Builder().build()); ServerTransportParametersExtension ext( QuicVersion::QUIC_V1, kDefaultConnectionFlowControlWindow, @@ -95,13 +104,16 @@ TEST(ServerTransportParametersTest, TestQuicV1RejectDraftExtensionNumber) { generateStatelessResetToken(), ConnectionId::createAndMaybeCrash( std::vector{0xff, 0xfe, 0xfd, 0xfc}), - ConnectionId::createZeroLength()); + ConnectionId::createZeroLength(), + conn); EXPECT_THROW( ext.getExtensions(getClientHello(QuicVersion::MVFST)), FizzException); EXPECT_NO_THROW(ext.getExtensions(getClientHello(QuicVersion::QUIC_V1))); } TEST(ServerTransportParametersTest, TestQuicV1RejectDuplicateExtensions) { + QuicServerConnectionState conn( + FizzServerQuicHandshakeContext::Builder().build()); ServerTransportParametersExtension ext( QuicVersion::QUIC_V1, kDefaultConnectionFlowControlWindow, @@ -118,7 +130,8 @@ TEST(ServerTransportParametersTest, TestQuicV1RejectDuplicateExtensions) { ConnectionId::createAndMaybeCrash( std::vector{0xff, 0xfe, 0xfd, 0xfc}), ConnectionId::createAndMaybeCrash( - std::vector{0xfb, 0xfa, 0xf9, 0xf8})); + std::vector{0xfb, 0xfa, 0xf9, 0xf8}), + conn); auto chlo = getClientHello(QuicVersion::QUIC_V1); ClientTransportParameters duplicateClientParams; @@ -133,6 +146,8 @@ TEST(ServerTransportParametersTest, TestQuicV1RejectDuplicateExtensions) { } TEST(ServerTransportParametersTest, TestQuicV1Fields) { + QuicServerConnectionState conn( + FizzServerQuicHandshakeContext::Builder().build()); ServerTransportParametersExtension ext( QuicVersion::QUIC_V1, kDefaultConnectionFlowControlWindow, @@ -149,7 +164,8 @@ TEST(ServerTransportParametersTest, TestQuicV1Fields) { ConnectionId::createAndMaybeCrash( std::vector{0xff, 0xfe, 0xfd, 0xfc}), ConnectionId::createAndMaybeCrash( - std::vector{0xfb, 0xfa, 0xf9, 0xf8})); + std::vector{0xfb, 0xfa, 0xf9, 0xf8}), + conn); auto extensions = ext.getExtensions(getClientHello(QuicVersion::QUIC_V1)); EXPECT_EQ(extensions.size(), 1); @@ -175,6 +191,8 @@ TEST(ServerTransportParametersTest, TestQuicV1Fields) { } TEST(ServerTransportParametersTest, TestMvfstFields) { + QuicServerConnectionState conn( + FizzServerQuicHandshakeContext::Builder().build()); ServerTransportParametersExtension ext( QuicVersion::MVFST, kDefaultConnectionFlowControlWindow, @@ -191,7 +209,8 @@ TEST(ServerTransportParametersTest, TestMvfstFields) { ConnectionId::createAndMaybeCrash( std::vector{0xff, 0xfe, 0xfd, 0xfc}), ConnectionId::createAndMaybeCrash( - std::vector{0xfb, 0xfa, 0xf9, 0xf8})); + std::vector{0xfb, 0xfa, 0xf9, 0xf8}), + conn); auto extensions = ext.getExtensions(getClientHello(QuicVersion::MVFST)); EXPECT_EQ(extensions.size(), 1); diff --git a/quic/server/state/ServerStateMachine.cpp b/quic/server/state/ServerStateMachine.cpp index b7137f058..1082a9d2f 100644 --- a/quic/server/state/ServerStateMachine.cpp +++ b/quic/server/state/ServerStateMachine.cpp @@ -1030,6 +1030,7 @@ folly::Expected onServerReadDataFromOpen( *newServerConnIdData->token, conn.serverConnectionId.value(), initialDestinationConnectionId, + conn, customTransportParams)); conn.transportParametersEncoded = true; const CryptoFactory& cryptoFactory =