1
0
mirror of https://github.com/tensorchord/pgvecto.rs.git synced 2025-08-08 14:22:07 +03:00

feat: Add high-level API for Python (#123)

* feat: init high level api

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* feat: prettify things

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* feat: add test && filter subpackage

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* fix: dependency

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* test: fix Action

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* feat: add isort for format

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* fix: create extension with init client

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* docs: add readme

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* chore: bump version

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* feat: rename things

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* feat: delete embedder

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* feat: simplify filter

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* feat: config ruff

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* feat: clean up client.py

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* feat: modify PGVectoRs interfaces

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* chore: add docs

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* feat: delete text column

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* rename things

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* Revert "feat: delete text column"

This reverts commit df5452b9ad.

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* feat: rename insert

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* chore: delete __all__ for filters.py

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* chore: update things

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* chore: update lint config

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* prettify things

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* pdm lock -G :all -S direct_minimal_versions

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* replace relative import

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* change Record.from_text

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* make lint happy

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* fix Record.from_text

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

---------

Signed-off-by: 盐粒 Yanli <mail@yanli.one>
This commit is contained in:
盐粒 Yanli
2023-11-16 20:52:15 +08:00
committed by GitHub
parent 9ce6c3b4cb
commit f8344dd039
12 changed files with 844 additions and 174 deletions

View File

@@ -0,0 +1,96 @@
from typing import List
import numpy as np
import pytest
from pgvecto_rs.sdk import Filter, PGVectoRs, Record, filters
from tests import (
EXPECTED_NEG_COS_DIS,
EXPECTED_NEG_DOT_PROD_DIS,
EXPECTED_SQRT_EUCLID_DIS,
OP_NEG_COS_DIS,
OP_NEG_DOT_PROD_DIS,
OP_SQRT_EUCLID_DIS,
URL,
VECTORS,
)
# Switch the connection URL to the "postgresql+psycopg" form (presumably the
# dialect+driver scheme PGVectoRs expects -- confirm against the SDK).
# Only the first occurrence is rewritten so that a user, password, or
# database name that happens to contain "postgresql" is left untouched.
URL = URL.replace("postgresql", "postgresql+psycopg", 1)

# Text -> embedding lookup; the `client` fixture turns each entry into records.
mockTexts = {
    "text0": VECTORS[0],
    "text1": VECTORS[1],
    "text2": VECTORS[2],
}
class MockEmbedder:
    """Stand-in embedder that serves pre-computed vectors from ``mockTexts``."""

    def embed(self, text: str) -> np.ndarray:
        """Return the embedding registered for *text*.

        Values stored as plain lists are converted to ``float32`` arrays;
        anything else is handed back unchanged.

        Raises:
            KeyError: if *text* is not a key of ``mockTexts``.
        """
        vector = mockTexts[text]
        if not isinstance(vector, list):
            return vector
        return np.array(vector, dtype=np.float32)
@pytest.fixture(scope="module")
def client():
    """Yield a ``PGVectoRs`` client whose collection is pre-populated.

    Every entry of ``mockTexts`` is inserted twice: once with metadata
    ``{"src": "src1"}`` and once with ``{"src": "src2"}``. ``drop()`` is
    called on the client during teardown, even if population fails.
    """
    store = PGVectoRs(db_url=URL, collection_name="empty", dimension=3)
    try:
        # Build every batch before inserting anything, one batch per source tag.
        batches = [
            [
                Record.from_text(text, vector, {"src": source})
                for text, vector in mockTexts.items()
            ]
            for source in ("src1", "src2")
        ]
        for batch in batches:
            store.insert(batch)
        yield store
    finally:
        store.drop()
# Filter built through the SDK helper `filters.meta_contains` for {"src": "src1"}.
filter_src1 = filters.meta_contains({"src": "src1"})
# Hand-written filter that calls `meta.contains` on the object it receives;
# presumably the raw equivalent of `filters.meta_contains` -- confirm in the SDK.
filter_src2: Filter = lambda r: r.meta.contains({"src": "src2"})
@pytest.mark.parametrize("filter", [filter_src1, filter_src2])
@pytest.mark.parametrize(
    "dis_op, dis_oprand, dis_expected",
    zip(
        ["<->", "<#>", "<=>"],
        [OP_SQRT_EUCLID_DIS, OP_NEG_DOT_PROD_DIS, OP_NEG_COS_DIS],
        [EXPECTED_SQRT_EUCLID_DIS, EXPECTED_NEG_DOT_PROD_DIS, EXPECTED_NEG_COS_DIS],
    ),
)
def test_search_filter_and_op(
    client: PGVectoRs,
    filter: Filter,
    dis_op: str,
    dis_oprand: List[float],
    dis_expected: List[float],
):
    """Check the distance reported for every record of a filtered search.

    Args:
        client: populated client provided by the ``client`` fixture.
        filter: metadata filter forwarded to ``client.search``.
        dis_op: distance operator string passed to ``client.search``.
        dis_oprand: query vector the distances are measured against.
        dis_expected: expected distance for each entry of ``VECTORS``,
            in the same order.
    """
    results = list(client.search(dis_oprand, dis_op, top_k=99, filter=filter))
    # An empty result would make the loop below pass without checking anything.
    assert results, "filtered search returned no records"
    for rec, dis in results:
        # Locate which known vector this record carries so the matching
        # expected distance can be looked up.
        cnt = next(
            (i for i, vec in enumerate(VECTORS) if np.allclose(rec.embedding, vec)),
            None,
        )
        assert cnt is not None, f"unexpected embedding returned: {rec.embedding}"
        assert np.allclose(dis, dis_expected[cnt])
@pytest.mark.parametrize(
    "dis_op, dis_oprand, dis_expected",
    zip(
        ["<->", "<#>", "<=>"],
        [OP_SQRT_EUCLID_DIS, OP_NEG_DOT_PROD_DIS, OP_NEG_COS_DIS],
        [EXPECTED_SQRT_EUCLID_DIS, EXPECTED_NEG_DOT_PROD_DIS, EXPECTED_NEG_COS_DIS],
    ),
)
def test_search_order_and_limit(
    client: PGVectoRs,
    dis_op: str,
    dis_oprand: List[float],
    dis_expected: List[float],
):
    """Check that search results are ordered by distance and capped at ``top_k``.

    Args:
        client: populated client provided by the ``client`` fixture.
        dis_op: distance operator string passed to ``client.search``.
        dis_oprand: query vector the distances are measured against.
        dis_expected: expected distances; sorted locally (the caller's
            sequence is not mutated) to obtain the expected ranking.
    """
    top_k = 4
    ranked = sorted(dis_expected)
    results = list(client.search(dis_oprand, dis_op, top_k=top_k))
    # The fixture inserts six records (three texts, two sources each), so the
    # limit must actually truncate the result set.
    assert len(results) == top_k
    for i, (_, dis) in enumerate(results):
        # Each vector was inserted twice, so equal distances arrive in pairs.
        assert np.allclose(dis, ranked[i // 2])