You've already forked pgvecto.rs
mirror of
https://github.com/tensorchord/pgvecto.rs.git
synced 2025-08-08 14:22:07 +03:00
feat: Add high-level API for Python (#123)
* feat: init high level api
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* feat: pretify things
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* feat: add test && filter subpackage
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* fix: dependency
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* test: fix Action
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* feat: add isort for format
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* fix: create extension with init client
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* docs: add readme
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* chore: bump version
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* feat: rename things
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* feat: delete embedder
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* feat: simplify filter
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* feat: config ruff
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* feat: clean up client.py
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* feat: modify PGVectoRs interfaces
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* chore: add docs
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* feat: delete text column
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* rename things
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* Revert "feat: delete text column"
This reverts commit df5452b9ad
.
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* feat: rename insert
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* chore: delete __all__ for filters.py
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* chore: update things
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* chore: update lint config
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* pretify things
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* pdm lock -G :all -S direct_minimal_versions
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* replace relative import
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* change Record.from_text
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* make lint happ
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
* fix Record.from_text
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
---------
Signed-off-by: 盐粒 Yanli <mail@yanli.one>
This commit is contained in:
96
bindings/python/tests/test_sdk.py
Normal file
96
bindings/python/tests/test_sdk.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from pgvecto_rs.sdk import Filter, PGVectoRs, Record, filters
|
||||
from tests import (
|
||||
EXPECTED_NEG_COS_DIS,
|
||||
EXPECTED_NEG_DOT_PROD_DIS,
|
||||
EXPECTED_SQRT_EUCLID_DIS,
|
||||
OP_NEG_COS_DIS,
|
||||
OP_NEG_DOT_PROD_DIS,
|
||||
OP_SQRT_EUCLID_DIS,
|
||||
URL,
|
||||
VECTORS,
|
||||
)
|
||||
|
||||
URL = URL.replace("postgresql", "postgresql+psycopg")
|
||||
mockTexts = {
|
||||
"text0": VECTORS[0],
|
||||
"text1": VECTORS[1],
|
||||
"text2": VECTORS[2],
|
||||
}
|
||||
|
||||
|
||||
class MockEmbedder:
|
||||
def embed(self, text: str) -> np.ndarray:
|
||||
if isinstance(mockTexts[text], list):
|
||||
return np.array(mockTexts[text], dtype=np.float32)
|
||||
return mockTexts[text]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
client = PGVectoRs(db_url=URL, collection_name="empty", dimension=3)
|
||||
try:
|
||||
records1 = [
|
||||
Record.from_text(t, v, {"src": "src1"}) for t, v in mockTexts.items()
|
||||
]
|
||||
records2 = [
|
||||
Record.from_text(t, v, {"src": "src2"}) for t, v in mockTexts.items()
|
||||
]
|
||||
client.insert(records1)
|
||||
client.insert(records2)
|
||||
yield client
|
||||
finally:
|
||||
client.drop()
|
||||
|
||||
|
||||
filter_src1 = filters.meta_contains({"src": "src1"})
|
||||
filter_src2: Filter = lambda r: r.meta.contains({"src": "src2"})
|
||||
|
||||
|
||||
@pytest.mark.parametrize("filter", [filter_src1, filter_src2])
|
||||
@pytest.mark.parametrize(
|
||||
"dis_op, dis_oprand, dis_expected",
|
||||
zip(
|
||||
["<->", "<#>", "<=>"],
|
||||
[OP_SQRT_EUCLID_DIS, OP_NEG_DOT_PROD_DIS, OP_NEG_COS_DIS],
|
||||
[EXPECTED_SQRT_EUCLID_DIS, EXPECTED_NEG_DOT_PROD_DIS, EXPECTED_NEG_COS_DIS],
|
||||
),
|
||||
)
|
||||
def test_search_filter_and_op(
|
||||
client: PGVectoRs,
|
||||
filter: Filter,
|
||||
dis_op: str,
|
||||
dis_oprand: List[float],
|
||||
dis_expected: List[float],
|
||||
):
|
||||
for rec, dis in client.search(dis_oprand, dis_op, top_k=99, filter=filter):
|
||||
cnt = None
|
||||
for i in range(len(VECTORS)):
|
||||
if np.allclose(rec.embedding, VECTORS[i]):
|
||||
cnt = i
|
||||
break
|
||||
assert np.allclose(dis, dis_expected[cnt])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dis_op, dis_oprand, dis_expected",
|
||||
zip(
|
||||
["<->", "<#>", "<=>"],
|
||||
[OP_SQRT_EUCLID_DIS, OP_NEG_DOT_PROD_DIS, OP_NEG_COS_DIS],
|
||||
[EXPECTED_SQRT_EUCLID_DIS, EXPECTED_NEG_DOT_PROD_DIS, EXPECTED_NEG_COS_DIS],
|
||||
),
|
||||
)
|
||||
def test_search_order_and_limit(
|
||||
client: PGVectoRs,
|
||||
dis_op: str,
|
||||
dis_oprand: List[float],
|
||||
dis_expected: List[float],
|
||||
):
|
||||
dis_expected = dis_expected.copy()
|
||||
dis_expected.sort()
|
||||
for i, (rec, dis) in enumerate(client.search(dis_oprand, dis_op, top_k=4)):
|
||||
assert np.allclose(dis, dis_expected[i // 2])
|
Reference in New Issue
Block a user