diff --git a/quic/QuicConstants.h b/quic/QuicConstants.h index 0d65a73c4..af1f386f7 100644 --- a/quic/QuicConstants.h +++ b/quic/QuicConstants.h @@ -293,6 +293,7 @@ enum class LocalErrorCode : uint32_t { CALLBACK_ALREADY_INSTALLED = 0x4000001A, KNOB_FRAME_UNSUPPORTED = 0x4000001B, PACER_NOT_AVAILABLE = 0x4000001C, + RTX_POLICIES_LIMIT_EXCEEDED = 0x4000001D, }; enum class QuicNodeType : bool { diff --git a/quic/QuicException.cpp b/quic/QuicException.cpp index b681d164f..6ab4c8b45 100644 --- a/quic/QuicException.cpp +++ b/quic/QuicException.cpp @@ -123,6 +123,8 @@ folly::StringPiece toString(LocalErrorCode code) { return "Knob Frame Not Supported"; case LocalErrorCode::PACER_NOT_AVAILABLE: return "Pacer not available"; + case LocalErrorCode::RTX_POLICIES_LIMIT_EXCEEDED: + return "Retransmission policies limit exceeded"; default: break; } diff --git a/quic/api/QuicSocket.h b/quic/api/QuicSocket.h index 349e53d51..9c6d4c4bc 100644 --- a/quic/api/QuicSocket.h +++ b/quic/api/QuicSocket.h @@ -1424,7 +1424,7 @@ class QuicSocket { virtual folly::Expected setStreamGroupRetransmissionPolicy( StreamGroupId groupId, - QuicStreamGroupRetransmissionPolicy policy) noexcept = 0; + std::optional policy) noexcept = 0; protected: /** diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index d23662866..05e362fad 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -3811,11 +3811,24 @@ bool QuicTransportBase::checkCustomRetransmissionProfilesEnabled() const { folly::Expected QuicTransportBase::setStreamGroupRetransmissionPolicy( - StreamGroupId /* groupId */, - QuicStreamGroupRetransmissionPolicy /* policy */) noexcept { + StreamGroupId groupId, + std::optional policy) noexcept { + // Reset the policy to default one. + if (policy == std::nullopt) { + conn_->retransmissionPolicies.erase(groupId); + return folly::unit; + } + if (!checkCustomRetransmissionProfilesEnabled()) { return folly::makeUnexpected(LocalErrorCode::INVALID_OPERATION); } + + if (conn_->retransmissionPolicies.size() >= + conn_->transportSettings.advertisedMaxStreamGroups) { + return folly::makeUnexpected(LocalErrorCode::RTX_POLICIES_LIMIT_EXCEEDED); + } + + conn_->retransmissionPolicies.emplace(groupId, *policy); return folly::unit; } diff --git a/quic/api/QuicTransportBase.h b/quic/api/QuicTransportBase.h index b13a9e849..8f348d87d 100644 --- a/quic/api/QuicTransportBase.h +++ b/quic/api/QuicTransportBase.h @@ -655,12 +655,21 @@ class QuicTransportBase : public QuicSocket, QuicStreamPrioritiesObserver { void appendCmsgs(const folly::SocketOptionMap& options); /** - * Sets stream group retransmission policy. + * Sets the policy per stream group id. + * If policy == std::nullopt, the policy is removed for corresponding stream + * group id (reset to the default rtx policy). */ folly::Expected setStreamGroupRetransmissionPolicy( StreamGroupId groupId, - QuicStreamGroupRetransmissionPolicy policy) noexcept override; + std::optional policy) noexcept + override; + + [[nodiscard]] const folly:: + F14FastMap& + getStreamGroupRetransmissionPolicies() const { + return conn_->retransmissionPolicies; + } protected: void updateCongestionControlSettings( diff --git a/quic/api/test/MockQuicSocket.h b/quic/api/test/MockQuicSocket.h index dd9cdbf76..ee2b948ec 100644 --- a/quic/api/test/MockQuicSocket.h +++ b/quic/api/test/MockQuicSocket.h @@ -366,7 +366,7 @@ class MockQuicSocket : public QuicSocket { MOCK_METHOD( (folly::Expected), setStreamGroupRetransmissionPolicy, - (StreamGroupId, QuicStreamGroupRetransmissionPolicy), + (StreamGroupId, std::optional), (noexcept)); }; } // namespace quic diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index 98a0c2922..3e9d3b619 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -4454,6 +4454,105 @@ TEST_P( res = transport->setStreamGroupRetransmissionPolicy(groupId, policy); EXPECT_TRUE(res.hasError()); EXPECT_EQ(res.error(), LocalErrorCode::INVALID_OPERATION); + EXPECT_EQ(1, transport->getStreamGroupRetransmissionPolicies().size()); + + transport.reset(); +} + +TEST_P( + QuicTransportImplTestWithGroups, + TestStreamGroupRetransmissionPolicyReset) { + auto transportSettings = transport->getTransportSettings(); + transportSettings.advertisedMaxStreamGroups = 16; + transport->setTransportSettings(transportSettings); + transport->getConnectionState().streamManager->refreshTransportSettings( + transportSettings); + + const StreamGroupId groupId = 0x00; + QuicStreamGroupRetransmissionPolicy policy; + + // Add the policy. + auto res = transport->setStreamGroupRetransmissionPolicy(groupId, policy); + EXPECT_TRUE(res.hasValue()); + EXPECT_EQ(transport->getStreamGroupRetransmissionPolicies().size(), 1); + + // Reset allowed. + res = transport->setStreamGroupRetransmissionPolicy(groupId, std::nullopt); + EXPECT_TRUE(res.hasValue()); + EXPECT_EQ(transport->getStreamGroupRetransmissionPolicies().size(), 0); + + // Add the policy back. + res = transport->setStreamGroupRetransmissionPolicy(groupId, policy); + EXPECT_TRUE(res.hasValue()); + EXPECT_EQ(transport->getStreamGroupRetransmissionPolicies().size(), 1); + + // Reset allowed even if custom policies are disabled. + transportSettings.advertisedMaxStreamGroups = 0; + res = transport->setStreamGroupRetransmissionPolicy(groupId, std::nullopt); + EXPECT_TRUE(res.hasValue()); + EXPECT_EQ(transport->getStreamGroupRetransmissionPolicies().size(), 0); + + transport.reset(); +} + +TEST_P( + QuicTransportImplTestWithGroups, + TestStreamGroupRetransmissionPolicyAddRemove) { + auto transportSettings = transport->getTransportSettings(); + transportSettings.advertisedMaxStreamGroups = 16; + transport->setTransportSettings(transportSettings); + transport->getConnectionState().streamManager->refreshTransportSettings( + transportSettings); + + // Add a policy. + const StreamGroupId groupId = 0x00; + const QuicStreamGroupRetransmissionPolicy policy; + auto res = transport->setStreamGroupRetransmissionPolicy(groupId, policy); + EXPECT_TRUE(res.hasValue()); + EXPECT_EQ(transport->getStreamGroupRetransmissionPolicies().size(), 1); + + // Add another one. + const StreamGroupId groupId2 = 0x04; + const QuicStreamGroupRetransmissionPolicy policy2; + res = transport->setStreamGroupRetransmissionPolicy(groupId2, policy2); + EXPECT_TRUE(res.hasValue()); + EXPECT_EQ(transport->getStreamGroupRetransmissionPolicies().size(), 2); + + // Remove second policy. + res = transport->setStreamGroupRetransmissionPolicy(groupId2, std::nullopt); + EXPECT_TRUE(res.hasValue()); + EXPECT_EQ(transport->getStreamGroupRetransmissionPolicies().size(), 1); + + // Remove first policy. + res = transport->setStreamGroupRetransmissionPolicy(groupId, std::nullopt); + EXPECT_TRUE(res.hasValue()); + EXPECT_EQ(transport->getStreamGroupRetransmissionPolicies().size(), 0); + + transport.reset(); +} + +TEST_P( + QuicTransportImplTestWithGroups, + TestStreamGroupRetransmissionPolicyMaxLimit) { + auto transportSettings = transport->getTransportSettings(); + transportSettings.advertisedMaxStreamGroups = 1; + transport->setTransportSettings(transportSettings); + transport->getConnectionState().streamManager->refreshTransportSettings( + transportSettings); + + // Add a policy. + const StreamGroupId groupId = 0x00; + const QuicStreamGroupRetransmissionPolicy policy; + auto res = transport->setStreamGroupRetransmissionPolicy(groupId, policy); + EXPECT_TRUE(res.hasValue()); + EXPECT_EQ(transport->getStreamGroupRetransmissionPolicies().size(), 1); + + // Try adding another one; should be over the limit. + const StreamGroupId groupId2 = 0x04; + res = transport->setStreamGroupRetransmissionPolicy(groupId2, policy); + EXPECT_TRUE(res.hasError()); + EXPECT_EQ(res.error(), LocalErrorCode::RTX_POLICIES_LIMIT_EXCEEDED); + EXPECT_EQ(transport->getStreamGroupRetransmissionPolicies().size(), 1); transport.reset(); } diff --git a/quic/state/StateData.h b/quic/state/StateData.h index a63f1b113..746f37acd 100644 --- a/quic/state/StateData.h +++ b/quic/state/StateData.h @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -723,6 +724,9 @@ struct QuicConnectionStateBase : public folly::DelayedDestruction { maybePeerAckReceiveTimestampsConfig; bool peerAdvertisedKnobFrameSupport{false}; + // Retransmission policies map. + folly::F14FastMap + retransmissionPolicies; }; std::ostream& operator<<(std::ostream& os, const QuicConnectionStateBase& st);