diff --git a/quic/QuicConstants.h b/quic/QuicConstants.h index cec2a93ce..5fec49a02 100644 --- a/quic/QuicConstants.h +++ b/quic/QuicConstants.h @@ -378,6 +378,8 @@ constexpr uint64_t kMinNumAvailableConnIds = 8; // default capability of QUIC partial reliability constexpr TransportPartialReliabilitySetting kDefaultPartialReliability = false; +constexpr uint64_t kMaxPacketNumber = (1ull << 62) - 1; + enum class ZeroRttSourceTokenMatchingPolicy : uint8_t { REJECT_IF_NO_EXACT_MATCH, LIMIT_IF_NO_EXACT_MATCH, diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index a51c788df..d3e53c85c 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -2277,6 +2277,11 @@ void QuicTransportBase::writeSocketData() { auto packetsBefore = conn_->outstandingPackets.size(); writeData(); if (closeState_ != CloseState::CLOSED) { + if (conn_->pendingEvents.closeTransport == true) { + throw QuicTransportException( + "Max packet number reached", + TransportErrorCode::PROTOCOL_VIOLATION); + } setLossDetectionAlarm(*conn_, *this); auto packetsAfter = conn_->outstandingPackets.size(); bool packetWritten = (packetsAfter > packetsBefore); diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index 0ab1c6a0e..26f6c68df 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -356,6 +356,10 @@ class TestQuicTransport closeImpl(folly::none, false, false); } + void invokeWriteSocketData() { + writeSocketData(); + } + QuicServerConnectionState* transportConn; std::unique_ptr aead; std::unique_ptr headerCipher; @@ -1285,7 +1289,14 @@ TEST_P(QuicTransportImplTestClose, TestNotifyPendingWriteOnCloseWithError) { } evb->loopOnce(); } +TEST_F(QuicTransportImplTest, TestTransportCloseWithMaxPacketNumber) { + transport->setServerConnectionId(); + transport->transportConn->pendingEvents.closeTransport = false; + EXPECT_NO_THROW(transport->invokeWriteSocketData()); + transport->transportConn->pendingEvents.closeTransport = true; + EXPECT_THROW(transport->invokeWriteSocketData(), QuicTransportException); +} TEST_F(QuicTransportImplTest, TestGracefulCloseWithActiveStream) { EXPECT_CALL(connCallback, onConnectionEnd()).Times(0); EXPECT_CALL(connCallback, onConnectionError(_)).Times(0); diff --git a/quic/state/QuicStateFunctions.cpp b/quic/state/QuicStateFunctions.cpp index 1c8cbde63..b516f20c8 100644 --- a/quic/state/QuicStateFunctions.cpp +++ b/quic/state/QuicStateFunctions.cpp @@ -213,6 +213,9 @@ void increaseNextPacketNum( QuicConnectionStateBase& conn, PacketNumberSpace pnSpace) noexcept { getAckState(conn, pnSpace).nextPacketNum++; + if (getAckState(conn, pnSpace).nextPacketNum == kMaxPacketNumber - 1) { + conn.pendingEvents.closeTransport = true; + } } std::deque::iterator getFirstOutstandingPacket( diff --git a/quic/state/StateData.h b/quic/state/StateData.h index b2b8d633b..9d3ca6bc6 100644 --- a/quic/state/StateData.h +++ b/quic/state/StateData.h @@ -609,6 +609,9 @@ struct QuicConnectionStateBase { uint8_t numProbePackets{0}; bool cancelPingTimeout{false}; + + // close transport when the next packet number reaches kMaxPacketNum + bool closeTransport{false}; }; PendingEvents pendingEvents; diff --git a/quic/state/test/QuicStateFunctionsTest.cpp b/quic/state/test/QuicStateFunctionsTest.cpp index 6f06de404..4888f2974 100644 --- a/quic/state/test/QuicStateFunctionsTest.cpp +++ b/quic/state/test/QuicStateFunctionsTest.cpp @@ -584,6 +584,14 @@ TEST_F(QuicStateFunctionsTest, EarliestLossTimer) { EXPECT_EQ(currentTime, earliestLossTimer(conn).first.value()); } +TEST_P(QuicStateFunctionsTest, CloseTranportStateChange) { + QuicConnectionStateBase conn(QuicNodeType::Server); + getAckState(conn, GetParam()).nextPacketNum = kMaxPacketNumber - 2; + EXPECT_FALSE(conn.pendingEvents.closeTransport); + increaseNextPacketNum(conn, GetParam()); + EXPECT_TRUE(conn.pendingEvents.closeTransport); +} + INSTANTIATE_TEST_CASE_P( QuicStateFunctionsTests, QuicStateFunctionsTest,