diff --git a/example/hset-struct/README.md b/example/hset-struct/README.md new file mode 100644 index 00000000..e6cb4523 --- /dev/null +++ b/example/hset-struct/README.md @@ -0,0 +1,7 @@ +# Example for setting struct fields as hash fields + +To run this example: + +```shell +go run . +``` diff --git a/example/hset-struct/go.mod b/example/hset-struct/go.mod new file mode 100644 index 00000000..fca1a597 --- /dev/null +++ b/example/hset-struct/go.mod @@ -0,0 +1,15 @@ +module github.com/redis/go-redis/example/scan-struct + +go 1.18 + +replace github.com/redis/go-redis/v9 => ../.. + +require ( + github.com/davecgh/go-spew v1.1.1 + github.com/redis/go-redis/v9 v9.6.2 +) + +require ( + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect +) diff --git a/example/hset-struct/go.sum b/example/hset-struct/go.sum new file mode 100644 index 00000000..1602e702 --- /dev/null +++ b/example/hset-struct/go.sum @@ -0,0 +1,10 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= diff --git a/example/hset-struct/main.go b/example/hset-struct/main.go new file mode 100644 index 00000000..2e08f542 --- /dev/null +++ b/example/hset-struct/main.go @@ -0,0 +1,129 @@ +package main + +import ( + "context" + "time" + + "github.com/davecgh/go-spew/spew" + + "github.com/redis/go-redis/v9" +) + +type Model struct { + Str1 string `redis:"str1"` + Str2 string `redis:"str2"` + Str3 *string `redis:"str3"` + Str4 *string `redis:"str4"` + Bytes []byte `redis:"bytes"` + Int int `redis:"int"` + Int2 *int `redis:"int2"` + Int3 *int `redis:"int3"` + Bool bool `redis:"bool"` + Bool2 *bool `redis:"bool2"` + Bool3 *bool `redis:"bool3"` + Bool4 *bool `redis:"bool4,omitempty"` + Time time.Time `redis:"time"` + Time2 *time.Time `redis:"time2"` + Time3 *time.Time `redis:"time3"` + Ignored struct{} `redis:"-"` +} + +func main() { + ctx := context.Background() + + rdb := redis.NewClient(&redis.Options{ + Addr: ":6379", + }) + + _ = rdb.FlushDB(ctx).Err() + + t := time.Date(2025, 02, 8, 0, 0, 0, 0, time.UTC) + + data := Model{ + Str1: "hello", + Str2: "world", + Str3: ToPtr("hello"), + Str4: nil, + Bytes: []byte("this is bytes !"), + Int: 123, + Int2: ToPtr(0), + Int3: nil, + Bool: true, + Bool2: ToPtr(false), + Bool3: nil, + Time: t, + Time2: ToPtr(t), + Time3: nil, + Ignored: struct{}{}, + } + + // Set some fields. + if _, err := rdb.Pipelined(ctx, func(rdb redis.Pipeliner) error { + rdb.HMSet(ctx, "key", data) + return nil + }); err != nil { + panic(err) + } + + var model1, model2 Model + + // Scan all fields into the model. + if err := rdb.HGetAll(ctx, "key").Scan(&model1); err != nil { + panic(err) + } + + // Or scan a subset of the fields. + if err := rdb.HMGet(ctx, "key", "str1", "int").Scan(&model2); err != nil { + panic(err) + } + + spew.Dump(model1) + // Output: + // (main.Model) { + // Str1: (string) (len=5) "hello", + // Str2: (string) (len=5) "world", + // Str3: (*string)(0xc000016970)((len=5) "hello"), + // Str4: (*string)(0xc000016980)(""), + // Bytes: ([]uint8) (len=15 cap=16) { + // 00000000 74 68 69 73 20 69 73 20 62 79 74 65 73 20 21 |this is bytes !| + // }, + // Int: (int) 123, + // Int2: (*int)(0xc000014568)(0), + // Int3: (*int)(0xc000014560)(0), + // Bool: (bool) true, + // Bool2: (*bool)(0xc000014570)(false), + // Bool3: (*bool)(0xc000014548)(false), + // Bool4: (*bool)(), + // Time: (time.Time) 2025-02-08 00:00:00 +0000 UTC, + // Time2: (*time.Time)(0xc0000122a0)(2025-02-08 00:00:00 +0000 UTC), + // Time3: (*time.Time)(0xc000012288)(0001-01-01 00:00:00 +0000 UTC), + // Ignored: (struct {}) { + // } + // } + + spew.Dump(model2) + // Output: + // (main.Model) { + // Str1: (string) (len=5) "hello", + // Str2: (string) "", + // Str3: (*string)(), + // Str4: (*string)(), + // Bytes: ([]uint8) , + // Int: (int) 123, + // Int2: (*int)(), + // Int3: (*int)(), + // Bool: (bool) false, + // Bool2: (*bool)(), + // Bool3: (*bool)(), + // Bool4: (*bool)(), + // Time: (time.Time) 0001-01-01 00:00:00 +0000 UTC, + // Time2: (*time.Time)(), + // Time3: (*time.Time)(), + // Ignored: (struct {}) { + // } + // } +} + +func ToPtr[T any](v T) *T { + return &v +} diff --git a/example/scan-struct/main.go b/example/scan-struct/main.go index cc877b84..2dc5b85c 100644 --- a/example/scan-struct/main.go +++ b/example/scan-struct/main.go @@ -11,9 +11,12 @@ import ( type Model struct { Str1 string `redis:"str1"` Str2 string `redis:"str2"` + Str3 *string `redis:"str3"` Bytes []byte `redis:"bytes"` Int int `redis:"int"` + Int2 *int `redis:"int2"` Bool bool `redis:"bool"` + Bool2 *bool `redis:"bool2"` Ignored struct{} `redis:"-"` } @@ -29,8 +32,11 @@ func main() { if _, err := rdb.Pipelined(ctx, func(rdb redis.Pipeliner) error { rdb.HSet(ctx, "key", "str1", "hello") rdb.HSet(ctx, "key", "str2", "world") + rdb.HSet(ctx, "key", "str3", "") rdb.HSet(ctx, "key", "int", 123) + rdb.HSet(ctx, "key", "int2", 0) rdb.HSet(ctx, "key", "bool", 1) + rdb.HSet(ctx, "key", "bool2", 0) rdb.HSet(ctx, "key", "bytes", []byte("this is bytes !")) return nil }); err != nil { diff --git a/internal/proto/writer.go b/internal/proto/writer.go index 78595cc4..38e66c68 100644 --- a/internal/proto/writer.go +++ b/internal/proto/writer.go @@ -66,56 +66,95 @@ func (w *Writer) WriteArg(v interface{}) error { case string: return w.string(v) case *string: + if v == nil { + return w.string("") + } return w.string(*v) case []byte: return w.bytes(v) case int: return w.int(int64(v)) case *int: + if v == nil { + return w.int(0) + } return w.int(int64(*v)) case int8: return w.int(int64(v)) case *int8: + if v == nil { + return w.int(0) + } return w.int(int64(*v)) case int16: return w.int(int64(v)) case *int16: + if v == nil { + return w.int(0) + } return w.int(int64(*v)) case int32: return w.int(int64(v)) case *int32: + if v == nil { + return w.int(0) + } return w.int(int64(*v)) case int64: return w.int(v) case *int64: + if v == nil { + return w.int(0) + } return w.int(*v) case uint: return w.uint(uint64(v)) case *uint: + if v == nil { + return w.uint(0) + } return w.uint(uint64(*v)) case uint8: return w.uint(uint64(v)) case *uint8: + if v == nil { + return w.string("") + } return w.uint(uint64(*v)) case uint16: return w.uint(uint64(v)) case *uint16: + if v == nil { + return w.uint(0) + } return w.uint(uint64(*v)) case uint32: return w.uint(uint64(v)) case *uint32: + if v == nil { + return w.uint(0) + } return w.uint(uint64(*v)) case uint64: return w.uint(v) case *uint64: + if v == nil { + return w.uint(0) + } return w.uint(*v) case float32: return w.float(float64(v)) case *float32: + if v == nil { + return w.float(0) + } return w.float(float64(*v)) case float64: return w.float(v) case *float64: + if v == nil { + return w.float(0) + } return w.float(*v) case bool: if v { @@ -123,6 +162,9 @@ func (w *Writer) WriteArg(v interface{}) error { } return w.int(0) case *bool: + if v == nil { + return w.int(0) + } if *v { return w.int(1) } @@ -130,8 +172,19 @@ func (w *Writer) WriteArg(v interface{}) error { case time.Time: w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano) return w.bytes(w.numBuf) + case *time.Time: + if v == nil { + v = &time.Time{} + } + w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano) + return w.bytes(w.numBuf) case time.Duration: return w.int(v.Nanoseconds()) + case *time.Duration: + if v == nil { + return w.int(0) + } + return w.int(v.Nanoseconds()) case encoding.BinaryMarshaler: b, err := v.MarshalBinary() if err != nil { diff --git a/internal/proto/writer_test.go b/internal/proto/writer_test.go index 7c9d2088..1d5152dc 100644 --- a/internal/proto/writer_test.go +++ b/internal/proto/writer_test.go @@ -111,36 +111,61 @@ var _ = Describe("WriteArg", func() { wr = proto.NewWriter(buf) }) + t := time.Date(2025, 2, 8, 00, 00, 00, 0, time.UTC) + args := map[any]string{ - "hello": "$5\r\nhello\r\n", - int(10): "$2\r\n10\r\n", - util.ToPtr(int(10)): "$2\r\n10\r\n", - int8(10): "$2\r\n10\r\n", - util.ToPtr(int8(10)): "$2\r\n10\r\n", - int16(10): "$2\r\n10\r\n", - util.ToPtr(int16(10)): "$2\r\n10\r\n", - int32(10): "$2\r\n10\r\n", - util.ToPtr(int32(10)): "$2\r\n10\r\n", - int64(10): "$2\r\n10\r\n", - util.ToPtr(int64(10)): "$2\r\n10\r\n", - uint(10): "$2\r\n10\r\n", - util.ToPtr(uint(10)): "$2\r\n10\r\n", - uint8(10): "$2\r\n10\r\n", - util.ToPtr(uint8(10)): "$2\r\n10\r\n", - uint16(10): "$2\r\n10\r\n", - util.ToPtr(uint16(10)): "$2\r\n10\r\n", - uint32(10): "$2\r\n10\r\n", - util.ToPtr(uint32(10)): "$2\r\n10\r\n", - uint64(10): "$2\r\n10\r\n", - util.ToPtr(uint64(10)): "$2\r\n10\r\n", - float32(10.3): "$18\r\n10.300000190734863\r\n", - util.ToPtr(float32(10.3)): "$18\r\n10.300000190734863\r\n", - float64(10.3): "$4\r\n10.3\r\n", - util.ToPtr(float64(10.3)): "$4\r\n10.3\r\n", - bool(true): "$1\r\n1\r\n", - bool(false): "$1\r\n0\r\n", - util.ToPtr(bool(true)): "$1\r\n1\r\n", - util.ToPtr(bool(false)): "$1\r\n0\r\n", + "hello": "$5\r\nhello\r\n", + util.ToPtr("hello"): "$5\r\nhello\r\n", + (*string)(nil): "$0\r\n\r\n", + int(10): "$2\r\n10\r\n", + util.ToPtr(int(10)): "$2\r\n10\r\n", + (*int)(nil): "$1\r\n0\r\n", + int8(10): "$2\r\n10\r\n", + util.ToPtr(int8(10)): "$2\r\n10\r\n", + (*int8)(nil): "$1\r\n0\r\n", + int16(10): "$2\r\n10\r\n", + util.ToPtr(int16(10)): "$2\r\n10\r\n", + (*int16)(nil): "$1\r\n0\r\n", + int32(10): "$2\r\n10\r\n", + util.ToPtr(int32(10)): "$2\r\n10\r\n", + (*int32)(nil): "$1\r\n0\r\n", + int64(10): "$2\r\n10\r\n", + util.ToPtr(int64(10)): "$2\r\n10\r\n", + (*int64)(nil): "$1\r\n0\r\n", + uint(10): "$2\r\n10\r\n", + util.ToPtr(uint(10)): "$2\r\n10\r\n", + (*uint)(nil): "$1\r\n0\r\n", + uint8(10): "$2\r\n10\r\n", + util.ToPtr(uint8(10)): "$2\r\n10\r\n", + (*uint8)(nil): "$0\r\n\r\n", + uint16(10): "$2\r\n10\r\n", + util.ToPtr(uint16(10)): "$2\r\n10\r\n", + (*uint16)(nil): "$1\r\n0\r\n", + uint32(10): "$2\r\n10\r\n", + util.ToPtr(uint32(10)): "$2\r\n10\r\n", + (*uint32)(nil): "$1\r\n0\r\n", + uint64(10): "$2\r\n10\r\n", + util.ToPtr(uint64(10)): "$2\r\n10\r\n", + (*uint64)(nil): "$1\r\n0\r\n", + float32(10.3): "$18\r\n10.300000190734863\r\n", + util.ToPtr(float32(10.3)): "$18\r\n10.300000190734863\r\n", + (*float32)(nil): "$1\r\n0\r\n", + float64(10.3): "$4\r\n10.3\r\n", + util.ToPtr(float64(10.3)): "$4\r\n10.3\r\n", + (*float64)(nil): "$1\r\n0\r\n", + bool(true): "$1\r\n1\r\n", + bool(false): "$1\r\n0\r\n", + util.ToPtr(bool(true)): "$1\r\n1\r\n", + util.ToPtr(bool(false)): "$1\r\n0\r\n", + (*bool)(nil): "$1\r\n0\r\n", + time.Time(t): "$20\r\n2025-02-08T00:00:00Z\r\n", + util.ToPtr(time.Time(t)): "$20\r\n2025-02-08T00:00:00Z\r\n", + (*time.Time)(nil): "$20\r\n0001-01-01T00:00:00Z\r\n", + time.Duration(time.Second): "$10\r\n1000000000\r\n", + util.ToPtr(time.Duration(time.Second)): "$10\r\n1000000000\r\n", + (*time.Duration)(nil): "$1\r\n0\r\n", + (encoding.BinaryMarshaler)(&MyType{}): "$5\r\nhello\r\n", + (encoding.BinaryMarshaler)(nil): "$0\r\n\r\n", } for arg, expect := range args {