diff --git a/helper/helper.go b/helper/helper.go new file mode 100644 index 00000000..7047c8ae --- /dev/null +++ b/helper/helper.go @@ -0,0 +1,11 @@ +package helper + +import "github.com/redis/go-redis/v9/internal/util" + +func ParseFloat(s string) (float64, error) { + return util.ParseStringToFloat(s) +} + +func MustParseFloat(s string) float64 { + return util.MustParseFloat(s) +} diff --git a/internal/util/convert.go b/internal/util/convert.go new file mode 100644 index 00000000..d326d50d --- /dev/null +++ b/internal/util/convert.go @@ -0,0 +1,30 @@ +package util + +import ( + "fmt" + "math" + "strconv" +) + +// ParseFloat parses a Redis RESP3 float reply into a Go float64, +// handling "inf", "-inf", "nan" per Redis conventions. +func ParseStringToFloat(s string) (float64, error) { + switch s { + case "inf": + return math.Inf(1), nil + case "-inf": + return math.Inf(-1), nil + case "nan", "-nan": + return math.NaN(), nil + } + return strconv.ParseFloat(s, 64) +} + +// MustParseFloat is like ParseFloat but panics on parse errors. +func MustParseFloat(s string) float64 { + f, err := ParseStringToFloat(s) + if err != nil { + panic(fmt.Sprintf("redis: failed to parse float %q: %v", s, err)) + } + return f +} diff --git a/internal/util/convert_test.go b/internal/util/convert_test.go new file mode 100644 index 00000000..ffa3ee9f --- /dev/null +++ b/internal/util/convert_test.go @@ -0,0 +1,40 @@ +package util + +import ( + "math" + "testing" +) + +func TestParseStringToFloat(t *testing.T) { + tests := []struct { + in string + want float64 + ok bool + }{ + {"1.23", 1.23, true}, + {"inf", math.Inf(1), true}, + {"-inf", math.Inf(-1), true}, + {"nan", math.NaN(), true}, + {"oops", 0, false}, + } + + for _, tc := range tests { + got, err := ParseStringToFloat(tc.in) + if tc.ok { + if err != nil { + t.Fatalf("ParseFloat(%q) error: %v", tc.in, err) + } + if math.IsNaN(tc.want) { + if !math.IsNaN(got) { + t.Errorf("ParseFloat(%q) = %v; want NaN", tc.in, got) + } + } else if got != tc.want { + t.Errorf("ParseFloat(%q) = %v; want %v", tc.in, got, tc.want) + } + } else { + if err == nil { + t.Errorf("ParseFloat(%q) expected error, got nil", tc.in) + } + } + } +} diff --git a/search_test.go b/search_test.go index 019acbe3..fdcd0d24 100644 --- a/search_test.go +++ b/search_test.go @@ -1,15 +1,18 @@ package redis_test import ( + "bytes" "context" + "encoding/binary" "fmt" - "strconv" + "math" "strings" "time" . "github.com/bsm/ginkgo/v2" . "github.com/bsm/gomega" "github.com/redis/go-redis/v9" + "github.com/redis/go-redis/v9/helper" ) func WaitForIndexing(c *redis.Client, index string) { @@ -27,6 +30,14 @@ func WaitForIndexing(c *redis.Client, index string) { } } +func encodeFloat32Vector(vec []float32) []byte { + buf := new(bytes.Buffer) + for _, v := range vec { + binary.Write(buf, binary.LittleEndian, v) + } + return buf.Bytes() +} + var _ = Describe("RediSearch commands Resp 2", Label("search"), func() { ctx := context.TODO() var client *redis.Client @@ -693,9 +704,9 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() { Expect(err).NotTo(HaveOccurred()) Expect(res).ToNot(BeNil()) Expect(len(res.Rows)).To(BeEquivalentTo(2)) - score1, err := strconv.ParseFloat(fmt.Sprintf("%s", res.Rows[0].Fields["__score"]), 64) + score1, err := helper.ParseFloat(fmt.Sprintf("%s", res.Rows[0].Fields["__score"])) Expect(err).NotTo(HaveOccurred()) - score2, err := strconv.ParseFloat(fmt.Sprintf("%s", res.Rows[1].Fields["__score"]), 64) + score2, err := helper.ParseFloat(fmt.Sprintf("%s", res.Rows[1].Fields["__score"])) Expect(err).NotTo(HaveOccurred()) Expect(score1).To(BeNumerically(">", score2)) @@ -712,9 +723,9 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() { Expect(err).NotTo(HaveOccurred()) Expect(resDM).ToNot(BeNil()) Expect(len(resDM.Rows)).To(BeEquivalentTo(2)) - score1DM, err := strconv.ParseFloat(fmt.Sprintf("%s", resDM.Rows[0].Fields["__score"]), 64) + score1DM, err := helper.ParseFloat(fmt.Sprintf("%s", resDM.Rows[0].Fields["__score"])) Expect(err).NotTo(HaveOccurred()) - score2DM, err := strconv.ParseFloat(fmt.Sprintf("%s", resDM.Rows[1].Fields["__score"]), 64) + score2DM, err := helper.ParseFloat(fmt.Sprintf("%s", resDM.Rows[1].Fields["__score"])) Expect(err).NotTo(HaveOccurred()) Expect(score1DM).To(BeNumerically(">", score2DM)) @@ -1684,6 +1695,56 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() { Expect(resUint8.Docs[0].ID).To(BeEquivalentTo("doc1")) }) + It("should return special float scores in FT.SEARCH vecsim", Label("search", "ftsearch", "vecsim"), func() { + SkipBeforeRedisVersion(7.4, "doesn't work with older redis stack images") + + vecField := &redis.FTFlatOptions{ + Type: "FLOAT32", + Dim: 2, + DistanceMetric: "IP", + } + _, err := client.FTCreate(ctx, "idx_vec", + &redis.FTCreateOptions{OnHash: true, Prefix: []interface{}{"doc:"}}, + &redis.FieldSchema{FieldName: "vector", FieldType: redis.SearchFieldTypeVector, VectorArgs: &redis.FTVectorArgs{FlatOptions: vecField}}).Result() + Expect(err).NotTo(HaveOccurred()) + WaitForIndexing(client, "idx_vec") + + bigPos := []float32{1e38, 1e38} + bigNeg := []float32{-1e38, -1e38} + nanVec := []float32{float32(math.NaN()), 0} + negNanVec := []float32{float32(math.Copysign(math.NaN(), -1)), 0} + + client.HSet(ctx, "doc:1", "vector", encodeFloat32Vector(bigPos)) + client.HSet(ctx, "doc:2", "vector", encodeFloat32Vector(bigNeg)) + client.HSet(ctx, "doc:3", "vector", encodeFloat32Vector(nanVec)) + client.HSet(ctx, "doc:4", "vector", encodeFloat32Vector(negNanVec)) + + searchOptions := &redis.FTSearchOptions{WithScores: true, Params: map[string]interface{}{"vec": encodeFloat32Vector(bigPos)}} + res, err := client.FTSearchWithArgs(ctx, "idx_vec", "*=>[KNN 4 @vector $vec]", searchOptions).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(res.Total).To(BeEquivalentTo(4)) + + var scores []float64 + for _, row := range res.Docs { + raw := fmt.Sprintf("%v", row.Fields["__vector_score"]) + f, err := helper.ParseFloat(raw) + Expect(err).NotTo(HaveOccurred()) + scores = append(scores, f) + } + + Expect(scores).To(ContainElement(BeNumerically("==", math.Inf(1)))) + Expect(scores).To(ContainElement(BeNumerically("==", math.Inf(-1)))) + + // For NaN values, use a custom check since NaN != NaN in floating point math + nanCount := 0 + for _, score := range scores { + if math.IsNaN(score) { + nanCount++ + } + } + Expect(nanCount).To(Equal(2)) + }) + It("should fail when using a non-zero offset with a zero limit", Label("search", "ftsearch"), func() { SkipBeforeRedisVersion(7.9, "requires Redis 8.x") val, err := client.FTCreate(ctx, "testIdx", &redis.FTCreateOptions{}, &redis.FieldSchema{