/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include using namespace std; using namespace testing; static constexpr folly::StringPiece kTestHostname = "www.facebook.com"; namespace quic { namespace test { class MockServerHandshakeCallback : public ServerHandshake::HandshakeCallback { public: ~MockServerHandshakeCallback() override = default; MOCK_METHOD(void, onCryptoEventAvailable, (), (noexcept)); }; struct TestingServerConnectionState : public QuicServerConnectionState { explicit TestingServerConnectionState( std::shared_ptr context) : QuicServerConnectionState(std::move(context)) {} uint32_t getDestructorGuardCount() const { return folly::DelayedDestruction::getDestructorGuardCount(); } }; class ServerHandshakeTest : public Test { public: ~ServerHandshakeTest() override = default; virtual void setupClientAndServerContext() {} QuicVersion getVersion() { return QuicVersion::MVFST; } virtual void initialize() { handshake->initialize(&evb, &serverCallback); } void SetUp() override { // This client context is used outside the context of QUIC in this test, so // we have to manually configure the QUIC record customizations. clientCtx = quic::test::createClientCtx(); clientCtx->setOmitEarlyRecordLayer(true); clientCtx->setFactory(std::make_shared()); serverCtx = quic::test::createServerCtx(); setupClientAndServerContext(); auto fizzServerContext = FizzServerQuicHandshakeContext::Builder() .setFizzServerContext(serverCtx) .build(); conn.reset(new TestingServerConnectionState(fizzServerContext)); cryptoState = conn->cryptoState.get(); handshake = conn->serverHandshakeLayer; hostname = kTestHostname.str(); verifier = std::make_shared(); uint64_t initialMaxData = kDefaultConnectionFlowControlWindow; uint64_t initialMaxStreamDataBidiLocal = kDefaultStreamFlowControlWindow; uint64_t initialMaxStreamDataBidiRemote = kDefaultStreamFlowControlWindow; uint64_t initialMaxStreamDataUni = kDefaultStreamFlowControlWindow; uint64_t initialMaxStreamsBidi = kDefaultMaxStreamsBidirectional; uint64_t initialMaxStreamsUni = kDefaultMaxStreamsUnidirectional; auto clientExtensions = std::make_shared( getVersion(), initialMaxData, initialMaxStreamDataBidiLocal, initialMaxStreamDataBidiRemote, initialMaxStreamDataUni, initialMaxStreamsBidi, initialMaxStreamsUni, kDefaultIdleTimeout, kDefaultAckDelayExponent, kDefaultUDPSendPacketLen, kDefaultActiveConnectionIdLimit, ConnectionId(std::vector())); fizzClient.reset(new fizz::client::FizzClient< ServerHandshakeTest, fizz::client::ClientStateMachine>( clientState, clientReadBuffer, readAeadOptions, *this, dg.get())); std::vector supportedVersions = {getVersion()}; auto params = std::make_shared( getVersion(), initialMaxData, initialMaxStreamDataBidiLocal, initialMaxStreamDataBidiRemote, initialMaxStreamDataUni, initialMaxStreamsBidi, initialMaxStreamsUni, /*disableMigration=*/true, kDefaultIdleTimeout, kDefaultAckDelayExponent, kDefaultUDPSendPacketLen, generateStatelessResetToken(), ConnectionId(std::vector{0xff, 0xfe, 0xfd, 0xfc}), ConnectionId(std::vector())); initialize(); handshake->accept(params); ON_CALL(serverCallback, onCryptoEventAvailable()) .WillByDefault(Invoke([this]() { VLOG(1) << "onCryptoEventAvailable"; processCryptoEvents(); })); auto cachedPsk = clientCtx->getPsk(hostname); fizzClient->connect( clientCtx, verifier, hostname, cachedPsk, Optional>(none), std::make_shared(clientExtensions)); } void processCryptoEvents() { try { setHandshakeState(); waitForData = false; do { auto writableBytes = getHandshakeWriteBytes(); if (writableBytes->empty()) { break; } VLOG(1) << "server->client bytes=" << writableBytes->computeChainDataLength(); clientReadBuffer.append(std::move(writableBytes)); fizzClient->newTransportData(); } while (!waitForData); } catch (const QuicTransportException& e) { VLOG(1) << "server exception " << e.what(); ex = std::make_exception_ptr(e); } if (!inRoundScope_ && !handshakeCv.ready()) { VLOG(1) << "Posting handshake cv"; handshakeCv.post(); } } void clientServerRound() { SCOPE_EXIT { inRoundScope_ = false; }; inRoundScope_ = true; evb.loop(); try { for (auto& clientWrite : clientWrites) { for (auto& content : clientWrite.contents) { auto encryptionLevel = getEncryptionLevelFromFizz(content.encryptionLevel); handshake->doHandshake(std::move(content.data), encryptionLevel); } } processCryptoEvents(); } catch (const QuicTransportException&) { ex = std::current_exception(); } evb.loopIgnoreKeepAlive(); } void serverClientRound() { SCOPE_EXIT { inRoundScope_ = false; }; inRoundScope_ = true; evb.loop(); waitForData = false; do { auto writableBytes = getHandshakeWriteBytes(); if (writableBytes->empty()) { break; } VLOG(1) << "server->client bytes=" << writableBytes->computeChainDataLength(); clientReadBuffer.append(std::move(writableBytes)); fizzClient->newTransportData(); } while (!waitForData); evb.loop(); } void setHandshakeState() { auto oneRttWriteCipherTmp = handshake->getFirstOneRttWriteCipher(); auto oneRttReadCipherTmp = handshake->getFirstOneRttReadCipher(); auto zeroRttReadCipherTmp = handshake->getZeroRttReadCipher(); auto handshakeWriteCipherTmp = std::move(conn->handshakeWriteCipher); auto handshakeReadCipherTmp = handshake->getHandshakeReadCipher(); if (oneRttWriteCipherTmp) { oneRttWriteCipher = std::move(oneRttWriteCipherTmp); } if (oneRttReadCipherTmp) { oneRttReadCipher = std::move(oneRttReadCipherTmp); } if (zeroRttReadCipherTmp) { zeroRttReadCipher = std::move(zeroRttReadCipherTmp); } if (handshakeReadCipherTmp) { handshakeReadCipher = std::move(handshakeReadCipherTmp); } if (handshakeWriteCipherTmp) { handshakeWriteCipher = std::move(handshakeWriteCipherTmp); } } void expectOneRttReadCipher(bool expected) { EXPECT_EQ(oneRttReadCipher.get() != nullptr, expected); } void expectOneRttWriteCipher(bool expected) { EXPECT_EQ(oneRttWriteCipher.get() != nullptr, expected); } void expectOneRttCipher(bool expected) { expectOneRttWriteCipher(expected); expectOneRttReadCipher(expected); if (expected) { EXPECT_EQ(zeroRttReadCipher.get(), nullptr); } else { EXPECT_EQ(zeroRttReadCipher.get(), nullptr); } } void expectZeroRttCipher(bool expected, bool oneRttRead) { CHECK(expected || !oneRttRead) << "invalid condition supplied"; EXPECT_NE(oneRttWriteCipher.get(), nullptr); if (expected) { if (oneRttRead) { EXPECT_NE(oneRttReadCipher.get(), nullptr); } else { EXPECT_EQ(oneRttReadCipher.get(), nullptr); } EXPECT_NE(zeroRttReadCipher.get(), nullptr); } else { EXPECT_EQ(oneRttReadCipher.get(), nullptr); EXPECT_EQ(zeroRttReadCipher.get(), nullptr); } } Buf getHandshakeWriteBytes() { auto buf = folly::IOBuf::create(0); switch (clientState.readRecordLayer()->getEncryptionLevel()) { case fizz::EncryptionLevel::Plaintext: if (!cryptoState->initialStream.writeBuffer.empty()) { buf->prependChain(cryptoState->initialStream.writeBuffer.move()); } break; case fizz::EncryptionLevel::Handshake: case fizz::EncryptionLevel::EarlyData: if (!cryptoState->handshakeStream.writeBuffer.empty()) { buf->prependChain(cryptoState->handshakeStream.writeBuffer.move()); } break; case fizz::EncryptionLevel::AppTraffic: if (!cryptoState->oneRttStream.writeBuffer.empty()) { buf->prependChain(cryptoState->oneRttStream.writeBuffer.move()); } } return buf; } void operator()(fizz::DeliverAppData&) {} void operator()(fizz::WriteToSocket& write) { clientWrites.push_back(std::move(write)); } void operator()(fizz::client::ReportEarlyHandshakeSuccess&) { earlyHandshakeSuccess = true; } void operator()(fizz::client::ReportHandshakeSuccess&) { handshakeSuccess = true; } void operator()(fizz::client::ReportEarlyWriteFailed&) { earlyWriteFailed = true; } void operator()(fizz::ReportError&) { error = true; } void operator()(fizz::WaitForData&) { waitForData = true; fizzClient->waitForData(); } void operator()(fizz::client::MutateState& mutator) { mutator(clientState); } void operator()(fizz::client::NewCachedPsk& newCachedPsk) { clientCtx->putPsk(hostname, std::move(newCachedPsk.psk)); } void operator()(fizz::SecretAvailable&) {} void operator()(fizz::EndOfData&) {} void operator()(fizz::client::ECHRetryAvailable&) {} class DelayedHolder : public folly::DelayedDestruction {}; std::unique_ptr dg; folly::EventBase evb; std::unique_ptr< TestingServerConnectionState, folly::DelayedDestruction::Destructor> conn{nullptr}; ServerHandshake* handshake; QuicCryptoState* cryptoState; fizz::client::State clientState; std::unique_ptr> fizzClient; folly::IOBufQueue clientReadBuffer{folly::IOBufQueue::cacheChainLength()}; bool earlyHandshakeSuccess{false}; bool handshakeSuccess{false}; bool earlyWriteFailed{false}; bool error{false}; fizz::Aead::AeadOptions readAeadOptions; std::vector clientWrites; MockServerHandshakeCallback serverCallback; std::unique_ptr oneRttWriteCipher; std::unique_ptr oneRttReadCipher; std::unique_ptr zeroRttReadCipher; std::unique_ptr handshakeWriteCipher; std::unique_ptr handshakeReadCipher; std::exception_ptr ex; std::string hostname; std::shared_ptr verifier; std::shared_ptr clientCtx; std::shared_ptr serverCtx; folly::Baton<> handshakeCv; bool inRoundScope_{false}; bool waitForData{false}; }; TEST_F(ServerHandshakeTest, TestGetExportedKeyingMaterial) { // Sanity check. getExportedKeyingMaterial() should return nullptr prior to // an handshake. auto ekm = handshake->getExportedKeyingMaterial("EXPORTER-Some-Label", none, 32); EXPECT_TRUE(!ekm.has_value()); clientServerRound(); serverClientRound(); ekm = handshake->getExportedKeyingMaterial("EXPORTER-Some-Label", none, 32); ASSERT_TRUE(ekm.has_value()); EXPECT_EQ(ekm->size(), 32); ekm = handshake->getExportedKeyingMaterial( "EXPORTER-Some-Label", folly::ByteRange(), 32); ASSERT_TRUE(ekm.has_value()); EXPECT_EQ(ekm->size(), 32); } TEST_F(ServerHandshakeTest, TestHandshakeSuccess) { clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake); serverClientRound(); clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established); if (ex) { std::rethrow_exception(ex); } expectOneRttCipher(true); EXPECT_EQ(handshake->getApplicationProtocol(), "quic_test"); EXPECT_TRUE(handshakeSuccess); } TEST_F(ServerHandshakeTest, TestHandshakeSuccessIgnoreNonHandshake) { fizz::WriteToSocket write; fizz::TLSContent content; content.contentType = fizz::ContentType::alert; content.data = folly::IOBuf::copyBuffer(folly::unhexlify("01000000")); content.encryptionLevel = fizz::EncryptionLevel::Plaintext; write.contents.push_back(std::move(content)); clientWrites.push_back(std::move(write)); clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake); serverClientRound(); clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established); if (ex) { std::rethrow_exception(ex); } expectOneRttCipher(true); EXPECT_EQ(handshake->getApplicationProtocol(), "quic_test"); EXPECT_TRUE(handshakeSuccess); } TEST_F(ServerHandshakeTest, TestMalformedHandshakeMessage) { fizz::WriteToSocket write; fizz::TLSContent content; content.contentType = fizz::ContentType::handshake; content.data = folly::IOBuf::copyBuffer(folly::unhexlify("01000000")); content.encryptionLevel = fizz::EncryptionLevel::Plaintext; write.contents.push_back(std::move(content)); clientWrites.clear(); clientWrites.push_back(std::move(write)); clientServerRound(); EXPECT_TRUE(ex); } class AsyncRejectingTicketCipher : public fizz::server::TicketCipher { public: ~AsyncRejectingTicketCipher() override = default; folly::SemiFuture< Optional, std::chrono::seconds>>> encrypt(fizz::server::ResumptionState) const override { if (!encryptAsync_) { return std::make_pair(folly::IOBuf::create(0), 2s); } else { encryptAsync_ = false; return std::move(encryptFuture_).deferValue([](auto&&) { VLOG(1) << "got ticket async"; return folly::makeSemiFuture, std::chrono::seconds>>>( std::make_pair(folly::IOBuf::create(0), 2s)); }); } } void setDecryptAsync(bool async, folly::SemiFuture future) { decryptAsync_ = async; decryptFuture_ = std::move(future); } void setEncryptAsync(bool async, folly::SemiFuture future) { encryptAsync_ = async; encryptFuture_ = std::move(future); } void setDecryptError(bool error) { error_ = error; } folly::SemiFuture< std::pair>> decrypt(std::unique_ptr) const override { if (!decryptAsync_) { if (error_) { throw std::runtime_error("test decrypt error"); } return std::make_pair(fizz::PskType::Rejected, none); } else { decryptAsync_ = false; return std::move(decryptFuture_).deferValue([&](auto&&) { VLOG(1) << "triggered reject"; if (error_) { throw std::runtime_error("test decrypt error"); } return folly::makeSemiFuture< std::pair>>( std::make_pair(fizz::PskType::Rejected, none)); }); } } private: mutable folly::SemiFuture decryptFuture_; mutable folly::SemiFuture encryptFuture_; mutable bool decryptAsync_{true}; mutable bool encryptAsync_{false}; bool error_{false}; }; class ServerHandshakeWriteNSTTest : public ServerHandshakeTest { public: void setupClientAndServerContext() override { serverCtx->setSendNewSessionTicket(false); ticketCipher_ = std::make_shared(); ticketCipher_->setDefaults(); serverCtx->setTicketCipher(ticketCipher_); cache_ = std::make_shared(); clientCtx->setPskCache(cache_); clientCtx->setSupportedAlpns({"h3"}); serverCtx->setSupportedAlpns({"h3", "hq"}); } protected: std::shared_ptr cache_; std::shared_ptr ticketCipher_; }; TEST_F(ServerHandshakeWriteNSTTest, TestWriteNST) { clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake); serverClientRound(); clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established); expectOneRttCipher(true); AppToken appToken; EXPECT_FALSE(cache_->getPsk(kTestHostname.str())); EXPECT_CALL(*ticketCipher_, _encrypt(_)) .WillOnce(Invoke([&appToken](fizz::server::ResumptionState& resState) { EXPECT_TRUE( folly::IOBufEqualTo()(resState.appToken, encodeAppToken(appToken))); return std::make_pair(folly::IOBuf::copyBuffer("appToken"), 100s); })); handshake->writeNewSessionTicket(appToken); processCryptoEvents(); evb.loop(); EXPECT_TRUE(cache_->getPsk(kTestHostname.str())); } class ServerHandshakePskTest : public ServerHandshakeTest { public: ~ServerHandshakePskTest() override = default; void SetUp() override { cache = std::make_shared(); psk.psk = std::string("psk"); psk.secret = std::string("secret"); psk.type = fizz::PskType::Resumption; psk.version = fizz::ProtocolVersion::tls_1_3; psk.cipher = fizz::CipherSuite::TLS_AES_128_GCM_SHA256; psk.group = fizz::NamedGroup::x25519; psk.serverCert = std::make_shared(); psk.alpn = std::string("h3"); psk.ticketAgeAdd = 1; psk.ticketIssueTime = std::chrono::system_clock::time_point(); psk.ticketExpirationTime = std::chrono::system_clock::time_point(std::chrono::seconds(20)); psk.ticketHandshakeTime = std::chrono::system_clock::time_point(); psk.maxEarlyDataSize = 2; ServerHandshakeTest::SetUp(); } void setupClientAndServerContext() override { cache->putPsk(kTestHostname.str(), psk); ticketCipher = makeTicketCipher(); serverCtx->setTicketCipher(ticketCipher); clientCtx->setPskCache(cache); clientCtx->setSupportedAlpns({"h3"}); serverCtx->setSupportedAlpns({"h3", "hq"}); } virtual std::shared_ptr makeTicketCipher() = 0; std::shared_ptr cache; folly::Promise promise; std::shared_ptr ticketCipher; fizz::client::CachedPsk psk; }; class ServerHandshakeHRRTest : public ServerHandshakePskTest { public: ~ServerHandshakeHRRTest() override = default; void setupClientAndServerContext() override { // Make a group mismatch happen. psk.group = fizz::NamedGroup::secp256r1; clientCtx->setSupportedGroups( {fizz::NamedGroup::secp256r1, fizz::NamedGroup::x25519}); clientCtx->setSupportedVersions({fizz::ProtocolVersion::tls_1_3}); clientCtx->setDefaultShares({fizz::NamedGroup::secp256r1}); serverCtx->setSupportedGroups({fizz::NamedGroup::x25519}); serverCtx->setSupportedVersions({fizz::ProtocolVersion::tls_1_3}); clientCtx->setSupportedAlpns({"h3"}); serverCtx->setSupportedAlpns({"h3", "hq"}); ServerHandshakePskTest::setupClientAndServerContext(); } std::shared_ptr makeTicketCipher() override { auto cipher = std::make_shared(); cipher->setDecryptAsync(true, promise.getFuture()); return cipher; } }; TEST_F(ServerHandshakeHRRTest, TestHRR) { auto rejectingCipher = dynamic_cast(ticketCipher.get()); rejectingCipher->setDecryptAsync(false, folly::makeFuture()); clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake); serverClientRound(); clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake); expectOneRttReadCipher(false); expectOneRttWriteCipher(true); serverClientRound(); clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established); EXPECT_EQ(handshake->getApplicationProtocol(), "h3"); expectOneRttCipher(true); } TEST_F(ServerHandshakeHRRTest, TestAsyncHRR) { // Make an async ticket decryption operation. clientServerRound(); promise.setValue(); evb.loop(); expectOneRttCipher(false); handshakeCv.wait(); handshakeCv.reset(); clientServerRound(); serverClientRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake); clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established); EXPECT_EQ(handshake->getApplicationProtocol(), "h3"); expectOneRttCipher(true); } TEST_F(ServerHandshakeHRRTest, TestAsyncCancel) { // Make an async ticket decryption operation. clientServerRound(); handshake->cancel(); // Let's destroy the crypto state to make sure it is not referenced. conn->cryptoState.reset(); promise.setValue(); evb.loop(); EXPECT_EQ(handshake->getApplicationProtocol(), none); expectOneRttCipher(false); } class ServerHandshakeAsyncTest : public ServerHandshakePskTest { public: ~ServerHandshakeAsyncTest() override = default; std::shared_ptr makeTicketCipher() override { auto cipher = std::make_shared(); cipher->setDecryptAsync(false, folly::makeFuture()); cipher->setEncryptAsync(true, promise.getFuture()); return cipher; } }; TEST_F(ServerHandshakeAsyncTest, TestAsyncCancel) { // Make an async ticket decryption operation. clientServerRound(); serverClientRound(); clientServerRound(); handshake->cancel(); // Let's destroy the crypto state to make sure it is not referenced. conn->cryptoState.reset(); promise.setValue(); evb.loop(); EXPECT_EQ(conn->getDestructorGuardCount(), 0); } class ServerHandshakeAsyncErrorTest : public ServerHandshakePskTest { public: ~ServerHandshakeAsyncErrorTest() override = default; std::shared_ptr makeTicketCipher() override { auto cipher = std::make_shared(); cipher->setDecryptAsync(true, promise.getFuture()); cipher->setDecryptError(true); return cipher; } }; TEST_F(ServerHandshakeAsyncErrorTest, TestAsyncError) { clientServerRound(); bool error = false; EXPECT_CALL(serverCallback, onCryptoEventAvailable()) .WillRepeatedly(Invoke([&] { try { handshake->getFirstOneRttReadCipher(); } catch (std::exception&) { error = true; } })); promise.setValue(); evb.loop(); EXPECT_TRUE(error); } TEST_F(ServerHandshakeAsyncErrorTest, TestCancelOnAsyncError) { clientServerRound(); EXPECT_CALL(serverCallback, onCryptoEventAvailable()) .WillRepeatedly(Invoke([&] { handshake->cancel(); // Let's destroy the crypto state to make sure it is not referenced. conn->cryptoState.reset(); })); promise.setValue(); evb.loop(); EXPECT_THROW(handshake->getFirstOneRttReadCipher(), std::runtime_error); } TEST_F(ServerHandshakeAsyncErrorTest, TestCancelWhileWaitingAsyncError) { clientServerRound(); handshake->cancel(); // Let's destroy the crypto state to make sure it is not referenced. conn->cryptoState.reset(); promise.setValue(); evb.loop(); EXPECT_THROW(handshake->getFirstOneRttReadCipher(), std::runtime_error); } class ServerHandshakeSyncErrorTest : public ServerHandshakePskTest { public: ~ServerHandshakeSyncErrorTest() override = default; std::shared_ptr makeTicketCipher() override { auto cipher = std::make_shared(); cipher->setDecryptError(true); cipher->setDecryptAsync(false, folly::makeFuture()); return cipher; } }; TEST_F(ServerHandshakeSyncErrorTest, TestError) { // Make an async ticket decryption operation. clientServerRound(); evb.loop(); EXPECT_THROW(handshake->getFirstOneRttReadCipher(), std::runtime_error); } class ServerHandshakeZeroRttDefaultAppTokenValidatorTest : public ServerHandshakePskTest { public: ~ServerHandshakeZeroRttDefaultAppTokenValidatorTest() override = default; /** * This cipher can currently resume only 1 connection. */ class AcceptingTicketCipher : public fizz::server::TicketCipher { public: ~AcceptingTicketCipher() override = default; folly::SemiFuture, std::chrono::seconds>>> encrypt(fizz::server::ResumptionState) const override { // Fake handshake, no need todo anything here. return std::make_pair(folly::IOBuf::create(0), 2s); } void setPsk(fizz::client::CachedPsk psk) { resState.version = psk.version; resState.cipher = psk.cipher; resState.resumptionSecret = folly::IOBuf::copyBuffer(psk.secret); resState.alpn = psk.alpn; resState.ticketIssueTime = std::chrono::system_clock::time_point(); resState.handshakeTime = std::chrono::system_clock::time_point(); resState.serverCert = psk.serverCert; } folly::SemiFuture< std::pair>> decrypt(std::unique_ptr) const override { return std::make_pair(fizz::PskType::Resumption, std::move(resState)); } private: mutable fizz::server::ResumptionState resState; }; void setupClientAndServerContext() override { clientCtx->setSendEarlyData(true); serverCtx->setEarlyDataSettings( true, fizz::server::ClockSkewTolerance{-1000ms, 1000ms}, std::make_shared()); clientCtx->setSupportedVersions({fizz::ProtocolVersion::tls_1_3}); serverCtx->setSupportedVersions({fizz::ProtocolVersion::tls_1_3}); clientCtx->setSupportedAlpns({"h3"}); serverCtx->setSupportedAlpns({"h3", "hq"}); ServerHandshakePskTest::setupClientAndServerContext(); } std::shared_ptr makeTicketCipher() override { auto cipher = std::make_shared(); cipher->setPsk(psk); return cipher; } }; TEST_F( ServerHandshakeZeroRttDefaultAppTokenValidatorTest, TestDefaultAppTokenValidatorRejectZeroRtt) { clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake); expectZeroRttCipher(false, false); serverClientRound(); clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established); expectOneRttCipher(true); } class ServerHandshakeZeroRttTest : public ServerHandshakeZeroRttDefaultAppTokenValidatorTest { void initialize() override { auto validator = std::make_unique(); validator_ = validator.get(); handshake->initialize(&evb, &serverCallback, std::move(validator)); } protected: fizz::server::test::MockAppTokenValidator* validator_; }; TEST_F(ServerHandshakeZeroRttTest, TestResumption) { EXPECT_CALL(*validator_, validate(_)).WillOnce(Return(true)); clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::KeysDerived); expectZeroRttCipher(true, false); serverClientRound(); clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established); expectZeroRttCipher(true, true); } TEST_F(ServerHandshakeZeroRttTest, TestRejectZeroRttNotEnabled) { auto realServerCtx = dynamic_cast(handshake)->getContext(); auto nonConstServerCtx = const_cast(realServerCtx); nonConstServerCtx->setEarlyDataSettings( false, fizz::server::ClockSkewTolerance(), nullptr); EXPECT_CALL(*validator_, validate(_)).Times(0); clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake); expectZeroRttCipher(false, false); serverClientRound(); clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established); expectOneRttCipher(true); } TEST_F(ServerHandshakeZeroRttTest, TestRejectZeroRttInvalidToken) { EXPECT_CALL(*validator_, validate(_)).WillOnce(Return(false)); clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Handshake); expectZeroRttCipher(false, false); serverClientRound(); clientServerRound(); EXPECT_EQ(handshake->getPhase(), ServerHandshake::Phase::Established); expectOneRttCipher(true); } } // namespace test } // namespace quic