1
0
mirror of https://github.com/tensorchord/pgvecto.rs.git synced 2025-04-18 21:44:00 +03:00
pgvecto.rs/bindings/python/tests/test_sqlalchemy.py
Usamoi 5c0450274d
feat: fp16 vector (#178)
* feat: fp16 vector

Signed-off-by: usamoi <usamoi@outlook.com>

* feat: detect avx512fp16

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: install clang-16 for ci

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: clippy

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: rename c to pgvectorsc

Signed-off-by: usamoi <usamoi@outlook.com>

* feat: hand-writing avx512fp16

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: index on fp16

Signed-off-by: usamoi <usamoi@outlook.com>

* feat: hand-writing avx2

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: clippy

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: add rerun in build script

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: cross compilation

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: do not leave uninitialized bytes in datatype input function

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: compiler built-in function calling convention workaround

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: cross compile on aarch64

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: fix detect avx512fp16

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: avx512 codegen by multiversion

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: enable more target features for c

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: use tensorchord/stdarch

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: ci

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: remove no-run cross test

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: vbase

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: error and document

Signed-off-by: usamoi <usamoi@outlook.com>

* [skip ci]

Signed-off-by: usamoi <usamoi@outlook.com>

---------

Signed-off-by: usamoi <usamoi@outlook.com>
2023-12-14 09:50:52 +00:00

148 lines
4.1 KiB
Python

import numpy as np
import pytest
from sqlalchemy import Index, Integer, create_engine, delete, insert, select, text
from sqlalchemy.exc import StatementError
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
from pgvecto_rs.sqlalchemy import Vector
from tests import (
EXPECTED_NEG_COS_DIS,
EXPECTED_NEG_DOT_PROD_DIS,
EXPECTED_SQRT_EUCLID_DIS,
INVALID_VECTORS,
LEN_AFT_DEL,
OP_NEG_COS_DIS,
OP_NEG_DOT_PROD_DIS,
OP_SQRT_EUCLID_DIS,
TOML_SETTINGS,
URL,
VECTORS,
)
class Base(DeclarativeBase):
pass
class Document(Base):
__tablename__ = "tb_test_item"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
embedding: Mapped[np.ndarray] = mapped_column(Vector(3))
def __repr__(self) -> str:
return f"{self.embedding}"
@pytest.fixture(scope="module")
def session():
"""Connect to the test db pointed by the URL. Can check more details
in `tests/__init__.py`
"""
engine = create_engine(URL.replace("postgresql", "postgresql+psycopg"))
# ensure that we have installed pgvector.rs extension
with engine.connect() as conn:
conn.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
conn.execute(text("DROP TABLE IF EXISTS tb_test_item"))
conn.commit()
with Session(engine) as session:
Document.metadata.create_all(engine)
try:
yield session
finally:
session.rollback()
Document.metadata.drop_all(engine)
# =================================
# Prefix functional tests
# =================================
@pytest.mark.parametrize(("index_name", "index_setting"), TOML_SETTINGS.items())
def test_create_index(session: Session, index_name: str, index_setting: str):
index = Index(
index_name,
Document.embedding,
postgresql_using="vectors",
postgresql_with={"options": f"$${index_setting}$$"},
postgresql_ops={"embedding": "vector_l2_ops"},
)
index.create(session.bind)
session.commit()
@pytest.mark.parametrize(("i", "e"), enumerate(INVALID_VECTORS))
def test_invalid_insert(session: Session, i: int, e: np.ndarray):
try:
session.execute(insert(Document).values(id=i, embedding=e))
except StatementError:
pass
else:
raise AssertionError( # noqa: TRY003
f"failed to raise invalid value error for {i}th vector {e}",
)
finally:
session.rollback()
# =================================
# Semetic search tests
# =================================
def test_insert(session: Session):
for stat in [
insert(Document).values(id=i, embedding=e) for i, e in enumerate(VECTORS)
]:
session.execute(stat)
session.commit()
for row in session.scalars(select(Document)):
assert np.allclose(row.embedding, VECTORS[row.id], atol=1e-10)
def test_squared_euclidean_distance(session: Session):
for row in session.execute(
select(
Document.id,
Document.embedding.squared_euclidean_distance(OP_SQRT_EUCLID_DIS),
),
):
(i, res) = row
assert np.allclose(EXPECTED_SQRT_EUCLID_DIS[i], res, atol=1e-10)
def test_negative_dot_product_distance(session: Session):
for row in session.execute(
select(
Document.id,
Document.embedding.negative_dot_product_distance(OP_NEG_DOT_PROD_DIS),
),
):
(i, res) = row
assert np.allclose(EXPECTED_NEG_DOT_PROD_DIS[i], res, atol=1e-10)
def test_negative_cosine_distance(session: Session):
for row in session.execute(
select(
Document.id, Document.embedding.negative_cosine_distance(OP_NEG_COS_DIS)
),
):
(i, res) = row
assert np.allclose(EXPECTED_NEG_COS_DIS[i], res, atol=1e-10)
# # =================================
# # Suffix functional tests
# # =================================
def test_delete(session: Session):
session.execute(delete(Document).where(Document.embedding == VECTORS[0]))
session.commit()
res = session.execute(select(Document))
assert len(list(res)) == LEN_AFT_DEL