1
0
mirror of https://github.com/facebookincubator/mvfst.git synced 2025-11-24 04:01:07 +03:00
Files
mvfst/quic/server/handshake/test/ServerHandshakeTest.cpp
Aman Sharma 56c0231b9d Implement direct encap transport parameter negotiation + Fix build errors
Summary:
This diff implements the transport parameter negotiation logic for direct encapsulation support on top of D77604174, addresses reviewer feedback by changing the connection state pointer to a reference, and **fixes critical build errors** caused by the constructor signature changes.

**Changes Made:**

1. **Client-side logic**: The client sends the `client_direct_encap` transport parameter with no value if `supportDirectEncap` is true.

2. **Server-side logic**: The server sends the `server_direct_encap` transport parameter if `directEncapAddress` is not null AND the client sent the `client_direct_encap` parameter. The value is the IP address bytes in network byte order.

3. **Pointer to Reference Change**: Changed `const QuicConnectionStateBase* conn_` to `const QuicConnectionStateBase& conn_` in ServerTransportParametersExtension as requested by reviewer feedback, since nullability is not possible (non-null is an invariant).

4. **🔧 Build Error Fixes**: Fixed multiple test files that were broken by the constructor signature changes:

**Build Fixes Applied:**

- **Fixed 3 critical build failures** that prevented compilation:
  - `fbcode//quic/facebook/mbed/test:mbed_client_handshake`
  - `fbcode//quic/fizz/client/handshake/test:fizz_client_handshake_test`
  - `fbcode//quic/server/handshake/test:ServerHandshakeTest`

- **Updated constructor calls** in test files to include the new `const QuicConnectionStateBase& conn` parameter
- **Fixed helper functions** like `constructServerTp()` to accept and pass connection state
- **Updated test classes** like `MalformedServerTransportParamsExt` to handle the new parameter

**Files Fixed:**
- `fbcode/quic/facebook/mbed/test/MbedClientHandshake.cpp` - Fixed 4 constructor calls and helper functions
- `fbcode/quic/fizz/client/handshake/test/FizzClientHandshakeTest.cpp` - Fixed constructor call
- `fbcode/quic/server/handshake/test/ServerHandshakeTest.cpp` - Fixed constructor call

**Test Results:**
- `buck test fbcode//quic/facebook/mbed/test:mbed_client_handshake` → Pass 7, Fail 0
- `buck test fbcode//quic/fizz/client/handshake/test:fizz_client_handshake_test` → Pass 12, Fail 0
- All previously failing builds now compile successfully

**Implementation Details:**

- Added `encodeIPAddressParameter()` function to handle IP address encoding (supports both IPv4 and IPv6)
- Modified `getSupportedExtTransportParams()` to include client-side direct encap logic
- Created new `getClientDependentExtTransportParams()` function that specifically handles server-side direct encap logic based on client parameters
- Updated `ServerTransportParametersExtension` to use the new function for adding client-dependent parameters
- Updated `ServerStateMachine` to pass connection state to the extension
- **Changed constructor parameter order**: `conn` parameter now comes before `customTransportParameters` to maintain C++ default parameter rules
- **Updated member initialization order**: Fixed to match class declaration order
- **Fixed all test constructors**: Updated test cases to provide connection state parameter

**Architecture:**

Instead of overloading `getSupportedExtTransportParams()` with two parameters, the solution now uses a dedicated `getClientDependentExtTransportParams()` function that:
- Only handles parameters that depend on client capabilities (currently `server_direct_encap`)
- Returns a clean list of parameters without duplicating base transport parameters
- Provides better separation of concerns and clearer function naming

**Unit Tests Added:**

- Comprehensive test suite in `fbcode/quic/handshake/test/TransportParametersTest.cpp`
- 8 test cases covering all client/server scenarios with IPv4/IPv6 support
- Tests verify parameter presence/absence and correct IP address byte encoding
- All tests pass successfully
- **Updated test infrastructure**: Fixed ServerTransportParametersTest.cpp to work with reference-based connection state

**Requirements Fulfilled:**
- Client sends `client_direct_encap` parameter with no value if `supportDirectEncap` is true
- Server sends `server_direct_encap` parameter with IP address bytes if conditions are met
- Changed connection state from pointer to reference as requested by reviewer
- **Fixed all build errors caused by constructor signature changes**
 ---
> Generated by [RACER](https://www.internalfb.com/wiki/RACER_(Risk-Aware_Code_Editing_and_Refactoring)/), powered by [Confucius](https://www.internalfb.com/wiki/Confucius/Analect/Shared_Analects/Confucius_Code_Assist_(CCA)/)
[Session](https://www.internalfb.com/confucius?entry_name=RACER&mode=Focused&namespace[0]=agentrix&session_id=8c84b14a-56a5-11f0-8e69-214e73924e50&tab=Chat), [Trace](https://www.internalfb.com/confucius?entry_name=RACER&mode=Focused&namespace[0]=agentrix&session_id=8c84b14a-56a5-11f0-8e69-214e73924e50&tab=Trace)
[Session](https://www.internalfb.com/confucius?entry_name=RACER&mode=Focused&namespace[0]=agentrix&session_id=439da8ee-5798-11f0-ace1-b7dae9e7575d&tab=Chat), [Trace](https://www.internalfb.com/confucius?entry_name=RACER&mode=Focused&namespace[0]=agentrix&session_id=439da8ee-5798-11f0-ace1-b7dae9e7575d&tab=Trace)
[Session](https://www.internalfb.com/confucius?session_id=7ed2dc86-5847-11f0-8055-b73b775dc61a&tab=Chat), [Trace](https://www.internalfb.com/confucius?session_id=7ed2dc86-5847-11f0-8055-b73b775dc61a&tab=Trace)
[Session](https://www.internalfb.com/confucius?session_id=8bdc0a0c-584b-11f0-9977-35e1e0d6200a&tab=Chat), [Trace](https://www.internalfb.com/confucius?session_id=8bdc0a0c-584b-11f0-9977-35e1e0d6200a&tab=Trace)
**[Current Session](https://www.internalfb.com/confucius?session_id={{ session_id }}&tab=Chat), [Trace](https://www.internalfb.com/confucius?session_id={{ session_id }}&tab=Trace)**
[Session](https://www.internalfb.com/confucius?entry_name=RACER&mode=Focused&namespace[0]=agentrix&session_id=08290174-5b4d-11f0-ac9d-93447239bce3&tab=Chat), [Trace](https://www.internalfb.com/confucius?entry_name=RACER&mode=Focused&namespace[0]=agentrix&session_id=08290174-5b4d-11f0-ac9d-93447239bce3&tab=Trace)
[Session](https://www.internalfb.com/confucius?entry_name=RACER&mode=Focused&namespace[0]=agentrix&session_id=ded2f5f2-5b6d-11f0-b259-5db72d7f2f63&tab=Chat), [Trace](https://www.internalfb.com/confucius?entry_name=RACER&mode=Focused&namespace[0]=agentrix&session_id=ded2f5f2-5b6d-11f0-b259-5db72d7f2f63&tab=Trace)

Reviewed By: hanidamlaj

Differential Revision: D77605298

fbshipit-source-id: 22d3faffaa93f1aa57e05c984339ab3b2e817ac1
2025-07-07 20:04:24 -07:00

907 lines
30 KiB
C++

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <fizz/client/test/Mocks.h>
#include <fizz/crypto/test/TestUtil.h>
#include <fizz/protocol/test/Mocks.h>
#include <fizz/server/test/Mocks.h>
#include <folly/io/async/SSLContext.h>
#include <folly/io/async/ScopedEventBaseThread.h>
#include <quic/QuicConstants.h>
#include <quic/common/test/TestUtils.h>
#include <quic/fizz/client/handshake/FizzClientExtensions.h>
#include <quic/fizz/handshake/FizzBridge.h>
#include <quic/fizz/handshake/QuicFizzFactory.h>
#include <quic/fizz/server/handshake/AppToken.h>
#include <quic/fizz/server/handshake/FizzServerHandshake.h>
#include <quic/fizz/server/handshake/FizzServerQuicHandshakeContext.h>
#include <quic/server/handshake/AppToken.h>
#include <quic/server/handshake/ServerHandshake.h>
#include <quic/state/StateData.h>
using namespace std;
using namespace testing;
// Hostname the fizz client connects to in these tests; it is also the key
// under which PSKs are stored in and looked up from the client's PSK cache.
static constexpr folly::StringPiece kTestHostname = "www.facebook.com";
namespace quic::test {
class MockServerHandshakeCallback : public ServerHandshake::HandshakeCallback {
public:
~MockServerHandshakeCallback() override = default;
MOCK_METHOD(void, onCryptoEventAvailable, (), (noexcept));
};
// QuicServerConnectionState built from a fizz server handshake context, with a
// public accessor for the folly::DelayedDestruction guard count so tests can
// verify that no destructor guards remain outstanding.
struct TestingServerConnectionState : public QuicServerConnectionState {
  explicit TestingServerConnectionState(
      std::shared_ptr<FizzServerQuicHandshakeContext> handshakeContext)
      : QuicServerConnectionState(std::move(handshakeContext)) {}

  uint32_t getDestructorGuardCount() const {
    const uint32_t outstandingGuards =
        folly::DelayedDestruction::getDestructorGuardCount();
    return outstandingGuards;
  }
};
class ServerHandshakeTest : public Test {
public:
~ServerHandshakeTest() override = default;
virtual void setupClientAndServerContext() {}
QuicVersion getVersion() {
return QuicVersion::MVFST;
}
virtual void initialize() {
handshake->initialize(&evb, &serverCallback);
}
void SetUp() override {
// This client context is used outside the context of QUIC in this test, so
// we have to manually configure the QUIC record customizations.
clientCtx = quic::test::createClientCtx();
clientCtx->setOmitEarlyRecordLayer(true);
clientCtx->setFactory(std::make_shared<QuicFizzFactory>());
serverCtx = quic::test::createServerCtx();
setupClientAndServerContext();
auto fizzServerContext = FizzServerQuicHandshakeContext::Builder()
.setFizzServerContext(serverCtx)
.build();
conn.reset(new TestingServerConnectionState(fizzServerContext));
cryptoState = conn->cryptoState.get();
handshake = conn->serverHandshakeLayer;
hostname = kTestHostname.str();
verifier = std::make_shared<fizz::test::MockCertificateVerifier>();
uint64_t initialMaxData = kDefaultConnectionFlowControlWindow;
uint64_t initialMaxStreamDataBidiLocal = kDefaultStreamFlowControlWindow;
uint64_t initialMaxStreamDataBidiRemote = kDefaultStreamFlowControlWindow;
uint64_t initialMaxStreamDataUni = kDefaultStreamFlowControlWindow;
uint64_t initialMaxStreamsBidi = kDefaultMaxStreamsBidirectional;
uint64_t initialMaxStreamsUni = kDefaultMaxStreamsUnidirectional;
auto clientExtensions =
std::make_shared<ClientTransportParametersExtension>(
getVersion(),
initialMaxData,
initialMaxStreamDataBidiLocal,
initialMaxStreamDataBidiRemote,
initialMaxStreamDataUni,
initialMaxStreamsBidi,
initialMaxStreamsUni,
kDefaultIdleTimeout,
kDefaultAckDelayExponent,
kDefaultUDPSendPacketLen,
kDefaultActiveConnectionIdLimit,
ConnectionId::createZeroLength());
fizzClient.reset(new fizz::client::FizzClient<
ServerHandshakeTest,
fizz::client::ClientStateMachine>(
clientState, clientReadBuffer, readAeadOptions, *this, dg.get()));
std::vector<QuicVersion> supportedVersions = {getVersion()};
auto params = std::make_shared<ServerTransportParametersExtension>(
getVersion(),
initialMaxData,
initialMaxStreamDataBidiLocal,
initialMaxStreamDataBidiRemote,
initialMaxStreamDataUni,
initialMaxStreamsBidi,
initialMaxStreamsUni,
/*disableMigration=*/true,
kDefaultIdleTimeout,
kDefaultAckDelayExponent,
kDefaultUDPSendPacketLen,
generateStatelessResetToken(),
ConnectionId::createAndMaybeCrash(
std::vector<uint8_t>{0xff, 0xfe, 0xfd, 0xfc}),
ConnectionId::createZeroLength(),
*conn);
initialize();
handshake->accept(params);
ON_CALL(serverCallback, onCryptoEventAvailable())
.WillByDefault(Invoke([this]() {
VLOG(1) << "onCryptoEventAvailable";
processCryptoEvents();
}));
auto cachedPsk = clientCtx->getPsk(hostname);
fizzClient->connect(
clientCtx,
verifier,
hostname,
cachedPsk,
folly::Optional<std::vector<fizz::ech::ParsedECHConfig>>(folly::none),
std::make_shared<FizzClientExtensions>(clientExtensions, 0));
}
void processCryptoEvents() {
auto handshakeStateResult = setHandshakeState();
if (handshakeStateResult.hasError()) {
VLOG(1) << "server exception " << handshakeStateResult.error().message;
ex = folly::makeUnexpected(handshakeStateResult.error());
if (!inRoundScope_ && !handshakeCv.ready()) {
VLOG(1) << "Posting handshake cv";
handshakeCv.post();
}
return;
}
waitForData = false;
do {
auto writableBytes = getHandshakeWriteBytes();
if (writableBytes->empty()) {
break;
}
VLOG(1) << "server->client bytes="
<< writableBytes->computeChainDataLength();
clientReadBuffer.append(std::move(writableBytes));
fizzClient->newTransportData();
} while (!waitForData);
if (!inRoundScope_ && !handshakeCv.ready()) {
VLOG(1) << "Posting handshake cv";
handshakeCv.post();
}
}
void clientServerRound() {
SCOPE_EXIT {
inRoundScope_ = false;
};
inRoundScope_ = true;
evb.loop();
for (auto& clientWrite : clientWrites) {
for (auto& content : clientWrite.contents) {
auto encryptionLevel =
getEncryptionLevelFromFizz(content.encryptionLevel);
auto result =
handshake->doHandshake(std::move(content.data), encryptionLevel);
if (result.hasError()) {
ex = folly::makeUnexpected(result.error());
}
}
}
processCryptoEvents();
evb.loopIgnoreKeepAlive();
}
void serverClientRound() {
SCOPE_EXIT {
inRoundScope_ = false;
};
inRoundScope_ = true;
evb.loop();
waitForData = false;
do {
auto writableBytes = getHandshakeWriteBytes();
if (writableBytes->empty()) {
break;
}
VLOG(1) << "server->client bytes="
<< writableBytes->computeChainDataLength();
clientReadBuffer.append(std::move(writableBytes));
fizzClient->newTransportData();
} while (!waitForData);
evb.loop();
}
[[nodiscard]] folly::Expected<folly::Unit, QuicError> setHandshakeState() {
auto oneRttWriteCipherTmp = handshake->getFirstOneRttWriteCipher();
if (oneRttWriteCipherTmp.hasError()) {
return folly::makeUnexpected(oneRttWriteCipherTmp.error());
}
auto oneRttReadCipherTmp = handshake->getFirstOneRttReadCipher();
if (oneRttReadCipherTmp.hasError()) {
return folly::makeUnexpected(oneRttReadCipherTmp.error());
}
auto zeroRttReadCipherTmp = handshake->getZeroRttReadCipher();
if (zeroRttReadCipherTmp.hasError()) {
return folly::makeUnexpected(zeroRttReadCipherTmp.error());
}
auto handshakeWriteCipherTmp = std::move(conn->handshakeWriteCipher);
auto handshakeReadCipherTmp = handshake->getHandshakeReadCipher();
if (handshakeReadCipherTmp.hasError()) {
return folly::makeUnexpected(handshakeReadCipherTmp.error());
}
if (oneRttWriteCipherTmp.value()) {
oneRttWriteCipher = std::move(oneRttWriteCipherTmp.value());
}
if (oneRttReadCipherTmp.value()) {
oneRttReadCipher = std::move(oneRttReadCipherTmp.value());
}
if (zeroRttReadCipherTmp.value()) {
zeroRttReadCipher = std::move(zeroRttReadCipherTmp.value());
}
if (handshakeReadCipherTmp.value()) {
handshakeReadCipher = std::move(handshakeReadCipherTmp.value());
}
if (handshakeWriteCipherTmp) {
handshakeWriteCipher = std::move(handshakeWriteCipherTmp);
}
return folly::unit;
}
void expectOneRttReadCipher(bool expected) {
EXPECT_EQ(oneRttReadCipher.get() != nullptr, expected);
}
void expectOneRttWriteCipher(bool expected) {
EXPECT_EQ(oneRttWriteCipher.get() != nullptr, expected);
}
void expectOneRttCipher(bool expected) {
expectOneRttWriteCipher(expected);
expectOneRttReadCipher(expected);
if (expected) {
EXPECT_EQ(zeroRttReadCipher.get(), nullptr);
} else {
EXPECT_EQ(zeroRttReadCipher.get(), nullptr);
}
}
void expectZeroRttCipher(bool expected, bool oneRttRead) {
CHECK(expected || !oneRttRead) << "invalid condition supplied";
EXPECT_NE(oneRttWriteCipher.get(), nullptr);
if (expected) {
if (oneRttRead) {
EXPECT_NE(oneRttReadCipher.get(), nullptr);
} else {
EXPECT_EQ(oneRttReadCipher.get(), nullptr);
}
EXPECT_NE(zeroRttReadCipher.get(), nullptr);
} else {
EXPECT_EQ(oneRttReadCipher.get(), nullptr);
EXPECT_EQ(zeroRttReadCipher.get(), nullptr);
}
}
BufPtr getHandshakeWriteBytes() {
auto buf = folly::IOBuf::create(0);
switch (clientState.readRecordLayer()->getEncryptionLevel()) {
case fizz::EncryptionLevel::Plaintext:
if (!cryptoState->initialStream.writeBuffer.empty()) {
buf->appendToChain(cryptoState->initialStream.writeBuffer.move());
}
break;
case fizz::EncryptionLevel::Handshake:
case fizz::EncryptionLevel::EarlyData:
if (!cryptoState->handshakeStream.writeBuffer.empty()) {
buf->appendToChain(cryptoState->handshakeStream.writeBuffer.move());
}
break;
case fizz::EncryptionLevel::AppTraffic:
if (!cryptoState->oneRttStream.writeBuffer.empty()) {
buf->appendToChain(cryptoState->oneRttStream.writeBuffer.move());
}
}
return buf;
}
void operator()(fizz::DeliverAppData&) {}
void operator()(fizz::WriteToSocket& write) {
clientWrites.push_back(std::move(write));
}
void operator()(fizz::client::ReportEarlyHandshakeSuccess&) {
earlyHandshakeSuccess = true;
}
void operator()(fizz::client::ReportHandshakeSuccess&) {
handshakeSuccess = true;
}
void operator()(fizz::client::ReportEarlyWriteFailed&) {
earlyWriteFailed = true;
}
void operator()(fizz::ReportError&) {
error = true;
}
void operator()(fizz::WaitForData&) {
waitForData = true;
fizzClient->waitForData();
}
void operator()(fizz::client::MutateState& mutator) {
mutator(clientState);
}
void operator()(fizz::client::NewCachedPsk& newCachedPsk) {
clientCtx->putPsk(hostname, std::move(newCachedPsk.psk));
}
void operator()(fizz::SecretAvailable&) {}
void operator()(fizz::EndOfData&) {}
void operator()(fizz::client::ECHRetryAvailable&) {}
class DelayedHolder : public folly::DelayedDestruction {};
std::unique_ptr<DelayedHolder, folly::DelayedDestruction::Destructor> dg;
folly::EventBase evb;
std::unique_ptr<
TestingServerConnectionState,
folly::DelayedDestruction::Destructor>
conn{nullptr};
ServerHandshake* handshake;
QuicCryptoState* cryptoState;
fizz::client::State clientState;
std::unique_ptr<fizz::client::FizzClient<
ServerHandshakeTest,
fizz::client::ClientStateMachine>>
fizzClient;
folly::IOBufQueue clientReadBuffer{folly::IOBufQueue::cacheChainLength()};
bool earlyHandshakeSuccess{false};
bool handshakeSuccess{false};
bool earlyWriteFailed{false};
bool error{false};
fizz::Aead::AeadOptions readAeadOptions;
std::vector<fizz::WriteToSocket> clientWrites;
MockServerHandshakeCallback serverCallback;
std::unique_ptr<Aead> oneRttWriteCipher;
std::unique_ptr<Aead> oneRttReadCipher;
std::unique_ptr<Aead> zeroRttReadCipher;
std::unique_ptr<Aead> handshakeWriteCipher;
std::unique_ptr<Aead> handshakeReadCipher;
folly::Expected<folly::Unit, QuicError> ex{folly::unit};
std::string hostname;
std::shared_ptr<fizz::test::MockCertificateVerifier> verifier;
std::shared_ptr<fizz::client::FizzClientContext> clientCtx;
std::shared_ptr<fizz::server::FizzServerContext> serverCtx;
folly::Baton<> handshakeCv;
bool inRoundScope_{false};
bool waitForData{false};
};
// Exported keying material is unavailable before the handshake and, after one
// client/server exchange, can be exported both without a context and with an
// empty context.
TEST_F(ServerHandshakeTest, TestGetExportedKeyingMaterial) {
  // Sanity check: getExportedKeyingMaterial() returns no value before a
  // handshake has taken place.
  auto ekm = handshake->getExportedKeyingMaterial(
      "EXPORTER-Some-Label", std::nullopt, 32);
  EXPECT_FALSE(ekm.has_value());
  clientServerRound();
  serverClientRound();
  ekm = handshake->getExportedKeyingMaterial(
      "EXPORTER-Some-Label", std::nullopt, 32);
  ASSERT_TRUE(ekm.has_value());
  EXPECT_EQ(ekm->size(), 32u);
  ekm = handshake->getExportedKeyingMaterial(
      "EXPORTER-Some-Label", ByteRange(), 32);
  ASSERT_TRUE(ekm.has_value());
  EXPECT_EQ(ekm->size(), 32u);
}
// Drives a complete handshake with the default contexts. The server should be
// in the Handshake phase after the first client round and Established after the
// second, with no error, both 1-RTT ciphers, and the "quic_test" ALPN.
TEST_F(ServerHandshakeTest, TestHandshakeSuccess) {
clientServerRound();
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
serverClientRound();
clientServerRound();
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
ASSERT_FALSE(ex.hasError());
expectOneRttCipher(true);
EXPECT_EQ(handshake->getApplicationProtocol(), "quic_test");
EXPECT_TRUE(handshakeSuccess);
}
// Appends a plaintext alert record after the client's first flight. The server
// is expected to ignore the non-handshake content and still complete the
// handshake exactly as in TestHandshakeSuccess.
TEST_F(ServerHandshakeTest, TestHandshakeSuccessIgnoreNonHandshake) {
fizz::WriteToSocket write;
fizz::TLSContent content;
content.contentType = fizz::ContentType::alert;
content.data = folly::IOBuf::copyBuffer(folly::unhexlify("01000000"));
content.encryptionLevel = fizz::EncryptionLevel::Plaintext;
write.contents.push_back(std::move(content));
clientWrites.push_back(std::move(write));
clientServerRound();
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
serverClientRound();
clientServerRound();
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
ASSERT_FALSE(ex.hasError());
expectOneRttCipher(true);
EXPECT_EQ(handshake->getApplicationProtocol(), "quic_test");
EXPECT_TRUE(handshakeSuccess);
}
// Replaces the client's first flight with a single plaintext handshake record
// "01000000" (presumably a ClientHello header with a zero length) and expects
// the server to record a handshake error in `ex`.
TEST_F(ServerHandshakeTest, TestMalformedHandshakeMessage) {
fizz::WriteToSocket write;
fizz::TLSContent content;
content.contentType = fizz::ContentType::handshake;
content.data = folly::IOBuf::copyBuffer(folly::unhexlify("01000000"));
content.encryptionLevel = fizz::EncryptionLevel::Plaintext;
write.contents.push_back(std::move(content));
// Drop the real client flight queued by SetUp() so only the malformed record
// reaches the server.
clientWrites.clear();
clientWrites.push_back(std::move(write));
clientServerRound();
EXPECT_TRUE(ex.hasError());
}
// Ticket cipher that rejects every PSK on decrypt() and issues an empty ticket
// with a 2s lifetime on encrypt(). Either operation can be made asynchronous:
// the first call after enabling async mode waits on the supplied SemiFuture,
// and later calls complete inline. decrypt() can also be made to throw in
// order to simulate a cipher failure.
class AsyncRejectingTicketCipher : public fizz::server::TicketCipher {
 public:
  ~AsyncRejectingTicketCipher() override = default;

  // When `async` is true, the next decrypt() waits on `future`.
  void setDecryptAsync(bool async, folly::SemiFuture<folly::Unit> future) {
    decryptFuture_ = std::move(future);
    decryptAsync_ = async;
  }

  // When `async` is true, the next encrypt() waits on `future`.
  void setEncryptAsync(bool async, folly::SemiFuture<folly::Unit> future) {
    encryptFuture_ = std::move(future);
    encryptAsync_ = async;
  }

  // When `error` is true, decrypt() throws instead of rejecting.
  void setDecryptError(bool error) {
    error_ = error;
  }

  folly::SemiFuture<folly::Optional<
      std::pair<std::unique_ptr<folly::IOBuf>, std::chrono::seconds>>>
  encrypt(fizz::server::ResumptionState) const override {
    if (encryptAsync_) {
      // One-shot: only this call is deferred.
      encryptAsync_ = false;
      return std::move(encryptFuture_).deferValue([](auto&&) {
        VLOG(1) << "got ticket async";
        return folly::makeSemiFuture<folly::Optional<
            std::pair<std::unique_ptr<folly::IOBuf>, std::chrono::seconds>>>(
            std::make_pair(folly::IOBuf::create(0), 2s));
      });
    }
    return std::make_pair(folly::IOBuf::create(0), 2s);
  }

  folly::SemiFuture<
      std::pair<fizz::PskType, folly::Optional<fizz::server::ResumptionState>>>
  decrypt(std::unique_ptr<folly::IOBuf>) const override {
    if (decryptAsync_) {
      // One-shot: only this call is deferred. `error_` is read when the
      // deferred work runs, not when it is scheduled.
      decryptAsync_ = false;
      return std::move(decryptFuture_).deferValue([this](auto&&) {
        VLOG(1) << "triggered reject";
        if (error_) {
          throw std::runtime_error("test decrypt error");
        }
        return folly::makeSemiFuture<std::pair<
            fizz::PskType,
            folly::Optional<fizz::server::ResumptionState>>>(
            std::make_pair(fizz::PskType::Rejected, folly::none));
      });
    }
    if (error_) {
      throw std::runtime_error("test decrypt error");
    }
    return std::make_pair(fizz::PskType::Rejected, folly::none);
  }

 private:
  mutable folly::SemiFuture<folly::Unit> decryptFuture_;
  mutable folly::SemiFuture<folly::Unit> encryptFuture_;
  mutable bool decryptAsync_{true};
  mutable bool encryptAsync_{false};
  bool error_{false};
};
// Fixture for explicit NewSessionTicket writes. The server is told not to send
// session tickets on its own and uses a mock ticket cipher; the client gets a
// PSK cache so the test can observe whether a ticket arrived.
class ServerHandshakeWriteNSTTest : public ServerHandshakeTest {
 public:
  void setupClientAndServerContext() override {
    ticketCipher_ = std::make_shared<fizz::server::test::MockTicketCipher>();
    ticketCipher_->setDefaults();
    cache_ = std::make_shared<fizz::client::BasicPskCache>();

    serverCtx->setSendNewSessionTicket(false);
    serverCtx->setTicketCipher(ticketCipher_);
    serverCtx->setSupportedAlpns({"h3", "hq"});

    clientCtx->setPskCache(cache_);
    clientCtx->setSupportedAlpns({"h3"});
  }

 protected:
  std::shared_ptr<fizz::client::BasicPskCache> cache_;
  std::shared_ptr<fizz::server::test::MockTicketCipher> ticketCipher_;
};
// Completes a handshake, then explicitly writes a NewSessionTicket. The
// client's PSK cache must be empty before the write and populated afterwards;
// the mock cipher checks that the encoded app token is carried in the
// resumption state it is asked to encrypt.
TEST_F(ServerHandshakeWriteNSTTest, TestWriteNST) {
clientServerRound();
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
serverClientRound();
clientServerRound();
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
expectOneRttCipher(true);
AppToken appToken;
EXPECT_FALSE(cache_->getPsk(kTestHostname.str()));
EXPECT_CALL(*ticketCipher_, _encrypt(_))
.WillOnce(Invoke([&appToken](fizz::server::ResumptionState& resState) {
EXPECT_TRUE(
folly::IOBufEqualTo()(resState.appToken, encodeAppToken(appToken)));
return std::make_pair(folly::IOBuf::copyBuffer("appToken"), 100s);
}));
ASSERT_FALSE(handshake->writeNewSessionTicket(appToken).hasError());
// Push the ticket bytes to the client and let it process them.
processCryptoEvents();
evb.loop();
EXPECT_TRUE(cache_->getPsk(kTestHostname.str()));
}
// Base fixture for resumption tests. It seeds the client's PSK cache with a
// TLS 1.3 resumption PSK for kTestHostname and installs the ticket cipher
// produced by makeTicketCipher() on the server.
class ServerHandshakePskTest : public ServerHandshakeTest {
 public:
  ~ServerHandshakePskTest() override = default;

  // Fills in `psk` before running the base SetUp(), so that
  // setupClientAndServerContext() (called from it) can cache the PSK.
  void SetUp() override {
    cache = std::make_shared<fizz::client::BasicPskCache>();

    // Identity and secret.
    psk.type = fizz::PskType::Resumption;
    psk.psk = std::string("psk");
    psk.secret = std::string("secret");

    // Parameters the PSK was issued under.
    psk.version = fizz::ProtocolVersion::tls_1_3;
    psk.cipher = fizz::CipherSuite::TLS_AES_128_GCM_SHA256;
    psk.group = fizz::NamedGroup::x25519;
    psk.alpn = std::string("h3");
    psk.serverCert = std::make_shared<fizz::test::MockCert>();

    // Ticket timing and limits: issued/handshaked at the clock epoch and
    // expiring 20 seconds after it.
    psk.ticketIssueTime = std::chrono::system_clock::time_point();
    psk.ticketHandshakeTime = std::chrono::system_clock::time_point();
    psk.ticketExpirationTime =
        std::chrono::system_clock::time_point(std::chrono::seconds(20));
    psk.ticketAgeAdd = 1;
    psk.maxEarlyDataSize = 2;

    ServerHandshakeTest::SetUp();
  }

  void setupClientAndServerContext() override {
    cache->putPsk(kTestHostname.str(), psk);
    ticketCipher = makeTicketCipher();

    clientCtx->setPskCache(cache);
    clientCtx->setSupportedAlpns({"h3"});

    serverCtx->setTicketCipher(ticketCipher);
    serverCtx->setSupportedAlpns({"h3", "hq"});
  }

  // Supplies the server's ticket cipher; implemented by each concrete fixture.
  virtual std::shared_ptr<fizz::server::TicketCipher> makeTicketCipher() = 0;

  std::shared_ptr<fizz::client::BasicPskCache> cache;
  // Subclasses hand this promise's future to their cipher to gate async work.
  folly::Promise<folly::Unit> promise;
  std::shared_ptr<fizz::server::TicketCipher> ticketCipher;
  fizz::client::CachedPsk psk;
};
// Fixture that forces a key-share group mismatch. The client offers only a
// secp256r1 share, while the server supports only x25519, so the handshake
// needs a HelloRetryRequest (hence the name). The cached PSK's group is changed
// to match before the base class caches it.
class ServerHandshakeHRRTest : public ServerHandshakePskTest {
 public:
  ~ServerHandshakeHRRTest() override = default;

  void setupClientAndServerContext() override {
    psk.group = fizz::NamedGroup::secp256r1;

    clientCtx->setSupportedGroups(
        {fizz::NamedGroup::secp256r1, fizz::NamedGroup::x25519});
    clientCtx->setSupportedVersions({fizz::ProtocolVersion::tls_1_3});
    clientCtx->setDefaultShares({fizz::NamedGroup::secp256r1});
    clientCtx->setSupportedAlpns({"h3"});

    serverCtx->setSupportedGroups({fizz::NamedGroup::x25519});
    serverCtx->setSupportedVersions({fizz::ProtocolVersion::tls_1_3});
    serverCtx->setSupportedAlpns({"h3", "hq"});

    ServerHandshakePskTest::setupClientAndServerContext();
  }

  // Ticket decryption (which always rejects) waits on `promise` by default.
  std::shared_ptr<fizz::server::TicketCipher> makeTicketCipher() override {
    auto asyncCipher = std::make_shared<AsyncRejectingTicketCipher>();
    asyncCipher->setDecryptAsync(true, promise.getFuture());
    return asyncCipher;
  }
};
// Synchronous HelloRetryRequest flow: the handshake needs an extra client round
// before the server can derive 1-RTT read keys, and it still negotiates "h3".
TEST_F(ServerHandshakeHRRTest, TestHRR) {
  auto* rejectingCipher =
      dynamic_cast<AsyncRejectingTicketCipher*>(ticketCipher.get());
  // Fail cleanly instead of dereferencing null if the fixture's cipher type
  // ever changes.
  ASSERT_NE(rejectingCipher, nullptr);
  // Make ticket decryption synchronous for this test.
  rejectingCipher->setDecryptAsync(false, folly::makeFuture());
  clientServerRound();
  EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
  serverClientRound();
  clientServerRound();
  // Still in Handshake after the second client round (presumably the retried
  // ClientHello): 1-RTT write keys exist, but 1-RTT read keys do not yet.
  EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
  expectOneRttReadCipher(false);
  expectOneRttWriteCipher(true);
  serverClientRound();
  clientServerRound();
  EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
  EXPECT_EQ(handshake->getApplicationProtocol(), "h3");
  expectOneRttCipher(true);
}
// HelloRetryRequest flow where ticket decryption completes asynchronously.
// The handshake must not progress until `promise` is fulfilled.
TEST_F(ServerHandshakeHRRTest, TestAsyncHRR) {
  // Make an async ticket decryption operation.
  clientServerRound();
  // Resolve the pending decryption. The resulting crypto event is handled
  // outside a round, so the fixture's default action is expected to post
  // handshakeCv.
  promise.setValue();
  evb.loop();
  expectOneRttCipher(false);
  // Bounded wait, so a missing notification fails the test instead of hanging
  // it indefinitely.
  ASSERT_TRUE(handshakeCv.try_wait_for(std::chrono::seconds(5)));
  handshakeCv.reset();
  clientServerRound();
  serverClientRound();
  EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
  clientServerRound();
  EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
  EXPECT_EQ(handshake->getApplicationProtocol(), "h3");
  expectOneRttCipher(true);
}
// Cancels the handshake while ticket decryption is still pending, destroys the
// crypto state, and then lets the decryption complete. No ALPN may be reported
// and no 1-RTT ciphers may have been produced.
TEST_F(ServerHandshakeHRRTest, TestAsyncCancel) {
// Make an async ticket decryption operation.
clientServerRound();
handshake->cancel();
// Let's destroy the crypto state to make sure it is not referenced.
conn->cryptoState.reset();
promise.setValue();
evb.loop();
EXPECT_EQ(handshake->getApplicationProtocol(), std::nullopt);
expectOneRttCipher(false);
}
// PSK fixture whose ticket decryption (a rejection) completes synchronously,
// while its first ticket encryption waits on `promise`.
class ServerHandshakeAsyncTest : public ServerHandshakePskTest {
 public:
  ~ServerHandshakeAsyncTest() override = default;

  std::shared_ptr<fizz::server::TicketCipher> makeTicketCipher() override {
    auto asyncCipher = std::make_shared<AsyncRejectingTicketCipher>();
    asyncCipher->setEncryptAsync(true, promise.getFuture());
    asyncCipher->setDecryptAsync(false, folly::makeFuture());
    return asyncCipher;
  }
};
// Cancels the handshake while the server's ticket *encryption* is still waiting
// on `promise` (decryption is synchronous in this fixture). After the async
// work completes, no DelayedDestruction guards may remain on the connection.
TEST_F(ServerHandshakeAsyncTest, TestAsyncCancel) {
  // Run the handshake; the server's ticket encryption is deferred on `promise`.
  clientServerRound();
  serverClientRound();
  clientServerRound();
  handshake->cancel();
  // Let's destroy the crypto state to make sure it is not referenced.
  conn->cryptoState.reset();
  promise.setValue();
  evb.loop();
  EXPECT_EQ(conn->getDestructorGuardCount(), 0u);
}
// PSK fixture whose first ticket decryption waits on `promise` and then throws.
class ServerHandshakeAsyncErrorTest : public ServerHandshakePskTest {
 public:
  ~ServerHandshakeAsyncErrorTest() override = default;

  std::shared_ptr<fizz::server::TicketCipher> makeTicketCipher() override {
    auto failingCipher = std::make_shared<AsyncRejectingTicketCipher>();
    failingCipher->setDecryptError(true);
    failingCipher->setDecryptAsync(true, promise.getFuture());
    return failingCipher;
  }
};
// When the deferred ticket decryption throws, the crypto event delivered to
// the callback must surface as an error from the handshake layer.
TEST_F(ServerHandshakeAsyncErrorTest, TestAsyncError) {
  clientServerRound();
  // Named so it does not shadow the fixture's `error` member, which tracks the
  // fizz client's ReportError action.
  bool readCipherError = false;
  EXPECT_CALL(serverCallback, onCryptoEventAvailable())
      .WillRepeatedly(Invoke([&] {
        if (handshake->getFirstOneRttReadCipher().hasError()) {
          readCipherError = true;
        }
      }));
  // Release the deferred decryption, which is configured to throw.
  promise.setValue();
  evb.loop();
  EXPECT_TRUE(readCipherError);
}
// Cancels the handshake and destroys the crypto state from inside the crypto
// event callback fired by the failing async decryption. The handshake layer
// must still report the error afterwards.
TEST_F(ServerHandshakeAsyncErrorTest, TestCancelOnAsyncError) {
clientServerRound();
EXPECT_CALL(serverCallback, onCryptoEventAvailable())
.WillRepeatedly(Invoke([&] {
handshake->cancel();
// Let's destroy the crypto state to make sure it is not referenced.
conn->cryptoState.reset();
}));
promise.setValue();
evb.loop();
EXPECT_TRUE(handshake->getFirstOneRttReadCipher().hasError());
}
// Cancels the handshake and destroys the crypto state before the failing async
// decryption is released. The handshake layer must still report an error.
TEST_F(ServerHandshakeAsyncErrorTest, TestCancelWhileWaitingAsyncError) {
clientServerRound();
handshake->cancel();
// Let's destroy the crypto state to make sure it is not referenced.
conn->cryptoState.reset();
promise.setValue();
evb.loop();
EXPECT_TRUE(handshake->getFirstOneRttReadCipher().hasError());
}
// PSK fixture whose ticket decryption throws synchronously.
class ServerHandshakeSyncErrorTest : public ServerHandshakePskTest {
 public:
  ~ServerHandshakeSyncErrorTest() override = default;

  std::shared_ptr<fizz::server::TicketCipher> makeTicketCipher() override {
    auto failingCipher = std::make_shared<AsyncRejectingTicketCipher>();
    failingCipher->setDecryptAsync(false, folly::makeFuture());
    failingCipher->setDecryptError(true);
    return failingCipher;
  }
};
// A ticket decryption that throws synchronously must surface as a handshake
// error.
TEST_F(ServerHandshakeSyncErrorTest, TestError) {
// Ticket decryption fails synchronously in this fixture.
clientServerRound();
evb.loop();
EXPECT_TRUE(handshake->getFirstOneRttReadCipher().hasError());
}
// Resumption fixture with early data enabled on both sides and a ticket cipher
// that accepts the cached PSK. It uses the handshake layer's default app token
// validator.
class ServerHandshakeZeroRttDefaultAppTokenValidatorTest
    : public ServerHandshakePskTest {
 public:
  ~ServerHandshakeZeroRttDefaultAppTokenValidatorTest() override = default;

  /**
   * Ticket cipher that treats any ticket as a resumption of the PSK passed to
   * setPsk(). encrypt() is a no-op that returns an empty ticket.
   *
   * decrypt() moves the stored ResumptionState out, so this cipher can
   * currently resume only 1 connection.
   */
  class AcceptingTicketCipher : public fizz::server::TicketCipher {
   public:
    ~AcceptingTicketCipher() override = default;

    folly::SemiFuture<folly::Optional<
        std::pair<std::unique_ptr<folly::IOBuf>, std::chrono::seconds>>>
    encrypt(fizz::server::ResumptionState) const override {
      // Fake handshake, no need to do anything here.
      return std::make_pair(folly::IOBuf::create(0), 2s);
    }

    // Builds the resumption state returned by decrypt() from `cachedPsk`.
    // Taken by const reference to avoid copying the PSK's strings and cert.
    void setPsk(const fizz::client::CachedPsk& cachedPsk) {
      resState.version = cachedPsk.version;
      resState.cipher = cachedPsk.cipher;
      resState.resumptionSecret = folly::IOBuf::copyBuffer(cachedPsk.secret);
      resState.alpn = cachedPsk.alpn;
      resState.ticketIssueTime = std::chrono::system_clock::time_point();
      resState.handshakeTime = std::chrono::system_clock::time_point();
      resState.serverCert = cachedPsk.serverCert;
    }

    folly::SemiFuture<std::pair<
        fizz::PskType,
        folly::Optional<fizz::server::ResumptionState>>>
    decrypt(std::unique_ptr<folly::IOBuf>) const override {
      return std::make_pair(fizz::PskType::Resumption, std::move(resState));
    }

   private:
    mutable fizz::server::ResumptionState resState;
  };

  // Enables early data on the client and server (with a ±1000ms clock-skew
  // tolerance and an AllowAllReplayReplayCache), restricts both sides to
  // TLS 1.3, and then applies the PSK fixture's setup.
  void setupClientAndServerContext() override {
    clientCtx->setSendEarlyData(true);
    serverCtx->setEarlyDataSettings(
        true,
        fizz::server::ClockSkewTolerance{-1000ms, 1000ms},
        std::make_shared<fizz::server::AllowAllReplayReplayCache>());
    clientCtx->setSupportedVersions({fizz::ProtocolVersion::tls_1_3});
    serverCtx->setSupportedVersions({fizz::ProtocolVersion::tls_1_3});
    clientCtx->setSupportedAlpns({"h3"});
    serverCtx->setSupportedAlpns({"h3", "hq"});
    ServerHandshakePskTest::setupClientAndServerContext();
  }

  std::shared_ptr<fizz::server::TicketCipher> makeTicketCipher() override {
    auto cipher = std::make_shared<AcceptingTicketCipher>();
    cipher->setPsk(psk);
    return cipher;
  }
};
// With the default app token validator, 0-RTT is not accepted even though early
// data is enabled and the PSK resumes. No 0-RTT read cipher is derived, and the
// handshake completes with 1-RTT keys only.
TEST_F(
ServerHandshakeZeroRttDefaultAppTokenValidatorTest,
TestDefaultAppTokenValidatorRejectZeroRtt) {
clientServerRound();
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
expectZeroRttCipher(false, false);
serverClientRound();
clientServerRound();
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
expectOneRttCipher(true);
}
// Same as ServerHandshakeZeroRttDefaultAppTokenValidatorTest, but the server
// handshake uses a MockAppTokenValidator so tests control whether the app token
// is accepted.
class ServerHandshakeZeroRttTest
    : public ServerHandshakeZeroRttDefaultAppTokenValidatorTest {
  void initialize() override {
    auto validator =
        std::make_unique<fizz::server::test::MockAppTokenValidator>();
    // Keep a non-owning pointer for setting expectations; ownership of the
    // validator is moved into the handshake layer.
    validator_ = validator.get();
    handshake->initialize(&evb, &serverCallback, std::move(validator));
  }

 protected:
  // Non-owning. Null until initialize() runs during SetUp().
  fizz::server::test::MockAppTokenValidator* validator_{nullptr};
};
// The validator accepts the app token. After the first client round the server
// reaches KeysDerived with a 0-RTT read cipher but no 1-RTT read cipher. It then
// ends Established with both 0-RTT and 1-RTT read ciphers.
TEST_F(ServerHandshakeZeroRttTest, TestResumption) {
EXPECT_CALL(*validator_, validate(_)).WillOnce(Return(true));
clientServerRound();
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::KeysDerived);
expectZeroRttCipher(true, false);
serverClientRound();
clientServerRound();
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
expectZeroRttCipher(true, true);
}
// Disabling early data on the server context the handshake actually uses means
// 0-RTT is rejected without consulting the app token validator. The handshake
// still completes with 1-RTT keys.
TEST_F(ServerHandshakeZeroRttTest, TestRejectZeroRttNotEnabled) {
  auto* fizzHandshake = dynamic_cast<FizzServerHandshake*>(handshake);
  // Fail cleanly instead of dereferencing null if the handshake layer is not a
  // FizzServerHandshake.
  ASSERT_NE(fizzHandshake, nullptr);
  auto realServerCtx = fizzHandshake->getContext();
  // getContext() exposes the context as const; the test deliberately mutates
  // it to turn off early data after the handshake layer was built.
  auto nonConstServerCtx =
      const_cast<fizz::server::FizzServerContext*>(realServerCtx);
  nonConstServerCtx->setEarlyDataSettings(
      false, fizz::server::ClockSkewTolerance(), nullptr);
  EXPECT_CALL(*validator_, validate(_)).Times(0);
  clientServerRound();
  EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
  expectZeroRttCipher(false, false);
  serverClientRound();
  clientServerRound();
  EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
  expectOneRttCipher(true);
}
// The validator rejects the app token, so 0-RTT is refused (no 0-RTT read
// cipher), but the handshake still completes with 1-RTT keys.
TEST_F(ServerHandshakeZeroRttTest, TestRejectZeroRttInvalidToken) {
EXPECT_CALL(*validator_, validate(_)).WillOnce(Return(false));
clientServerRound();
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
expectZeroRttCipher(false, false);
serverClientRound();
clientServerRound();
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
expectOneRttCipher(true);
}
} // namespace quic::test