package redis

import (
	"context"
	"fmt"
	"reflect"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/redis/go-redis/v9/internal/pool"
	"github.com/redis/go-redis/v9/internal/proto"

	. "github.com/bsm/ginkgo/v2"
	. "github.com/bsm/gomega"
)

var _ = Describe("newClusterState", func() {
	var state *clusterState

	createClusterState := func(slots []ClusterSlot) *clusterState {
		opt := &ClusterOptions{}
		opt.init()
		nodes := newClusterNodes(opt)
		state, err := newClusterState(nodes, slots, "10.10.10.10:1234")
		Expect(err).NotTo(HaveOccurred())
		return state
	}

	Describe("sorting", func() {
		BeforeEach(func() {
			state = createClusterState([]ClusterSlot{{
				Start: 1000,
				End:   1999,
			}, {
				Start: 0,
				End:   999,
			}, {
				Start: 2000,
				End:   2999,
			}})
		})

		It("sorts slots", func() {
			Expect(state.slots).To(Equal([]*clusterSlot{
				{start: 0, end: 999, nodes: nil},
				{start: 1000, end: 1999, nodes: nil},
				{start: 2000, end: 2999, nodes: nil},
			}))
		})
	})

	Describe("loopback", func() {
		BeforeEach(func() {
			state = createClusterState([]ClusterSlot{{
				Nodes: []ClusterNode{{Addr: "127.0.0.1:7001"}},
			}, {
				Nodes: []ClusterNode{{Addr: "127.0.0.1:7002"}},
			}, {
				Nodes: []ClusterNode{{Addr: "1.2.3.4:1234"}},
			}, {
				Nodes: []ClusterNode{{Addr: ":1234"}},
			}})
		})

		It("replaces loopback hosts in addresses", func() {
			slotAddr := func(slot *clusterSlot) string {
				return slot.nodes[0].Client.Options().Addr
			}

			Expect(slotAddr(state.slots[0])).To(Equal("10.10.10.10:7001"))
			Expect(slotAddr(state.slots[1])).To(Equal("10.10.10.10:7002"))
			Expect(slotAddr(state.slots[2])).To(Equal("1.2.3.4:1234"))
			Expect(slotAddr(state.slots[3])).To(Equal(":1234"))
		})
	})
})

type fixedHash string

func (h fixedHash) Get(string) string {
	return string(h)
}

func TestRingSetAddrsAndRebalanceRace(t *testing.T) {
	const (
		ringShard1Name = "ringShardOne"
		ringShard2Name = "ringShardTwo"

		ringShard1Port = "6390"
		ringShard2Port = "6391"
	)

	ring := NewRing(&RingOptions{
		Addrs: map[string]string{
			ringShard1Name: ":" + ringShard1Port,
		},
		// Disable heartbeat
		HeartbeatFrequency: 1 * time.Hour,
		NewConsistentHash: func(shards []string) ConsistentHash {
			switch len(shards) {
			case 1:
				return fixedHash(ringShard1Name)
			case 2:
				return fixedHash(ringShard2Name)
			default:
				t.Fatalf("Unexpected number of shards: %v", shards)
				return nil
			}
		},
	})
	defer ring.Close()

	// Continuously update addresses by adding and removing one address
	updatesDone := make(chan struct{})
	defer func() { close(updatesDone) }()
	go func() {
		for i := 0; ; i++ {
			select {
			case <-updatesDone:
				return
			default:
				if i%2 == 0 {
					ring.SetAddrs(map[string]string{
						ringShard1Name: ":" + ringShard1Port,
					})
				} else {
					ring.SetAddrs(map[string]string{
						ringShard1Name: ":" + ringShard1Port,
						ringShard2Name: ":" + ringShard2Port,
					})
				}
			}
		}
	}()

	timer := time.NewTimer(1 * time.Second)
	for running := true; running; {
		select {
		case <-timer.C:
			running = false
		default:
			shard, err := ring.sharding.GetByKey("whatever")
			if err == nil && shard == nil {
				t.Fatal("shard is nil")
			}
		}
	}
}

func BenchmarkRingShardingRebalanceLocked(b *testing.B) {
	opts := &RingOptions{
		Addrs: make(map[string]string),
		// Disable heartbeat
		HeartbeatFrequency: 1 * time.Hour,
	}
	for i := 0; i < 100; i++ {
		opts.Addrs[fmt.Sprintf("shard%d", i)] = fmt.Sprintf(":63%02d", i)
	}

	ring := NewRing(opts)
	defer ring.Close()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		ring.sharding.rebalanceLocked()
	}
}

type testCounter struct {
	mu sync.Mutex
	t  *testing.T
	m  map[string]int
}

func newTestCounter(t *testing.T) *testCounter {
	return &testCounter{t: t, m: make(map[string]int)}
}

func (ct *testCounter) increment(key string) {
	ct.mu.Lock()
	defer ct.mu.Unlock()
	ct.m[key]++
}

func (ct *testCounter) expect(values map[string]int) {
	ct.mu.Lock()
	defer ct.mu.Unlock()
	ct.t.Helper()
	if !reflect.DeepEqual(values, ct.m) {
		ct.t.Errorf("expected %v != actual %v", values, ct.m)
	}
}

func TestRingShardsCleanup(t *testing.T) {
	const (
		ringShard1Name = "ringShardOne"
		ringShard2Name = "ringShardTwo"

		ringShard1Addr = "shard1.test"
		ringShard2Addr = "shard2.test"
	)

	t.Run("closes unused shards", func(t *testing.T) {
		closeCounter := newTestCounter(t)

		ring := NewRing(&RingOptions{
			Addrs: map[string]string{
				ringShard1Name: ringShard1Addr,
				ringShard2Name: ringShard2Addr,
			},
			NewClient: func(opt *Options) *Client {
				c := NewClient(opt)
				c.baseClient.onClose = func() error {
					closeCounter.increment(opt.Addr)
					return nil
				}
				return c
			},
		})
		closeCounter.expect(map[string]int{})

		// no change due to the same addresses
		ring.SetAddrs(map[string]string{
			ringShard1Name: ringShard1Addr,
			ringShard2Name: ringShard2Addr,
		})
		closeCounter.expect(map[string]int{})

		ring.SetAddrs(map[string]string{
			ringShard1Name: ringShard1Addr,
		})
		closeCounter.expect(map[string]int{ringShard2Addr: 1})

		ring.SetAddrs(map[string]string{
			ringShard2Name: ringShard2Addr,
		})
		closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1})

		ring.Close()
		closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 2})
	})

	t.Run("closes created shards if ring was closed", func(t *testing.T) {
		createCounter := newTestCounter(t)
		closeCounter := newTestCounter(t)

		var (
			ring        *Ring
			shouldClose int32
		)

		ring = NewRing(&RingOptions{
			Addrs: map[string]string{
				ringShard1Name: ringShard1Addr,
			},
			NewClient: func(opt *Options) *Client {
				if atomic.LoadInt32(&shouldClose) != 0 {
					ring.Close()
				}
				createCounter.increment(opt.Addr)
				c := NewClient(opt)
				c.baseClient.onClose = func() error {
					closeCounter.increment(opt.Addr)
					return nil
				}
				return c
			},
		})
		createCounter.expect(map[string]int{ringShard1Addr: 1})
		closeCounter.expect(map[string]int{})

		atomic.StoreInt32(&shouldClose, 1)

		ring.SetAddrs(map[string]string{
			ringShard2Name: ringShard2Addr,
		})
		createCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1})
		closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1})
	})
}

//------------------------------------------------------------------------------

type timeoutErr struct {
	error
}

func (e timeoutErr) Timeout() bool {
	return true
}

func (e timeoutErr) Temporary() bool {
	return true
}

func (e timeoutErr) Error() string {
	return "i/o timeout"
}

var _ = Describe("withConn", func() {
	var client *Client

	BeforeEach(func() {
		client = NewClient(&Options{
			PoolSize: 1,
		})
	})

	AfterEach(func() {
		client.Close()
	})

	It("should replace the connection in the pool when there is no error", func() {
		var conn *pool.Conn

		client.withConn(ctx, func(ctx context.Context, c *pool.Conn) error {
			conn = c
			return nil
		})

		newConn, err := client.connPool.Get(ctx)
		Expect(err).To(BeNil())
		Expect(newConn).To(Equal(conn))
	})

	It("should replace the connection in the pool when there is an error not related to a bad connection", func() {
		var conn *pool.Conn

		client.withConn(ctx, func(ctx context.Context, c *pool.Conn) error {
			conn = c
			return proto.RedisError("LOADING")
		})

		newConn, err := client.connPool.Get(ctx)
		Expect(err).To(BeNil())
		Expect(newConn).To(Equal(conn))
	})

	It("should remove the connection from the pool when it times out", func() {
		var conn *pool.Conn

		client.withConn(ctx, func(ctx context.Context, c *pool.Conn) error {
			conn = c
			return timeoutErr{}
		})

		newConn, err := client.connPool.Get(ctx)
		Expect(err).To(BeNil())
		Expect(newConn).NotTo(Equal(conn))
		Expect(client.connPool.Len()).To(Equal(1))
	})
})