You've already forked pgvecto.rs
mirror of
https://github.com/tensorchord/pgvecto.rs.git
synced 2025-08-08 14:22:07 +03:00
feat: add Python bindings by SQLAlchemy (#95)
* feat: init py bindings Signed-off-by: Aurutus <emslhy@hotmail.com> * feat: basic binding for sqlalchemy Signed-off-by: Aurutus <emslhy@hotmail.com> * fix: return value for decorator Signed-off-by: Aurutus <emslhy@hotmail.com> * test: impl basic db test order Signed-off-by: Aurutus <emslhy@hotmail.com> * test: add invalid value checker Signed-off-by: Aurutus <emslhy@hotmail.com> * test: fix insert value check Signed-off-by: Aurutus <emslhy@hotmail.com> * feat: impl vector operator Signed-off-by: Aurutus <emslhy@hotmail.com> * test: finish op tests Signed-off-by: Aurutus <emslhy@hotmail.com> * test: add test for creating index Signed-off-by: Aurutus <emslhy@hotmail.com> * docs: complete basic docs Signed-off-by: Aurutus <emslhy@hotmail.com> * chore: add python requirements Signed-off-by: Aurutus <emslhy@hotmail.com> * chore: change requirement.txt Signed-off-by: 盐粒 Yanli <mail@yanli.one> * feat: change the structure of code Signed-off-by: 盐粒 Yanli <mail@yanli.one> * test: update the test Signed-off-by: 盐粒 Yanli <mail@yanli.one> * chore: rewrite the readme with SQLAlchemy ORM Signed-off-by: 盐粒 Yanli <mail@yanli.one> chore: fix readme Signed-off-by: 盐粒 Yanli <mail@yanli.one> chore: fix readme Signed-off-by: 盐粒 Yanli <mail@yanli.one> * test: rewrite tests using Alchemy ORM Signed-off-by: 盐粒 Yanli <mail@yanli.one> * feat: delete serializer for binary (since it's not available for now) Signed-off-by: 盐粒 Yanli <mail@yanli.one> * feat: use psycopg 3 for the SQLALchemy Signed-off-by: 盐粒 Yanli <mail@yanli.one> * test: update tests Signed-off-by: 盐粒 Yanli <mail@yanli.one> * test: comment ivf and vamana index due to #97 Signed-off-by: 盐粒 Yanli <mail@yanli.one> * chore: format code Signed-off-by: 盐粒 Yanli <mail@yanli.one> * test: update test_invalid_insert Signed-off-by: 盐粒 Yanli <mail@yanli.one> * fix: rename pgvector_rs to pgvecto_rs Signed-off-by: 盐粒 Yanli <mail@yanli.one> * feat: re-construct to use PDM Signed-off-by: 盐粒 Yanli <mail@yanli.one> * chore: fix readme and add 
LICENSE Signed-off-by: 盐粒 Yanli <mail@yanli.one> * fix: tox.ini and pyproject.toml Signed-off-by: 盐粒 Yanli <mail@yanli.one> * feat: add Github Action (for example) Signed-off-by: 盐粒 Yanli <mail@yanli.one> feat: fix Action * feat: support python_check in Action Signed-off-by: 盐粒 Yanli <mail@yanli.one> * feat: delete Action example Signed-off-by: 盐粒 Yanli <mail@yanli.one> * feat: enhance lint check && fix Signed-off-by: 盐粒 Yanli <mail@yanli.one> * fix: test problem Signed-off-by: 盐粒 Yanli <mail@yanli.one> * feat: try to add python_release for CI Signed-off-by: 盐粒 Yanli <mail@yanli.one> * feat: Complete Python Release CI Signed-off-by: 盐粒 Yanli <mail@yanli.one> * test: try to test the package in more platform Signed-off-by: 盐粒 Yanli <mail@yanli.one> * test: try to fix test platforms Signed-off-by: 盐粒 Yanli <mail@yanli.one> * test: fix dependencies for multi-platform Signed-off-by: 盐粒 Yanli <mail@yanli.one> * fix: update lock file Signed-off-by: 盐粒 Yanli <mail@yanli.one> * test: fix test for macOS Signed-off-by: 盐粒 Yanli <mail@yanli.one> * test: delete test on macOS and Windows, since no docker img is provided Signed-off-by: 盐粒 Yanli <mail@yanli.one> * feat: use workflow_dispatch to manually trigger Python Release Signed-off-by: 盐粒 Yanli <mail@yanli.one> * feat: enhance lint && simplify its Action Signed-off-by: 盐粒 Yanli <mail@yanli.one> * chore: update readme Signed-off-by: 盐粒 Yanli <mail@yanli.one> --------- Signed-off-by: Aurutus <emslhy@hotmail.com> Signed-off-by: 盐粒 Yanli <mail@yanli.one> Co-authored-by: Aurutus <emslhy@hotmail.com>
This commit is contained in:
99
bindings/python/tests/__init__.py
Normal file
99
bindings/python/tests/__init__.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import os
|
||||
import toml
|
||||
import numpy as np
|
||||
|
||||
# Database connection settings, overridable from the environment so the
# same suite can run against a local container or CI service.
# NOTE: the default is a *string* ("5432"), matching the type os.getenv
# returns when the variable IS set — the original int default made PORT's
# type depend on the environment.
PORT = os.getenv("DB_PORT", "5432")
HOST = os.getenv("DB_HOST", "localhost")
USER = os.getenv("DB_USER", "postgres")
PASS = os.getenv("DB_PASS", "mysecretpassword")
DB_NAME = os.getenv("DB_NAME", "postgres")

# Run tests with shell:
# DB_HOST=localhost DB_USER=postgres DB_PASS=password DB_NAME=postgres python3 -m pytest bindings/python/tests/
URL = "postgresql://{username}:{password}@{host}:{port}/{db_name}".format(
    port=PORT,
    host=HOST,
    username=USER,
    password=PASS,
    db_name=DB_NAME,
)
|
||||
|
||||
|
||||
# ==== test_create_index ====


def _index_options(algorithm: str) -> str:
    """Render index options for *algorithm* as a $$-quoted TOML literal."""
    config = {
        "capacity": 2097152,
        "algorithm": {algorithm: {}},
    }
    return "$${}$$".format(toml.dumps(config))


# One entry per index flavor exercised by test_create_index.
TOML_SETTINGS = {name: _index_options(name) for name in ("flat", "hnsw")}
# "ivf" and "vamana" are excluded for now due to #97:
# TOML_SETTINGS["ivf"] = _index_options("ivf")
# TOML_SETTINGS["vamana"] = _index_options("vamana")
|
||||
|
||||
# ==== test_invalid_insert ====
# Inputs the 3-dimensional vector column must reject: wrong length,
# wrong rank, string entries, and the equivalent numpy arrays.
INVALID_VECTORS = [
    [1, 2, 3, 4],  # too many elements
    [1],  # too few elements
    [[1, 2], [3, 4], [5, 6]],  # nested list: rank 2, not a flat vector
    ["123.", "123", "a"],  # string entries instead of numbers
    np.array([1, 2, 3, 4]),  # numpy array, too many elements
    np.array([1, "3", 3]),  # numpy coerces this to a string dtype
    np.zeros(shape=(1, 2)),  # 2-D numpy array, wrong shape
]
|
||||
|
||||
# =================================
# Semantic search tests
# =================================
# Three valid 3-dimensional vectors inserted by test_insert;
# each row id is the vector's index in this list.
VECTORS = [
    [1, 2, 3],
    [0.0, -45, 2.34],
    np.ones(3),
]

# Query operand and the distance expected for each stored vector, per metric.
OP_SQRT_EUCLID_DIS = [0, 0, 0]
EXPECTED_SQRT_EUCLID_DIS = [14.0, 2030.4756, 3.0]

OP_NEG_DOT_PROD_DIS = [1, 2, 4]
EXPECTED_NEG_DOT_PROD_DIS = [-17.0, 80.64, -7.0]

OP_NEG_COS_DIS = [3, 2, 1]
EXPECTED_NEG_COS_DIS = [-0.7142857, 0.5199225, -0.92582005]
|
||||
|
||||
# ==== test_delete ====
# Expected row count after test_delete: 3 rows inserted, 1 deleted.
LEN_AFT_DEL = 2
|
||||
|
||||
# Names exported via `from tests import *`. The OP_* operand constants are
# added here because test_sqlalchemy.py imports them alongside the
# EXPECTED_* values — the previous list was incomplete.
__all__ = [
    "URL",
    "TOML_SETTINGS",
    "INVALID_VECTORS",
    "VECTORS",
    "OP_SQRT_EUCLID_DIS",
    "EXPECTED_SQRT_EUCLID_DIS",
    "OP_NEG_DOT_PROD_DIS",
    "EXPECTED_NEG_DOT_PROD_DIS",
    "OP_NEG_COS_DIS",
    "EXPECTED_NEG_COS_DIS",
    "LEN_AFT_DEL",
]
|
147
bindings/python/tests/test_sqlalchemy.py
Normal file
147
bindings/python/tests/test_sqlalchemy.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from tests import (
|
||||
URL,
|
||||
TOML_SETTINGS,
|
||||
VECTORS,
|
||||
INVALID_VECTORS,
|
||||
OP_NEG_DOT_PROD_DIS,
|
||||
EXPECTED_SQRT_EUCLID_DIS,
|
||||
OP_SQRT_EUCLID_DIS,
|
||||
EXPECTED_NEG_DOT_PROD_DIS,
|
||||
OP_NEG_COS_DIS,
|
||||
EXPECTED_NEG_COS_DIS,
|
||||
LEN_AFT_DEL,
|
||||
)
|
||||
from sqlalchemy import create_engine, select, text, insert, delete
|
||||
from sqlalchemy import Integer, Index
|
||||
from pgvecto_rs.sqlalchemy import Vector
|
||||
from sqlalchemy.orm import Session, DeclarativeBase, mapped_column, Mapped
|
||||
from sqlalchemy.exc import StatementError
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Declarative base shared by the ORM models in this test module."""

    pass
|
||||
|
||||
|
||||
class Document(Base):
    """Minimal ORM model exercising the pgvecto.rs Vector column type."""

    __tablename__ = "tb_test_item"

    # Integer primary key; the tests assign ids explicitly (enumerate index).
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # 3-dimensional vector column backed by the pgvecto.rs extension.
    embedding: Mapped[np.ndarray] = mapped_column(Vector(3))

    def __repr__(self) -> str:
        # Show only the embedding; the id is implied by insertion order.
        return f"{self.embedding}"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def session():
    """
    Connect to the test db pointed by the URL. Can check more details
    in `tests/__init__.py`
    """

    # This binding targets the psycopg 3 driver; splice it into the DSN.
    engine = create_engine(URL.replace("postgresql", "postgresql+psycopg"))

    # ensure that we have installed pgvector.rs extension
    with engine.connect() as conn:
        conn.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
        # Drop leftovers from a previous (possibly aborted) run.
        conn.execute(text("DROP TABLE IF EXISTS tb_test_item"))
        conn.commit()

    # One Session is shared by every test in the module (scope="module");
    # the table is created once up front and dropped when the module ends,
    # even if a test fails (finally block).
    with Session(engine) as session:
        Document.metadata.create_all(engine)
        try:
            yield session
        finally:
            session.rollback()
            Document.metadata.drop_all(engine)
|
||||
|
||||
|
||||
# =================================
|
||||
# Prefix functional tests
|
||||
# =================================
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name,index_setting", TOML_SETTINGS.items())
def test_create_index(session: Session, index_name: str, index_setting: str):
    """Create a `vectors` index on the embedding column for each TOML setting."""
    storage_options = {"options": index_setting}
    operator_classes = {"embedding": "l2_ops"}
    idx = Index(
        index_name,
        Document.embedding,
        postgresql_using="vectors",
        postgresql_with=storage_options,
        postgresql_ops=operator_classes,
    )
    idx.create(session.bind)
    # Keep the shared module-scoped session clean for the next test.
    session.rollback()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("i,e", enumerate(INVALID_VECTORS))
def test_invalid_insert(session: Session, i: int, e: np.ndarray):
    """Inserting any invalid vector must raise a StatementError.

    Replaces the manual try/except/else/raise-AssertionError pattern with
    `pytest.raises`, which fails with a clear message when no exception is
    raised; the parametrize id already identifies the offending vector.
    """
    try:
        with pytest.raises(StatementError):
            session.execute(insert(Document).values(id=i, embedding=e))
    finally:
        # Leave the shared module-scoped session usable for later tests.
        session.rollback()
|
||||
|
||||
|
||||
# =================================
|
||||
# Semantic search tests
|
||||
# =================================
|
||||
|
||||
|
||||
def test_insert(session: Session):
    """Insert the sample vectors and verify they round-trip intact."""
    for doc_id, vec in enumerate(VECTORS):
        session.execute(insert(Document).values(id=doc_id, embedding=vec))
    session.commit()

    # Every stored embedding must match its source within float tolerance.
    for doc in session.scalars(select(Document)):
        assert np.allclose(doc.embedding, VECTORS[doc.id], atol=1e-10)
|
||||
|
||||
|
||||
def test_squared_euclidean_distance(session: Session):
    """Compare squared-Euclidean distances against the precomputed values."""
    query = select(
        Document.id,
        Document.embedding.squared_euclidean_distance(OP_SQRT_EUCLID_DIS),
    )
    for doc_id, distance in session.execute(query):
        assert np.allclose(EXPECTED_SQRT_EUCLID_DIS[doc_id], distance, atol=1e-10)
|
||||
|
||||
|
||||
def test_negative_dot_product_distance(session: Session):
    """Compare negative dot-product distances against the precomputed values."""
    query = select(
        Document.id,
        Document.embedding.negative_dot_product_distance(OP_NEG_DOT_PROD_DIS),
    )
    for doc_id, distance in session.execute(query):
        assert np.allclose(EXPECTED_NEG_DOT_PROD_DIS[doc_id], distance, atol=1e-10)
|
||||
|
||||
|
||||
def test_negative_cosine_distance(session: Session):
    """Compare negative cosine distances against the precomputed values."""
    query = select(
        Document.id,
        Document.embedding.negative_cosine_distance(OP_NEG_COS_DIS),
    )
    for doc_id, distance in session.execute(query):
        assert np.allclose(EXPECTED_NEG_COS_DIS[doc_id], distance, atol=1e-10)
|
||||
|
||||
|
||||
# =================================
# Suffix functional tests
# =================================
|
||||
|
||||
|
||||
def test_delete(session: Session):
    """Delete the row whose embedding equals VECTORS[0]; check the row count."""
    session.execute(delete(Document).where(Document.embedding == VECTORS[0]))
    session.commit()
    remaining = list(session.execute(select(Document)))
    assert len(remaining) == LEN_AFT_DEL
|
Reference in New Issue
Block a user