mirror of
https://github.com/facebookincubator/mvfst.git
synced 2025-08-09 20:42:44 +03:00
Summary: See title Reviewed By: kvtsoy Differential Revision: D73444489 fbshipit-source-id: f83566ce023e8237335d3bb43d89fc471f053afa
902 lines
30 KiB
C++
902 lines
30 KiB
C++
/*
|
|
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
*
|
|
* This source code is licensed under the MIT license found in the
|
|
* LICENSE file in the root directory of this source tree.
|
|
*/
|
|
|
|
#include <gmock/gmock.h>
|
|
#include <gtest/gtest.h>
|
|
|
|
#include <fizz/client/test/Mocks.h>
|
|
#include <fizz/crypto/test/TestUtil.h>
|
|
#include <fizz/protocol/test/Mocks.h>
|
|
#include <fizz/server/test/Mocks.h>
|
|
|
|
#include <folly/io/async/SSLContext.h>
|
|
#include <folly/io/async/ScopedEventBaseThread.h>
|
|
|
|
#include <quic/QuicConstants.h>
|
|
#include <quic/common/test/TestUtils.h>
|
|
#include <quic/fizz/client/handshake/FizzClientExtensions.h>
|
|
#include <quic/fizz/handshake/FizzBridge.h>
|
|
#include <quic/fizz/handshake/QuicFizzFactory.h>
|
|
#include <quic/fizz/server/handshake/AppToken.h>
|
|
#include <quic/fizz/server/handshake/FizzServerHandshake.h>
|
|
#include <quic/fizz/server/handshake/FizzServerQuicHandshakeContext.h>
|
|
#include <quic/server/handshake/AppToken.h>
|
|
#include <quic/server/handshake/ServerHandshake.h>
|
|
#include <quic/state/StateData.h>
|
|
|
|
using namespace std;
|
|
using namespace testing;
|
|
|
|
static constexpr folly::StringPiece kTestHostname = "www.facebook.com";
|
|
|
|
namespace quic::test {
|
|
class MockServerHandshakeCallback : public ServerHandshake::HandshakeCallback {
|
|
public:
|
|
~MockServerHandshakeCallback() override = default;
|
|
|
|
MOCK_METHOD(void, onCryptoEventAvailable, (), (noexcept));
|
|
};
|
|
|
|
struct TestingServerConnectionState : public QuicServerConnectionState {
|
|
explicit TestingServerConnectionState(
|
|
std::shared_ptr<FizzServerQuicHandshakeContext> context)
|
|
: QuicServerConnectionState(std::move(context)) {}
|
|
|
|
uint32_t getDestructorGuardCount() const {
|
|
return folly::DelayedDestruction::getDestructorGuardCount();
|
|
}
|
|
};
|
|
|
|
class ServerHandshakeTest : public Test {
|
|
public:
|
|
~ServerHandshakeTest() override = default;
|
|
|
|
virtual void setupClientAndServerContext() {}
|
|
|
|
QuicVersion getVersion() {
|
|
return QuicVersion::MVFST;
|
|
}
|
|
|
|
virtual void initialize() {
|
|
handshake->initialize(&evb, &serverCallback);
|
|
}
|
|
|
|
void SetUp() override {
|
|
// This client context is used outside the context of QUIC in this test, so
|
|
// we have to manually configure the QUIC record customizations.
|
|
clientCtx = quic::test::createClientCtx();
|
|
clientCtx->setOmitEarlyRecordLayer(true);
|
|
clientCtx->setFactory(std::make_shared<QuicFizzFactory>());
|
|
serverCtx = quic::test::createServerCtx();
|
|
setupClientAndServerContext();
|
|
auto fizzServerContext = FizzServerQuicHandshakeContext::Builder()
|
|
.setFizzServerContext(serverCtx)
|
|
.build();
|
|
conn.reset(new TestingServerConnectionState(fizzServerContext));
|
|
cryptoState = conn->cryptoState.get();
|
|
handshake = conn->serverHandshakeLayer;
|
|
hostname = kTestHostname.str();
|
|
verifier = std::make_shared<fizz::test::MockCertificateVerifier>();
|
|
|
|
uint64_t initialMaxData = kDefaultConnectionFlowControlWindow;
|
|
uint64_t initialMaxStreamDataBidiLocal = kDefaultStreamFlowControlWindow;
|
|
uint64_t initialMaxStreamDataBidiRemote = kDefaultStreamFlowControlWindow;
|
|
uint64_t initialMaxStreamDataUni = kDefaultStreamFlowControlWindow;
|
|
uint64_t initialMaxStreamsBidi = kDefaultMaxStreamsBidirectional;
|
|
uint64_t initialMaxStreamsUni = kDefaultMaxStreamsUnidirectional;
|
|
auto clientExtensions =
|
|
std::make_shared<ClientTransportParametersExtension>(
|
|
getVersion(),
|
|
initialMaxData,
|
|
initialMaxStreamDataBidiLocal,
|
|
initialMaxStreamDataBidiRemote,
|
|
initialMaxStreamDataUni,
|
|
initialMaxStreamsBidi,
|
|
initialMaxStreamsUni,
|
|
kDefaultIdleTimeout,
|
|
kDefaultAckDelayExponent,
|
|
kDefaultUDPSendPacketLen,
|
|
kDefaultActiveConnectionIdLimit,
|
|
ConnectionId(std::vector<uint8_t>()));
|
|
fizzClient.reset(new fizz::client::FizzClient<
|
|
ServerHandshakeTest,
|
|
fizz::client::ClientStateMachine>(
|
|
clientState, clientReadBuffer, readAeadOptions, *this, dg.get()));
|
|
std::vector<QuicVersion> supportedVersions = {getVersion()};
|
|
auto params = std::make_shared<ServerTransportParametersExtension>(
|
|
getVersion(),
|
|
initialMaxData,
|
|
initialMaxStreamDataBidiLocal,
|
|
initialMaxStreamDataBidiRemote,
|
|
initialMaxStreamDataUni,
|
|
initialMaxStreamsBidi,
|
|
initialMaxStreamsUni,
|
|
/*disableMigration=*/true,
|
|
kDefaultIdleTimeout,
|
|
kDefaultAckDelayExponent,
|
|
kDefaultUDPSendPacketLen,
|
|
generateStatelessResetToken(),
|
|
ConnectionId(std::vector<uint8_t>{0xff, 0xfe, 0xfd, 0xfc}),
|
|
ConnectionId(std::vector<uint8_t>()));
|
|
initialize();
|
|
handshake->accept(params);
|
|
|
|
ON_CALL(serverCallback, onCryptoEventAvailable())
|
|
.WillByDefault(Invoke([this]() {
|
|
VLOG(1) << "onCryptoEventAvailable";
|
|
processCryptoEvents();
|
|
}));
|
|
auto cachedPsk = clientCtx->getPsk(hostname);
|
|
fizzClient->connect(
|
|
clientCtx,
|
|
verifier,
|
|
hostname,
|
|
cachedPsk,
|
|
Optional<std::vector<fizz::ech::ECHConfig>>(none),
|
|
std::make_shared<FizzClientExtensions>(clientExtensions, 0));
|
|
}
|
|
|
|
void processCryptoEvents() {
|
|
auto handshakeStateResult = setHandshakeState();
|
|
if (handshakeStateResult.hasError()) {
|
|
VLOG(1) << "server exception " << handshakeStateResult.error().message;
|
|
ex = folly::makeUnexpected(handshakeStateResult.error());
|
|
if (!inRoundScope_ && !handshakeCv.ready()) {
|
|
VLOG(1) << "Posting handshake cv";
|
|
handshakeCv.post();
|
|
}
|
|
return;
|
|
}
|
|
|
|
waitForData = false;
|
|
do {
|
|
auto writableBytes = getHandshakeWriteBytes();
|
|
if (writableBytes->empty()) {
|
|
break;
|
|
}
|
|
VLOG(1) << "server->client bytes="
|
|
<< writableBytes->computeChainDataLength();
|
|
clientReadBuffer.append(std::move(writableBytes));
|
|
fizzClient->newTransportData();
|
|
} while (!waitForData);
|
|
|
|
if (!inRoundScope_ && !handshakeCv.ready()) {
|
|
VLOG(1) << "Posting handshake cv";
|
|
handshakeCv.post();
|
|
}
|
|
}
|
|
|
|
void clientServerRound() {
|
|
SCOPE_EXIT {
|
|
inRoundScope_ = false;
|
|
};
|
|
inRoundScope_ = true;
|
|
evb.loop();
|
|
for (auto& clientWrite : clientWrites) {
|
|
for (auto& content : clientWrite.contents) {
|
|
auto encryptionLevel =
|
|
getEncryptionLevelFromFizz(content.encryptionLevel);
|
|
auto result =
|
|
handshake->doHandshake(std::move(content.data), encryptionLevel);
|
|
if (result.hasError()) {
|
|
ex = folly::makeUnexpected(result.error());
|
|
}
|
|
}
|
|
}
|
|
processCryptoEvents();
|
|
evb.loopIgnoreKeepAlive();
|
|
}
|
|
|
|
void serverClientRound() {
|
|
SCOPE_EXIT {
|
|
inRoundScope_ = false;
|
|
};
|
|
inRoundScope_ = true;
|
|
evb.loop();
|
|
waitForData = false;
|
|
do {
|
|
auto writableBytes = getHandshakeWriteBytes();
|
|
if (writableBytes->empty()) {
|
|
break;
|
|
}
|
|
VLOG(1) << "server->client bytes="
|
|
<< writableBytes->computeChainDataLength();
|
|
clientReadBuffer.append(std::move(writableBytes));
|
|
fizzClient->newTransportData();
|
|
} while (!waitForData);
|
|
evb.loop();
|
|
}
|
|
|
|
[[nodiscard]] folly::Expected<folly::Unit, QuicError> setHandshakeState() {
|
|
auto oneRttWriteCipherTmp = handshake->getFirstOneRttWriteCipher();
|
|
if (oneRttWriteCipherTmp.hasError()) {
|
|
return folly::makeUnexpected(oneRttWriteCipherTmp.error());
|
|
}
|
|
auto oneRttReadCipherTmp = handshake->getFirstOneRttReadCipher();
|
|
if (oneRttReadCipherTmp.hasError()) {
|
|
return folly::makeUnexpected(oneRttReadCipherTmp.error());
|
|
}
|
|
auto zeroRttReadCipherTmp = handshake->getZeroRttReadCipher();
|
|
if (zeroRttReadCipherTmp.hasError()) {
|
|
return folly::makeUnexpected(zeroRttReadCipherTmp.error());
|
|
}
|
|
auto handshakeWriteCipherTmp = std::move(conn->handshakeWriteCipher);
|
|
auto handshakeReadCipherTmp = handshake->getHandshakeReadCipher();
|
|
if (handshakeReadCipherTmp.hasError()) {
|
|
return folly::makeUnexpected(handshakeReadCipherTmp.error());
|
|
}
|
|
if (oneRttWriteCipherTmp.value()) {
|
|
oneRttWriteCipher = std::move(oneRttWriteCipherTmp.value());
|
|
}
|
|
if (oneRttReadCipherTmp.value()) {
|
|
oneRttReadCipher = std::move(oneRttReadCipherTmp.value());
|
|
}
|
|
if (zeroRttReadCipherTmp.value()) {
|
|
zeroRttReadCipher = std::move(zeroRttReadCipherTmp.value());
|
|
}
|
|
if (handshakeReadCipherTmp.value()) {
|
|
handshakeReadCipher = std::move(handshakeReadCipherTmp.value());
|
|
}
|
|
if (handshakeWriteCipherTmp) {
|
|
handshakeWriteCipher = std::move(handshakeWriteCipherTmp);
|
|
}
|
|
return folly::unit;
|
|
}
|
|
|
|
void expectOneRttReadCipher(bool expected) {
|
|
EXPECT_EQ(oneRttReadCipher.get() != nullptr, expected);
|
|
}
|
|
|
|
void expectOneRttWriteCipher(bool expected) {
|
|
EXPECT_EQ(oneRttWriteCipher.get() != nullptr, expected);
|
|
}
|
|
|
|
void expectOneRttCipher(bool expected) {
|
|
expectOneRttWriteCipher(expected);
|
|
expectOneRttReadCipher(expected);
|
|
|
|
if (expected) {
|
|
EXPECT_EQ(zeroRttReadCipher.get(), nullptr);
|
|
} else {
|
|
EXPECT_EQ(zeroRttReadCipher.get(), nullptr);
|
|
}
|
|
}
|
|
|
|
void expectZeroRttCipher(bool expected, bool oneRttRead) {
|
|
CHECK(expected || !oneRttRead) << "invalid condition supplied";
|
|
EXPECT_NE(oneRttWriteCipher.get(), nullptr);
|
|
if (expected) {
|
|
if (oneRttRead) {
|
|
EXPECT_NE(oneRttReadCipher.get(), nullptr);
|
|
} else {
|
|
EXPECT_EQ(oneRttReadCipher.get(), nullptr);
|
|
}
|
|
EXPECT_NE(zeroRttReadCipher.get(), nullptr);
|
|
} else {
|
|
EXPECT_EQ(oneRttReadCipher.get(), nullptr);
|
|
EXPECT_EQ(zeroRttReadCipher.get(), nullptr);
|
|
}
|
|
}
|
|
|
|
BufPtr getHandshakeWriteBytes() {
|
|
auto buf = folly::IOBuf::create(0);
|
|
switch (clientState.readRecordLayer()->getEncryptionLevel()) {
|
|
case fizz::EncryptionLevel::Plaintext:
|
|
if (!cryptoState->initialStream.writeBuffer.empty()) {
|
|
buf->prependChain(cryptoState->initialStream.writeBuffer.move());
|
|
}
|
|
break;
|
|
case fizz::EncryptionLevel::Handshake:
|
|
case fizz::EncryptionLevel::EarlyData:
|
|
if (!cryptoState->handshakeStream.writeBuffer.empty()) {
|
|
buf->prependChain(cryptoState->handshakeStream.writeBuffer.move());
|
|
}
|
|
break;
|
|
case fizz::EncryptionLevel::AppTraffic:
|
|
if (!cryptoState->oneRttStream.writeBuffer.empty()) {
|
|
buf->prependChain(cryptoState->oneRttStream.writeBuffer.move());
|
|
}
|
|
}
|
|
return buf;
|
|
}
|
|
|
|
void operator()(fizz::DeliverAppData&) {}
|
|
|
|
void operator()(fizz::WriteToSocket& write) {
|
|
clientWrites.push_back(std::move(write));
|
|
}
|
|
|
|
void operator()(fizz::client::ReportEarlyHandshakeSuccess&) {
|
|
earlyHandshakeSuccess = true;
|
|
}
|
|
|
|
void operator()(fizz::client::ReportHandshakeSuccess&) {
|
|
handshakeSuccess = true;
|
|
}
|
|
|
|
void operator()(fizz::client::ReportEarlyWriteFailed&) {
|
|
earlyWriteFailed = true;
|
|
}
|
|
|
|
void operator()(fizz::ReportError&) {
|
|
error = true;
|
|
}
|
|
|
|
void operator()(fizz::WaitForData&) {
|
|
waitForData = true;
|
|
fizzClient->waitForData();
|
|
}
|
|
|
|
void operator()(fizz::client::MutateState& mutator) {
|
|
mutator(clientState);
|
|
}
|
|
|
|
void operator()(fizz::client::NewCachedPsk& newCachedPsk) {
|
|
clientCtx->putPsk(hostname, std::move(newCachedPsk.psk));
|
|
}
|
|
|
|
void operator()(fizz::SecretAvailable&) {}
|
|
|
|
void operator()(fizz::EndOfData&) {}
|
|
|
|
void operator()(fizz::client::ECHRetryAvailable&) {}
|
|
|
|
class DelayedHolder : public folly::DelayedDestruction {};
|
|
|
|
std::unique_ptr<DelayedHolder, folly::DelayedDestruction::Destructor> dg;
|
|
|
|
folly::EventBase evb;
|
|
std::unique_ptr<
|
|
TestingServerConnectionState,
|
|
folly::DelayedDestruction::Destructor>
|
|
conn{nullptr};
|
|
ServerHandshake* handshake;
|
|
QuicCryptoState* cryptoState;
|
|
|
|
fizz::client::State clientState;
|
|
std::unique_ptr<fizz::client::FizzClient<
|
|
ServerHandshakeTest,
|
|
fizz::client::ClientStateMachine>>
|
|
fizzClient;
|
|
folly::IOBufQueue clientReadBuffer{folly::IOBufQueue::cacheChainLength()};
|
|
bool earlyHandshakeSuccess{false};
|
|
bool handshakeSuccess{false};
|
|
bool earlyWriteFailed{false};
|
|
bool error{false};
|
|
fizz::Aead::AeadOptions readAeadOptions;
|
|
|
|
std::vector<fizz::WriteToSocket> clientWrites;
|
|
MockServerHandshakeCallback serverCallback;
|
|
|
|
std::unique_ptr<Aead> oneRttWriteCipher;
|
|
std::unique_ptr<Aead> oneRttReadCipher;
|
|
std::unique_ptr<Aead> zeroRttReadCipher;
|
|
std::unique_ptr<Aead> handshakeWriteCipher;
|
|
std::unique_ptr<Aead> handshakeReadCipher;
|
|
|
|
folly::Expected<folly::Unit, QuicError> ex{folly::unit};
|
|
std::string hostname;
|
|
std::shared_ptr<fizz::test::MockCertificateVerifier> verifier;
|
|
std::shared_ptr<fizz::client::FizzClientContext> clientCtx;
|
|
std::shared_ptr<fizz::server::FizzServerContext> serverCtx;
|
|
folly::Baton<> handshakeCv;
|
|
bool inRoundScope_{false};
|
|
bool waitForData{false};
|
|
};
|
|
|
|
TEST_F(ServerHandshakeTest, TestGetExportedKeyingMaterial) {
|
|
// Sanity check. getExportedKeyingMaterial() should return nullptr prior to
|
|
// an handshake.
|
|
auto ekm =
|
|
handshake->getExportedKeyingMaterial("EXPORTER-Some-Label", none, 32);
|
|
EXPECT_TRUE(!ekm.has_value());
|
|
|
|
clientServerRound();
|
|
serverClientRound();
|
|
ekm = handshake->getExportedKeyingMaterial("EXPORTER-Some-Label", none, 32);
|
|
ASSERT_TRUE(ekm.has_value());
|
|
EXPECT_EQ(ekm->size(), 32);
|
|
|
|
ekm = handshake->getExportedKeyingMaterial(
|
|
"EXPORTER-Some-Label", ByteRange(), 32);
|
|
ASSERT_TRUE(ekm.has_value());
|
|
EXPECT_EQ(ekm->size(), 32);
|
|
}
|
|
|
|
TEST_F(ServerHandshakeTest, TestHandshakeSuccess) {
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
|
|
serverClientRound();
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
|
|
ASSERT_FALSE(ex.hasError());
|
|
expectOneRttCipher(true);
|
|
EXPECT_EQ(handshake->getApplicationProtocol(), "quic_test");
|
|
EXPECT_TRUE(handshakeSuccess);
|
|
}
|
|
|
|
TEST_F(ServerHandshakeTest, TestHandshakeSuccessIgnoreNonHandshake) {
|
|
fizz::WriteToSocket write;
|
|
fizz::TLSContent content;
|
|
content.contentType = fizz::ContentType::alert;
|
|
content.data = folly::IOBuf::copyBuffer(folly::unhexlify("01000000"));
|
|
content.encryptionLevel = fizz::EncryptionLevel::Plaintext;
|
|
write.contents.push_back(std::move(content));
|
|
clientWrites.push_back(std::move(write));
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
|
|
serverClientRound();
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
|
|
ASSERT_FALSE(ex.hasError());
|
|
expectOneRttCipher(true);
|
|
EXPECT_EQ(handshake->getApplicationProtocol(), "quic_test");
|
|
EXPECT_TRUE(handshakeSuccess);
|
|
}
|
|
|
|
TEST_F(ServerHandshakeTest, TestMalformedHandshakeMessage) {
|
|
fizz::WriteToSocket write;
|
|
fizz::TLSContent content;
|
|
content.contentType = fizz::ContentType::handshake;
|
|
content.data = folly::IOBuf::copyBuffer(folly::unhexlify("01000000"));
|
|
content.encryptionLevel = fizz::EncryptionLevel::Plaintext;
|
|
write.contents.push_back(std::move(content));
|
|
clientWrites.clear();
|
|
clientWrites.push_back(std::move(write));
|
|
clientServerRound();
|
|
|
|
EXPECT_TRUE(ex.hasError());
|
|
}
|
|
|
|
class AsyncRejectingTicketCipher : public fizz::server::TicketCipher {
|
|
public:
|
|
~AsyncRejectingTicketCipher() override = default;
|
|
|
|
folly::SemiFuture<
|
|
Optional<std::pair<std::unique_ptr<folly::IOBuf>, std::chrono::seconds>>>
|
|
encrypt(fizz::server::ResumptionState) const override {
|
|
if (!encryptAsync_) {
|
|
return std::make_pair(folly::IOBuf::create(0), 2s);
|
|
} else {
|
|
encryptAsync_ = false;
|
|
return std::move(encryptFuture_).deferValue([](auto&&) {
|
|
VLOG(1) << "got ticket async";
|
|
return folly::makeSemiFuture<Optional<
|
|
std::pair<std::unique_ptr<folly::IOBuf>, std::chrono::seconds>>>(
|
|
std::make_pair(folly::IOBuf::create(0), 2s));
|
|
});
|
|
}
|
|
}
|
|
|
|
void setDecryptAsync(bool async, folly::SemiFuture<folly::Unit> future) {
|
|
decryptAsync_ = async;
|
|
decryptFuture_ = std::move(future);
|
|
}
|
|
|
|
void setEncryptAsync(bool async, folly::SemiFuture<folly::Unit> future) {
|
|
encryptAsync_ = async;
|
|
encryptFuture_ = std::move(future);
|
|
}
|
|
|
|
void setDecryptError(bool error) {
|
|
error_ = error;
|
|
}
|
|
|
|
folly::SemiFuture<
|
|
std::pair<fizz::PskType, Optional<fizz::server::ResumptionState>>>
|
|
decrypt(std::unique_ptr<folly::IOBuf>) const override {
|
|
if (!decryptAsync_) {
|
|
if (error_) {
|
|
throw std::runtime_error("test decrypt error");
|
|
}
|
|
return std::make_pair(fizz::PskType::Rejected, none);
|
|
} else {
|
|
decryptAsync_ = false;
|
|
return std::move(decryptFuture_).deferValue([&](auto&&) {
|
|
VLOG(1) << "triggered reject";
|
|
if (error_) {
|
|
throw std::runtime_error("test decrypt error");
|
|
}
|
|
return folly::makeSemiFuture<
|
|
std::pair<fizz::PskType, Optional<fizz::server::ResumptionState>>>(
|
|
std::make_pair(fizz::PskType::Rejected, none));
|
|
});
|
|
}
|
|
}
|
|
|
|
private:
|
|
mutable folly::SemiFuture<folly::Unit> decryptFuture_;
|
|
mutable folly::SemiFuture<folly::Unit> encryptFuture_;
|
|
mutable bool decryptAsync_{true};
|
|
mutable bool encryptAsync_{false};
|
|
bool error_{false};
|
|
};
|
|
|
|
class ServerHandshakeWriteNSTTest : public ServerHandshakeTest {
|
|
public:
|
|
void setupClientAndServerContext() override {
|
|
serverCtx->setSendNewSessionTicket(false);
|
|
ticketCipher_ = std::make_shared<fizz::server::test::MockTicketCipher>();
|
|
ticketCipher_->setDefaults();
|
|
serverCtx->setTicketCipher(ticketCipher_);
|
|
cache_ = std::make_shared<fizz::client::BasicPskCache>();
|
|
clientCtx->setPskCache(cache_);
|
|
clientCtx->setSupportedAlpns({"h3"});
|
|
serverCtx->setSupportedAlpns({"h3", "hq"});
|
|
}
|
|
|
|
protected:
|
|
std::shared_ptr<fizz::client::BasicPskCache> cache_;
|
|
std::shared_ptr<fizz::server::test::MockTicketCipher> ticketCipher_;
|
|
};
|
|
|
|
TEST_F(ServerHandshakeWriteNSTTest, TestWriteNST) {
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
|
|
serverClientRound();
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
|
|
expectOneRttCipher(true);
|
|
|
|
AppToken appToken;
|
|
|
|
EXPECT_FALSE(cache_->getPsk(kTestHostname.str()));
|
|
EXPECT_CALL(*ticketCipher_, _encrypt(_))
|
|
.WillOnce(Invoke([&appToken](fizz::server::ResumptionState& resState) {
|
|
EXPECT_TRUE(
|
|
folly::IOBufEqualTo()(resState.appToken, encodeAppToken(appToken)));
|
|
return std::make_pair(folly::IOBuf::copyBuffer("appToken"), 100s);
|
|
}));
|
|
ASSERT_FALSE(handshake->writeNewSessionTicket(appToken).hasError());
|
|
processCryptoEvents();
|
|
evb.loop();
|
|
EXPECT_TRUE(cache_->getPsk(kTestHostname.str()));
|
|
}
|
|
|
|
class ServerHandshakePskTest : public ServerHandshakeTest {
|
|
public:
|
|
~ServerHandshakePskTest() override = default;
|
|
|
|
void SetUp() override {
|
|
cache = std::make_shared<fizz::client::BasicPskCache>();
|
|
psk.psk = std::string("psk");
|
|
psk.secret = std::string("secret");
|
|
psk.type = fizz::PskType::Resumption;
|
|
psk.version = fizz::ProtocolVersion::tls_1_3;
|
|
psk.cipher = fizz::CipherSuite::TLS_AES_128_GCM_SHA256;
|
|
psk.group = fizz::NamedGroup::x25519;
|
|
psk.serverCert = std::make_shared<fizz::test::MockCert>();
|
|
psk.alpn = std::string("h3");
|
|
psk.ticketAgeAdd = 1;
|
|
psk.ticketIssueTime = std::chrono::system_clock::time_point();
|
|
psk.ticketExpirationTime =
|
|
std::chrono::system_clock::time_point(std::chrono::seconds(20));
|
|
psk.ticketHandshakeTime = std::chrono::system_clock::time_point();
|
|
psk.maxEarlyDataSize = 2;
|
|
ServerHandshakeTest::SetUp();
|
|
}
|
|
|
|
void setupClientAndServerContext() override {
|
|
cache->putPsk(kTestHostname.str(), psk);
|
|
ticketCipher = makeTicketCipher();
|
|
serverCtx->setTicketCipher(ticketCipher);
|
|
clientCtx->setPskCache(cache);
|
|
clientCtx->setSupportedAlpns({"h3"});
|
|
serverCtx->setSupportedAlpns({"h3", "hq"});
|
|
}
|
|
|
|
virtual std::shared_ptr<fizz::server::TicketCipher> makeTicketCipher() = 0;
|
|
|
|
std::shared_ptr<fizz::client::BasicPskCache> cache;
|
|
folly::Promise<folly::Unit> promise;
|
|
std::shared_ptr<fizz::server::TicketCipher> ticketCipher;
|
|
fizz::client::CachedPsk psk;
|
|
};
|
|
|
|
class ServerHandshakeHRRTest : public ServerHandshakePskTest {
|
|
public:
|
|
~ServerHandshakeHRRTest() override = default;
|
|
|
|
void setupClientAndServerContext() override {
|
|
// Make a group mismatch happen.
|
|
psk.group = fizz::NamedGroup::secp256r1;
|
|
clientCtx->setSupportedGroups(
|
|
{fizz::NamedGroup::secp256r1, fizz::NamedGroup::x25519});
|
|
clientCtx->setSupportedVersions({fizz::ProtocolVersion::tls_1_3});
|
|
clientCtx->setDefaultShares({fizz::NamedGroup::secp256r1});
|
|
serverCtx->setSupportedGroups({fizz::NamedGroup::x25519});
|
|
serverCtx->setSupportedVersions({fizz::ProtocolVersion::tls_1_3});
|
|
clientCtx->setSupportedAlpns({"h3"});
|
|
serverCtx->setSupportedAlpns({"h3", "hq"});
|
|
ServerHandshakePskTest::setupClientAndServerContext();
|
|
}
|
|
|
|
std::shared_ptr<fizz::server::TicketCipher> makeTicketCipher() override {
|
|
auto cipher = std::make_shared<AsyncRejectingTicketCipher>();
|
|
cipher->setDecryptAsync(true, promise.getFuture());
|
|
return cipher;
|
|
}
|
|
};
|
|
|
|
TEST_F(ServerHandshakeHRRTest, TestHRR) {
|
|
auto rejectingCipher =
|
|
dynamic_cast<AsyncRejectingTicketCipher*>(ticketCipher.get());
|
|
rejectingCipher->setDecryptAsync(false, folly::makeFuture());
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
|
|
serverClientRound();
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
|
|
|
|
expectOneRttReadCipher(false);
|
|
expectOneRttWriteCipher(true);
|
|
serverClientRound();
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
|
|
EXPECT_EQ(handshake->getApplicationProtocol(), "h3");
|
|
expectOneRttCipher(true);
|
|
}
|
|
|
|
TEST_F(ServerHandshakeHRRTest, TestAsyncHRR) {
|
|
// Make an async ticket decryption operation.
|
|
clientServerRound();
|
|
|
|
promise.setValue();
|
|
evb.loop();
|
|
|
|
expectOneRttCipher(false);
|
|
|
|
handshakeCv.wait();
|
|
handshakeCv.reset();
|
|
clientServerRound();
|
|
serverClientRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
|
|
EXPECT_EQ(handshake->getApplicationProtocol(), "h3");
|
|
expectOneRttCipher(true);
|
|
}
|
|
|
|
TEST_F(ServerHandshakeHRRTest, TestAsyncCancel) {
|
|
// Make an async ticket decryption operation.
|
|
clientServerRound();
|
|
|
|
handshake->cancel();
|
|
// Let's destroy the crypto state to make sure it is not referenced.
|
|
conn->cryptoState.reset();
|
|
|
|
promise.setValue();
|
|
evb.loop();
|
|
|
|
EXPECT_EQ(handshake->getApplicationProtocol(), none);
|
|
expectOneRttCipher(false);
|
|
}
|
|
|
|
class ServerHandshakeAsyncTest : public ServerHandshakePskTest {
|
|
public:
|
|
~ServerHandshakeAsyncTest() override = default;
|
|
|
|
std::shared_ptr<fizz::server::TicketCipher> makeTicketCipher() override {
|
|
auto cipher = std::make_shared<AsyncRejectingTicketCipher>();
|
|
cipher->setDecryptAsync(false, folly::makeFuture());
|
|
cipher->setEncryptAsync(true, promise.getFuture());
|
|
return cipher;
|
|
}
|
|
};
|
|
|
|
TEST_F(ServerHandshakeAsyncTest, TestAsyncCancel) {
|
|
// Make an async ticket decryption operation.
|
|
clientServerRound();
|
|
serverClientRound();
|
|
clientServerRound();
|
|
|
|
handshake->cancel();
|
|
// Let's destroy the crypto state to make sure it is not referenced.
|
|
conn->cryptoState.reset();
|
|
|
|
promise.setValue();
|
|
evb.loop();
|
|
|
|
EXPECT_EQ(conn->getDestructorGuardCount(), 0);
|
|
}
|
|
|
|
class ServerHandshakeAsyncErrorTest : public ServerHandshakePskTest {
|
|
public:
|
|
~ServerHandshakeAsyncErrorTest() override = default;
|
|
|
|
std::shared_ptr<fizz::server::TicketCipher> makeTicketCipher() override {
|
|
auto cipher = std::make_shared<AsyncRejectingTicketCipher>();
|
|
cipher->setDecryptAsync(true, promise.getFuture());
|
|
cipher->setDecryptError(true);
|
|
return cipher;
|
|
}
|
|
};
|
|
|
|
TEST_F(ServerHandshakeAsyncErrorTest, TestAsyncError) {
|
|
clientServerRound();
|
|
|
|
bool error = false;
|
|
EXPECT_CALL(serverCallback, onCryptoEventAvailable())
|
|
.WillRepeatedly(Invoke([&] {
|
|
if (handshake->getFirstOneRttReadCipher().hasError()) {
|
|
error = true;
|
|
}
|
|
}));
|
|
promise.setValue();
|
|
evb.loop();
|
|
EXPECT_TRUE(error);
|
|
}
|
|
|
|
TEST_F(ServerHandshakeAsyncErrorTest, TestCancelOnAsyncError) {
|
|
clientServerRound();
|
|
|
|
EXPECT_CALL(serverCallback, onCryptoEventAvailable())
|
|
.WillRepeatedly(Invoke([&] {
|
|
handshake->cancel();
|
|
// Let's destroy the crypto state to make sure it is not referenced.
|
|
conn->cryptoState.reset();
|
|
}));
|
|
promise.setValue();
|
|
evb.loop();
|
|
EXPECT_TRUE(handshake->getFirstOneRttReadCipher().hasError());
|
|
}
|
|
|
|
TEST_F(ServerHandshakeAsyncErrorTest, TestCancelWhileWaitingAsyncError) {
|
|
clientServerRound();
|
|
handshake->cancel();
|
|
// Let's destroy the crypto state to make sure it is not referenced.
|
|
conn->cryptoState.reset();
|
|
|
|
promise.setValue();
|
|
evb.loop();
|
|
EXPECT_TRUE(handshake->getFirstOneRttReadCipher().hasError());
|
|
}
|
|
|
|
class ServerHandshakeSyncErrorTest : public ServerHandshakePskTest {
|
|
public:
|
|
~ServerHandshakeSyncErrorTest() override = default;
|
|
|
|
std::shared_ptr<fizz::server::TicketCipher> makeTicketCipher() override {
|
|
auto cipher = std::make_shared<AsyncRejectingTicketCipher>();
|
|
cipher->setDecryptError(true);
|
|
cipher->setDecryptAsync(false, folly::makeFuture());
|
|
return cipher;
|
|
}
|
|
};
|
|
|
|
TEST_F(ServerHandshakeSyncErrorTest, TestError) {
|
|
// Make an async ticket decryption operation.
|
|
clientServerRound();
|
|
evb.loop();
|
|
EXPECT_TRUE(handshake->getFirstOneRttReadCipher().hasError());
|
|
}
|
|
|
|
class ServerHandshakeZeroRttDefaultAppTokenValidatorTest
|
|
: public ServerHandshakePskTest {
|
|
public:
|
|
~ServerHandshakeZeroRttDefaultAppTokenValidatorTest() override = default;
|
|
|
|
/**
|
|
* This cipher can currently resume only 1 connection.
|
|
*/
|
|
class AcceptingTicketCipher : public fizz::server::TicketCipher {
|
|
public:
|
|
~AcceptingTicketCipher() override = default;
|
|
|
|
folly::SemiFuture<Optional<
|
|
std::pair<std::unique_ptr<folly::IOBuf>, std::chrono::seconds>>>
|
|
encrypt(fizz::server::ResumptionState) const override {
|
|
// Fake handshake, no need todo anything here.
|
|
return std::make_pair(folly::IOBuf::create(0), 2s);
|
|
}
|
|
|
|
void setPsk(fizz::client::CachedPsk psk) {
|
|
resState.version = psk.version;
|
|
resState.cipher = psk.cipher;
|
|
resState.resumptionSecret = folly::IOBuf::copyBuffer(psk.secret);
|
|
resState.alpn = psk.alpn;
|
|
resState.ticketIssueTime = std::chrono::system_clock::time_point();
|
|
resState.handshakeTime = std::chrono::system_clock::time_point();
|
|
resState.serverCert = psk.serverCert;
|
|
}
|
|
|
|
folly::SemiFuture<
|
|
std::pair<fizz::PskType, Optional<fizz::server::ResumptionState>>>
|
|
decrypt(std::unique_ptr<folly::IOBuf>) const override {
|
|
return std::make_pair(fizz::PskType::Resumption, std::move(resState));
|
|
}
|
|
|
|
private:
|
|
mutable fizz::server::ResumptionState resState;
|
|
};
|
|
|
|
void setupClientAndServerContext() override {
|
|
clientCtx->setSendEarlyData(true);
|
|
serverCtx->setEarlyDataSettings(
|
|
true,
|
|
fizz::server::ClockSkewTolerance{-1000ms, 1000ms},
|
|
std::make_shared<fizz::server::AllowAllReplayReplayCache>());
|
|
|
|
clientCtx->setSupportedVersions({fizz::ProtocolVersion::tls_1_3});
|
|
serverCtx->setSupportedVersions({fizz::ProtocolVersion::tls_1_3});
|
|
clientCtx->setSupportedAlpns({"h3"});
|
|
serverCtx->setSupportedAlpns({"h3", "hq"});
|
|
ServerHandshakePskTest::setupClientAndServerContext();
|
|
}
|
|
|
|
std::shared_ptr<fizz::server::TicketCipher> makeTicketCipher() override {
|
|
auto cipher = std::make_shared<AcceptingTicketCipher>();
|
|
cipher->setPsk(psk);
|
|
return cipher;
|
|
}
|
|
};
|
|
|
|
TEST_F(
|
|
ServerHandshakeZeroRttDefaultAppTokenValidatorTest,
|
|
TestDefaultAppTokenValidatorRejectZeroRtt) {
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
|
|
expectZeroRttCipher(false, false);
|
|
serverClientRound();
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
|
|
expectOneRttCipher(true);
|
|
}
|
|
|
|
class ServerHandshakeZeroRttTest
|
|
: public ServerHandshakeZeroRttDefaultAppTokenValidatorTest {
|
|
void initialize() override {
|
|
auto validator =
|
|
std::make_unique<fizz::server::test::MockAppTokenValidator>();
|
|
validator_ = validator.get();
|
|
handshake->initialize(&evb, &serverCallback, std::move(validator));
|
|
}
|
|
|
|
protected:
|
|
fizz::server::test::MockAppTokenValidator* validator_;
|
|
};
|
|
|
|
TEST_F(ServerHandshakeZeroRttTest, TestResumption) {
|
|
EXPECT_CALL(*validator_, validate(_)).WillOnce(Return(true));
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::KeysDerived);
|
|
expectZeroRttCipher(true, false);
|
|
serverClientRound();
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
|
|
expectZeroRttCipher(true, true);
|
|
}
|
|
|
|
TEST_F(ServerHandshakeZeroRttTest, TestRejectZeroRttNotEnabled) {
|
|
auto realServerCtx =
|
|
dynamic_cast<FizzServerHandshake*>(handshake)->getContext();
|
|
auto nonConstServerCtx =
|
|
const_cast<fizz::server::FizzServerContext*>(realServerCtx);
|
|
nonConstServerCtx->setEarlyDataSettings(
|
|
false, fizz::server::ClockSkewTolerance(), nullptr);
|
|
EXPECT_CALL(*validator_, validate(_)).Times(0);
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
|
|
expectZeroRttCipher(false, false);
|
|
serverClientRound();
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
|
|
expectOneRttCipher(true);
|
|
}
|
|
|
|
TEST_F(ServerHandshakeZeroRttTest, TestRejectZeroRttInvalidToken) {
|
|
EXPECT_CALL(*validator_, validate(_)).WillOnce(Return(false));
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake);
|
|
expectZeroRttCipher(false, false);
|
|
serverClientRound();
|
|
clientServerRound();
|
|
EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established);
|
|
expectOneRttCipher(true);
|
|
}
|
|
} // namespace quic::test
|