mirror of
https://github.com/redis/go-redis.git
synced 2025-12-03 18:31:14 +03:00
* Add search module builders and tests (#1) * Add search module builders and tests * Add tests * Use builders and Actions in more clean way * Update search_builders.go Co-authored-by: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> * Update search_builders.go Co-authored-by: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> * feat(routing): add internal request/response policy enums * feat: load the policy table in cluster client (#4) * feat: load the policy table in cluster client * Remove comments * modify Tips and command pplicy in commandInfo (#5) * centralize cluster command routing in osscluster_router.go and refactor osscluster.go (#6) * centralize cluster command routing in osscluster_router.go and refactor osscluster.go * enalbe ci on all branches * Add debug prints * Add debug prints * FIX: deal with nil policy * FIX: fixing clusterClient process * chore(osscluster): simplify switch case * wip(command): ai generated clone method for commands * feat: implement response aggregator for Redis cluster commands * feat: implement response aggregator for Redis cluster commands * fix: solve concurrency errors * fix: solve concurrency errors * return MaxRedirects settings * remove locks from getCommandPolicy * Handle MOVED errors more robustly, remove cluster reloading at exectutions, ennsure better routing * Fix: supports Process hook test * Fix: remove response aggregation for single shard commands * Add more preformant type conversion for Cmd type * Add router logic into processPipeline --------- Co-authored-by: Nedyalko Dyakov <nedyalko.dyakov@gmail.com> * remove thread debugging code * remove thread debugging code && reject commands with policy that cannot be used in pipeline * refactor processPipline and cmdType enum * remove FDescribe from cluster tests * Add tests * fix aggregation test * fix mget test * fix mget test * remove 
aggregateKeyedResponses * added scaffolding for the req-resp manager * added default policies for the search commands * split command map into module->command * cleanup, added logic to refresh the cache * added reactive cache refresh * revert cluster refresh * fixed lint * addresed first batch of comments * rewrote aggregator implementations with atomic for native or nearnative primitives * addressed more comments, fixed lint * added batch aggregator operations * fixed lint * updated batch aggregator, fixed extractcommandvalue * fixed lint * added batching to aggregateResponses * fixed deadlocks * changed aggregator logic, added error params * added preemptive return to the aggregators * more work on the aggregators * updated and and or aggregators * fixed lint * added configurable policy resolvers * slight refactor * removed the interface, slight refactor * change func signature from cmdName to cmder * added nil safety assertions * few small refactors * added read only policies * removed leftover prints * Rebased to master, resolved comnflicts * fixed lint * updated gha * fixed tests, minor consistency refactor * preallocated simple errors * changed numeric aggregators to use float64 * speculative test fix * Update command.go Co-authored-by: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> * Update main_test.go Co-authored-by: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> * Add static shard picker * Fix nil value handling in command aggregation * Modify the Clone method to return a shallow copy * Add clone method to digest command * Optimize keyless command routing to respect ShardPicker policy * Remove MGET references * Fix MGET aggregation to map individual values to keys across shards * Add clone method to hybrid search commands * Undo changes in route keyless test * remove comments * Add test for DisableRoutingPolicies option * Add Routing Policies Comprehensive Test Suite and Fix multi keyed aggregation for different step --------- 
Co-authored-by: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> Co-authored-by: Nedyalko Dyakov <nedyalko.dyakov@gmail.com> Co-authored-by: Hristo Temelski <hristo.temelski@redis.com>
1003 lines
20 KiB
Go
1003 lines
20 KiB
Go
package routing
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"sync"
|
|
|
|
"sync/atomic"
|
|
|
|
"github.com/redis/go-redis/v9/internal/util"
|
|
uberAtomic "go.uber.org/atomic"
|
|
)
|
|
|
|
// Sentinel errors reported by the numeric and logical aggregators when
// Result is called before any usable reply was recorded.
var (
	// ErrMaxAggregation is returned by AggMaxAggregator.Result when no value was recorded.
	ErrMaxAggregation = errors.New("redis: no valid results to aggregate for max operation")
	// ErrMinAggregation is returned by AggMinAggregator.Result when no value was recorded.
	ErrMinAggregation = errors.New("redis: no valid results to aggregate for min operation")
	// ErrAndAggregation is returned by AggLogicalAndAggregator.Result when no value was recorded.
	ErrAndAggregation = errors.New("redis: no valid results to aggregate for logical AND operation")
	// ErrOrAggregation is returned by AggLogicalOrAggregator.Result when no value was recorded.
	ErrOrAggregation = errors.New("redis: no valid results to aggregate for logical OR operation")
)
|
|
|
|
// ResponseAggregator defines the interface for aggregating responses from
// multiple shards into a single reply.
type ResponseAggregator interface {
	// Add processes a single shard response.
	Add(result interface{}, err error) error

	// AddWithKey processes a single shard response for a specific key
	// (used by keyed aggregators).
	AddWithKey(key string, result interface{}, err error) error

	// BatchAdd processes a set of shard responses in one call.
	BatchAdd(results map[string]AggregatorResErr) error

	// BatchSlice processes a list of shard responses in one call.
	BatchSlice(results []AggregatorResErr) error

	// Result returns the final aggregated result and any error.
	Result() (interface{}, error)
}

// AggregatorResErr pairs the reply of one shard with the error it returned.
type AggregatorResErr struct {
	Result interface{}
	Err    error
}
|
|
|
|
// NewResponseAggregator creates an aggregator based on the response policy.
|
|
func NewResponseAggregator(policy ResponsePolicy, cmdName string) ResponseAggregator {
|
|
switch policy {
|
|
case RespDefaultKeyless:
|
|
return &DefaultKeylessAggregator{results: make([]interface{}, 0)}
|
|
case RespDefaultHashSlot:
|
|
return &DefaultKeyedAggregator{results: make(map[string]interface{})}
|
|
case RespAllSucceeded:
|
|
return &AllSucceededAggregator{}
|
|
case RespOneSucceeded:
|
|
return &OneSucceededAggregator{}
|
|
case RespAggSum:
|
|
return &AggSumAggregator{
|
|
// res:
|
|
}
|
|
case RespAggMin:
|
|
return &AggMinAggregator{
|
|
res: util.NewAtomicMin(),
|
|
}
|
|
case RespAggMax:
|
|
return &AggMaxAggregator{
|
|
res: util.NewAtomicMax(),
|
|
}
|
|
case RespAggLogicalAnd:
|
|
andAgg := &AggLogicalAndAggregator{}
|
|
andAgg.res.Add(1)
|
|
|
|
return andAgg
|
|
case RespAggLogicalOr:
|
|
return &AggLogicalOrAggregator{}
|
|
case RespSpecial:
|
|
return NewSpecialAggregator(cmdName)
|
|
default:
|
|
return &AllSucceededAggregator{}
|
|
}
|
|
}
|
|
|
|
func NewDefaultAggregator(isKeyed bool) ResponseAggregator {
|
|
if isKeyed {
|
|
return &DefaultKeyedAggregator{
|
|
results: make(map[string]interface{}),
|
|
}
|
|
}
|
|
return &DefaultKeylessAggregator{}
|
|
}
|
|
|
|
// AllSucceededAggregator returns one non-error reply if every shard succeeded,
|
|
// propagates the first error otherwise.
|
|
type AllSucceededAggregator struct {
|
|
err atomic.Value
|
|
res atomic.Value
|
|
}
|
|
|
|
func (a *AllSucceededAggregator) Add(result interface{}, err error) error {
|
|
if err != nil {
|
|
a.err.CompareAndSwap(nil, err)
|
|
return nil
|
|
}
|
|
|
|
if result != nil {
|
|
a.res.CompareAndSwap(nil, result)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *AllSucceededAggregator) BatchAdd(results map[string]AggregatorResErr) error {
|
|
for _, res := range results {
|
|
err := a.Add(res.Result, res.Err)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if res.Err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *AllSucceededAggregator) BatchSlice(results []AggregatorResErr) error {
|
|
for _, res := range results {
|
|
err := a.Add(res.Result, res.Err)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if res.Err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *AllSucceededAggregator) Result() (interface{}, error) {
|
|
var err error
|
|
res, e := a.res.Load(), a.err.Load()
|
|
if e != nil {
|
|
err = e.(error)
|
|
}
|
|
|
|
return res, err
|
|
}
|
|
|
|
func (a *AllSucceededAggregator) AddWithKey(key string, result interface{}, err error) error {
|
|
return a.Add(result, err)
|
|
}
|
|
|
|
// OneSucceededAggregator returns the first non-error reply,
|
|
// if all shards errored, returns any one of those errors.
|
|
type OneSucceededAggregator struct {
|
|
err atomic.Value
|
|
res atomic.Value
|
|
}
|
|
|
|
func (a *OneSucceededAggregator) Add(result interface{}, err error) error {
|
|
if err != nil {
|
|
a.err.CompareAndSwap(nil, err)
|
|
return nil
|
|
}
|
|
|
|
if result != nil {
|
|
a.res.CompareAndSwap(nil, result)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *OneSucceededAggregator) BatchAdd(results map[string]AggregatorResErr) error {
|
|
for _, res := range results {
|
|
err := a.Add(res.Result, res.Err)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if res.Err == nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *OneSucceededAggregator) AddWithKey(key string, result interface{}, err error) error {
|
|
return a.Add(result, err)
|
|
}
|
|
|
|
func (a *OneSucceededAggregator) BatchSlice(results []AggregatorResErr) error {
|
|
for _, res := range results {
|
|
err := a.Add(res.Result, res.Err)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if res.Err == nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *OneSucceededAggregator) Result() (interface{}, error) {
|
|
res, e := a.res.Load(), a.err.Load()
|
|
if res == nil {
|
|
return nil, e.(error)
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// AggSumAggregator sums numeric replies from all shards.
|
|
type AggSumAggregator struct {
|
|
err atomic.Value
|
|
res uberAtomic.Float64
|
|
}
|
|
|
|
func (a *AggSumAggregator) Add(result interface{}, err error) error {
|
|
if err != nil {
|
|
a.err.CompareAndSwap(nil, err)
|
|
}
|
|
|
|
if result != nil {
|
|
val, err := toFloat64(result)
|
|
if err != nil {
|
|
a.err.CompareAndSwap(nil, err)
|
|
return err
|
|
}
|
|
a.res.Add(val)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *AggSumAggregator) BatchAdd(results map[string]AggregatorResErr) error {
|
|
var sum int64
|
|
|
|
for _, res := range results {
|
|
if res.Err != nil {
|
|
return a.Add(res.Result, res.Err)
|
|
}
|
|
|
|
intRes, err := toInt64(res.Result)
|
|
if err != nil {
|
|
return a.Add(nil, err)
|
|
}
|
|
|
|
sum += intRes
|
|
}
|
|
|
|
return a.Add(sum, nil)
|
|
}
|
|
|
|
func (a *AggSumAggregator) AddWithKey(key string, result interface{}, err error) error {
|
|
return a.Add(result, err)
|
|
}
|
|
|
|
func (a *AggSumAggregator) BatchSlice(results []AggregatorResErr) error {
|
|
var sum int64
|
|
|
|
for _, res := range results {
|
|
if res.Err != nil {
|
|
return a.Add(res.Result, res.Err)
|
|
}
|
|
|
|
intRes, err := toInt64(res.Result)
|
|
if err != nil {
|
|
return a.Add(nil, err)
|
|
}
|
|
|
|
sum += intRes
|
|
}
|
|
|
|
return a.Add(sum, nil)
|
|
}
|
|
|
|
func (a *AggSumAggregator) Result() (interface{}, error) {
|
|
res, err := a.res.Load(), a.err.Load()
|
|
if err != nil {
|
|
return nil, err.(error)
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// AggMinAggregator returns the minimum numeric value from all shards.
|
|
type AggMinAggregator struct {
|
|
err atomic.Value
|
|
res *util.AtomicMin
|
|
}
|
|
|
|
func (a *AggMinAggregator) Add(result interface{}, err error) error {
|
|
if err != nil {
|
|
a.err.CompareAndSwap(nil, err)
|
|
return nil
|
|
}
|
|
|
|
floatVal, e := toFloat64(result)
|
|
if e != nil {
|
|
a.err.CompareAndSwap(nil, err)
|
|
return nil
|
|
}
|
|
|
|
a.res.Value(floatVal)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *AggMinAggregator) BatchAdd(results map[string]AggregatorResErr) error {
|
|
min := int64(math.MaxInt64)
|
|
|
|
for _, res := range results {
|
|
if res.Err != nil {
|
|
_ = a.Add(nil, res.Err)
|
|
return nil
|
|
}
|
|
|
|
resInt, err := toInt64(res.Result)
|
|
if err != nil {
|
|
_ = a.Add(nil, res.Err)
|
|
return nil
|
|
}
|
|
|
|
if resInt < min {
|
|
min = resInt
|
|
}
|
|
|
|
}
|
|
|
|
return a.Add(min, nil)
|
|
}
|
|
|
|
func (a *AggMinAggregator) AddWithKey(key string, result interface{}, err error) error {
|
|
return a.Add(result, err)
|
|
}
|
|
|
|
func (a *AggMinAggregator) BatchSlice(results []AggregatorResErr) error {
|
|
min := float64(math.MaxFloat64)
|
|
|
|
for _, res := range results {
|
|
if res.Err != nil {
|
|
_ = a.Add(nil, res.Err)
|
|
return nil
|
|
}
|
|
|
|
floatVal, err := toFloat64(res.Result)
|
|
if err != nil {
|
|
_ = a.Add(nil, res.Err)
|
|
return nil
|
|
}
|
|
|
|
if floatVal < min {
|
|
min = floatVal
|
|
}
|
|
|
|
}
|
|
|
|
return a.Add(min, nil)
|
|
}
|
|
|
|
func (a *AggMinAggregator) Result() (interface{}, error) {
|
|
err := a.err.Load()
|
|
if err != nil {
|
|
return nil, err.(error)
|
|
}
|
|
|
|
val, hasVal := a.res.Min()
|
|
if !hasVal {
|
|
return nil, ErrMinAggregation
|
|
}
|
|
return val, nil
|
|
}
|
|
|
|
// AggMaxAggregator returns the maximum numeric value from all shards.
|
|
type AggMaxAggregator struct {
|
|
err atomic.Value
|
|
res *util.AtomicMax
|
|
}
|
|
|
|
func (a *AggMaxAggregator) Add(result interface{}, err error) error {
|
|
if err != nil {
|
|
a.err.CompareAndSwap(nil, err)
|
|
return nil
|
|
}
|
|
|
|
floatVal, e := toFloat64(result)
|
|
if e != nil {
|
|
a.err.CompareAndSwap(nil, err)
|
|
return nil
|
|
}
|
|
|
|
a.res.Value(floatVal)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *AggMaxAggregator) BatchAdd(results map[string]AggregatorResErr) error {
|
|
max := int64(math.MinInt64)
|
|
|
|
for _, res := range results {
|
|
if res.Err != nil {
|
|
_ = a.Add(nil, res.Err)
|
|
return nil
|
|
}
|
|
|
|
resInt, err := toInt64(res.Result)
|
|
if err != nil {
|
|
_ = a.Add(nil, res.Err)
|
|
return nil
|
|
}
|
|
|
|
if resInt > max {
|
|
max = resInt
|
|
}
|
|
|
|
}
|
|
|
|
return a.Add(max, nil)
|
|
}
|
|
|
|
func (a *AggMaxAggregator) AddWithKey(key string, result interface{}, err error) error {
|
|
return a.Add(result, err)
|
|
}
|
|
|
|
func (a *AggMaxAggregator) BatchSlice(results []AggregatorResErr) error {
|
|
max := int64(math.MinInt64)
|
|
|
|
for _, res := range results {
|
|
if res.Err != nil {
|
|
_ = a.Add(nil, res.Err)
|
|
return nil
|
|
}
|
|
|
|
resInt, err := toInt64(res.Result)
|
|
if err != nil {
|
|
_ = a.Add(nil, res.Err)
|
|
return nil
|
|
}
|
|
|
|
if resInt > max {
|
|
max = resInt
|
|
}
|
|
|
|
}
|
|
|
|
return a.Add(max, nil)
|
|
}
|
|
|
|
func (a *AggMaxAggregator) Result() (interface{}, error) {
|
|
err := a.err.Load()
|
|
if err != nil {
|
|
return nil, err.(error)
|
|
}
|
|
|
|
val, hasVal := a.res.Max()
|
|
if !hasVal {
|
|
return nil, ErrMaxAggregation
|
|
}
|
|
return val, nil
|
|
}
|
|
|
|
// AggLogicalAndAggregator performs logical AND on boolean values.
|
|
type AggLogicalAndAggregator struct {
|
|
err atomic.Value
|
|
res atomic.Int64
|
|
hasResult atomic.Bool
|
|
}
|
|
|
|
func (a *AggLogicalAndAggregator) Add(result interface{}, err error) error {
|
|
if err != nil {
|
|
a.err.CompareAndSwap(nil, err)
|
|
return nil
|
|
}
|
|
|
|
val, e := toBool(result)
|
|
if e != nil {
|
|
a.err.CompareAndSwap(nil, e)
|
|
return e
|
|
}
|
|
|
|
if val {
|
|
a.res.And(1)
|
|
} else {
|
|
a.res.And(0)
|
|
}
|
|
|
|
a.hasResult.Store(true)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *AggLogicalAndAggregator) BatchAdd(results map[string]AggregatorResErr) error {
|
|
result := true
|
|
|
|
for _, res := range results {
|
|
if res.Err != nil {
|
|
return a.Add(nil, res.Err)
|
|
}
|
|
|
|
boolRes, err := toBool(res.Result)
|
|
if err != nil {
|
|
return a.Add(nil, err)
|
|
}
|
|
|
|
result = result && boolRes
|
|
}
|
|
|
|
return a.Add(result, nil)
|
|
}
|
|
|
|
func (a *AggLogicalAndAggregator) AddWithKey(key string, result interface{}, err error) error {
|
|
return a.Add(result, err)
|
|
}
|
|
|
|
func (a *AggLogicalAndAggregator) BatchSlice(results []AggregatorResErr) error {
|
|
result := true
|
|
|
|
for _, res := range results {
|
|
if res.Err != nil {
|
|
return a.Add(nil, res.Err)
|
|
}
|
|
|
|
boolRes, err := toBool(res.Result)
|
|
if err != nil {
|
|
return a.Add(nil, err)
|
|
}
|
|
|
|
result = result && boolRes
|
|
}
|
|
|
|
return a.Add(result, nil)
|
|
}
|
|
|
|
func (a *AggLogicalAndAggregator) Result() (interface{}, error) {
|
|
err := a.err.Load()
|
|
if err != nil {
|
|
return nil, err.(error)
|
|
}
|
|
|
|
if !a.hasResult.Load() {
|
|
return nil, ErrAndAggregation
|
|
}
|
|
return a.res.Load() != 0, nil
|
|
}
|
|
|
|
// AggLogicalOrAggregator performs logical OR on boolean values.
|
|
type AggLogicalOrAggregator struct {
|
|
err atomic.Value
|
|
res atomic.Int64
|
|
hasResult atomic.Bool
|
|
}
|
|
|
|
func (a *AggLogicalOrAggregator) Add(result interface{}, err error) error {
|
|
if err != nil {
|
|
a.err.CompareAndSwap(nil, err)
|
|
return nil
|
|
}
|
|
|
|
val, e := toBool(result)
|
|
if e != nil {
|
|
a.err.CompareAndSwap(nil, e)
|
|
return e
|
|
}
|
|
|
|
if val {
|
|
a.res.Or(1)
|
|
} else {
|
|
a.res.Or(0)
|
|
}
|
|
|
|
a.hasResult.Store(true)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *AggLogicalOrAggregator) BatchAdd(results map[string]AggregatorResErr) error {
|
|
result := false
|
|
|
|
for _, res := range results {
|
|
if res.Err != nil {
|
|
return a.Add(nil, res.Err)
|
|
}
|
|
|
|
boolRes, err := toBool(res.Result)
|
|
if err != nil {
|
|
return a.Add(nil, err)
|
|
}
|
|
|
|
result = result || boolRes
|
|
}
|
|
|
|
return a.Add(result, nil)
|
|
}
|
|
|
|
func (a *AggLogicalOrAggregator) AddWithKey(key string, result interface{}, err error) error {
|
|
return a.Add(result, err)
|
|
}
|
|
|
|
func (a *AggLogicalOrAggregator) BatchSlice(results []AggregatorResErr) error {
|
|
result := false
|
|
|
|
for _, res := range results {
|
|
if res.Err != nil {
|
|
return a.Add(nil, res.Err)
|
|
}
|
|
|
|
boolRes, err := toBool(res.Result)
|
|
if err != nil {
|
|
return a.Add(nil, err)
|
|
}
|
|
|
|
result = result || boolRes
|
|
}
|
|
|
|
return a.Add(result, nil)
|
|
}
|
|
|
|
func (a *AggLogicalOrAggregator) Result() (interface{}, error) {
|
|
err := a.err.Load()
|
|
if err != nil {
|
|
return nil, err.(error)
|
|
}
|
|
|
|
if !a.hasResult.Load() {
|
|
return nil, ErrOrAggregation
|
|
}
|
|
return a.res.Load() != 0, nil
|
|
}
|
|
|
|
// toInt64 converts a numeric reply to int64. nil converts to 0. int, int32
// and int64 convert directly. float32 and float64 are accepted only when they
// hold an integral value inside the int64 range; NaN, ±Inf, fractional and
// out-of-range floats yield an error instead of an implementation-defined
// conversion. Any other type yields an error.
func toInt64(val interface{}) (int64, error) {
	switch v := val.(type) {
	case nil:
		return 0, nil
	case int64:
		return v, nil
	case int:
		return int64(v), nil
	case int32:
		return int64(v), nil
	case float32:
		return toInt64(float64(v))
	case float64:
		// NaN fails this check too, because NaN != NaN.
		if v != math.Trunc(v) {
			return 0, fmt.Errorf("cannot convert float %f to int64", v)
		}
		// float64(math.MaxInt64) rounds up to 2^63, the first value that
		// does not fit; this also rejects ±Inf.
		if v < math.MinInt64 || v >= float64(math.MaxInt64) {
			return 0, fmt.Errorf("float %f overflows int64", v)
		}
		return int64(v), nil
	default:
		return 0, fmt.Errorf("cannot convert %T to int64", val)
	}
}
|
|
|
|
// toFloat64 converts a numeric reply to float64. nil converts to 0; float64,
// float32, int, int32 and int64 are converted directly. Any other type yields
// an error naming the offending type.
func toFloat64(val interface{}) (float64, error) {
	switch v := val.(type) {
	case nil:
		return 0, nil
	case float64:
		return v, nil
	case float32:
		return float64(v), nil
	case int64:
		return float64(v), nil
	case int:
		return float64(v), nil
	case int32:
		return float64(v), nil
	}
	return 0, fmt.Errorf("cannot convert %T to float64", val)
}
|
|
|
|
// toBool interprets a reply as a boolean: nil is false, a bool is used as-is,
// and int/int64 are true when non-zero. Any other type yields an error.
func toBool(val interface{}) (bool, error) {
	switch v := val.(type) {
	case nil:
		return false, nil
	case bool:
		return v, nil
	case int:
		return v != 0, nil
	case int64:
		return v != 0, nil
	default:
		return false, fmt.Errorf("cannot convert %T to bool", val)
	}
}
|
|
|
|
// DefaultKeylessAggregator collects all results in an array, order doesn't matter.
|
|
type DefaultKeylessAggregator struct {
|
|
mu sync.Mutex
|
|
results []interface{}
|
|
firstErr error
|
|
}
|
|
|
|
func (a *DefaultKeylessAggregator) add(result interface{}, err error) error {
|
|
if err != nil && a.firstErr == nil {
|
|
a.firstErr = err
|
|
return nil
|
|
}
|
|
if err == nil {
|
|
a.results = append(a.results, result)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a *DefaultKeylessAggregator) Add(result interface{}, err error) error {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
return a.add(result, err)
|
|
}
|
|
|
|
func (a *DefaultKeylessAggregator) BatchAdd(results map[string]AggregatorResErr) error {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
for _, res := range results {
|
|
err := a.add(res.Result, res.Err)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if res.Err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *DefaultKeylessAggregator) AddWithKey(key string, result interface{}, err error) error {
|
|
return a.Add(result, err)
|
|
}
|
|
|
|
func (a *DefaultKeylessAggregator) BatchSlice(results []AggregatorResErr) error {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
for _, res := range results {
|
|
err := a.add(res.Result, res.Err)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if res.Err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *DefaultKeylessAggregator) Result() (interface{}, error) {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
if a.firstErr != nil {
|
|
return nil, a.firstErr
|
|
}
|
|
return a.results, nil
|
|
}
|
|
|
|
// DefaultKeyedAggregator reassembles replies in the exact key order of the original request.
|
|
type DefaultKeyedAggregator struct {
|
|
mu sync.Mutex
|
|
results map[string]interface{}
|
|
keyOrder []string
|
|
firstErr error
|
|
}
|
|
|
|
func NewDefaultKeyedAggregator(keyOrder []string) *DefaultKeyedAggregator {
|
|
return &DefaultKeyedAggregator{
|
|
results: make(map[string]interface{}),
|
|
keyOrder: keyOrder,
|
|
}
|
|
}
|
|
|
|
func (a *DefaultKeyedAggregator) add(result interface{}, err error) error {
|
|
if err != nil && a.firstErr == nil {
|
|
a.firstErr = err
|
|
return nil
|
|
}
|
|
// For non-keyed Add, just collect the result without ordering
|
|
if err == nil {
|
|
a.results["__default__"] = result
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a *DefaultKeyedAggregator) Add(result interface{}, err error) error {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
return a.add(result, err)
|
|
}
|
|
|
|
func (a *DefaultKeyedAggregator) BatchAdd(results map[string]AggregatorResErr) error {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
for _, res := range results {
|
|
err := a.add(res.Result, res.Err)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if res.Err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *DefaultKeyedAggregator) addWithKey(key string, result interface{}, err error) error {
|
|
if err != nil && a.firstErr == nil {
|
|
a.firstErr = err
|
|
return nil
|
|
}
|
|
if err == nil {
|
|
a.results[key] = result
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a *DefaultKeyedAggregator) AddWithKey(key string, result interface{}, err error) error {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
return a.addWithKey(key, result, err)
|
|
}
|
|
|
|
func (a *DefaultKeyedAggregator) BatchAddWithKeyOrder(results map[string]AggregatorResErr, keyOrder []string) error {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
a.keyOrder = keyOrder
|
|
for key, res := range results {
|
|
err := a.addWithKey(key, res.Result, res.Err)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
if res.Err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *DefaultKeyedAggregator) SetKeyOrder(keyOrder []string) {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
a.keyOrder = keyOrder
|
|
}
|
|
|
|
func (a *DefaultKeyedAggregator) BatchSlice(results []AggregatorResErr) error {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
for _, res := range results {
|
|
err := a.add(res.Result, res.Err)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if res.Err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *DefaultKeyedAggregator) Result() (interface{}, error) {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
if a.firstErr != nil {
|
|
return nil, a.firstErr
|
|
}
|
|
|
|
// If no explicit key order is set, return results in any order
|
|
if len(a.keyOrder) == 0 {
|
|
orderedResults := make([]interface{}, 0, len(a.results))
|
|
for _, result := range a.results {
|
|
orderedResults = append(orderedResults, result)
|
|
}
|
|
return orderedResults, nil
|
|
}
|
|
|
|
// Return results in the exact key order
|
|
orderedResults := make([]interface{}, len(a.keyOrder))
|
|
for i, key := range a.keyOrder {
|
|
if result, exists := a.results[key]; exists {
|
|
orderedResults[i] = result
|
|
}
|
|
}
|
|
return orderedResults, nil
|
|
}
|
|
|
|
// SpecialAggregator provides a registry for command-specific aggregation logic.
|
|
type SpecialAggregator struct {
|
|
mu sync.Mutex
|
|
aggregatorFunc func([]interface{}, []error) (interface{}, error)
|
|
results []interface{}
|
|
errors []error
|
|
}
|
|
|
|
func (a *SpecialAggregator) add(result interface{}, err error) error {
|
|
a.results = append(a.results, result)
|
|
a.errors = append(a.errors, err)
|
|
return nil
|
|
}
|
|
|
|
func (a *SpecialAggregator) Add(result interface{}, err error) error {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
return a.add(result, err)
|
|
}
|
|
|
|
func (a *SpecialAggregator) BatchAdd(results map[string]AggregatorResErr) error {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
for _, res := range results {
|
|
err := a.add(res.Result, res.Err)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if res.Err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *SpecialAggregator) AddWithKey(key string, result interface{}, err error) error {
|
|
return a.Add(result, err)
|
|
}
|
|
|
|
func (a *SpecialAggregator) BatchSlice(results []AggregatorResErr) error {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
for _, res := range results {
|
|
err := a.add(res.Result, res.Err)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if res.Err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *SpecialAggregator) Result() (interface{}, error) {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
if a.aggregatorFunc != nil {
|
|
return a.aggregatorFunc(a.results, a.errors)
|
|
}
|
|
// Default behavior: return first non-error result or first error
|
|
for i, err := range a.errors {
|
|
if err == nil {
|
|
return a.results[i], nil
|
|
}
|
|
}
|
|
if len(a.errors) > 0 {
|
|
return nil, a.errors[0]
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
// SpecialAggregatorRegistry holds custom aggregation functions for specific commands.
|
|
var SpecialAggregatorRegistry = make(map[string]func([]interface{}, []error) (interface{}, error))
|
|
|
|
// RegisterSpecialAggregator registers a custom aggregation function for a command.
|
|
func RegisterSpecialAggregator(cmdName string, fn func([]interface{}, []error) (interface{}, error)) {
|
|
SpecialAggregatorRegistry[cmdName] = fn
|
|
}
|
|
|
|
// NewSpecialAggregator creates a special aggregator with command-specific logic if available.
|
|
func NewSpecialAggregator(cmdName string) *SpecialAggregator {
|
|
agg := &SpecialAggregator{}
|
|
if fn, exists := SpecialAggregatorRegistry[cmdName]; exists {
|
|
agg.aggregatorFunc = fn
|
|
}
|
|
return agg
|
|
}
|