mirror of
https://github.com/redis/go-redis.git
synced 2025-05-31 21:01:13 +03:00
utils: export ParseFloat and MustParseFloat wrapping internal utils (#3371)
* utils: expose ParseFloat via new public utils package * add tests for special float values in vector search
This commit is contained in:
parent
f174acba52
commit
42c32846e6
11
helper/helper.go
Normal file
11
helper/helper.go
Normal file
@ -0,0 +1,11 @@
|
||||
package helper
|
||||
|
||||
import "github.com/redis/go-redis/v9/internal/util"
|
||||
|
||||
func ParseFloat(s string) (float64, error) {
|
||||
return util.ParseStringToFloat(s)
|
||||
}
|
||||
|
||||
func MustParseFloat(s string) float64 {
|
||||
return util.MustParseFloat(s)
|
||||
}
|
30
internal/util/convert.go
Normal file
30
internal/util/convert.go
Normal file
@ -0,0 +1,30 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// ParseFloat parses a Redis RESP3 float reply into a Go float64,
|
||||
// handling "inf", "-inf", "nan" per Redis conventions.
|
||||
func ParseStringToFloat(s string) (float64, error) {
|
||||
switch s {
|
||||
case "inf":
|
||||
return math.Inf(1), nil
|
||||
case "-inf":
|
||||
return math.Inf(-1), nil
|
||||
case "nan", "-nan":
|
||||
return math.NaN(), nil
|
||||
}
|
||||
return strconv.ParseFloat(s, 64)
|
||||
}
|
||||
|
||||
// MustParseFloat is like ParseFloat but panics on parse errors.
|
||||
func MustParseFloat(s string) float64 {
|
||||
f, err := ParseStringToFloat(s)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("redis: failed to parse float %q: %v", s, err))
|
||||
}
|
||||
return f
|
||||
}
|
40
internal/util/convert_test.go
Normal file
40
internal/util/convert_test.go
Normal file
@ -0,0 +1,40 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseStringToFloat(t *testing.T) {
|
||||
tests := []struct {
|
||||
in string
|
||||
want float64
|
||||
ok bool
|
||||
}{
|
||||
{"1.23", 1.23, true},
|
||||
{"inf", math.Inf(1), true},
|
||||
{"-inf", math.Inf(-1), true},
|
||||
{"nan", math.NaN(), true},
|
||||
{"oops", 0, false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
got, err := ParseStringToFloat(tc.in)
|
||||
if tc.ok {
|
||||
if err != nil {
|
||||
t.Fatalf("ParseFloat(%q) error: %v", tc.in, err)
|
||||
}
|
||||
if math.IsNaN(tc.want) {
|
||||
if !math.IsNaN(got) {
|
||||
t.Errorf("ParseFloat(%q) = %v; want NaN", tc.in, got)
|
||||
}
|
||||
} else if got != tc.want {
|
||||
t.Errorf("ParseFloat(%q) = %v; want %v", tc.in, got, tc.want)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("ParseFloat(%q) expected error, got nil", tc.in)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,15 +1,18 @@
|
||||
package redis_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"math"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
. "github.com/bsm/ginkgo/v2"
|
||||
. "github.com/bsm/gomega"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/redis/go-redis/v9/helper"
|
||||
)
|
||||
|
||||
func WaitForIndexing(c *redis.Client, index string) {
|
||||
@ -27,6 +30,14 @@ func WaitForIndexing(c *redis.Client, index string) {
|
||||
}
|
||||
}
|
||||
|
||||
func encodeFloat32Vector(vec []float32) []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
for _, v := range vec {
|
||||
binary.Write(buf, binary.LittleEndian, v)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
|
||||
ctx := context.TODO()
|
||||
var client *redis.Client
|
||||
@ -693,9 +704,9 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res).ToNot(BeNil())
|
||||
Expect(len(res.Rows)).To(BeEquivalentTo(2))
|
||||
score1, err := strconv.ParseFloat(fmt.Sprintf("%s", res.Rows[0].Fields["__score"]), 64)
|
||||
score1, err := helper.ParseFloat(fmt.Sprintf("%s", res.Rows[0].Fields["__score"]))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
score2, err := strconv.ParseFloat(fmt.Sprintf("%s", res.Rows[1].Fields["__score"]), 64)
|
||||
score2, err := helper.ParseFloat(fmt.Sprintf("%s", res.Rows[1].Fields["__score"]))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(score1).To(BeNumerically(">", score2))
|
||||
|
||||
@ -712,9 +723,9 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(resDM).ToNot(BeNil())
|
||||
Expect(len(resDM.Rows)).To(BeEquivalentTo(2))
|
||||
score1DM, err := strconv.ParseFloat(fmt.Sprintf("%s", resDM.Rows[0].Fields["__score"]), 64)
|
||||
score1DM, err := helper.ParseFloat(fmt.Sprintf("%s", resDM.Rows[0].Fields["__score"]))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
score2DM, err := strconv.ParseFloat(fmt.Sprintf("%s", resDM.Rows[1].Fields["__score"]), 64)
|
||||
score2DM, err := helper.ParseFloat(fmt.Sprintf("%s", resDM.Rows[1].Fields["__score"]))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(score1DM).To(BeNumerically(">", score2DM))
|
||||
|
||||
@ -1684,6 +1695,56 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
|
||||
Expect(resUint8.Docs[0].ID).To(BeEquivalentTo("doc1"))
|
||||
})
|
||||
|
||||
It("should return special float scores in FT.SEARCH vecsim", Label("search", "ftsearch", "vecsim"), func() {
|
||||
SkipBeforeRedisVersion(7.4, "doesn't work with older redis stack images")
|
||||
|
||||
vecField := &redis.FTFlatOptions{
|
||||
Type: "FLOAT32",
|
||||
Dim: 2,
|
||||
DistanceMetric: "IP",
|
||||
}
|
||||
_, err := client.FTCreate(ctx, "idx_vec",
|
||||
&redis.FTCreateOptions{OnHash: true, Prefix: []interface{}{"doc:"}},
|
||||
&redis.FieldSchema{FieldName: "vector", FieldType: redis.SearchFieldTypeVector, VectorArgs: &redis.FTVectorArgs{FlatOptions: vecField}}).Result()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
WaitForIndexing(client, "idx_vec")
|
||||
|
||||
bigPos := []float32{1e38, 1e38}
|
||||
bigNeg := []float32{-1e38, -1e38}
|
||||
nanVec := []float32{float32(math.NaN()), 0}
|
||||
negNanVec := []float32{float32(math.Copysign(math.NaN(), -1)), 0}
|
||||
|
||||
client.HSet(ctx, "doc:1", "vector", encodeFloat32Vector(bigPos))
|
||||
client.HSet(ctx, "doc:2", "vector", encodeFloat32Vector(bigNeg))
|
||||
client.HSet(ctx, "doc:3", "vector", encodeFloat32Vector(nanVec))
|
||||
client.HSet(ctx, "doc:4", "vector", encodeFloat32Vector(negNanVec))
|
||||
|
||||
searchOptions := &redis.FTSearchOptions{WithScores: true, Params: map[string]interface{}{"vec": encodeFloat32Vector(bigPos)}}
|
||||
res, err := client.FTSearchWithArgs(ctx, "idx_vec", "*=>[KNN 4 @vector $vec]", searchOptions).Result()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Total).To(BeEquivalentTo(4))
|
||||
|
||||
var scores []float64
|
||||
for _, row := range res.Docs {
|
||||
raw := fmt.Sprintf("%v", row.Fields["__vector_score"])
|
||||
f, err := helper.ParseFloat(raw)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
scores = append(scores, f)
|
||||
}
|
||||
|
||||
Expect(scores).To(ContainElement(BeNumerically("==", math.Inf(1))))
|
||||
Expect(scores).To(ContainElement(BeNumerically("==", math.Inf(-1))))
|
||||
|
||||
// For NaN values, use a custom check since NaN != NaN in floating point math
|
||||
nanCount := 0
|
||||
for _, score := range scores {
|
||||
if math.IsNaN(score) {
|
||||
nanCount++
|
||||
}
|
||||
}
|
||||
Expect(nanCount).To(Equal(2))
|
||||
})
|
||||
|
||||
It("should fail when using a non-zero offset with a zero limit", Label("search", "ftsearch"), func() {
|
||||
SkipBeforeRedisVersion(7.9, "requires Redis 8.x")
|
||||
val, err := client.FTCreate(ctx, "testIdx", &redis.FTCreateOptions{}, &redis.FieldSchema{
|
||||
|
Loading…
x
Reference in New Issue
Block a user