diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index 7195ef570..b5e9bab25 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -2265,8 +2265,8 @@ void QuicTransportBase::setEarlyDataAppParamsFunctions( folly::Function&, const Buf&) const> validator, folly::Function getter) { - earlyDataAppParamsValidator_ = std::move(validator); - earlyDataAppParamsGetter_ = std::move(getter); + conn_->earlyDataAppParamsValidator = std::move(validator); + conn_->earlyDataAppParamsGetter = std::move(getter); } void QuicTransportBase::cancelAllAppCallbacks( diff --git a/quic/api/QuicTransportBase.h b/quic/api/QuicTransportBase.h index 8f7d13a7d..1b355b542 100644 --- a/quic/api/QuicTransportBase.h +++ b/quic/api/QuicTransportBase.h @@ -627,11 +627,6 @@ class QuicTransportBase : public QuicSocket { folly::SocketAddress localFallbackAddress; // CongestionController factory std::shared_ptr ccFactory_{nullptr}; - - folly::Function&, const Buf&) const> - earlyDataAppParamsValidator_; - - folly::Function earlyDataAppParamsGetter_; }; std::ostream& operator<<(std::ostream& os, const QuicTransportBase& qt); diff --git a/quic/client/QuicClientTransport.cpp b/quic/client/QuicClientTransport.cpp index 061ab8c1f..c8767cb59 100644 --- a/quic/client/QuicClientTransport.cpp +++ b/quic/client/QuicClientTransport.cpp @@ -788,8 +788,8 @@ folly::Optional QuicClientTransport::getPsk() { if (!conn_->transportSettings.attemptEarlyData) { quicCachedPsk->cachedPsk.maxEarlyDataSize = 0; } else if ( - earlyDataAppParamsValidator_ && - !earlyDataAppParamsValidator_( + conn_->earlyDataAppParamsValidator && + !conn_->earlyDataAppParamsValidator( quicCachedPsk->cachedPsk.alpn, folly::IOBuf::copyBuffer(quicCachedPsk->appParams))) { quicCachedPsk->cachedPsk.maxEarlyDataSize = 0; @@ -895,8 +895,8 @@ void QuicClientTransport::onNewCachedPsk( quicCachedPsk.transportParams = getServerCachedTransportParameters(*clientConn_); - if (earlyDataAppParamsGetter_) { - auto appParams = earlyDataAppParamsGetter_(); + if (conn_->earlyDataAppParamsGetter) { + auto appParams = conn_->earlyDataAppParamsGetter(); if (appParams) { quicCachedPsk.appParams = appParams->moveToFbString().toStdString(); } diff --git a/quic/client/state/ClientStateMachine.cpp b/quic/client/state/ClientStateMachine.cpp index 07a40dbdc..cc1368814 100644 --- a/quic/client/state/ClientStateMachine.cpp +++ b/quic/client/state/ClientStateMachine.cpp @@ -49,6 +49,9 @@ std::unique_ptr undoAllClientStateForRetry( newConn->readCodec->setClientConnectionId(*conn->clientConnectionId); newConn->readCodec->setCodecParameters(CodecParameters( conn->peerAckDelayExponent, conn->originalVersion.value())); + newConn->earlyDataAppParamsValidator = + std::move(conn->earlyDataAppParamsValidator); + newConn->earlyDataAppParamsGetter = std::move(conn->earlyDataAppParamsGetter); return newConn; } diff --git a/quic/server/QuicServerTransport.cpp b/quic/server/QuicServerTransport.cpp index d01b29d7c..34705cd47 100644 --- a/quic/server/QuicServerTransport.cpp +++ b/quic/server/QuicServerTransport.cpp @@ -140,8 +140,7 @@ void QuicServerTransport::accept() { evb_, ctx_, this, - std::make_unique( - serverConn_, std::move(earlyDataAppParamsValidator_))); + std::make_unique(serverConn_)); } void QuicServerTransport::writeData() { @@ -414,8 +413,8 @@ void QuicServerTransport::maybeWriteNewSessionTicket() { if (appToken.sourceAddresses.empty()) { appToken.sourceAddresses.push_back(conn_->peerAddress.getIPAddress()); } - if (earlyDataAppParamsGetter_) { - appToken.appParams = earlyDataAppParamsGetter_(); + if (conn_->earlyDataAppParamsGetter) { + appToken.appParams = conn_->earlyDataAppParamsGetter(); } serverConn_->serverHandshakeLayer->writeNewSessionTicket(appToken); } diff --git a/quic/server/handshake/DefaultAppTokenValidator.cpp b/quic/server/handshake/DefaultAppTokenValidator.cpp index da31d453e..5ad460711 100644 --- a/quic/server/handshake/DefaultAppTokenValidator.cpp +++ b/quic/server/handshake/DefaultAppTokenValidator.cpp @@ -28,13 +28,8 @@ namespace quic { DefaultAppTokenValidator::DefaultAppTokenValidator( - QuicServerConnectionState* conn, - folly::Function& alpn, - const std::unique_ptr& appParams) const> - earlyDataAppParamsValidator) - : conn_(conn), - earlyDataAppParamsValidator_(std::move(earlyDataAppParamsValidator)) {} + QuicServerConnectionState* conn) + : conn_(conn) {} bool DefaultAppTokenValidator::validate( const fizz::server::ResumptionState& resumptionState) const { @@ -139,8 +134,8 @@ bool DefaultAppTokenValidator::validate( // If application has set validator and the token is invalid, reject 0-RTT. // If application did not set validator, it's valid. - if (earlyDataAppParamsValidator_ && - !earlyDataAppParamsValidator_( + if (conn_->earlyDataAppParamsValidator && + !conn_->earlyDataAppParamsValidator( resumptionState.alpn, appToken->appParams)) { VLOG(10) << "Invalid app params"; return false; diff --git a/quic/server/handshake/DefaultAppTokenValidator.h b/quic/server/handshake/DefaultAppTokenValidator.h index 6b92231cd..86eb39c64 100644 --- a/quic/server/handshake/DefaultAppTokenValidator.h +++ b/quic/server/handshake/DefaultAppTokenValidator.h @@ -28,21 +28,12 @@ struct QuicServerConnectionState; class DefaultAppTokenValidator : public fizz::server::AppTokenValidator { public: - explicit DefaultAppTokenValidator( - QuicServerConnectionState* conn, - folly::Function& alpn, - const std::unique_ptr& appParams) const> - earlyDataAppParamsValidator); + explicit DefaultAppTokenValidator(QuicServerConnectionState* conn); bool validate(const fizz::server::ResumptionState&) const override; private: QuicServerConnectionState* conn_; - folly::Function& alpn, - const std::unique_ptr& appParams) const> - earlyDataAppParamsValidator_; }; } // namespace quic diff --git a/quic/server/handshake/test/DefaultAppTokenValidatorTest.cpp b/quic/server/handshake/test/DefaultAppTokenValidatorTest.cpp index 13266e72f..016dc6bbe 100644 --- a/quic/server/handshake/test/DefaultAppTokenValidatorTest.cpp +++ b/quic/server/handshake/test/DefaultAppTokenValidatorTest.cpp @@ -42,9 +42,9 @@ TEST(DefaultAppTokenValidatorTest, TestValidParams) { ResumptionState resState; resState.appToken = encodeAppToken(appToken); - auto appParamsValidator = [](const folly::Optional&, - const Buf&) { return true; }; - DefaultAppTokenValidator validator(&conn, std::move(appParamsValidator)); + conn.earlyDataAppParamsValidator = [](const folly::Optional&, + const Buf&) { return true; }; + DefaultAppTokenValidator validator(&conn); EXPECT_TRUE(validator.validate(resState)); } @@ -70,9 +70,9 @@ TEST( ResumptionState resState; resState.appToken = encodeAppToken(appToken); - auto appParamsValidator = [](const folly::Optional&, - const Buf&) { return true; }; - DefaultAppTokenValidator validator(&conn, std::move(appParamsValidator)); + conn.earlyDataAppParamsValidator = [](const folly::Optional&, + const Buf&) { return true; }; + DefaultAppTokenValidator validator(&conn); EXPECT_TRUE(validator.validate(resState)); EXPECT_EQ( @@ -88,12 +88,12 @@ TEST(DefaultAppTokenValidatorTest, TestInvalidNullAppToken) { conn.version = QuicVersion::MVFST; ResumptionState resState; - auto appParamsValidator = [](const folly::Optional&, - const Buf&) { + conn.earlyDataAppParamsValidator = [](const folly::Optional&, + const Buf&) { EXPECT_TRUE(false); return true; }; - DefaultAppTokenValidator validator(&conn, std::move(appParamsValidator)); + DefaultAppTokenValidator validator(&conn); EXPECT_FALSE(validator.validate(resState)); } @@ -106,12 +106,12 @@ TEST(DefaultAppTokenValidatorTest, TestInvalidEmptyTransportParams) { ResumptionState resState; resState.appToken = encodeAppToken(appToken); - auto appParamsValidator = [](const folly::Optional&, - const Buf&) { + conn.earlyDataAppParamsValidator = [](const folly::Optional&, + const Buf&) { EXPECT_TRUE(false); return true; }; - DefaultAppTokenValidator validator(&conn, std::move(appParamsValidator)); + DefaultAppTokenValidator validator(&conn); EXPECT_FALSE(validator.validate(resState)); } @@ -141,12 +141,12 @@ TEST(DefaultAppTokenValidatorTest, TestInvalidMissingParams) { ResumptionState resState; resState.appToken = encodeAppToken(appToken); - auto appParamsValidator = [](const folly::Optional&, - const Buf&) { + conn.earlyDataAppParamsValidator = [](const folly::Optional&, + const Buf&) { EXPECT_TRUE(false); return true; }; - DefaultAppTokenValidator validator(&conn, std::move(appParamsValidator)); + DefaultAppTokenValidator validator(&conn); EXPECT_FALSE(validator.validate(resState)); } @@ -170,12 +170,12 @@ TEST(DefaultAppTokenValidatorTest, TestInvalidRedundantParameter) { ResumptionState resState; resState.appToken = encodeAppToken(appToken); - auto appParamsValidator = [](const folly::Optional&, - const Buf&) { + conn.earlyDataAppParamsValidator = [](const folly::Optional&, + const Buf&) { EXPECT_TRUE(false); return true; }; - DefaultAppTokenValidator validator(&conn, std::move(appParamsValidator)); + DefaultAppTokenValidator validator(&conn); EXPECT_FALSE(validator.validate(resState)); } @@ -197,12 +197,12 @@ TEST(DefaultAppTokenValidatorTest, TestInvalidDecreasedInitialMaxStreamData) { ResumptionState resState; resState.appToken = encodeAppToken(appToken); - auto appParamsValidator = [](const folly::Optional&, - const Buf&) { + conn.earlyDataAppParamsValidator = [](const folly::Optional&, + const Buf&) { EXPECT_TRUE(false); return true; }; - DefaultAppTokenValidator validator(&conn, std::move(appParamsValidator)); + DefaultAppTokenValidator validator(&conn); EXPECT_FALSE(validator.validate(resState)); } @@ -224,12 +224,12 @@ TEST(DefaultAppTokenValidatorTest, TestChangedIdleTimeout) { ResumptionState resState; resState.appToken = encodeAppToken(appToken); - auto appParamsValidator = [](const folly::Optional&, - const Buf&) { + conn.earlyDataAppParamsValidator = [](const folly::Optional&, + const Buf&) { EXPECT_TRUE(false); return true; }; - DefaultAppTokenValidator validator(&conn, std::move(appParamsValidator)); + DefaultAppTokenValidator validator(&conn); EXPECT_FALSE(validator.validate(resState)); } @@ -251,12 +251,12 @@ TEST(DefaultAppTokenValidatorTest, TestDecreasedInitialMaxStreams) { ResumptionState resState; resState.appToken = encodeAppToken(appToken); - auto appParamsValidator = [](const folly::Optional&, - const Buf&) { + conn.earlyDataAppParamsValidator = [](const folly::Optional&, + const Buf&) { EXPECT_TRUE(false); return true; }; - DefaultAppTokenValidator validator(&conn, std::move(appParamsValidator)); + DefaultAppTokenValidator validator(&conn); EXPECT_FALSE(validator.validate(resState)); } @@ -280,9 +280,9 @@ TEST(DefaultAppTokenValidatorTest, TestInvalidAppParams) { ResumptionState resState; resState.appToken = encodeAppToken(appToken); - auto appParamsValidator = [](const folly::Optional&, - const Buf&) { return false; }; - DefaultAppTokenValidator validator(&conn, std::move(appParamsValidator)); + conn.earlyDataAppParamsValidator = [](const folly::Optional&, + const Buf&) { return false; }; + DefaultAppTokenValidator validator(&conn); EXPECT_FALSE(validator.validate(resState)); } @@ -307,9 +307,11 @@ class SourceAddressTokenTest : public Test { ResumptionState resState; resState.appToken = encodeAppToken(appToken_); - auto appParamsValidator = [=](const folly::Optional&, - const Buf&) { return acceptZeroRtt; }; - DefaultAppTokenValidator validator(&conn_, std::move(appParamsValidator)); + conn_.earlyDataAppParamsValidator = [=](const folly::Optional&, + const Buf&) { + return acceptZeroRtt; + }; + DefaultAppTokenValidator validator(&conn_); EXPECT_EQ(validator.validate(resState), acceptZeroRtt); } diff --git a/quic/state/StateData.h b/quic/state/StateData.h index 982708941..f49014373 100644 --- a/quic/state/StateData.h +++ b/quic/state/StateData.h @@ -810,6 +810,13 @@ struct QuicConnectionStateBase : public folly::DelayedDestruction { // Use this measured rtt as init rtt (from Transport Settings) TimePoint pathChallengeStartTime; + /** + * Eary data app params functions. + */ + folly::Function&, const Buf&) const> + earlyDataAppParamsValidator; + folly::Function earlyDataAppParamsGetter; + /** * Selects a previously unused peer-issued connection id to use. * If there are no available ids return false and don't change anything.