1
0
mirror of https://github.com/redis/go-redis.git synced 2025-05-31 21:01:13 +03:00

utils: export ParseFloat and MustParseFloat wrapping internal utils (#3371)

* utils: expose ParseFloat via new public utils package

* add tests for special float values in vector search
This commit is contained in:
ofekshenawa 2025-05-09 12:24:36 +03:00 committed by GitHub
parent f174acba52
commit 42c32846e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 147 additions and 5 deletions

11
helper/helper.go Normal file
View File

@ -0,0 +1,11 @@
package helper
import "github.com/redis/go-redis/v9/internal/util"
func ParseFloat(s string) (float64, error) {
return util.ParseStringToFloat(s)
}
func MustParseFloat(s string) float64 {
return util.MustParseFloat(s)
}

30
internal/util/convert.go Normal file
View File

@ -0,0 +1,30 @@
package util
import (
"fmt"
"math"
"strconv"
)
// ParseFloat parses a Redis RESP3 float reply into a Go float64,
// handling "inf", "-inf", "nan" per Redis conventions.
func ParseStringToFloat(s string) (float64, error) {
switch s {
case "inf":
return math.Inf(1), nil
case "-inf":
return math.Inf(-1), nil
case "nan", "-nan":
return math.NaN(), nil
}
return strconv.ParseFloat(s, 64)
}
// MustParseFloat is like ParseFloat but panics on parse errors.
func MustParseFloat(s string) float64 {
f, err := ParseStringToFloat(s)
if err != nil {
panic(fmt.Sprintf("redis: failed to parse float %q: %v", s, err))
}
return f
}

View File

@ -0,0 +1,40 @@
package util
import (
"math"
"testing"
)
func TestParseStringToFloat(t *testing.T) {
tests := []struct {
in string
want float64
ok bool
}{
{"1.23", 1.23, true},
{"inf", math.Inf(1), true},
{"-inf", math.Inf(-1), true},
{"nan", math.NaN(), true},
{"oops", 0, false},
}
for _, tc := range tests {
got, err := ParseStringToFloat(tc.in)
if tc.ok {
if err != nil {
t.Fatalf("ParseFloat(%q) error: %v", tc.in, err)
}
if math.IsNaN(tc.want) {
if !math.IsNaN(got) {
t.Errorf("ParseFloat(%q) = %v; want NaN", tc.in, got)
}
} else if got != tc.want {
t.Errorf("ParseFloat(%q) = %v; want %v", tc.in, got, tc.want)
}
} else {
if err == nil {
t.Errorf("ParseFloat(%q) expected error, got nil", tc.in)
}
}
}
}

View File

@ -1,15 +1,18 @@
package redis_test
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"strconv"
"math"
"strings"
"time"
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/helper"
)
func WaitForIndexing(c *redis.Client, index string) {
@ -27,6 +30,14 @@ func WaitForIndexing(c *redis.Client, index string) {
}
}
func encodeFloat32Vector(vec []float32) []byte {
buf := new(bytes.Buffer)
for _, v := range vec {
binary.Write(buf, binary.LittleEndian, v)
}
return buf.Bytes()
}
var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
ctx := context.TODO()
var client *redis.Client
@ -693,9 +704,9 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
Expect(err).NotTo(HaveOccurred())
Expect(res).ToNot(BeNil())
Expect(len(res.Rows)).To(BeEquivalentTo(2))
score1, err := strconv.ParseFloat(fmt.Sprintf("%s", res.Rows[0].Fields["__score"]), 64)
score1, err := helper.ParseFloat(fmt.Sprintf("%s", res.Rows[0].Fields["__score"]))
Expect(err).NotTo(HaveOccurred())
score2, err := strconv.ParseFloat(fmt.Sprintf("%s", res.Rows[1].Fields["__score"]), 64)
score2, err := helper.ParseFloat(fmt.Sprintf("%s", res.Rows[1].Fields["__score"]))
Expect(err).NotTo(HaveOccurred())
Expect(score1).To(BeNumerically(">", score2))
@ -712,9 +723,9 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
Expect(err).NotTo(HaveOccurred())
Expect(resDM).ToNot(BeNil())
Expect(len(resDM.Rows)).To(BeEquivalentTo(2))
score1DM, err := strconv.ParseFloat(fmt.Sprintf("%s", resDM.Rows[0].Fields["__score"]), 64)
score1DM, err := helper.ParseFloat(fmt.Sprintf("%s", resDM.Rows[0].Fields["__score"]))
Expect(err).NotTo(HaveOccurred())
score2DM, err := strconv.ParseFloat(fmt.Sprintf("%s", resDM.Rows[1].Fields["__score"]), 64)
score2DM, err := helper.ParseFloat(fmt.Sprintf("%s", resDM.Rows[1].Fields["__score"]))
Expect(err).NotTo(HaveOccurred())
Expect(score1DM).To(BeNumerically(">", score2DM))
@ -1684,6 +1695,56 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
Expect(resUint8.Docs[0].ID).To(BeEquivalentTo("doc1"))
})
It("should return special float scores in FT.SEARCH vecsim", Label("search", "ftsearch", "vecsim"), func() {
SkipBeforeRedisVersion(7.4, "doesn't work with older redis stack images")
vecField := &redis.FTFlatOptions{
Type: "FLOAT32",
Dim: 2,
DistanceMetric: "IP",
}
_, err := client.FTCreate(ctx, "idx_vec",
&redis.FTCreateOptions{OnHash: true, Prefix: []interface{}{"doc:"}},
&redis.FieldSchema{FieldName: "vector", FieldType: redis.SearchFieldTypeVector, VectorArgs: &redis.FTVectorArgs{FlatOptions: vecField}}).Result()
Expect(err).NotTo(HaveOccurred())
WaitForIndexing(client, "idx_vec")
bigPos := []float32{1e38, 1e38}
bigNeg := []float32{-1e38, -1e38}
nanVec := []float32{float32(math.NaN()), 0}
negNanVec := []float32{float32(math.Copysign(math.NaN(), -1)), 0}
client.HSet(ctx, "doc:1", "vector", encodeFloat32Vector(bigPos))
client.HSet(ctx, "doc:2", "vector", encodeFloat32Vector(bigNeg))
client.HSet(ctx, "doc:3", "vector", encodeFloat32Vector(nanVec))
client.HSet(ctx, "doc:4", "vector", encodeFloat32Vector(negNanVec))
searchOptions := &redis.FTSearchOptions{WithScores: true, Params: map[string]interface{}{"vec": encodeFloat32Vector(bigPos)}}
res, err := client.FTSearchWithArgs(ctx, "idx_vec", "*=>[KNN 4 @vector $vec]", searchOptions).Result()
Expect(err).NotTo(HaveOccurred())
Expect(res.Total).To(BeEquivalentTo(4))
var scores []float64
for _, row := range res.Docs {
raw := fmt.Sprintf("%v", row.Fields["__vector_score"])
f, err := helper.ParseFloat(raw)
Expect(err).NotTo(HaveOccurred())
scores = append(scores, f)
}
Expect(scores).To(ContainElement(BeNumerically("==", math.Inf(1))))
Expect(scores).To(ContainElement(BeNumerically("==", math.Inf(-1))))
// For NaN values, use a custom check since NaN != NaN in floating point math
nanCount := 0
for _, score := range scores {
if math.IsNaN(score) {
nanCount++
}
}
Expect(nanCount).To(Equal(2))
})
It("should fail when using a non-zero offset with a zero limit", Label("search", "ftsearch"), func() {
SkipBeforeRedisVersion(7.9, "requires Redis 8.x")
val, err := client.FTCreate(ctx, "testIdx", &redis.FTCreateOptions{}, &redis.FieldSchema{