diff --git a/quic/api/QuicTransportFunctions.cpp b/quic/api/QuicTransportFunctions.cpp index 51c99bf8c..45f69ce91 100644 --- a/quic/api/QuicTransportFunctions.cpp +++ b/quic/api/QuicTransportFunctions.cpp @@ -92,6 +92,36 @@ bool toWriteAppDataAcks(const quic::QuicConnectionStateBase& conn) { using namespace quic; +/** + * This function returns the number of write bytes that are available until we + * reach the writableBytesLimit. It may or may not be the limiting factor on the + * number of bytes we can write on the wire. + * + * If the client's address has not been verified, this will return the number of + * write bytes available until writableBytesLimit is reached. + * + * Otherwise if the client's address is validated, it will return unlimited + * number of bytes to write. + */ +uint64_t maybeUnvalidatedClientWritableBytes( + quic::QuicConnectionStateBase& conn) { + if (!conn.writableBytesLimit) { + return unlimitedWritableBytes(conn); + } + + if (*conn.writableBytesLimit <= conn.lossState.totalBytesSent) { + QUIC_STATS(conn.statsCallback, onConnectionWritableBytesLimited); + return 0; + } + + uint64_t writableBytes = + *conn.writableBytesLimit - conn.lossState.totalBytesSent; + + // round the result up to the nearest multiple of udpSendPacketLen. + return (writableBytes + conn.udpSendPacketLen - 1) / conn.udpSendPacketLen * + conn.udpSendPacketLen; +} + WriteQuicDataResult writeQuicDataToSocketImpl( folly::AsyncUDPSocket& sock, QuicConnectionStateBase& connection, @@ -899,7 +929,15 @@ void updateConnection( } } -uint64_t congestionControlWritableBytes(const QuicConnectionStateBase& conn) { +uint64_t probePacketWritableBytes(QuicConnectionStateBase& conn) { + uint64_t probeWritableBytes = maybeUnvalidatedClientWritableBytes(conn); + if (!probeWritableBytes) { + conn.numProbesWritableBytesLimited++; + } + return probeWritableBytes; +} + +uint64_t congestionControlWritableBytes(QuicConnectionStateBase& conn) { uint64_t writableBytes = std::numeric_limits::max(); if (conn.pendingEvents.pathChallenge || conn.outstandingPathValidation) { @@ -914,11 +952,7 @@ uint64_t congestionControlWritableBytes(const QuicConnectionStateBase& conn) { std::chrono::steady_clock::now(), conn.lossState.srtt == 0us ? kDefaultInitialRtt : conn.lossState.srtt); } else if (conn.writableBytesLimit) { - if (*conn.writableBytesLimit <= conn.lossState.totalBytesSent) { - QUIC_STATS(conn.statsCallback, onConnectionWritableBytesLimited); - return 0; - } - writableBytes = *conn.writableBytesLimit - conn.lossState.totalBytesSent; + writableBytes = maybeUnvalidatedClientWritableBytes(conn); } if (conn.congestionController) { @@ -936,7 +970,7 @@ uint64_t congestionControlWritableBytes(const QuicConnectionStateBase& conn) { conn.udpSendPacketLen; } -uint64_t unlimitedWritableBytes(const QuicConnectionStateBase&) { +uint64_t unlimitedWritableBytes(QuicConnectionStateBase&) { return std::numeric_limits::max(); } @@ -1461,6 +1495,7 @@ uint64_t writeProbingDataToSocket( CloningScheduler cloningScheduler( scheduler, connection, "CloningScheduler", aead.getCipherOverhead()); auto writeLoopBeginTime = Clock::now(); + auto written = writeConnectionDataToSocket( sock, connection, @@ -1469,7 +1504,9 @@ uint64_t writeProbingDataToSocket( builder, pnSpace, cloningScheduler, - unlimitedWritableBytes, + connection.transportSettings.enableWritableBytesLimit + ? probePacketWritableBytes + : unlimitedWritableBytes, probesToSend, aead, headerCipher, @@ -1493,7 +1530,9 @@ uint64_t writeProbingDataToSocket( builder, pnSpace, pingScheduler, - unlimitedWritableBytes, + connection.transportSettings.enableWritableBytesLimit + ? probePacketWritableBytes + : unlimitedWritableBytes, probesToSend - written, aead, headerCipher, @@ -1552,7 +1591,7 @@ uint64_t writeD6DProbeToSocket( return written; } -WriteDataReason shouldWriteData(const QuicConnectionStateBase& conn) { +WriteDataReason shouldWriteData(/*const*/ QuicConnectionStateBase& conn) { auto& numProbePackets = conn.pendingEvents.numProbePackets; bool shouldWriteInitialProbes = numProbePackets[PacketNumberSpace::Initial] && conn.initialWriteCipher; diff --git a/quic/api/QuicTransportFunctions.h b/quic/api/QuicTransportFunctions.h index d14cdd008..8cb322b87 100644 --- a/quic/api/QuicTransportFunctions.h +++ b/quic/api/QuicTransportFunctions.h @@ -75,7 +75,7 @@ using HeaderBuilder = std::function; using WritableBytesFunc = - std::function; + std::function; // Encapsulating the return value for the write functions. // Useful because probes can go over the packet limit. @@ -150,7 +150,7 @@ uint64_t writeZeroRttDataToSocket( * Whether we should and can write data. * */ -WriteDataReason shouldWriteData(const QuicConnectionStateBase& conn); +WriteDataReason shouldWriteData(QuicConnectionStateBase& conn); bool hasAckDataToWrite(const QuicConnectionStateBase& conn); WriteDataReason hasNonAckDataToWrite(const QuicConnectionStateBase& conn); @@ -207,13 +207,21 @@ void updateConnection( uint32_t encodedBodySize, bool isDSRPacket); +/** + * Returns the number of writable bytes available for constructing a PTO packet. + * This will either return std::numeric_limits::max() or the number + * of bytes until the writableBytesLimit is reached – depending on whether the + * client's address has been validated. + */ +uint64_t probePacketWritableBytes(QuicConnectionStateBase& conn); + /** * Returns the minimum available bytes window out of path validation rate * limiting, 0-rtt total bytes sent limiting, and the congestion controller. */ -uint64_t congestionControlWritableBytes(const QuicConnectionStateBase& conn); +uint64_t congestionControlWritableBytes(QuicConnectionStateBase& conn); -uint64_t unlimitedWritableBytes(const QuicConnectionStateBase&); +uint64_t unlimitedWritableBytes(QuicConnectionStateBase&); void writeCloseCommon( folly::AsyncUDPSocket& sock, diff --git a/quic/api/test/QuicTransportFunctionsTest.cpp b/quic/api/test/QuicTransportFunctionsTest.cpp index f1d6d6f05..65dd54f29 100644 --- a/quic/api/test/QuicTransportFunctionsTest.cpp +++ b/quic/api/test/QuicTransportFunctionsTest.cpp @@ -2970,8 +2970,10 @@ TEST_F(QuicTransportFunctionsTest, WriteProbingNewData) { auto currentPacketSeqNum = conn->ackStates.appDataAckState.nextPacketNum; auto mockCongestionController = std::make_unique>(); + // Probing data is not limited by congestion control, this should not affect + // anything EXPECT_CALL(*mockCongestionController, getWritableBytes()) - .WillRepeatedly(Return(2000)); + .WillRepeatedly(Return(0)); auto rawCongestionController = mockCongestionController.get(); conn->congestionController = std::move(mockCongestionController); EventBase evb; @@ -3061,8 +3063,10 @@ TEST_F(QuicTransportFunctionsTest, WriteProbingCryptoData) { // Replace real congestionController with MockCongestionController: auto mockCongestionController = std::make_unique>(); + // Probing data is not limited by congestion control, this should not affect + // anything EXPECT_CALL(*mockCongestionController, getWritableBytes()) - .WillRepeatedly(Return(2000)); + .WillRepeatedly(Return(0)); auto rawCongestionController = mockCongestionController.get(); conn.congestionController = std::move(mockCongestionController); EventBase evb; @@ -3091,6 +3095,54 @@ TEST_F(QuicTransportFunctionsTest, WriteProbingCryptoData) { EXPECT_FALSE(cryptoStream->retransmissionBuffer.empty()); } +TEST_F(QuicTransportFunctionsTest, WriteableBytesLimitedProbingCryptoData) { + QuicServerConnectionState conn( + FizzServerQuicHandshakeContext::Builder().build()); + conn.statsCallback = quicStats_.get(); + conn.transportSettings.enableWritableBytesLimit = true; + conn.writableBytesLimit = 2 * conn.udpSendPacketLen; + + conn.serverConnectionId = getTestConnectionId(); + conn.clientConnectionId = getTestConnectionId(); + // writeCryptoDataProbesToSocketForTest writes Initial LongHeader, thus it + // writes at Initial level. + auto currentPacketSeqNum = conn.ackStates.initialAckState.nextPacketNum; + // Replace real congestionController with MockCongestionController: + auto mockCongestionController = + std::make_unique>(); + auto rawCongestionController = mockCongestionController.get(); + conn.congestionController = std::move(mockCongestionController); + EventBase evb; + auto socket = + std::make_unique>(&evb); + auto rawSocket = socket.get(); + auto cryptoStream = &conn.cryptoState->initialStream; + uint8_t probesToSend = 4; + auto buf = buildRandomInputData(conn.udpSendPacketLen * probesToSend); + EXPECT_CALL(*quicStats_, onConnectionWritableBytesLimited()) + .Times(AtLeast(1)); + writeDataToQuicStream(*cryptoStream, buf->clone()); + + auto currentStreamWriteOffset = cryptoStream->currentWriteOffset; + EXPECT_CALL(*rawCongestionController, onPacketSent(_)).Times(2); + EXPECT_CALL(*rawSocket, write(_, _)) + .WillRepeatedly(Invoke([&](const SocketAddress&, + const std::unique_ptr& iobuf) { + auto len = iobuf->computeChainDataLength(); + EXPECT_EQ(conn.udpSendPacketLen - aead->getCipherOverhead(), len); + return len; + })); + writeCryptoDataProbesToSocketForTest( + *rawSocket, conn, probesToSend, *aead, *headerCipher, getVersion(conn)); + + EXPECT_EQ(conn.numProbesWritableBytesLimited, 1); + EXPECT_LT(currentPacketSeqNum, conn.ackStates.initialAckState.nextPacketNum); + EXPECT_FALSE(conn.outstandings.packets.empty()); + EXPECT_TRUE(conn.pendingEvents.setLossDetectionAlarm); + EXPECT_GT(cryptoStream->currentWriteOffset, currentStreamWriteOffset); + EXPECT_FALSE(cryptoStream->retransmissionBuffer.empty()); +} + TEST_F(QuicTransportFunctionsTest, ProbingNotFallbackToPingWhenNoQuota) { auto conn = createConn(); auto mockCongestionController = diff --git a/quic/server/QuicServerTransport.cpp b/quic/server/QuicServerTransport.cpp index 037a14f86..036a6f9c0 100644 --- a/quic/server/QuicServerTransport.cpp +++ b/quic/server/QuicServerTransport.cpp @@ -146,6 +146,9 @@ void QuicServerTransport::onReadData( readData.peer = peer; readData.networkData = std::move(networkData); bool waitingForFirstPacket = !hasReceivedPackets(*conn_); + uint64_t prevWritableBytes = serverConn_->writableBytesLimit + ? *serverConn_->writableBytesLimit + : std::numeric_limits::max(); onServerReadData(*serverConn_, readData); processPendingData(true); @@ -163,6 +166,20 @@ void QuicServerTransport::onReadData( hasReceivedPackets(*conn_)) { connSetupCallback_->onFirstPeerPacketProcessed(); } + + uint64_t curWritableBytes = serverConn_->writableBytesLimit + ? *serverConn_->writableBytesLimit + : std::numeric_limits::max(); + + // If we've increased our writable bytes limit after processing incoming data + // and we were previously blocked from writing probes, fire the PTO alarm + if (serverConn_->transportSettings.enableWritableBytesLimit && + serverConn_->numProbesWritableBytesLimited && + prevWritableBytes < curWritableBytes) { + onPTOAlarm(*serverConn_); + serverConn_->numProbesWritableBytesLimited = 0; + } + maybeWriteNewSessionTicket(); maybeNotifyConnectionIdBound(); maybeNotifyHandshakeFinished(); diff --git a/quic/server/handshake/test/DefaultAppTokenValidatorTest.cpp b/quic/server/handshake/test/DefaultAppTokenValidatorTest.cpp index 2e6a828fd..1286b1662 100644 --- a/quic/server/handshake/test/DefaultAppTokenValidatorTest.cpp +++ b/quic/server/handshake/test/DefaultAppTokenValidatorTest.cpp @@ -390,6 +390,7 @@ TEST_F(LimitIfNoMatchPolicyTest, EmptySourceToken) { ASSERT_THAT( conn_.tokenSourceAddresses, ElementsAre(conn_.peerAddress.getIPAddress())); + EXPECT_FALSE(conn_.isClientAddrVerified); } TEST_F(LimitIfNoMatchPolicyTest, OneSourceTokenNoAddrMatch) { @@ -403,6 +404,7 @@ TEST_F(LimitIfNoMatchPolicyTest, OneSourceTokenNoAddrMatch) { conn_.tokenSourceAddresses, ElementsAre( folly::IPAddress("1.2.3.5"), conn_.peerAddress.getIPAddress())); + EXPECT_FALSE(conn_.isClientAddrVerified); } TEST_F(LimitIfNoMatchPolicyTest, OneSourceTokenAddrMatch) { @@ -413,6 +415,7 @@ TEST_F(LimitIfNoMatchPolicyTest, OneSourceTokenAddrMatch) { ASSERT_THAT( conn_.tokenSourceAddresses, ElementsAre(conn_.peerAddress.getIPAddress())); + EXPECT_TRUE(conn_.isClientAddrVerified); } TEST_F(LimitIfNoMatchPolicyTest, MaxNumSourceTokenNoAddrMatch) { @@ -431,6 +434,7 @@ TEST_F(LimitIfNoMatchPolicyTest, MaxNumSourceTokenNoAddrMatch) { folly::IPAddress("1.2.3.6"), folly::IPAddress("1.2.3.7"), conn_.peerAddress.getIPAddress())); + EXPECT_FALSE(conn_.isClientAddrVerified); } TEST_F(LimitIfNoMatchPolicyTest, MaxNumSourceTokenAddrMatch) { @@ -447,6 +451,7 @@ TEST_F(LimitIfNoMatchPolicyTest, MaxNumSourceTokenAddrMatch) { folly::IPAddress("1.2.3.5"), folly::IPAddress("1.2.3.7"), conn_.peerAddress.getIPAddress())); + EXPECT_TRUE(conn_.isClientAddrVerified); } class RejectIfNoMatchPolicyTest : public SourceAddressTokenTest { @@ -476,6 +481,7 @@ TEST_F(RejectIfNoMatchPolicyTest, OneSourceTokenNoAddrMatch) { conn_.tokenSourceAddresses, ElementsAre( folly::IPAddress("1.2.3.5"), conn_.peerAddress.getIPAddress())); + EXPECT_FALSE(conn_.isClientAddrVerified); } TEST_F(RejectIfNoMatchPolicyTest, OneSourceTokenAddrMatch) { @@ -486,6 +492,7 @@ TEST_F(RejectIfNoMatchPolicyTest, OneSourceTokenAddrMatch) { ASSERT_THAT( conn_.tokenSourceAddresses, ElementsAre(conn_.peerAddress.getIPAddress())); + EXPECT_TRUE(conn_.isClientAddrVerified); } TEST_F(RejectIfNoMatchPolicyTest, MaxNumSourceTokenNoAddrMatch) { @@ -502,6 +509,7 @@ TEST_F(RejectIfNoMatchPolicyTest, MaxNumSourceTokenNoAddrMatch) { folly::IPAddress("1.2.3.6"), folly::IPAddress("1.2.3.7"), conn_.peerAddress.getIPAddress())); + EXPECT_FALSE(conn_.isClientAddrVerified); } TEST_F(RejectIfNoMatchPolicyTest, MaxNumSourceTokenAddrMatch) { @@ -519,6 +527,7 @@ TEST_F(RejectIfNoMatchPolicyTest, MaxNumSourceTokenAddrMatch) { folly::IPAddress("1.2.3.5"), folly::IPAddress("1.2.3.7"), conn_.peerAddress.getIPAddress())); + EXPECT_TRUE(conn_.isClientAddrVerified); } class AlwaysRejectPolicyTest : public SourceAddressTokenTest { @@ -537,6 +546,7 @@ TEST_F(AlwaysRejectPolicyTest, EmptySourceToken) { ASSERT_THAT( conn_.tokenSourceAddresses, ElementsAre(conn_.peerAddress.getIPAddress())); + EXPECT_FALSE(conn_.isClientAddrVerified); } TEST_F(AlwaysRejectPolicyTest, OneSourceTokenNoAddrMatch) { @@ -548,6 +558,7 @@ TEST_F(AlwaysRejectPolicyTest, OneSourceTokenNoAddrMatch) { conn_.tokenSourceAddresses, ElementsAre( folly::IPAddress("1.2.3.5"), conn_.peerAddress.getIPAddress())); + EXPECT_FALSE(conn_.isClientAddrVerified); } TEST_F(AlwaysRejectPolicyTest, OneSourceTokenAddrMatch) { @@ -558,6 +569,7 @@ TEST_F(AlwaysRejectPolicyTest, OneSourceTokenAddrMatch) { ASSERT_THAT( conn_.tokenSourceAddresses, ElementsAre(conn_.peerAddress.getIPAddress())); + EXPECT_TRUE(conn_.isClientAddrVerified); } TEST_F(AlwaysRejectPolicyTest, MaxNumSourceTokenNoAddrMatch) { @@ -574,6 +586,7 @@ TEST_F(AlwaysRejectPolicyTest, MaxNumSourceTokenNoAddrMatch) { folly::IPAddress("1.2.3.6"), folly::IPAddress("1.2.3.7"), conn_.peerAddress.getIPAddress())); + EXPECT_FALSE(conn_.isClientAddrVerified); } TEST_F(AlwaysRejectPolicyTest, MaxNumSourceTokenAddrMatch) { @@ -591,6 +604,7 @@ TEST_F(AlwaysRejectPolicyTest, MaxNumSourceTokenAddrMatch) { folly::IPAddress("1.2.3.5"), folly::IPAddress("1.2.3.7"), conn_.peerAddress.getIPAddress())); + EXPECT_TRUE(conn_.isClientAddrVerified); } } // namespace test diff --git a/quic/server/state/ServerStateMachine.cpp b/quic/server/state/ServerStateMachine.cpp index 4f64c2a00..67a65a955 100644 --- a/quic/server/state/ServerStateMachine.cpp +++ b/quic/server/state/ServerStateMachine.cpp @@ -101,7 +101,7 @@ void maybeSetExperimentalSettings(QuicServerConnectionState& conn) { conn.pacer->setExperimental(true); } } else if (conn.version == QuicVersion::MVFST_EXPERIMENTAL3) { - conn.enableWritableBytesLimit = true; + conn.transportSettings.enableWritableBytesLimit = true; } } } // namespace @@ -444,6 +444,7 @@ bool validateAndUpdateSourceToken( // of vector to increase its favorability. sourceAddresses.erase(sourceAddresses.begin() + ii); sourceAddresses.push_back(conn.peerAddress.getIPAddress()); + conn.isClientAddrVerified = true; } } conn.sourceTokenMatching = foundMatch; @@ -483,7 +484,9 @@ void updateWritableByteLimitOnRecvPacket(QuicServerConnectionState& conn) { if (conn.writableBytesLimit) { conn.writableBytesLimit = *conn.writableBytesLimit + conn.transportSettings.limitedCwndInMss * conn.udpSendPacketLen; - } else if (!conn.isClientAddrVerified && conn.enableWritableBytesLimit) { + } else if ( + !conn.isClientAddrVerified && + conn.transportSettings.enableWritableBytesLimit) { conn.writableBytesLimit = conn.transportSettings.limitedCwndInMss * conn.udpSendPacketLen; } diff --git a/quic/server/state/ServerStateMachine.h b/quic/server/state/ServerStateMachine.h index ae75ac65d..c8156019a 100644 --- a/quic/server/state/ServerStateMachine.h +++ b/quic/server/state/ServerStateMachine.h @@ -138,9 +138,6 @@ struct QuicServerConnectionState : public QuicConnectionStateBase { // NewToken). bool isClientAddrVerified{false}; - // Whether or not to enable WritableBytes limit - bool enableWritableBytesLimit{false}; - #ifdef CCP_ENABLED // Pointer to struct that maintains state needed for interacting with libccp. // Once instance of this struct is created for each instance of diff --git a/quic/server/test/QuicServerTransportTest.cpp b/quic/server/test/QuicServerTransportTest.cpp index b77e8e3f6..657316c7c 100644 --- a/quic/server/test/QuicServerTransportTest.cpp +++ b/quic/server/test/QuicServerTransportTest.cpp @@ -3270,8 +3270,8 @@ TEST_F( */ auto transportSettings = server->getTransportSettings(); transportSettings.limitedCwndInMss = 5; + transportSettings.enableWritableBytesLimit = true; server->setTransportSettings(transportSettings); - server->getNonConstConn().enableWritableBytesLimit = true; EXPECT_CALL(*quicStats_, onConnectionWritableBytesLimited()).Times(0); recvClientHello(true, QuicVersion::MVFST, "CHLO_CERT"); @@ -3328,8 +3328,8 @@ TEST_F( */ auto transportSettings = server->getTransportSettings(); transportSettings.limitedCwndInMss = 3; + transportSettings.enableWritableBytesLimit = true; server->setTransportSettings(transportSettings); - server->getNonConstConn().enableWritableBytesLimit = true; recvClientHello(true, QuicVersion::MVFST, "CHLO_CERT"); @@ -3413,8 +3413,8 @@ TEST_F( */ auto transportSettings = server->getTransportSettings(); transportSettings.limitedCwndInMss = 3; + transportSettings.enableWritableBytesLimit = true; server->setTransportSettings(transportSettings); - server->getNonConstConn().enableWritableBytesLimit = true; EXPECT_CALL(*quicStats_, onConnectionWritableBytesLimited()) .Times(AtLeast(1)); diff --git a/quic/state/StateData.h b/quic/state/StateData.h index 95a5b21a6..b3248f211 100644 --- a/quic/state/StateData.h +++ b/quic/state/StateData.h @@ -708,6 +708,9 @@ struct QuicConnectionStateBase : public folly::DelayedDestruction { // Whether we successfully used 0-RTT keys in this connection. bool usedZeroRtt{false}; + // Number of probe packets that were writableBytesLimited + uint64_t numProbesWritableBytesLimited{0}; + struct DatagramState { uint16_t maxReadFrameSize{kDefaultMaxDatagramFrameSize}; uint16_t maxWriteFrameSize{kDefaultMaxDatagramFrameSize}; diff --git a/quic/state/TransportSettings.h b/quic/state/TransportSettings.h index 2d2e11d58..94785891c 100644 --- a/quic/state/TransportSettings.h +++ b/quic/state/TransportSettings.h @@ -291,6 +291,8 @@ struct TransportSettings { // has to be set to something >> the RTT of the connection. bool enableKeepalive{false}; std::string flowPriming = ""; + // Whether or not to enable WritableBytes limit (server only) + bool enableWritableBytesLimit{false}; }; } // namespace quic