1
0
mirror of https://github.com/redis/go-redis.git synced 2025-08-10 11:03:00 +03:00

Allow setting and scaning interface{} values.

This commit is contained in:
Vladimir Mihailenco
2015-05-28 15:51:19 +03:00
parent 90888881b4
commit 3c1f2bd45a
10 changed files with 726 additions and 230 deletions

241
parser.go
View File

@@ -17,18 +17,201 @@ var (
//------------------------------------------------------------------------------
func appendArgs(buf []byte, args []string) []byte {
buf = append(buf, '*')
buf = strconv.AppendUint(buf, uint64(len(args)), 10)
buf = append(buf, '\r', '\n')
for _, arg := range args {
buf = append(buf, '$')
buf = strconv.AppendUint(buf, uint64(len(arg)), 10)
buf = append(buf, '\r', '\n')
buf = append(buf, arg...)
buf = append(buf, '\r', '\n')
// Copy of encoding.BinaryMarshaler.
type binaryMarshaler interface {
MarshalBinary() (data []byte, err error)
}
// Copy of encoding.BinaryUnmarshaler.
type binaryUnmarshaler interface {
UnmarshalBinary(data []byte) error
}
func appendString(b []byte, s string) []byte {
b = append(b, '$')
b = strconv.AppendUint(b, uint64(len(s)), 10)
b = append(b, '\r', '\n')
b = append(b, s...)
b = append(b, '\r', '\n')
return b
}
func appendBytes(b, bb []byte) []byte {
b = append(b, '$')
b = strconv.AppendUint(b, uint64(len(bb)), 10)
b = append(b, '\r', '\n')
b = append(b, bb...)
b = append(b, '\r', '\n')
return b
}
func appendArg(b []byte, val interface{}) ([]byte, error) {
switch v := val.(type) {
case nil:
b = appendString(b, "")
case string:
b = appendString(b, v)
case []byte:
b = appendBytes(b, v)
case int:
b = appendString(b, formatInt(int64(v)))
case int8:
b = appendString(b, formatInt(int64(v)))
case int16:
b = appendString(b, formatInt(int64(v)))
case int32:
b = appendString(b, formatInt(int64(v)))
case int64:
b = appendString(b, formatInt(v))
case uint:
b = appendString(b, formatUint(uint64(v)))
case uint8:
b = appendString(b, formatUint(uint64(v)))
case uint16:
b = appendString(b, formatUint(uint64(v)))
case uint32:
b = appendString(b, formatUint(uint64(v)))
case uint64:
b = appendString(b, formatUint(v))
case float32:
b = appendString(b, formatFloat(float64(v)))
case float64:
b = appendString(b, formatFloat(v))
case bool:
if v {
b = appendString(b, "1")
} else {
b = appendString(b, "0")
}
default:
if bm, ok := val.(binaryMarshaler); ok {
bb, err := bm.MarshalBinary()
if err != nil {
return nil, err
}
b = appendBytes(b, bb)
} else {
err := fmt.Errorf(
"redis: can't marshal %T (consider implementing BinaryMarshaler)", val)
return nil, err
}
}
return b, nil
}
func appendArgs(b []byte, args []interface{}) ([]byte, error) {
b = append(b, '*')
b = strconv.AppendUint(b, uint64(len(args)), 10)
b = append(b, '\r', '\n')
for _, arg := range args {
var err error
b, err = appendArg(b, arg)
if err != nil {
return nil, err
}
}
return b, nil
}
func scan(b []byte, val interface{}) error {
switch v := val.(type) {
case nil:
return errorf("redis: Scan(nil)")
case *string:
*v = string(b)
return nil
case *[]byte:
*v = b
return nil
case *int:
var err error
*v, err = strconv.Atoi(string(b))
return err
case *int8:
n, err := strconv.ParseInt(string(b), 10, 8)
if err != nil {
return err
}
*v = int8(n)
return nil
case *int16:
n, err := strconv.ParseInt(string(b), 10, 16)
if err != nil {
return err
}
*v = int16(n)
return nil
case *int32:
n, err := strconv.ParseInt(string(b), 10, 16)
if err != nil {
return err
}
*v = int32(n)
return nil
case *int64:
n, err := strconv.ParseInt(string(b), 10, 64)
if err != nil {
return err
}
*v = n
return nil
case *uint:
n, err := strconv.ParseUint(string(b), 10, 64)
if err != nil {
return err
}
*v = uint(n)
return nil
case *uint8:
n, err := strconv.ParseUint(string(b), 10, 8)
if err != nil {
return err
}
*v = uint8(n)
return nil
case *uint16:
n, err := strconv.ParseUint(string(b), 10, 16)
if err != nil {
return err
}
*v = uint16(n)
return nil
case *uint32:
n, err := strconv.ParseUint(string(b), 10, 32)
if err != nil {
return err
}
*v = uint32(n)
return nil
case *uint64:
n, err := strconv.ParseUint(string(b), 10, 64)
if err != nil {
return err
}
*v = n
return nil
case *float32:
n, err := strconv.ParseFloat(string(b), 32)
if err != nil {
return err
}
*v = float32(n)
return err
case *float64:
var err error
*v, err = strconv.ParseFloat(string(b), 64)
return err
case *bool:
*v = len(b) == 1 && b[0] == '1'
return nil
default:
if bu, ok := val.(binaryUnmarshaler); ok {
return bu.UnmarshalBinary(b)
}
err := fmt.Errorf(
"redis: can't unmarshal %T (consider implementing BinaryUnmarshaler)", val)
return err
}
return buf
}
//------------------------------------------------------------------------------
@@ -120,7 +303,7 @@ func parseReply(rd *bufio.Reader, p multiBulkParser) (interface{}, error) {
case '-':
return nil, errorf(string(line[1:]))
case '+':
return string(line[1:]), nil
return line[1:], nil
case ':':
v, err := strconv.ParseInt(string(line[1:]), 10, 64)
if err != nil {
@@ -141,7 +324,7 @@ func parseReply(rd *bufio.Reader, p multiBulkParser) (interface{}, error) {
if err != nil {
return nil, err
}
return string(b[:replyLen]), nil
return b[:replyLen], nil
case '*':
if len(line) == 3 && line[1] == '-' && line[2] == '1' {
return nil, Nil
@@ -166,7 +349,12 @@ func parseSlice(rd *bufio.Reader, n int64) (interface{}, error) {
} else if err != nil {
return nil, err
} else {
vals = append(vals, v)
switch vv := v.(type) {
case []byte:
vals = append(vals, string(vv))
default:
vals = append(vals, v)
}
}
}
return vals, nil
@@ -179,11 +367,11 @@ func parseStringSlice(rd *bufio.Reader, n int64) (interface{}, error) {
if err != nil {
return nil, err
}
v, ok := viface.(string)
v, ok := viface.([]byte)
if !ok {
return nil, fmt.Errorf("got %T, expected string", viface)
}
vals = append(vals, v)
vals = append(vals, string(v))
}
return vals, nil
}
@@ -211,7 +399,7 @@ func parseStringStringMap(rd *bufio.Reader, n int64) (interface{}, error) {
if err != nil {
return nil, err
}
key, ok := keyiface.(string)
key, ok := keyiface.([]byte)
if !ok {
return nil, fmt.Errorf("got %T, expected string", keyiface)
}
@@ -220,12 +408,12 @@ func parseStringStringMap(rd *bufio.Reader, n int64) (interface{}, error) {
if err != nil {
return nil, err
}
value, ok := valueiface.(string)
value, ok := valueiface.([]byte)
if !ok {
return nil, fmt.Errorf("got %T, expected string", valueiface)
}
m[key] = value
m[string(key)] = string(value)
}
return m, nil
}
@@ -237,7 +425,7 @@ func parseStringIntMap(rd *bufio.Reader, n int64) (interface{}, error) {
if err != nil {
return nil, err
}
key, ok := keyiface.(string)
key, ok := keyiface.([]byte)
if !ok {
return nil, fmt.Errorf("got %T, expected string", keyiface)
}
@@ -248,15 +436,14 @@ func parseStringIntMap(rd *bufio.Reader, n int64) (interface{}, error) {
}
switch value := valueiface.(type) {
case int64:
m[key] = value
m[string(key)] = value
case string:
m[key], err = strconv.ParseInt(value, 10, 64)
m[string(key)], err = strconv.ParseInt(value, 10, 64)
if err != nil {
return nil, fmt.Errorf("got %v, expected number", value)
}
default:
return nil, fmt.Errorf("got %T, expected number or string", valueiface)
}
}
return m, nil
@@ -271,21 +458,21 @@ func parseZSlice(rd *bufio.Reader, n int64) (interface{}, error) {
if err != nil {
return nil, err
}
member, ok := memberiface.(string)
member, ok := memberiface.([]byte)
if !ok {
return nil, fmt.Errorf("got %T, expected string", memberiface)
}
z.Member = member
z.Member = string(member)
scoreiface, err := parseReply(rd, nil)
if err != nil {
return nil, err
}
scorestr, ok := scoreiface.(string)
scoreb, ok := scoreiface.([]byte)
if !ok {
return nil, fmt.Errorf("got %T, expected string", scoreiface)
}
score, err := strconv.ParseFloat(scorestr, 64)
score, err := strconv.ParseFloat(string(scoreb), 64)
if err != nil {
return nil, err
}