1
0
mirror of https://github.com/redis/go-redis.git synced 2025-07-29 17:41:15 +03:00
This commit is contained in:
ofekshenawa
2024-06-20 02:30:37 +03:00
commit 0b95fd7fa5
188 changed files with 45644 additions and 0 deletions

58
internal/arg.go Normal file
View File

@ -0,0 +1,58 @@
package internal
import (
"fmt"
"strconv"
"time"
"github.com/redis/go-redis/v9/internal/util"
)
func AppendArg(b []byte, v interface{}) []byte {
switch v := v.(type) {
case nil:
return append(b, "<nil>"...)
case string:
return appendUTF8String(b, util.StringToBytes(v))
case []byte:
return appendUTF8String(b, v)
case int:
return strconv.AppendInt(b, int64(v), 10)
case int8:
return strconv.AppendInt(b, int64(v), 10)
case int16:
return strconv.AppendInt(b, int64(v), 10)
case int32:
return strconv.AppendInt(b, int64(v), 10)
case int64:
return strconv.AppendInt(b, v, 10)
case uint:
return strconv.AppendUint(b, uint64(v), 10)
case uint8:
return strconv.AppendUint(b, uint64(v), 10)
case uint16:
return strconv.AppendUint(b, uint64(v), 10)
case uint32:
return strconv.AppendUint(b, uint64(v), 10)
case uint64:
return strconv.AppendUint(b, v, 10)
case float32:
return strconv.AppendFloat(b, float64(v), 'f', -1, 64)
case float64:
return strconv.AppendFloat(b, v, 'f', -1, 64)
case bool:
if v {
return append(b, "true"...)
}
return append(b, "false"...)
case time.Time:
return v.AppendFormat(b, time.RFC3339Nano)
default:
return append(b, fmt.Sprint(v)...)
}
}
func appendUTF8String(dst []byte, src []byte) []byte {
dst = append(dst, src...)
return dst
}

1
internal/customvet/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/customvet

View File

@ -0,0 +1,62 @@
package setval
import (
"go/ast"
"go/token"
"go/types"
"golang.org/x/tools/go/analysis"
)
var Analyzer = &analysis.Analyzer{
Name: "setval",
Doc: "find Cmder types that are missing a SetVal method",
Run: func(pass *analysis.Pass) (interface{}, error) {
cmderTypes := make(map[string]token.Pos)
typesWithSetValMethod := make(map[string]bool)
for _, file := range pass.Files {
for _, decl := range file.Decls {
funcName, receiverType := parseFuncDecl(decl, pass.TypesInfo)
switch funcName {
case "Result":
cmderTypes[receiverType] = decl.Pos()
case "SetVal":
typesWithSetValMethod[receiverType] = true
}
}
}
for cmder, pos := range cmderTypes {
if !typesWithSetValMethod[cmder] {
pass.Reportf(pos, "%s is missing a SetVal method", cmder)
}
}
return nil, nil
},
}
func parseFuncDecl(decl ast.Decl, typesInfo *types.Info) (funcName, receiverType string) {
funcDecl, ok := decl.(*ast.FuncDecl)
if !ok {
return "", "" // Not a function declaration.
}
if funcDecl.Recv == nil {
return "", "" // Not a method.
}
if len(funcDecl.Recv.List) != 1 {
return "", "" // Unexpected number of receiver arguments. (Can this happen?)
}
receiverTypeObj := typesInfo.TypeOf(funcDecl.Recv.List[0].Type)
if receiverTypeObj == nil {
return "", "" // Unable to determine the receiver type.
}
return funcDecl.Name.Name, receiverTypeObj.String()
}

View File

@ -0,0 +1,14 @@
package setval_test
import (
"testing"
"golang.org/x/tools/go/analysis/analysistest"
"github.com/redis/go-redis/internal/customvet/checks/setval"
)
func Test(t *testing.T) {
testdata := analysistest.TestData()
analysistest.Run(t, testdata, setval.Analyzer, "a")
}

View File

@ -0,0 +1,29 @@
package a
type GoodCmd struct {
val int
}
func (c *GoodCmd) SetVal(val int) {
c.val = val
}
func (c *GoodCmd) Result() (int, error) {
return c.val, nil
}
type BadCmd struct {
val int
}
func (c *BadCmd) Result() (int, error) { // want "\\*a.BadCmd is missing a SetVal method"
return c.val, nil
}
type NotACmd struct {
val int
}
func (c *NotACmd) Val() int {
return c.val
}

10
internal/customvet/go.mod Normal file
View File

@ -0,0 +1,10 @@
module github.com/redis/go-redis/internal/customvet
go 1.17
require golang.org/x/tools v0.5.0
require (
golang.org/x/mod v0.7.0 // indirect
golang.org/x/sys v0.4.0 // indirect
)

View File

@ -0,0 +1,7 @@
golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA=
golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/tools v0.5.0 h1:+bSpV5HIeWkuvgaMfI3UmKRThoTA5ODJTUd8T17NO+4=
golang.org/x/tools v0.5.0/go.mod h1:N+Kgy78s5I24c24dU8OfWNEotWjutIs8SnJvn5IDq+k=

View File

@ -0,0 +1,13 @@
package main
import (
"golang.org/x/tools/go/analysis/multichecker"
"github.com/redis/go-redis/internal/customvet/checks/setval"
)
func main() {
multichecker.Main(
setval.Analyzer,
)
}

View File

@ -0,0 +1,78 @@
package hashtag
import (
"strings"
"github.com/redis/go-redis/v9/internal/rand"
)
const slotNumber = 16384
// CRC16 implementation according to CCITT standards.
// Copyright 2001-2010 Georges Menie (www.menie.org)
// Copyright 2013 The Go Authors. All rights reserved.
// http://redis.io/topics/cluster-spec#appendix-a-crc16-reference-implementation-in-ansi-c
var crc16tab = [256]uint16{
0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7,
0x8108, 0x9129, 0xa14a, 0xb16b, 0xc18c, 0xd1ad, 0xe1ce, 0xf1ef,
0x1231, 0x0210, 0x3273, 0x2252, 0x52b5, 0x4294, 0x72f7, 0x62d6,
0x9339, 0x8318, 0xb37b, 0xa35a, 0xd3bd, 0xc39c, 0xf3ff, 0xe3de,
0x2462, 0x3443, 0x0420, 0x1401, 0x64e6, 0x74c7, 0x44a4, 0x5485,
0xa56a, 0xb54b, 0x8528, 0x9509, 0xe5ee, 0xf5cf, 0xc5ac, 0xd58d,
0x3653, 0x2672, 0x1611, 0x0630, 0x76d7, 0x66f6, 0x5695, 0x46b4,
0xb75b, 0xa77a, 0x9719, 0x8738, 0xf7df, 0xe7fe, 0xd79d, 0xc7bc,
0x48c4, 0x58e5, 0x6886, 0x78a7, 0x0840, 0x1861, 0x2802, 0x3823,
0xc9cc, 0xd9ed, 0xe98e, 0xf9af, 0x8948, 0x9969, 0xa90a, 0xb92b,
0x5af5, 0x4ad4, 0x7ab7, 0x6a96, 0x1a71, 0x0a50, 0x3a33, 0x2a12,
0xdbfd, 0xcbdc, 0xfbbf, 0xeb9e, 0x9b79, 0x8b58, 0xbb3b, 0xab1a,
0x6ca6, 0x7c87, 0x4ce4, 0x5cc5, 0x2c22, 0x3c03, 0x0c60, 0x1c41,
0xedae, 0xfd8f, 0xcdec, 0xddcd, 0xad2a, 0xbd0b, 0x8d68, 0x9d49,
0x7e97, 0x6eb6, 0x5ed5, 0x4ef4, 0x3e13, 0x2e32, 0x1e51, 0x0e70,
0xff9f, 0xefbe, 0xdfdd, 0xcffc, 0xbf1b, 0xaf3a, 0x9f59, 0x8f78,
0x9188, 0x81a9, 0xb1ca, 0xa1eb, 0xd10c, 0xc12d, 0xf14e, 0xe16f,
0x1080, 0x00a1, 0x30c2, 0x20e3, 0x5004, 0x4025, 0x7046, 0x6067,
0x83b9, 0x9398, 0xa3fb, 0xb3da, 0xc33d, 0xd31c, 0xe37f, 0xf35e,
0x02b1, 0x1290, 0x22f3, 0x32d2, 0x4235, 0x5214, 0x6277, 0x7256,
0xb5ea, 0xa5cb, 0x95a8, 0x8589, 0xf56e, 0xe54f, 0xd52c, 0xc50d,
0x34e2, 0x24c3, 0x14a0, 0x0481, 0x7466, 0x6447, 0x5424, 0x4405,
0xa7db, 0xb7fa, 0x8799, 0x97b8, 0xe75f, 0xf77e, 0xc71d, 0xd73c,
0x26d3, 0x36f2, 0x0691, 0x16b0, 0x6657, 0x7676, 0x4615, 0x5634,
0xd94c, 0xc96d, 0xf90e, 0xe92f, 0x99c8, 0x89e9, 0xb98a, 0xa9ab,
0x5844, 0x4865, 0x7806, 0x6827, 0x18c0, 0x08e1, 0x3882, 0x28a3,
0xcb7d, 0xdb5c, 0xeb3f, 0xfb1e, 0x8bf9, 0x9bd8, 0xabbb, 0xbb9a,
0x4a75, 0x5a54, 0x6a37, 0x7a16, 0x0af1, 0x1ad0, 0x2ab3, 0x3a92,
0xfd2e, 0xed0f, 0xdd6c, 0xcd4d, 0xbdaa, 0xad8b, 0x9de8, 0x8dc9,
0x7c26, 0x6c07, 0x5c64, 0x4c45, 0x3ca2, 0x2c83, 0x1ce0, 0x0cc1,
0xef1f, 0xff3e, 0xcf5d, 0xdf7c, 0xaf9b, 0xbfba, 0x8fd9, 0x9ff8,
0x6e17, 0x7e36, 0x4e55, 0x5e74, 0x2e93, 0x3eb2, 0x0ed1, 0x1ef0,
}
func Key(key string) string {
if s := strings.IndexByte(key, '{'); s > -1 {
if e := strings.IndexByte(key[s+1:], '}'); e > 0 {
return key[s+1 : s+e+1]
}
}
return key
}
func RandomSlot() int {
return rand.Intn(slotNumber)
}
// Slot returns a consistent slot number between 0 and 16383
// for any given string key.
func Slot(key string) int {
if key == "" {
return RandomSlot()
}
key = Key(key)
return int(crc16sum(key)) % slotNumber
}
func crc16sum(key string) (crc uint16) {
for i := 0; i < len(key); i++ {
crc = (crc << 8) ^ crc16tab[(byte(crc>>8)^key[i])&0x00ff]
}
return
}

View File

@ -0,0 +1,71 @@
package hashtag
import (
"testing"
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9/internal/rand"
)
func TestGinkgoSuite(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "hashtag")
}
var _ = Describe("CRC16", func() {
// http://redis.io/topics/cluster-spec#keys-distribution-model
It("should calculate CRC16", func() {
tests := []struct {
s string
n uint16
}{
{"123456789", 0x31C3},
{string([]byte{83, 153, 134, 118, 229, 214, 244, 75, 140, 37, 215, 215}), 21847},
}
for _, test := range tests {
Expect(crc16sum(test.s)).To(Equal(test.n), "for %s", test.s)
}
})
})
var _ = Describe("HashSlot", func() {
It("should calculate hash slots", func() {
tests := []struct {
key string
slot int
}{
{"123456789", 12739},
{"{}foo", 9500},
{"foo{}", 5542},
{"foo{}{bar}", 8363},
{"", 10503},
{"", 5176},
{string([]byte{83, 153, 134, 118, 229, 214, 244, 75, 140, 37, 215, 215}), 5463},
}
// Empty keys receive random slot.
rand.Seed(100)
for _, test := range tests {
Expect(Slot(test.key)).To(Equal(test.slot), "for %s", test.key)
}
})
It("should extract keys from tags", func() {
tests := []struct {
one, two string
}{
{"foo{bar}", "bar"},
{"{foo}bar", "foo"},
{"{user1000}.following", "{user1000}.followers"},
{"foo{{bar}}zap", "{bar"},
{"foo{bar}{zap}", "bar"},
}
for _, test := range tests {
Expect(Slot(test.one)).To(Equal(Slot(test.two)), "for %s <-> %s", test.one, test.two)
}
})
})

207
internal/hscan/hscan.go Normal file
View File

@ -0,0 +1,207 @@
package hscan
import (
"errors"
"fmt"
"reflect"
"strconv"
)
// decoderFunc represents decoding functions for default built-in types.
type decoderFunc func(reflect.Value, string) error
// Scanner is the interface implemented by themselves,
// which will override the decoding behavior of decoderFunc.
type Scanner interface {
ScanRedis(s string) error
}
var (
// List of built-in decoders indexed by their numeric constant values (eg: reflect.Bool = 1).
decoders = []decoderFunc{
reflect.Bool: decodeBool,
reflect.Int: decodeInt,
reflect.Int8: decodeInt8,
reflect.Int16: decodeInt16,
reflect.Int32: decodeInt32,
reflect.Int64: decodeInt64,
reflect.Uint: decodeUint,
reflect.Uint8: decodeUint8,
reflect.Uint16: decodeUint16,
reflect.Uint32: decodeUint32,
reflect.Uint64: decodeUint64,
reflect.Float32: decodeFloat32,
reflect.Float64: decodeFloat64,
reflect.Complex64: decodeUnsupported,
reflect.Complex128: decodeUnsupported,
reflect.Array: decodeUnsupported,
reflect.Chan: decodeUnsupported,
reflect.Func: decodeUnsupported,
reflect.Interface: decodeUnsupported,
reflect.Map: decodeUnsupported,
reflect.Ptr: decodeUnsupported,
reflect.Slice: decodeSlice,
reflect.String: decodeString,
reflect.Struct: decodeUnsupported,
reflect.UnsafePointer: decodeUnsupported,
}
// Global map of struct field specs that is populated once for every new
// struct type that is scanned. This caches the field types and the corresponding
// decoder functions to avoid iterating through struct fields on subsequent scans.
globalStructMap = newStructMap()
)
func Struct(dst interface{}) (StructValue, error) {
v := reflect.ValueOf(dst)
// The destination to scan into should be a struct pointer.
if v.Kind() != reflect.Ptr || v.IsNil() {
return StructValue{}, fmt.Errorf("redis.Scan(non-pointer %T)", dst)
}
v = v.Elem()
if v.Kind() != reflect.Struct {
return StructValue{}, fmt.Errorf("redis.Scan(non-struct %T)", dst)
}
return StructValue{
spec: globalStructMap.get(v.Type()),
value: v,
}, nil
}
// Scan scans the results from a key-value Redis map result set to a destination struct.
// The Redis keys are matched to the struct's field with the `redis` tag.
func Scan(dst interface{}, keys []interface{}, vals []interface{}) error {
if len(keys) != len(vals) {
return errors.New("args should have the same number of keys and vals")
}
strct, err := Struct(dst)
if err != nil {
return err
}
// Iterate through the (key, value) sequence.
for i := 0; i < len(vals); i++ {
key, ok := keys[i].(string)
if !ok {
continue
}
val, ok := vals[i].(string)
if !ok {
continue
}
if err := strct.Scan(key, val); err != nil {
return err
}
}
return nil
}
func decodeBool(f reflect.Value, s string) error {
b, err := strconv.ParseBool(s)
if err != nil {
return err
}
f.SetBool(b)
return nil
}
func decodeInt8(f reflect.Value, s string) error {
return decodeNumber(f, s, 8)
}
func decodeInt16(f reflect.Value, s string) error {
return decodeNumber(f, s, 16)
}
func decodeInt32(f reflect.Value, s string) error {
return decodeNumber(f, s, 32)
}
func decodeInt64(f reflect.Value, s string) error {
return decodeNumber(f, s, 64)
}
func decodeInt(f reflect.Value, s string) error {
return decodeNumber(f, s, 0)
}
func decodeNumber(f reflect.Value, s string, bitSize int) error {
v, err := strconv.ParseInt(s, 10, bitSize)
if err != nil {
return err
}
f.SetInt(v)
return nil
}
func decodeUint8(f reflect.Value, s string) error {
return decodeUnsignedNumber(f, s, 8)
}
func decodeUint16(f reflect.Value, s string) error {
return decodeUnsignedNumber(f, s, 16)
}
func decodeUint32(f reflect.Value, s string) error {
return decodeUnsignedNumber(f, s, 32)
}
func decodeUint64(f reflect.Value, s string) error {
return decodeUnsignedNumber(f, s, 64)
}
func decodeUint(f reflect.Value, s string) error {
return decodeUnsignedNumber(f, s, 0)
}
func decodeUnsignedNumber(f reflect.Value, s string, bitSize int) error {
v, err := strconv.ParseUint(s, 10, bitSize)
if err != nil {
return err
}
f.SetUint(v)
return nil
}
func decodeFloat32(f reflect.Value, s string) error {
v, err := strconv.ParseFloat(s, 32)
if err != nil {
return err
}
f.SetFloat(v)
return nil
}
// although the default is float64, but we better define it.
func decodeFloat64(f reflect.Value, s string) error {
v, err := strconv.ParseFloat(s, 64)
if err != nil {
return err
}
f.SetFloat(v)
return nil
}
func decodeString(f reflect.Value, s string) error {
f.SetString(s)
return nil
}
func decodeSlice(f reflect.Value, s string) error {
// []byte slice ([]uint8).
if f.Type().Elem().Kind() == reflect.Uint8 {
f.SetBytes([]byte(s))
}
return nil
}
func decodeUnsupported(v reflect.Value, s string) error {
return fmt.Errorf("redis.Scan(unsupported %s)", v.Type())
}

View File

@ -0,0 +1,220 @@
package hscan
import (
"math"
"strconv"
"testing"
"time"
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9/internal/util"
)
type data struct {
Omit string `redis:"-"`
Empty string
String string `redis:"string"`
Bytes []byte `redis:"byte"`
Int int `redis:"int"`
Int8 int8 `redis:"int8"`
Int16 int16 `redis:"int16"`
Int32 int32 `redis:"int32"`
Int64 int64 `redis:"int64"`
Uint uint `redis:"uint"`
Uint8 uint8 `redis:"uint8"`
Uint16 uint16 `redis:"uint16"`
Uint32 uint32 `redis:"uint32"`
Uint64 uint64 `redis:"uint64"`
Float float32 `redis:"float"`
Float64 float64 `redis:"float64"`
Bool bool `redis:"bool"`
BoolRef *bool `redis:"boolRef"`
}
type TimeRFC3339Nano struct {
time.Time
}
func (t *TimeRFC3339Nano) ScanRedis(s string) (err error) {
t.Time, err = time.Parse(time.RFC3339Nano, s)
return
}
type TimeData struct {
Name string `redis:"name"`
Time *TimeRFC3339Nano `redis:"login"`
}
type i []interface{}
func TestGinkgoSuite(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "hscan")
}
var _ = Describe("Scan", func() {
It("catches bad args", func() {
var d data
Expect(Scan(&d, i{}, i{})).NotTo(HaveOccurred())
Expect(d).To(Equal(data{}))
Expect(Scan(&d, i{"key"}, i{})).To(HaveOccurred())
Expect(Scan(&d, i{"key"}, i{"1", "2"})).To(HaveOccurred())
Expect(Scan(nil, i{"key", "1"}, i{})).To(HaveOccurred())
var m map[string]interface{}
Expect(Scan(&m, i{"key"}, i{"1"})).To(HaveOccurred())
Expect(Scan(data{}, i{"key"}, i{"1"})).To(HaveOccurred())
Expect(Scan(data{}, i{"key", "string"}, i{nil, nil})).To(HaveOccurred())
})
It("number out of range", func() {
f := func(v uint64) string {
return strconv.FormatUint(v, 10) + "1"
}
keys := i{"int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", "float", "float64"}
vals := i{
f(math.MaxInt8), f(math.MaxInt16), f(math.MaxInt32), f(math.MaxInt64),
f(math.MaxUint8), f(math.MaxUint16), f(math.MaxUint32), strconv.FormatUint(math.MaxUint64, 10) + "1",
"13.4028234663852886e+38", "11.79769313486231570e+308",
}
for k, v := range keys {
var d data
Expect(Scan(&d, i{v}, i{vals[k]})).To(HaveOccurred())
}
// success
f = func(v uint64) string {
return strconv.FormatUint(v, 10)
}
keys = i{"int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", "float", "float64"}
vals = i{
f(math.MaxInt8), f(math.MaxInt16), f(math.MaxInt32), f(math.MaxInt64),
f(math.MaxUint8), f(math.MaxUint16), f(math.MaxUint32), strconv.FormatUint(math.MaxUint64, 10),
"3.40282346638528859811704183484516925440e+38", "1.797693134862315708145274237317043567981e+308",
}
var d data
Expect(Scan(&d, keys, vals)).NotTo(HaveOccurred())
Expect(d).To(Equal(data{
Int8: math.MaxInt8,
Int16: math.MaxInt16,
Int32: math.MaxInt32,
Int64: math.MaxInt64,
Uint8: math.MaxUint8,
Uint16: math.MaxUint16,
Uint32: math.MaxUint32,
Uint64: math.MaxUint64,
Float: math.MaxFloat32,
Float64: math.MaxFloat64,
}))
})
It("scans good values", func() {
var d data
// non-tagged fields.
Expect(Scan(&d, i{"key"}, i{"value"})).NotTo(HaveOccurred())
Expect(d).To(Equal(data{}))
keys := i{"string", "byte", "int", "int64", "uint", "uint64", "float", "float64", "bool", "boolRef"}
vals := i{
"str!", "bytes!", "123", "123456789123456789", "456", "987654321987654321",
"123.456", "123456789123456789.987654321987654321", "1", "1",
}
Expect(Scan(&d, keys, vals)).NotTo(HaveOccurred())
Expect(d).To(Equal(data{
String: "str!",
Bytes: []byte("bytes!"),
Int: 123,
Int64: 123456789123456789,
Uint: 456,
Uint64: 987654321987654321,
Float: 123.456,
Float64: 1.2345678912345678e+17,
Bool: true,
BoolRef: util.ToPtr(true),
}))
// Scan a different type with the same values to test that
// the struct spec maps don't conflict.
type data2 struct {
String string `redis:"string"`
Bytes []byte `redis:"byte"`
Int int `redis:"int"`
Uint uint `redis:"uint"`
Float float32 `redis:"float"`
Bool bool `redis:"bool"`
}
var d2 data2
Expect(Scan(&d2, keys, vals)).NotTo(HaveOccurred())
Expect(d2).To(Equal(data2{
String: "str!",
Bytes: []byte("bytes!"),
Int: 123,
Uint: 456,
Float: 123.456,
Bool: true,
}))
Expect(Scan(&d, i{"string", "float", "bool"}, i{"", "1", "t"})).NotTo(HaveOccurred())
Expect(d).To(Equal(data{
String: "",
Bytes: []byte("bytes!"),
Int: 123,
Int64: 123456789123456789,
Uint: 456,
Uint64: 987654321987654321,
Float: 1.0,
Float64: 1.2345678912345678e+17,
Bool: true,
BoolRef: util.ToPtr(true),
}))
})
It("omits untagged fields", func() {
var d data
Expect(Scan(&d, i{"empty", "omit", "string"}, i{"value", "value", "str!"})).NotTo(HaveOccurred())
Expect(d).To(Equal(data{
String: "str!",
}))
})
It("catches bad values", func() {
var d data
Expect(Scan(&d, i{"int"}, i{"a"})).To(HaveOccurred())
Expect(Scan(&d, i{"uint"}, i{"a"})).To(HaveOccurred())
Expect(Scan(&d, i{"uint"}, i{""})).To(HaveOccurred())
Expect(Scan(&d, i{"float"}, i{"b"})).To(HaveOccurred())
Expect(Scan(&d, i{"bool"}, i{"-1"})).To(HaveOccurred())
Expect(Scan(&d, i{"bool"}, i{""})).To(HaveOccurred())
Expect(Scan(&d, i{"bool"}, i{"123"})).To(HaveOccurred())
})
It("Implements Scanner", func() {
var td TimeData
now := time.Now()
Expect(Scan(&td, i{"name", "login"}, i{"hello", now.Format(time.RFC3339Nano)})).NotTo(HaveOccurred())
Expect(td.Name).To(Equal("hello"))
Expect(td.Time.UnixNano()).To(Equal(now.UnixNano()))
Expect(td.Time.Format(time.RFC3339Nano)).To(Equal(now.Format(time.RFC3339Nano)))
})
It("should time.Time RFC3339Nano", func() {
type TimeTime struct {
Time time.Time `redis:"time"`
}
now := time.Now()
var tt TimeTime
Expect(Scan(&tt, i{"time"}, i{now.Format(time.RFC3339Nano)})).NotTo(HaveOccurred())
Expect(now.Unix()).To(Equal(tt.Time.Unix()))
})
})

125
internal/hscan/structmap.go Normal file
View File

@ -0,0 +1,125 @@
package hscan
import (
"encoding"
"fmt"
"reflect"
"strings"
"sync"
"github.com/redis/go-redis/v9/internal/util"
)
// structMap contains the map of struct fields for target structs
// indexed by the struct type.
type structMap struct {
m sync.Map
}
func newStructMap() *structMap {
return new(structMap)
}
func (s *structMap) get(t reflect.Type) *structSpec {
if v, ok := s.m.Load(t); ok {
return v.(*structSpec)
}
spec := newStructSpec(t, "redis")
s.m.Store(t, spec)
return spec
}
//------------------------------------------------------------------------------
// structSpec contains the list of all fields in a target struct.
type structSpec struct {
m map[string]*structField
}
func (s *structSpec) set(tag string, sf *structField) {
s.m[tag] = sf
}
func newStructSpec(t reflect.Type, fieldTag string) *structSpec {
numField := t.NumField()
out := &structSpec{
m: make(map[string]*structField, numField),
}
for i := 0; i < numField; i++ {
f := t.Field(i)
tag := f.Tag.Get(fieldTag)
if tag == "" || tag == "-" {
continue
}
tag = strings.Split(tag, ",")[0]
if tag == "" {
continue
}
// Use the built-in decoder.
kind := f.Type.Kind()
if kind == reflect.Pointer {
kind = f.Type.Elem().Kind()
}
out.set(tag, &structField{index: i, fn: decoders[kind]})
}
return out
}
//------------------------------------------------------------------------------
// structField represents a single field in a target struct.
type structField struct {
index int
fn decoderFunc
}
//------------------------------------------------------------------------------
type StructValue struct {
spec *structSpec
value reflect.Value
}
func (s StructValue) Scan(key string, value string) error {
field, ok := s.spec.m[key]
if !ok {
return nil
}
v := s.value.Field(field.index)
isPtr := v.Kind() == reflect.Ptr
if isPtr && v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
if !isPtr && v.Type().Name() != "" && v.CanAddr() {
v = v.Addr()
isPtr = true
}
if isPtr && v.Type().NumMethod() > 0 && v.CanInterface() {
switch scan := v.Interface().(type) {
case Scanner:
return scan.ScanRedis(value)
case encoding.TextUnmarshaler:
return scan.UnmarshalText(util.StringToBytes(value))
}
}
if isPtr {
v = v.Elem()
}
if err := field.fn(v, value); err != nil {
t := s.value.Type()
return fmt.Errorf("cannot scan redis.result %s into struct field %s.%s of type %s, error-%s",
value, t.Name(), t.Field(field.index).Name, t.Field(field.index).Type, err.Error())
}
return nil
}

29
internal/internal.go Normal file
View File

@ -0,0 +1,29 @@
package internal
import (
"time"
"github.com/redis/go-redis/v9/internal/rand"
)
func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration {
if retry < 0 {
panic("not reached")
}
if minBackoff == 0 {
return 0
}
d := minBackoff << uint(retry)
if d < minBackoff {
return maxBackoff
}
d = minBackoff + time.Duration(rand.Int63n(int64(d)))
if d > maxBackoff || d < minBackoff {
d = maxBackoff
}
return d
}

18
internal/internal_test.go Normal file
View File

@ -0,0 +1,18 @@
package internal
import (
"testing"
"time"
. "github.com/bsm/gomega"
)
func TestRetryBackoff(t *testing.T) {
RegisterTestingT(t)
for i := 0; i <= 16; i++ {
backoff := RetryBackoff(i, time.Millisecond, 512*time.Millisecond)
Expect(backoff >= 0).To(BeTrue())
Expect(backoff <= 512*time.Millisecond).To(BeTrue())
}
}

26
internal/log.go Normal file
View File

@ -0,0 +1,26 @@
package internal
import (
"context"
"fmt"
"log"
"os"
)
type Logging interface {
Printf(ctx context.Context, format string, v ...interface{})
}
type logger struct {
log *log.Logger
}
func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) {
_ = l.log.Output(2, fmt.Sprintf(format, v...))
}
// Logger calls Output to print to the stderr.
// Arguments are handled in the manner of fmt.Print.
var Logger Logging = &logger{
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
}

63
internal/once.go Normal file
View File

@ -0,0 +1,63 @@
/*
Copyright 2014 The Camlistore Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package internal
import (
"sync"
"sync/atomic"
)
// A Once will perform a successful action exactly once.
//
// Unlike a sync.Once, this Once's func returns an error
// and is re-armed on failure.
type Once struct {
m sync.Mutex
done uint32
}
// Do calls the function f if and only if Do has not been invoked
// without error for this instance of Once. In other words, given
//
// var once Once
//
// if once.Do(f) is called multiple times, only the first call will
// invoke f, even if f has a different value in each invocation unless
// f returns an error. A new instance of Once is required for each
// function to execute.
//
// Do is intended for initialization that must be run exactly once. Since f
// is niladic, it may be necessary to use a function literal to capture the
// arguments to a function to be invoked by Do:
//
// err := config.once.Do(func() error { return config.init(filename) })
func (o *Once) Do(f func() error) error {
if atomic.LoadUint32(&o.done) == 1 {
return nil
}
// Slow-path.
o.m.Lock()
defer o.m.Unlock()
var err error
if o.done == 0 {
err = f()
if err == nil {
atomic.StoreUint32(&o.done, 1)
}
}
return err
}

View File

@ -0,0 +1,95 @@
package pool_test
import (
"context"
"fmt"
"testing"
"time"
"github.com/redis/go-redis/v9/internal/pool"
)
type poolGetPutBenchmark struct {
poolSize int
}
func (bm poolGetPutBenchmark) String() string {
return fmt.Sprintf("pool=%d", bm.poolSize)
}
func BenchmarkPoolGetPut(b *testing.B) {
ctx := context.Background()
benchmarks := []poolGetPutBenchmark{
{1},
{2},
{8},
{32},
{64},
{128},
}
for _, bm := range benchmarks {
b.Run(bm.String(), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: bm.poolSize,
PoolTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
})
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Put(ctx, cn)
}
})
})
}
}
type poolGetRemoveBenchmark struct {
poolSize int
}
func (bm poolGetRemoveBenchmark) String() string {
return fmt.Sprintf("pool=%d", bm.poolSize)
}
func BenchmarkPoolGetRemove(b *testing.B) {
ctx := context.Background()
benchmarks := []poolGetRemoveBenchmark{
{1},
{2},
{8},
{32},
{64},
{128},
}
for _, bm := range benchmarks {
b.Run(bm.String(), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: bm.poolSize,
PoolTimeout: time.Second,
ConnMaxIdleTime: time.Hour,
})
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(ctx)
if err != nil {
b.Fatal(err)
}
connPool.Remove(ctx, cn, nil)
}
})
})
}
}

127
internal/pool/conn.go Normal file
View File

@ -0,0 +1,127 @@
package pool
import (
"bufio"
"context"
"net"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal/proto"
)
var noDeadline = time.Time{}
type Conn struct {
usedAt int64 // atomic
netConn net.Conn
rd *proto.Reader
bw *bufio.Writer
wr *proto.Writer
Inited bool
pooled bool
createdAt time.Time
}
func NewConn(netConn net.Conn) *Conn {
cn := &Conn{
netConn: netConn,
createdAt: time.Now(),
}
cn.rd = proto.NewReader(netConn)
cn.bw = bufio.NewWriter(netConn)
cn.wr = proto.NewWriter(cn.bw)
cn.SetUsedAt(time.Now())
return cn
}
func (cn *Conn) UsedAt() time.Time {
unix := atomic.LoadInt64(&cn.usedAt)
return time.Unix(unix, 0)
}
func (cn *Conn) SetUsedAt(tm time.Time) {
atomic.StoreInt64(&cn.usedAt, tm.Unix())
}
func (cn *Conn) SetNetConn(netConn net.Conn) {
cn.netConn = netConn
cn.rd.Reset(netConn)
cn.bw.Reset(netConn)
}
func (cn *Conn) Write(b []byte) (int, error) {
return cn.netConn.Write(b)
}
func (cn *Conn) RemoteAddr() net.Addr {
if cn.netConn != nil {
return cn.netConn.RemoteAddr()
}
return nil
}
func (cn *Conn) WithReader(
ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error,
) error {
if timeout >= 0 {
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
return err
}
}
return fn(cn.rd)
}
func (cn *Conn) WithWriter(
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
) error {
if timeout >= 0 {
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
return err
}
}
if cn.bw.Buffered() > 0 {
cn.bw.Reset(cn.netConn)
}
if err := fn(cn.wr); err != nil {
return err
}
return cn.bw.Flush()
}
func (cn *Conn) Close() error {
return cn.netConn.Close()
}
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
tm := time.Now()
cn.SetUsedAt(tm)
if timeout > 0 {
tm = tm.Add(timeout)
}
if ctx != nil {
deadline, ok := ctx.Deadline()
if ok {
if timeout == 0 {
return deadline
}
if deadline.Before(tm) {
return deadline
}
return tm
}
}
if timeout > 0 {
return tm
}
return noDeadline
}

View File

@ -0,0 +1,49 @@
//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd || solaris || illumos
package pool
import (
"errors"
"io"
"net"
"syscall"
"time"
)
var errUnexpectedRead = errors.New("unexpected read from socket")
func connCheck(conn net.Conn) error {
// Reset previous timeout.
_ = conn.SetDeadline(time.Time{})
sysConn, ok := conn.(syscall.Conn)
if !ok {
return nil
}
rawConn, err := sysConn.SyscallConn()
if err != nil {
return err
}
var sysErr error
if err := rawConn.Read(func(fd uintptr) bool {
var buf [1]byte
n, err := syscall.Read(int(fd), buf[:])
switch {
case n == 0 && err == nil:
sysErr = io.EOF
case n > 0:
sysErr = errUnexpectedRead
case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK:
sysErr = nil
default:
sysErr = err
}
return true
}); err != nil {
return err
}
return sysErr
}

View File

@ -0,0 +1,9 @@
//go:build !linux && !darwin && !dragonfly && !freebsd && !netbsd && !openbsd && !solaris && !illumos
package pool
import "net"
func connCheck(conn net.Conn) error {
return nil
}

View File

@ -0,0 +1,47 @@
//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd || solaris || illumos
package pool
import (
"net"
"net/http/httptest"
"time"
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
)
var _ = Describe("tests conn_check with real conns", func() {
var ts *httptest.Server
var conn net.Conn
var err error
BeforeEach(func() {
ts = httptest.NewServer(nil)
conn, err = net.DialTimeout(ts.Listener.Addr().Network(), ts.Listener.Addr().String(), time.Second)
Expect(err).NotTo(HaveOccurred())
})
AfterEach(func() {
ts.Close()
})
It("good conn check", func() {
Expect(connCheck(conn)).NotTo(HaveOccurred())
Expect(conn.Close()).NotTo(HaveOccurred())
Expect(connCheck(conn)).To(HaveOccurred())
})
It("bad conn check", func() {
Expect(conn.Close()).NotTo(HaveOccurred())
Expect(connCheck(conn)).To(HaveOccurred())
})
It("check conn deadline", func() {
Expect(conn.SetDeadline(time.Now())).NotTo(HaveOccurred())
time.Sleep(time.Millisecond * 10)
Expect(connCheck(conn)).NotTo(HaveOccurred())
Expect(conn.Close()).NotTo(HaveOccurred())
})
})

View File

@ -0,0 +1,14 @@
package pool
import (
"net"
"time"
)
func (cn *Conn) SetCreatedAt(tm time.Time) {
cn.createdAt = tm
}
func (cn *Conn) NetConn() net.Conn {
return cn.netConn
}

123
internal/pool/main_test.go Normal file
View File

@ -0,0 +1,123 @@
package pool_test
import (
"context"
"fmt"
"net"
"sync"
"syscall"
"testing"
"time"
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
)
func TestGinkgoSuite(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "pool")
}
func perform(n int, cbs ...func(int)) {
var wg sync.WaitGroup
for _, cb := range cbs {
for i := 0; i < n; i++ {
wg.Add(1)
go func(cb func(int), i int) {
defer GinkgoRecover()
defer wg.Done()
cb(i)
}(cb, i)
}
}
wg.Wait()
}
func dummyDialer(context.Context) (net.Conn, error) {
return newDummyConn(), nil
}
func newDummyConn() net.Conn {
return &dummyConn{
rawConn: new(dummyRawConn),
}
}
var (
_ net.Conn = (*dummyConn)(nil)
_ syscall.Conn = (*dummyConn)(nil)
)
type dummyConn struct {
rawConn *dummyRawConn
}
func (d *dummyConn) SyscallConn() (syscall.RawConn, error) {
return d.rawConn, nil
}
var errDummy = fmt.Errorf("dummyConn err")
func (d *dummyConn) Read(b []byte) (n int, err error) {
return 0, errDummy
}
func (d *dummyConn) Write(b []byte) (n int, err error) {
return 0, errDummy
}
func (d *dummyConn) Close() error {
d.rawConn.Close()
return nil
}
func (d *dummyConn) LocalAddr() net.Addr {
return &net.TCPAddr{}
}
func (d *dummyConn) RemoteAddr() net.Addr {
return &net.TCPAddr{}
}
func (d *dummyConn) SetDeadline(t time.Time) error {
return nil
}
func (d *dummyConn) SetReadDeadline(t time.Time) error {
return nil
}
func (d *dummyConn) SetWriteDeadline(t time.Time) error {
return nil
}
var _ syscall.RawConn = (*dummyRawConn)(nil)
type dummyRawConn struct {
mu sync.Mutex
closed bool
}
func (d *dummyRawConn) Control(f func(fd uintptr)) error {
return nil
}
func (d *dummyRawConn) Read(f func(fd uintptr) (done bool)) error {
d.mu.Lock()
defer d.mu.Unlock()
if d.closed {
return fmt.Errorf("dummyRawConn closed")
}
return nil
}
func (d *dummyRawConn) Write(f func(fd uintptr) (done bool)) error {
return nil
}
func (d *dummyRawConn) Close() {
d.mu.Lock()
d.closed = true
d.mu.Unlock()
}

518
internal/pool/pool.go Normal file
View File

@ -0,0 +1,518 @@
package pool
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
)
var (
// ErrClosed performs any operation on the closed client will return this error.
ErrClosed = errors.New("redis: client is closed")
// ErrPoolExhausted is returned from a pool connection method
// when the maximum number of database connections in the pool has been reached.
ErrPoolExhausted = errors.New("redis: connection pool exhausted")
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
ErrPoolTimeout = errors.New("redis: connection pool timeout")
)
var timers = sync.Pool{
New: func() interface{} {
t := time.NewTimer(time.Hour)
t.Stop()
return t
},
}
// Stats contains pool state information and accumulated stats.
type Stats struct {
Hits uint32 // number of times free connection was found in the pool
Misses uint32 // number of times free connection was NOT found in the pool
Timeouts uint32 // number of times a wait timeout occurred
TotalConns uint32 // number of total connections in the pool
IdleConns uint32 // number of idle connections in the pool
StaleConns uint32 // number of stale connections removed from the pool
}
type Pooler interface {
NewConn(context.Context) (*Conn, error)
CloseConn(*Conn) error
Get(context.Context) (*Conn, error)
Put(context.Context, *Conn)
Remove(context.Context, *Conn, error)
Len() int
IdleLen() int
Stats() *Stats
Close() error
}
type Options struct {
Dialer func(context.Context) (net.Conn, error)
PoolFIFO bool
PoolSize int
PoolTimeout time.Duration
MinIdleConns int
MaxIdleConns int
MaxActiveConns int
ConnMaxIdleTime time.Duration
ConnMaxLifetime time.Duration
}
type lastDialErrorWrap struct {
err error
}
type ConnPool struct {
cfg *Options
dialErrorsNum uint32 // atomic
lastDialError atomic.Value
queue chan struct{}
connsMu sync.Mutex
conns []*Conn
idleConns []*Conn
poolSize int
idleConnsLen int
stats Stats
_closed uint32 // atomic
}
var _ Pooler = (*ConnPool)(nil)
func NewConnPool(opt *Options) *ConnPool {
p := &ConnPool{
cfg: opt,
queue: make(chan struct{}, opt.PoolSize),
conns: make([]*Conn, 0, opt.PoolSize),
idleConns: make([]*Conn, 0, opt.PoolSize),
}
p.connsMu.Lock()
p.checkMinIdleConns()
p.connsMu.Unlock()
return p
}
func (p *ConnPool) checkMinIdleConns() {
if p.cfg.MinIdleConns == 0 {
return
}
for p.poolSize < p.cfg.PoolSize && p.idleConnsLen < p.cfg.MinIdleConns {
select {
case p.queue <- struct{}{}:
p.poolSize++
p.idleConnsLen++
go func() {
err := p.addIdleConn()
if err != nil && err != ErrClosed {
p.connsMu.Lock()
p.poolSize--
p.idleConnsLen--
p.connsMu.Unlock()
}
p.freeTurn()
}()
default:
return
}
}
}
func (p *ConnPool) addIdleConn() error {
cn, err := p.dialConn(context.TODO(), true)
if err != nil {
return err
}
p.connsMu.Lock()
defer p.connsMu.Unlock()
// It is not allowed to add new connections to the closed connection pool.
if p.closed() {
_ = cn.Close()
return ErrClosed
}
p.conns = append(p.conns, cn)
p.idleConns = append(p.idleConns, cn)
return nil
}
func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
return p.newConn(ctx, false)
}
func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
p.connsMu.Lock()
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
p.connsMu.Unlock()
return nil, ErrPoolExhausted
}
p.connsMu.Unlock()
cn, err := p.dialConn(ctx, pooled)
if err != nil {
return nil, err
}
p.connsMu.Lock()
defer p.connsMu.Unlock()
if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
_ = cn.Close()
return nil, ErrPoolExhausted
}
p.conns = append(p.conns, cn)
if pooled {
// If pool is full remove the cn on next Put.
if p.poolSize >= p.cfg.PoolSize {
cn.pooled = false
} else {
p.poolSize++
}
}
return cn, nil
}
func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.cfg.PoolSize) {
return nil, p.getLastDialError()
}
netConn, err := p.cfg.Dialer(ctx)
if err != nil {
p.setLastDialError(err)
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) {
go p.tryDial()
}
return nil, err
}
cn := NewConn(netConn)
cn.pooled = pooled
return cn, nil
}
func (p *ConnPool) tryDial() {
for {
if p.closed() {
return
}
conn, err := p.cfg.Dialer(context.Background())
if err != nil {
p.setLastDialError(err)
time.Sleep(time.Second)
continue
}
atomic.StoreUint32(&p.dialErrorsNum, 0)
_ = conn.Close()
return
}
}
func (p *ConnPool) setLastDialError(err error) {
p.lastDialError.Store(&lastDialErrorWrap{err: err})
}
func (p *ConnPool) getLastDialError() error {
err, _ := p.lastDialError.Load().(*lastDialErrorWrap)
if err != nil {
return err.err
}
return nil
}
// Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
if err := p.waitTurn(ctx); err != nil {
return nil, err
}
for {
p.connsMu.Lock()
cn, err := p.popIdle()
p.connsMu.Unlock()
if err != nil {
p.freeTurn()
return nil, err
}
if cn == nil {
break
}
if !p.isHealthyConn(cn) {
_ = p.CloseConn(cn)
continue
}
atomic.AddUint32(&p.stats.Hits, 1)
return cn, nil
}
atomic.AddUint32(&p.stats.Misses, 1)
newcn, err := p.newConn(ctx, true)
if err != nil {
p.freeTurn()
return nil, err
}
return newcn, nil
}
func (p *ConnPool) waitTurn(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
select {
case p.queue <- struct{}{}:
return nil
default:
}
timer := timers.Get().(*time.Timer)
timer.Reset(p.cfg.PoolTimeout)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
timers.Put(timer)
return ctx.Err()
case p.queue <- struct{}{}:
if !timer.Stop() {
<-timer.C
}
timers.Put(timer)
return nil
case <-timer.C:
timers.Put(timer)
atomic.AddUint32(&p.stats.Timeouts, 1)
return ErrPoolTimeout
}
}
func (p *ConnPool) freeTurn() {
<-p.queue
}
func (p *ConnPool) popIdle() (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
n := len(p.idleConns)
if n == 0 {
return nil, nil
}
var cn *Conn
if p.cfg.PoolFIFO {
cn = p.idleConns[0]
copy(p.idleConns, p.idleConns[1:])
p.idleConns = p.idleConns[:n-1]
} else {
idx := n - 1
cn = p.idleConns[idx]
p.idleConns = p.idleConns[:idx]
}
p.idleConnsLen--
p.checkMinIdleConns()
return cn, nil
}
func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
if cn.rd.Buffered() > 0 {
internal.Logger.Printf(ctx, "Conn has unread data")
p.Remove(ctx, cn, BadConnError{})
return
}
if !cn.pooled {
p.Remove(ctx, cn, nil)
return
}
var shouldCloseConn bool
p.connsMu.Lock()
if p.cfg.MaxIdleConns == 0 || p.idleConnsLen < p.cfg.MaxIdleConns {
p.idleConns = append(p.idleConns, cn)
p.idleConnsLen++
} else {
p.removeConn(cn)
shouldCloseConn = true
}
p.connsMu.Unlock()
p.freeTurn()
if shouldCloseConn {
_ = p.closeConn(cn)
}
}
func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) {
p.removeConnWithLock(cn)
p.freeTurn()
_ = p.closeConn(cn)
}
func (p *ConnPool) CloseConn(cn *Conn) error {
p.removeConnWithLock(cn)
return p.closeConn(cn)
}
func (p *ConnPool) removeConnWithLock(cn *Conn) {
p.connsMu.Lock()
defer p.connsMu.Unlock()
p.removeConn(cn)
}
func (p *ConnPool) removeConn(cn *Conn) {
for i, c := range p.conns {
if c == cn {
p.conns = append(p.conns[:i], p.conns[i+1:]...)
if cn.pooled {
p.poolSize--
p.checkMinIdleConns()
}
break
}
}
atomic.AddUint32(&p.stats.StaleConns, 1)
}
func (p *ConnPool) closeConn(cn *Conn) error {
return cn.Close()
}
// Len returns total number of connections.
func (p *ConnPool) Len() int {
p.connsMu.Lock()
n := len(p.conns)
p.connsMu.Unlock()
return n
}
// IdleLen returns number of idle connections.
func (p *ConnPool) IdleLen() int {
p.connsMu.Lock()
n := p.idleConnsLen
p.connsMu.Unlock()
return n
}
func (p *ConnPool) Stats() *Stats {
return &Stats{
Hits: atomic.LoadUint32(&p.stats.Hits),
Misses: atomic.LoadUint32(&p.stats.Misses),
Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
TotalConns: uint32(p.Len()),
IdleConns: uint32(p.IdleLen()),
StaleConns: atomic.LoadUint32(&p.stats.StaleConns),
}
}
func (p *ConnPool) closed() bool {
return atomic.LoadUint32(&p._closed) == 1
}
func (p *ConnPool) Filter(fn func(*Conn) bool) error {
p.connsMu.Lock()
defer p.connsMu.Unlock()
var firstErr error
for _, cn := range p.conns {
if fn(cn) {
if err := p.closeConn(cn); err != nil && firstErr == nil {
firstErr = err
}
}
}
return firstErr
}
func (p *ConnPool) Close() error {
if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) {
return ErrClosed
}
var firstErr error
p.connsMu.Lock()
for _, cn := range p.conns {
if err := p.closeConn(cn); err != nil && firstErr == nil {
firstErr = err
}
}
p.conns = nil
p.poolSize = 0
p.idleConns = nil
p.idleConnsLen = 0
p.connsMu.Unlock()
return firstErr
}
func (p *ConnPool) isHealthyConn(cn *Conn) bool {
now := time.Now()
if p.cfg.ConnMaxLifetime > 0 && now.Sub(cn.createdAt) >= p.cfg.ConnMaxLifetime {
return false
}
if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime {
return false
}
if connCheck(cn.netConn) != nil {
return false
}
cn.SetUsedAt(now)
return true
}

View File

@ -0,0 +1,58 @@
package pool
import "context"
type SingleConnPool struct {
pool Pooler
cn *Conn
stickyErr error
}
var _ Pooler = (*SingleConnPool)(nil)
func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool {
return &SingleConnPool{
pool: pool,
cn: cn,
}
}
func (p *SingleConnPool) NewConn(ctx context.Context) (*Conn, error) {
return p.pool.NewConn(ctx)
}
func (p *SingleConnPool) CloseConn(cn *Conn) error {
return p.pool.CloseConn(cn)
}
func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) {
if p.stickyErr != nil {
return nil, p.stickyErr
}
return p.cn, nil
}
func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {}
func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
p.cn = nil
p.stickyErr = reason
}
func (p *SingleConnPool) Close() error {
p.cn = nil
p.stickyErr = ErrClosed
return nil
}
func (p *SingleConnPool) Len() int {
return 0
}
func (p *SingleConnPool) IdleLen() int {
return 0
}
func (p *SingleConnPool) Stats() *Stats {
return &Stats{}
}

View File

@ -0,0 +1,201 @@
package pool
import (
"context"
"errors"
"fmt"
"sync/atomic"
)
const (
stateDefault = 0
stateInited = 1
stateClosed = 2
)
type BadConnError struct {
wrapped error
}
var _ error = (*BadConnError)(nil)
func (e BadConnError) Error() string {
s := "redis: Conn is in a bad state"
if e.wrapped != nil {
s += ": " + e.wrapped.Error()
}
return s
}
func (e BadConnError) Unwrap() error {
return e.wrapped
}
//------------------------------------------------------------------------------
type StickyConnPool struct {
pool Pooler
shared int32 // atomic
state uint32 // atomic
ch chan *Conn
_badConnError atomic.Value
}
var _ Pooler = (*StickyConnPool)(nil)
func NewStickyConnPool(pool Pooler) *StickyConnPool {
p, ok := pool.(*StickyConnPool)
if !ok {
p = &StickyConnPool{
pool: pool,
ch: make(chan *Conn, 1),
}
}
atomic.AddInt32(&p.shared, 1)
return p
}
func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) {
return p.pool.NewConn(ctx)
}
func (p *StickyConnPool) CloseConn(cn *Conn) error {
return p.pool.CloseConn(cn)
}
func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) {
// In worst case this races with Close which is not a very common operation.
for i := 0; i < 1000; i++ {
switch atomic.LoadUint32(&p.state) {
case stateDefault:
cn, err := p.pool.Get(ctx)
if err != nil {
return nil, err
}
if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
return cn, nil
}
p.pool.Remove(ctx, cn, ErrClosed)
case stateInited:
if err := p.badConnError(); err != nil {
return nil, err
}
cn, ok := <-p.ch
if !ok {
return nil, ErrClosed
}
return cn, nil
case stateClosed:
return nil, ErrClosed
default:
panic("not reached")
}
}
return nil, fmt.Errorf("redis: StickyConnPool.Get: infinite loop")
}
func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) {
defer func() {
if recover() != nil {
p.freeConn(ctx, cn)
}
}()
p.ch <- cn
}
func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) {
if err := p.badConnError(); err != nil {
p.pool.Remove(ctx, cn, err)
} else {
p.pool.Put(ctx, cn)
}
}
func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
defer func() {
if recover() != nil {
p.pool.Remove(ctx, cn, ErrClosed)
}
}()
p._badConnError.Store(BadConnError{wrapped: reason})
p.ch <- cn
}
func (p *StickyConnPool) Close() error {
if shared := atomic.AddInt32(&p.shared, -1); shared > 0 {
return nil
}
for i := 0; i < 1000; i++ {
state := atomic.LoadUint32(&p.state)
if state == stateClosed {
return ErrClosed
}
if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) {
close(p.ch)
cn, ok := <-p.ch
if ok {
p.freeConn(context.TODO(), cn)
}
return nil
}
}
return errors.New("redis: StickyConnPool.Close: infinite loop")
}
func (p *StickyConnPool) Reset(ctx context.Context) error {
if p.badConnError() == nil {
return nil
}
select {
case cn, ok := <-p.ch:
if !ok {
return ErrClosed
}
p.pool.Remove(ctx, cn, ErrClosed)
p._badConnError.Store(BadConnError{wrapped: nil})
default:
return errors.New("redis: StickyConnPool does not have a Conn")
}
if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) {
state := atomic.LoadUint32(&p.state)
return fmt.Errorf("redis: invalid StickyConnPool state: %d", state)
}
return nil
}
func (p *StickyConnPool) badConnError() error {
if v := p._badConnError.Load(); v != nil {
if err := v.(BadConnError); err.wrapped != nil {
return err
}
}
return nil
}
func (p *StickyConnPool) Len() int {
switch atomic.LoadUint32(&p.state) {
case stateDefault:
return 0
case stateInited:
return 1
case stateClosed:
return 0
default:
panic("not reached")
}
}
func (p *StickyConnPool) IdleLen() int {
return len(p.ch)
}
func (p *StickyConnPool) Stats() *Stats {
return &Stats{}
}

356
internal/pool/pool_test.go Normal file
View File

@ -0,0 +1,356 @@
package pool_test
import (
"context"
"net"
"sync"
"testing"
"time"
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9/internal/pool"
)
var _ = Describe("ConnPool", func() {
ctx := context.Background()
var connPool *pool.ConnPool
BeforeEach(func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 10,
PoolTimeout: time.Hour,
ConnMaxIdleTime: time.Millisecond,
})
})
AfterEach(func() {
connPool.Close()
})
It("should safe close", func() {
const minIdleConns = 10
var (
wg sync.WaitGroup
closedChan = make(chan struct{})
)
wg.Add(minIdleConns)
connPool = pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
wg.Done()
<-closedChan
return &net.TCPConn{}, nil
},
PoolSize: 10,
PoolTimeout: time.Hour,
ConnMaxIdleTime: time.Millisecond,
MinIdleConns: minIdleConns,
})
wg.Wait()
Expect(connPool.Close()).NotTo(HaveOccurred())
close(closedChan)
// We wait for 1 second and believe that checkMinIdleConns has been executed.
time.Sleep(time.Second)
Expect(connPool.Stats()).To(Equal(&pool.Stats{
Hits: 0,
Misses: 0,
Timeouts: 0,
TotalConns: 0,
IdleConns: 0,
StaleConns: 0,
}))
})
It("should unblock client when conn is removed", func() {
// Reserve one connection.
cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred())
// Reserve all other connections.
var cns []*pool.Conn
for i := 0; i < 9; i++ {
cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred())
cns = append(cns, cn)
}
started := make(chan bool, 1)
done := make(chan bool, 1)
go func() {
defer GinkgoRecover()
started <- true
_, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred())
done <- true
connPool.Put(ctx, cn)
}()
<-started
// Check that Get is blocked.
select {
case <-done:
Fail("Get is not blocked")
case <-time.After(time.Millisecond):
// ok
}
connPool.Remove(ctx, cn, nil)
// Check that Get is unblocked.
select {
case <-done:
// ok
case <-time.After(time.Second):
Fail("Get is not unblocked")
}
for _, cn := range cns {
connPool.Put(ctx, cn)
}
})
})
var _ = Describe("MinIdleConns", func() {
const poolSize = 100
ctx := context.Background()
var minIdleConns int
var connPool *pool.ConnPool
newConnPool := func() *pool.ConnPool {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: poolSize,
MinIdleConns: minIdleConns,
PoolTimeout: 100 * time.Millisecond,
ConnMaxIdleTime: -1,
})
Eventually(func() int {
return connPool.Len()
}).Should(Equal(minIdleConns))
return connPool
}
assert := func() {
It("has idle connections when created", func() {
Expect(connPool.Len()).To(Equal(minIdleConns))
Expect(connPool.IdleLen()).To(Equal(minIdleConns))
})
Context("after Get", func() {
var cn *pool.Conn
BeforeEach(func() {
var err error
cn, err = connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred())
Eventually(func() int {
return connPool.Len()
}).Should(Equal(minIdleConns + 1))
})
It("has idle connections", func() {
Expect(connPool.Len()).To(Equal(minIdleConns + 1))
Expect(connPool.IdleLen()).To(Equal(minIdleConns))
})
Context("after Remove", func() {
BeforeEach(func() {
connPool.Remove(ctx, cn, nil)
})
It("has idle connections", func() {
Expect(connPool.Len()).To(Equal(minIdleConns))
Expect(connPool.IdleLen()).To(Equal(minIdleConns))
})
})
})
Describe("Get does not exceed pool size", func() {
var mu sync.RWMutex
var cns []*pool.Conn
BeforeEach(func() {
cns = make([]*pool.Conn, 0)
perform(poolSize, func(_ int) {
defer GinkgoRecover()
cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred())
mu.Lock()
cns = append(cns, cn)
mu.Unlock()
})
Eventually(func() int {
return connPool.Len()
}).Should(BeNumerically(">=", poolSize))
})
It("Get is blocked", func() {
done := make(chan struct{})
go func() {
connPool.Get(ctx)
close(done)
}()
select {
case <-done:
Fail("Get is not blocked")
case <-time.After(time.Millisecond):
// ok
}
select {
case <-done:
// ok
case <-time.After(time.Second):
Fail("Get is not unblocked")
}
})
Context("after Put", func() {
BeforeEach(func() {
perform(len(cns), func(i int) {
mu.RLock()
connPool.Put(ctx, cns[i])
mu.RUnlock()
})
Eventually(func() int {
return connPool.Len()
}).Should(Equal(poolSize))
})
It("pool.Len is back to normal", func() {
Expect(connPool.Len()).To(Equal(poolSize))
Expect(connPool.IdleLen()).To(Equal(poolSize))
})
})
Context("after Remove", func() {
BeforeEach(func() {
perform(len(cns), func(i int) {
mu.RLock()
connPool.Remove(ctx, cns[i], nil)
mu.RUnlock()
})
Eventually(func() int {
return connPool.Len()
}).Should(Equal(minIdleConns))
})
It("has idle connections", func() {
Expect(connPool.Len()).To(Equal(minIdleConns))
Expect(connPool.IdleLen()).To(Equal(minIdleConns))
})
})
})
}
Context("minIdleConns = 1", func() {
BeforeEach(func() {
minIdleConns = 1
connPool = newConnPool()
})
AfterEach(func() {
connPool.Close()
})
assert()
})
Context("minIdleConns = 32", func() {
BeforeEach(func() {
minIdleConns = 32
connPool = newConnPool()
})
AfterEach(func() {
connPool.Close()
})
assert()
})
})
var _ = Describe("race", func() {
ctx := context.Background()
var connPool *pool.ConnPool
var C, N int
BeforeEach(func() {
C, N = 10, 1000
if testing.Short() {
C = 4
N = 100
}
})
AfterEach(func() {
connPool.Close()
})
It("does not happen on Get, Put, and Remove", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: 10,
PoolTimeout: time.Minute,
ConnMaxIdleTime: time.Millisecond,
})
perform(C, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred())
if err == nil {
connPool.Put(ctx, cn)
}
}
}, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get(ctx)
Expect(err).NotTo(HaveOccurred())
if err == nil {
connPool.Remove(ctx, cn, nil)
}
}
})
})
It("limit the number of connections", func() {
opt := &pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil
},
PoolSize: 1000,
MinIdleConns: 50,
PoolTimeout: 3 * time.Second,
}
p := pool.NewConnPool(opt)
var wg sync.WaitGroup
for i := 0; i < opt.PoolSize; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = p.Get(ctx)
}()
}
wg.Wait()
stats := p.Stats()
Expect(stats.IdleConns).To(Equal(uint32(0)))
Expect(stats.TotalConns).To(Equal(uint32(opt.PoolSize)))
})
})

View File

@ -0,0 +1,13 @@
package proto_test
import (
"testing"
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
)
func TestGinkgoSuite(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "proto")
}

552
internal/proto/reader.go Normal file
View File

@ -0,0 +1,552 @@
package proto
import (
"bufio"
"errors"
"fmt"
"io"
"math"
"math/big"
"strconv"
"github.com/redis/go-redis/v9/internal/util"
)
// redis resp protocol data type.
const (
RespStatus = '+' // +<string>\r\n
RespError = '-' // -<string>\r\n
RespString = '$' // $<length>\r\n<bytes>\r\n
RespInt = ':' // :<number>\r\n
RespNil = '_' // _\r\n
RespFloat = ',' // ,<floating-point-number>\r\n (golang float)
RespBool = '#' // true: #t\r\n false: #f\r\n
RespBlobError = '!' // !<length>\r\n<bytes>\r\n
RespVerbatim = '=' // =<length>\r\nFORMAT:<bytes>\r\n
RespBigInt = '(' // (<big number>\r\n
RespArray = '*' // *<len>\r\n... (same as resp2)
RespMap = '%' // %<len>\r\n(key)\r\n(value)\r\n... (golang map)
RespSet = '~' // ~<len>\r\n... (same as Array)
RespAttr = '|' // |<len>\r\n(key)\r\n(value)\r\n... + command reply
RespPush = '>' // ><len>\r\n... (same as Array)
)
// Not used temporarily.
// Redis has not used these two data types for the time being, and will implement them later.
// Streamed = "EOF:"
// StreamedAggregated = '?'
//------------------------------------------------------------------------------
const Nil = RedisError("redis: nil") // nolint:errname
type RedisError string
func (e RedisError) Error() string { return string(e) }
func (RedisError) RedisError() {}
func ParseErrorReply(line []byte) error {
return RedisError(line[1:])
}
//------------------------------------------------------------------------------
type Reader struct {
rd *bufio.Reader
}
func NewReader(rd io.Reader) *Reader {
return &Reader{
rd: bufio.NewReader(rd),
}
}
func (r *Reader) Buffered() int {
return r.rd.Buffered()
}
func (r *Reader) Peek(n int) ([]byte, error) {
return r.rd.Peek(n)
}
func (r *Reader) Reset(rd io.Reader) {
r.rd.Reset(rd)
}
// PeekReplyType returns the data type of the next response without advancing the Reader,
// and discard the attribute type.
func (r *Reader) PeekReplyType() (byte, error) {
b, err := r.rd.Peek(1)
if err != nil {
return 0, err
}
if b[0] == RespAttr {
if err = r.DiscardNext(); err != nil {
return 0, err
}
return r.PeekReplyType()
}
return b[0], nil
}
// ReadLine Return a valid reply, it will check the protocol or redis error,
// and discard the attribute type.
func (r *Reader) ReadLine() ([]byte, error) {
line, err := r.readLine()
if err != nil {
return nil, err
}
switch line[0] {
case RespError:
return nil, ParseErrorReply(line)
case RespNil:
return nil, Nil
case RespBlobError:
var blobErr string
blobErr, err = r.readStringReply(line)
if err == nil {
err = RedisError(blobErr)
}
return nil, err
case RespAttr:
if err = r.Discard(line); err != nil {
return nil, err
}
return r.ReadLine()
}
// Compatible with RESP2
if IsNilReply(line) {
return nil, Nil
}
return line, nil
}
// readLine returns an error if:
// - there is a pending read error;
// - or line does not end with \r\n.
func (r *Reader) readLine() ([]byte, error) {
b, err := r.rd.ReadSlice('\n')
if err != nil {
if err != bufio.ErrBufferFull {
return nil, err
}
full := make([]byte, len(b))
copy(full, b)
b, err = r.rd.ReadBytes('\n')
if err != nil {
return nil, err
}
full = append(full, b...) //nolint:makezero
b = full
}
if len(b) <= 2 || b[len(b)-1] != '\n' || b[len(b)-2] != '\r' {
return nil, fmt.Errorf("redis: invalid reply: %q", b)
}
return b[:len(b)-2], nil
}
func (r *Reader) ReadReply() (interface{}, error) {
line, err := r.ReadLine()
if err != nil {
return nil, err
}
switch line[0] {
case RespStatus:
return string(line[1:]), nil
case RespInt:
return util.ParseInt(line[1:], 10, 64)
case RespFloat:
return r.readFloat(line)
case RespBool:
return r.readBool(line)
case RespBigInt:
return r.readBigInt(line)
case RespString:
return r.readStringReply(line)
case RespVerbatim:
return r.readVerb(line)
case RespArray, RespSet, RespPush:
return r.readSlice(line)
case RespMap:
return r.readMap(line)
}
return nil, fmt.Errorf("redis: can't parse %.100q", line)
}
func (r *Reader) readFloat(line []byte) (float64, error) {
v := string(line[1:])
switch string(line[1:]) {
case "inf":
return math.Inf(1), nil
case "-inf":
return math.Inf(-1), nil
case "nan", "-nan":
return math.NaN(), nil
}
return strconv.ParseFloat(v, 64)
}
func (r *Reader) readBool(line []byte) (bool, error) {
switch string(line[1:]) {
case "t":
return true, nil
case "f":
return false, nil
}
return false, fmt.Errorf("redis: can't parse bool reply: %q", line)
}
func (r *Reader) readBigInt(line []byte) (*big.Int, error) {
i := new(big.Int)
if i, ok := i.SetString(string(line[1:]), 10); ok {
return i, nil
}
return nil, fmt.Errorf("redis: can't parse bigInt reply: %q", line)
}
func (r *Reader) readStringReply(line []byte) (string, error) {
n, err := replyLen(line)
if err != nil {
return "", err
}
b := make([]byte, n+2)
_, err = io.ReadFull(r.rd, b)
if err != nil {
return "", err
}
return util.BytesToString(b[:n]), nil
}
func (r *Reader) readVerb(line []byte) (string, error) {
s, err := r.readStringReply(line)
if err != nil {
return "", err
}
if len(s) < 4 || s[3] != ':' {
return "", fmt.Errorf("redis: can't parse verbatim string reply: %q", line)
}
return s[4:], nil
}
func (r *Reader) readSlice(line []byte) ([]interface{}, error) {
n, err := replyLen(line)
if err != nil {
return nil, err
}
val := make([]interface{}, n)
for i := 0; i < len(val); i++ {
v, err := r.ReadReply()
if err != nil {
if err == Nil {
val[i] = nil
continue
}
if err, ok := err.(RedisError); ok {
val[i] = err
continue
}
return nil, err
}
val[i] = v
}
return val, nil
}
func (r *Reader) readMap(line []byte) (map[interface{}]interface{}, error) {
n, err := replyLen(line)
if err != nil {
return nil, err
}
m := make(map[interface{}]interface{}, n)
for i := 0; i < n; i++ {
k, err := r.ReadReply()
if err != nil {
return nil, err
}
v, err := r.ReadReply()
if err != nil {
if err == Nil {
m[k] = nil
continue
}
if err, ok := err.(RedisError); ok {
m[k] = err
continue
}
return nil, err
}
m[k] = v
}
return m, nil
}
// -------------------------------
func (r *Reader) ReadInt() (int64, error) {
line, err := r.ReadLine()
if err != nil {
return 0, err
}
switch line[0] {
case RespInt, RespStatus:
return util.ParseInt(line[1:], 10, 64)
case RespString:
s, err := r.readStringReply(line)
if err != nil {
return 0, err
}
return util.ParseInt([]byte(s), 10, 64)
case RespBigInt:
b, err := r.readBigInt(line)
if err != nil {
return 0, err
}
if !b.IsInt64() {
return 0, fmt.Errorf("bigInt(%s) value out of range", b.String())
}
return b.Int64(), nil
}
return 0, fmt.Errorf("redis: can't parse int reply: %.100q", line)
}
func (r *Reader) ReadUint() (uint64, error) {
line, err := r.ReadLine()
if err != nil {
return 0, err
}
switch line[0] {
case RespInt, RespStatus:
return util.ParseUint(line[1:], 10, 64)
case RespString:
s, err := r.readStringReply(line)
if err != nil {
return 0, err
}
return util.ParseUint([]byte(s), 10, 64)
case RespBigInt:
b, err := r.readBigInt(line)
if err != nil {
return 0, err
}
if !b.IsUint64() {
return 0, fmt.Errorf("bigInt(%s) value out of range", b.String())
}
return b.Uint64(), nil
}
return 0, fmt.Errorf("redis: can't parse uint reply: %.100q", line)
}
func (r *Reader) ReadFloat() (float64, error) {
line, err := r.ReadLine()
if err != nil {
return 0, err
}
switch line[0] {
case RespFloat:
return r.readFloat(line)
case RespStatus:
return strconv.ParseFloat(string(line[1:]), 64)
case RespString:
s, err := r.readStringReply(line)
if err != nil {
return 0, err
}
return strconv.ParseFloat(s, 64)
}
return 0, fmt.Errorf("redis: can't parse float reply: %.100q", line)
}
func (r *Reader) ReadString() (string, error) {
line, err := r.ReadLine()
if err != nil {
return "", err
}
switch line[0] {
case RespStatus, RespInt, RespFloat:
return string(line[1:]), nil
case RespString:
return r.readStringReply(line)
case RespBool:
b, err := r.readBool(line)
return strconv.FormatBool(b), err
case RespVerbatim:
return r.readVerb(line)
case RespBigInt:
b, err := r.readBigInt(line)
if err != nil {
return "", err
}
return b.String(), nil
}
return "", fmt.Errorf("redis: can't parse reply=%.100q reading string", line)
}
func (r *Reader) ReadBool() (bool, error) {
s, err := r.ReadString()
if err != nil {
return false, err
}
return s == "OK" || s == "1" || s == "true", nil
}
func (r *Reader) ReadSlice() ([]interface{}, error) {
line, err := r.ReadLine()
if err != nil {
return nil, err
}
return r.readSlice(line)
}
// ReadFixedArrayLen read fixed array length.
func (r *Reader) ReadFixedArrayLen(fixedLen int) error {
n, err := r.ReadArrayLen()
if err != nil {
return err
}
if n != fixedLen {
return fmt.Errorf("redis: got %d elements in the array, wanted %d", n, fixedLen)
}
return nil
}
// ReadArrayLen Read and return the length of the array.
func (r *Reader) ReadArrayLen() (int, error) {
line, err := r.ReadLine()
if err != nil {
return 0, err
}
switch line[0] {
case RespArray, RespSet, RespPush:
return replyLen(line)
default:
return 0, fmt.Errorf("redis: can't parse array/set/push reply: %.100q", line)
}
}
// ReadFixedMapLen reads fixed map length.
func (r *Reader) ReadFixedMapLen(fixedLen int) error {
n, err := r.ReadMapLen()
if err != nil {
return err
}
if n != fixedLen {
return fmt.Errorf("redis: got %d elements in the map, wanted %d", n, fixedLen)
}
return nil
}
// ReadMapLen reads the length of the map type.
// If responding to the array type (RespArray/RespSet/RespPush),
// it must be a multiple of 2 and return n/2.
// Other types will return an error.
func (r *Reader) ReadMapLen() (int, error) {
line, err := r.ReadLine()
if err != nil {
return 0, err
}
switch line[0] {
case RespMap:
return replyLen(line)
case RespArray, RespSet, RespPush:
// Some commands and RESP2 protocol may respond to array types.
n, err := replyLen(line)
if err != nil {
return 0, err
}
if n%2 != 0 {
return 0, fmt.Errorf("redis: the length of the array must be a multiple of 2, got: %d", n)
}
return n / 2, nil
default:
return 0, fmt.Errorf("redis: can't parse map reply: %.100q", line)
}
}
// DiscardNext read and discard the data represented by the next line.
func (r *Reader) DiscardNext() error {
line, err := r.readLine()
if err != nil {
return err
}
return r.Discard(line)
}
// Discard the data represented by line.
func (r *Reader) Discard(line []byte) (err error) {
if len(line) == 0 {
return errors.New("redis: invalid line")
}
switch line[0] {
case RespStatus, RespError, RespInt, RespNil, RespFloat, RespBool, RespBigInt:
return nil
}
n, err := replyLen(line)
if err != nil && err != Nil {
return err
}
switch line[0] {
case RespBlobError, RespString, RespVerbatim:
// +\r\n
_, err = r.rd.Discard(n + 2)
return err
case RespArray, RespSet, RespPush:
for i := 0; i < n; i++ {
if err = r.DiscardNext(); err != nil {
return err
}
}
return nil
case RespMap, RespAttr:
// Read key & value.
for i := 0; i < n*2; i++ {
if err = r.DiscardNext(); err != nil {
return err
}
}
return nil
}
return fmt.Errorf("redis: can't parse %.100q", line)
}
func replyLen(line []byte) (n int, err error) {
n, err = util.Atoi(line[1:])
if err != nil {
return 0, err
}
if n < -1 {
return 0, fmt.Errorf("redis: invalid reply: %q", line)
}
switch line[0] {
case RespString, RespVerbatim, RespBlobError,
RespArray, RespSet, RespPush, RespMap, RespAttr:
if n == -1 {
return 0, Nil
}
}
return n, nil
}
// IsNilReply detects redis.Nil of RESP2.
func IsNilReply(line []byte) bool {
return len(line) == 3 &&
(line[0] == RespString || line[0] == RespArray) &&
line[1] == '-' && line[2] == '1'
}

View File

@ -0,0 +1,100 @@
package proto_test
import (
"bytes"
"io"
"testing"
"github.com/redis/go-redis/v9/internal/proto"
)
func BenchmarkReader_ParseReply_Status(b *testing.B) {
benchmarkParseReply(b, "+OK\r\n", false)
}
func BenchmarkReader_ParseReply_Int(b *testing.B) {
benchmarkParseReply(b, ":1\r\n", false)
}
func BenchmarkReader_ParseReply_Float(b *testing.B) {
benchmarkParseReply(b, ",123.456\r\n", false)
}
func BenchmarkReader_ParseReply_Bool(b *testing.B) {
benchmarkParseReply(b, "#t\r\n", false)
}
func BenchmarkReader_ParseReply_BigInt(b *testing.B) {
benchmarkParseReply(b, "(3492890328409238509324850943850943825024385\r\n", false)
}
func BenchmarkReader_ParseReply_Error(b *testing.B) {
benchmarkParseReply(b, "-Error message\r\n", true)
}
func BenchmarkReader_ParseReply_Nil(b *testing.B) {
benchmarkParseReply(b, "_\r\n", true)
}
func BenchmarkReader_ParseReply_BlobError(b *testing.B) {
benchmarkParseReply(b, "!21\r\nSYNTAX invalid syntax", true)
}
func BenchmarkReader_ParseReply_String(b *testing.B) {
benchmarkParseReply(b, "$5\r\nhello\r\n", false)
}
func BenchmarkReader_ParseReply_Verb(b *testing.B) {
benchmarkParseReply(b, "$9\r\ntxt:hello\r\n", false)
}
func BenchmarkReader_ParseReply_Slice(b *testing.B) {
benchmarkParseReply(b, "*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", false)
}
func BenchmarkReader_ParseReply_Set(b *testing.B) {
benchmarkParseReply(b, "~2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", false)
}
func BenchmarkReader_ParseReply_Push(b *testing.B) {
benchmarkParseReply(b, ">2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", false)
}
func BenchmarkReader_ParseReply_Map(b *testing.B) {
benchmarkParseReply(b, "%2\r\n$5\r\nhello\r\n$5\r\nworld\r\n+key\r\n+value\r\n", false)
}
func BenchmarkReader_ParseReply_Attr(b *testing.B) {
benchmarkParseReply(b, "%1\r\n+key\r\n+value\r\n+hello\r\n", false)
}
func TestReader_ReadLine(t *testing.T) {
original := bytes.Repeat([]byte("a"), 8192)
original[len(original)-2] = '\r'
original[len(original)-1] = '\n'
r := proto.NewReader(bytes.NewReader(original))
read, err := r.ReadLine()
if err != nil && err != io.EOF {
t.Errorf("Should be able to read the full buffer: %v", err)
}
if !bytes.Equal(read, original[:len(original)-2]) {
t.Errorf("Values must be equal: %d expected %d", len(read), len(original[:len(original)-2]))
}
}
func benchmarkParseReply(b *testing.B, reply string, wanterr bool) {
buf := new(bytes.Buffer)
for i := 0; i < b.N; i++ {
buf.WriteString(reply)
}
p := proto.NewReader(buf)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := p.ReadReply()
if !wanterr && err != nil {
b.Fatal(err)
}
}
}

185
internal/proto/scan.go Normal file
View File

@ -0,0 +1,185 @@
package proto
import (
"encoding"
"fmt"
"net"
"reflect"
"time"
"github.com/redis/go-redis/v9/internal/util"
)
// Scan parses bytes `b` to `v` with appropriate type.
//
//nolint:gocyclo
func Scan(b []byte, v interface{}) error {
switch v := v.(type) {
case nil:
return fmt.Errorf("redis: Scan(nil)")
case *string:
*v = util.BytesToString(b)
return nil
case *[]byte:
*v = b
return nil
case *int:
var err error
*v, err = util.Atoi(b)
return err
case *int8:
n, err := util.ParseInt(b, 10, 8)
if err != nil {
return err
}
*v = int8(n)
return nil
case *int16:
n, err := util.ParseInt(b, 10, 16)
if err != nil {
return err
}
*v = int16(n)
return nil
case *int32:
n, err := util.ParseInt(b, 10, 32)
if err != nil {
return err
}
*v = int32(n)
return nil
case *int64:
n, err := util.ParseInt(b, 10, 64)
if err != nil {
return err
}
*v = n
return nil
case *uint:
n, err := util.ParseUint(b, 10, 64)
if err != nil {
return err
}
*v = uint(n)
return nil
case *uint8:
n, err := util.ParseUint(b, 10, 8)
if err != nil {
return err
}
*v = uint8(n)
return nil
case *uint16:
n, err := util.ParseUint(b, 10, 16)
if err != nil {
return err
}
*v = uint16(n)
return nil
case *uint32:
n, err := util.ParseUint(b, 10, 32)
if err != nil {
return err
}
*v = uint32(n)
return nil
case *uint64:
n, err := util.ParseUint(b, 10, 64)
if err != nil {
return err
}
*v = n
return nil
case *float32:
n, err := util.ParseFloat(b, 32)
if err != nil {
return err
}
*v = float32(n)
return err
case *float64:
var err error
*v, err = util.ParseFloat(b, 64)
return err
case *bool:
*v = len(b) == 1 && b[0] == '1'
return nil
case *time.Time:
var err error
*v, err = time.Parse(time.RFC3339Nano, util.BytesToString(b))
return err
case *time.Duration:
n, err := util.ParseInt(b, 10, 64)
if err != nil {
return err
}
*v = time.Duration(n)
return nil
case encoding.BinaryUnmarshaler:
return v.UnmarshalBinary(b)
case *net.IP:
*v = b
return nil
default:
return fmt.Errorf(
"redis: can't unmarshal %T (consider implementing BinaryUnmarshaler)", v)
}
}
func ScanSlice(data []string, slice interface{}) error {
v := reflect.ValueOf(slice)
if !v.IsValid() {
return fmt.Errorf("redis: ScanSlice(nil)")
}
if v.Kind() != reflect.Ptr {
return fmt.Errorf("redis: ScanSlice(non-pointer %T)", slice)
}
v = v.Elem()
if v.Kind() != reflect.Slice {
return fmt.Errorf("redis: ScanSlice(non-slice %T)", slice)
}
next := makeSliceNextElemFunc(v)
for i, s := range data {
elem := next()
if err := Scan([]byte(s), elem.Addr().Interface()); err != nil {
err = fmt.Errorf("redis: ScanSlice index=%d value=%q failed: %w", i, s, err)
return err
}
}
return nil
}
func makeSliceNextElemFunc(v reflect.Value) func() reflect.Value {
elemType := v.Type().Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
return func() reflect.Value {
if v.Len() < v.Cap() {
v.Set(v.Slice(0, v.Len()+1))
elem := v.Index(v.Len() - 1)
if elem.IsNil() {
elem.Set(reflect.New(elemType))
}
return elem.Elem()
}
elem := reflect.New(elemType)
v.Set(reflect.Append(v, elem))
return elem.Elem()
}
}
zero := reflect.Zero(elemType)
return func() reflect.Value {
if v.Len() < v.Cap() {
v.Set(v.Slice(0, v.Len()+1))
return v.Index(v.Len() - 1)
}
v.Set(reflect.Append(v, zero))
return v.Index(v.Len() - 1)
}
}

View File

@ -0,0 +1,50 @@
package proto_test
import (
"encoding/json"
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9/internal/proto"
)
type testScanSliceStruct struct {
ID int
Name string
}
func (s *testScanSliceStruct) MarshalBinary() ([]byte, error) {
return json.Marshal(s)
}
func (s *testScanSliceStruct) UnmarshalBinary(b []byte) error {
return json.Unmarshal(b, s)
}
var _ = Describe("ScanSlice", func() {
data := []string{
`{"ID":-1,"Name":"Back Yu"}`,
`{"ID":1,"Name":"szyhf"}`,
}
It("[]testScanSliceStruct", func() {
var slice []testScanSliceStruct
err := proto.ScanSlice(data, &slice)
Expect(err).NotTo(HaveOccurred())
Expect(slice).To(Equal([]testScanSliceStruct{
{-1, "Back Yu"},
{1, "szyhf"},
}))
})
It("var testContainer []*testScanSliceStruct", func() {
var slice []*testScanSliceStruct
err := proto.ScanSlice(data, &slice)
Expect(err).NotTo(HaveOccurred())
Expect(slice).To(Equal([]*testScanSliceStruct{
{-1, "Back Yu"},
{1, "szyhf"},
}))
})
})

189
internal/proto/writer.go Normal file
View File

@ -0,0 +1,189 @@
package proto
import (
"encoding"
"fmt"
"io"
"net"
"strconv"
"time"
"github.com/redis/go-redis/v9/internal/util"
)
type writer interface {
io.Writer
io.ByteWriter
// WriteString implement io.StringWriter.
WriteString(s string) (n int, err error)
}
type Writer struct {
writer
lenBuf []byte
numBuf []byte
}
func NewWriter(wr writer) *Writer {
return &Writer{
writer: wr,
lenBuf: make([]byte, 64),
numBuf: make([]byte, 64),
}
}
func (w *Writer) WriteArgs(args []interface{}) error {
if err := w.WriteByte(RespArray); err != nil {
return err
}
if err := w.writeLen(len(args)); err != nil {
return err
}
for _, arg := range args {
if err := w.WriteArg(arg); err != nil {
return err
}
}
return nil
}
func (w *Writer) writeLen(n int) error {
w.lenBuf = strconv.AppendUint(w.lenBuf[:0], uint64(n), 10)
w.lenBuf = append(w.lenBuf, '\r', '\n')
_, err := w.Write(w.lenBuf)
return err
}
func (w *Writer) WriteArg(v interface{}) error {
switch v := v.(type) {
case nil:
return w.string("")
case string:
return w.string(v)
case *string:
return w.string(*v)
case []byte:
return w.bytes(v)
case int:
return w.int(int64(v))
case *int:
return w.int(int64(*v))
case int8:
return w.int(int64(v))
case *int8:
return w.int(int64(*v))
case int16:
return w.int(int64(v))
case *int16:
return w.int(int64(*v))
case int32:
return w.int(int64(v))
case *int32:
return w.int(int64(*v))
case int64:
return w.int(v)
case *int64:
return w.int(*v)
case uint:
return w.uint(uint64(v))
case *uint:
return w.uint(uint64(*v))
case uint8:
return w.uint(uint64(v))
case *uint8:
return w.uint(uint64(*v))
case uint16:
return w.uint(uint64(v))
case *uint16:
return w.uint(uint64(*v))
case uint32:
return w.uint(uint64(v))
case *uint32:
return w.uint(uint64(*v))
case uint64:
return w.uint(v)
case *uint64:
return w.uint(*v)
case float32:
return w.float(float64(v))
case *float32:
return w.float(float64(*v))
case float64:
return w.float(v)
case *float64:
return w.float(*v)
case bool:
if v {
return w.int(1)
}
return w.int(0)
case *bool:
if *v {
return w.int(1)
}
return w.int(0)
case time.Time:
w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano)
return w.bytes(w.numBuf)
case time.Duration:
return w.int(v.Nanoseconds())
case encoding.BinaryMarshaler:
b, err := v.MarshalBinary()
if err != nil {
return err
}
return w.bytes(b)
case net.IP:
return w.bytes(v)
default:
return fmt.Errorf(
"redis: can't marshal %T (implement encoding.BinaryMarshaler)", v)
}
}
func (w *Writer) bytes(b []byte) error {
if err := w.WriteByte(RespString); err != nil {
return err
}
if err := w.writeLen(len(b)); err != nil {
return err
}
if _, err := w.Write(b); err != nil {
return err
}
return w.crlf()
}
func (w *Writer) string(s string) error {
return w.bytes(util.StringToBytes(s))
}
func (w *Writer) uint(n uint64) error {
w.numBuf = strconv.AppendUint(w.numBuf[:0], n, 10)
return w.bytes(w.numBuf)
}
func (w *Writer) int(n int64) error {
w.numBuf = strconv.AppendInt(w.numBuf[:0], n, 10)
return w.bytes(w.numBuf)
}
func (w *Writer) float(f float64) error {
w.numBuf = strconv.AppendFloat(w.numBuf[:0], f, 'f', -1, 64)
return w.bytes(w.numBuf)
}
func (w *Writer) crlf() error {
if err := w.WriteByte('\r'); err != nil {
return err
}
return w.WriteByte('\n')
}

View File

@ -0,0 +1,154 @@
package proto_test
import (
"bytes"
"encoding"
"fmt"
"net"
"testing"
"time"
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/internal/util"
)
type MyType struct{}
var _ encoding.BinaryMarshaler = (*MyType)(nil)
func (t *MyType) MarshalBinary() ([]byte, error) {
return []byte("hello"), nil
}
var _ = Describe("WriteBuffer", func() {
var buf *bytes.Buffer
var wr *proto.Writer
BeforeEach(func() {
buf = new(bytes.Buffer)
wr = proto.NewWriter(buf)
})
It("should write args", func() {
err := wr.WriteArgs([]interface{}{
"string",
12,
34.56,
[]byte{'b', 'y', 't', 'e', 's'},
true,
nil,
})
Expect(err).NotTo(HaveOccurred())
Expect(buf.Bytes()).To(Equal([]byte("*6\r\n" +
"$6\r\nstring\r\n" +
"$2\r\n12\r\n" +
"$5\r\n34.56\r\n" +
"$5\r\nbytes\r\n" +
"$1\r\n1\r\n" +
"$0\r\n" +
"\r\n")))
})
It("should append time", func() {
tm := time.Date(2019, 1, 1, 9, 45, 10, 222125, time.UTC)
err := wr.WriteArgs([]interface{}{tm})
Expect(err).NotTo(HaveOccurred())
Expect(buf.Len()).To(Equal(41))
})
It("should append marshalable args", func() {
err := wr.WriteArgs([]interface{}{&MyType{}})
Expect(err).NotTo(HaveOccurred())
Expect(buf.Len()).To(Equal(15))
})
It("should append net.IP", func() {
ip := net.ParseIP("192.168.1.1")
err := wr.WriteArgs([]interface{}{ip})
Expect(err).NotTo(HaveOccurred())
Expect(buf.String()).To(Equal(fmt.Sprintf("*1\r\n$16\r\n%s\r\n", bytes.NewBuffer(ip))))
})
})
type discard struct{}
func (discard) Write(b []byte) (int, error) {
return len(b), nil
}
func (discard) WriteString(s string) (int, error) {
return len(s), nil
}
func (discard) WriteByte(c byte) error {
return nil
}
func BenchmarkWriteBuffer_Append(b *testing.B) {
buf := proto.NewWriter(discard{})
args := []interface{}{"hello", "world", "foo", "bar"}
for i := 0; i < b.N; i++ {
err := buf.WriteArgs(args)
if err != nil {
b.Fatal(err)
}
}
}
var _ = Describe("WriteArg", func() {
var buf *bytes.Buffer
var wr *proto.Writer
BeforeEach(func() {
buf = new(bytes.Buffer)
wr = proto.NewWriter(buf)
})
args := map[any]string{
"hello": "$5\r\nhello\r\n",
int(10): "$2\r\n10\r\n",
util.ToPtr(int(10)): "$2\r\n10\r\n",
int8(10): "$2\r\n10\r\n",
util.ToPtr(int8(10)): "$2\r\n10\r\n",
int16(10): "$2\r\n10\r\n",
util.ToPtr(int16(10)): "$2\r\n10\r\n",
int32(10): "$2\r\n10\r\n",
util.ToPtr(int32(10)): "$2\r\n10\r\n",
int64(10): "$2\r\n10\r\n",
util.ToPtr(int64(10)): "$2\r\n10\r\n",
uint(10): "$2\r\n10\r\n",
util.ToPtr(uint(10)): "$2\r\n10\r\n",
uint8(10): "$2\r\n10\r\n",
util.ToPtr(uint8(10)): "$2\r\n10\r\n",
uint16(10): "$2\r\n10\r\n",
util.ToPtr(uint16(10)): "$2\r\n10\r\n",
uint32(10): "$2\r\n10\r\n",
util.ToPtr(uint32(10)): "$2\r\n10\r\n",
uint64(10): "$2\r\n10\r\n",
util.ToPtr(uint64(10)): "$2\r\n10\r\n",
float32(10.3): "$18\r\n10.300000190734863\r\n",
util.ToPtr(float32(10.3)): "$18\r\n10.300000190734863\r\n",
float64(10.3): "$4\r\n10.3\r\n",
util.ToPtr(float64(10.3)): "$4\r\n10.3\r\n",
bool(true): "$1\r\n1\r\n",
bool(false): "$1\r\n0\r\n",
util.ToPtr(bool(true)): "$1\r\n1\r\n",
util.ToPtr(bool(false)): "$1\r\n0\r\n",
}
for arg, expect := range args {
arg, expect := arg, expect
It(fmt.Sprintf("should write arg of type %T", arg), func() {
err := wr.WriteArg(arg)
Expect(err).NotTo(HaveOccurred())
Expect(buf.String()).To(Equal(expect))
})
}
})

50
internal/rand/rand.go Normal file
View File

@ -0,0 +1,50 @@
package rand
import (
"math/rand"
"sync"
)
// Int returns a non-negative pseudo-random int.
func Int() int { return pseudo.Int() }
// Intn returns, as an int, a non-negative pseudo-random number in [0,n).
// It panics if n <= 0.
func Intn(n int) int { return pseudo.Intn(n) }
// Int63n returns, as an int64, a non-negative pseudo-random number in [0,n).
// It panics if n <= 0.
func Int63n(n int64) int64 { return pseudo.Int63n(n) }
// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers [0,n).
func Perm(n int) []int { return pseudo.Perm(n) }
// Seed uses the provided seed value to initialize the default Source to a
// deterministic state. If Seed is not called, the generator behaves as if
// seeded by Seed(1).
func Seed(n int64) { pseudo.Seed(n) }
var pseudo = rand.New(&source{src: rand.NewSource(1)})
type source struct {
src rand.Source
mu sync.Mutex
}
func (s *source) Int63() int64 {
s.mu.Lock()
n := s.src.Int63()
s.mu.Unlock()
return n
}
func (s *source) Seed(seed int64) {
s.mu.Lock()
s.src.Seed(seed)
s.mu.Unlock()
}
// Shuffle pseudo-randomizes the order of elements.
// n is the number of elements.
// swap swaps the elements with indexes i and j.
func Shuffle(n int, swap func(i, j int)) { pseudo.Shuffle(n, swap) }

83
internal/util.go Normal file
View File

@ -0,0 +1,83 @@
package internal
import (
"context"
"net"
"strings"
"time"
"github.com/redis/go-redis/v9/internal/util"
)
func Sleep(ctx context.Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func ToLower(s string) string {
if isLower(s) {
return s
}
b := make([]byte, len(s))
for i := range b {
c := s[i]
if c >= 'A' && c <= 'Z' {
c += 'a' - 'A'
}
b[i] = c
}
return util.BytesToString(b)
}
func isLower(s string) bool {
for i := 0; i < len(s); i++ {
c := s[i]
if c >= 'A' && c <= 'Z' {
return false
}
}
return true
}
func ReplaceSpaces(s string) string {
// Pre-allocate a builder with the same length as s to minimize allocations.
// This is a basic optimization; adjust the initial size based on your use case.
var builder strings.Builder
builder.Grow(len(s))
for _, char := range s {
if char == ' ' {
// Replace space with a hyphen.
builder.WriteRune('-')
} else {
// Copy the character as-is.
builder.WriteRune(char)
}
}
return builder.String()
}
func GetAddr(addr string) string {
ind := strings.LastIndexByte(addr, ':')
if ind == -1 {
return ""
}
if strings.IndexByte(addr, '.') != -1 {
return addr
}
if addr[0] == '[' {
return addr
}
return net.JoinHostPort(addr[:ind], addr[ind+1:])
}

11
internal/util/safe.go Normal file
View File

@ -0,0 +1,11 @@
//go:build appengine
package util
func BytesToString(b []byte) string {
return string(b)
}
func StringToBytes(s string) []byte {
return []byte(s)
}

19
internal/util/strconv.go Normal file
View File

@ -0,0 +1,19 @@
package util
import "strconv"
func Atoi(b []byte) (int, error) {
return strconv.Atoi(BytesToString(b))
}
func ParseInt(b []byte, base int, bitSize int) (int64, error) {
return strconv.ParseInt(BytesToString(b), base, bitSize)
}
func ParseUint(b []byte, base int, bitSize int) (uint64, error) {
return strconv.ParseUint(BytesToString(b), base, bitSize)
}
func ParseFloat(b []byte, bitSize int) (float64, error) {
return strconv.ParseFloat(BytesToString(b), bitSize)
}

5
internal/util/type.go Normal file
View File

@ -0,0 +1,5 @@
package util
func ToPtr[T any](v T) *T {
return &v
}

22
internal/util/unsafe.go Normal file
View File

@ -0,0 +1,22 @@
//go:build !appengine
package util
import (
"unsafe"
)
// BytesToString converts byte slice to string.
func BytesToString(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
// StringToBytes converts string to byte slice.
func StringToBytes(s string) []byte {
return *(*[]byte)(unsafe.Pointer(
&struct {
string
Cap int
}{s, len(s)},
))
}

74
internal/util_test.go Normal file
View File

@ -0,0 +1,74 @@
package internal
import (
"strings"
"testing"
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
)
func BenchmarkToLowerStd(b *testing.B) {
str := "AaBbCcDdEeFfGgHhIiJjKk"
for i := 0; i < b.N; i++ {
_ = strings.ToLower(str)
}
}
// util.ToLower is 3x faster than strings.ToLower.
func BenchmarkToLowerInternal(b *testing.B) {
str := "AaBbCcDdEeFfGgHhIiJjKk"
for i := 0; i < b.N; i++ {
_ = ToLower(str)
}
}
func TestToLower(t *testing.T) {
It("toLower", func() {
str := "AaBbCcDdEeFfGg"
Expect(ToLower(str)).To(Equal(strings.ToLower(str)))
str = "ABCDE"
Expect(ToLower(str)).To(Equal(strings.ToLower(str)))
str = "ABCDE"
Expect(ToLower(str)).To(Equal(strings.ToLower(str)))
str = "abced"
Expect(ToLower(str)).To(Equal(strings.ToLower(str)))
})
}
func TestIsLower(t *testing.T) {
It("isLower", func() {
str := "AaBbCcDdEeFfGg"
Expect(isLower(str)).To(BeFalse())
str = "ABCDE"
Expect(isLower(str)).To(BeFalse())
str = "abcdefg"
Expect(isLower(str)).To(BeTrue())
})
}
func TestGetAddr(t *testing.T) {
It("getAddr", func() {
str := "127.0.0.1:1234"
Expect(GetAddr(str)).To(Equal(str))
str = "[::1]:1234"
Expect(GetAddr(str)).To(Equal(str))
str = "[fd01:abcd::7d03]:6379"
Expect(GetAddr(str)).To(Equal(str))
Expect(GetAddr("::1:1234")).To(Equal("[::1]:1234"))
Expect(GetAddr("fd01:abcd::7d03:6379")).To(Equal("[fd01:abcd::7d03]:6379"))
Expect(GetAddr("127.0.0.1")).To(Equal(""))
Expect(GetAddr("127")).To(Equal(""))
})
}