1
0
mirror of https://github.com/tensorchord/pgvecto.rs.git synced 2025-07-29 08:21:12 +03:00

feat: add more ruff rules (#138)

* feat: add more ruff rules

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* chore: modified readme

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* rename error class

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

---------

Signed-off-by: 盐粒 Yanli <mail@yanli.one>
This commit is contained in:
盐粒 Yanli
2023-11-17 17:47:28 +08:00
committed by GitHub
parent f8344dd039
commit f6e382d0fc
16 changed files with 138 additions and 86 deletions

View File

@ -106,6 +106,7 @@ pdm sync
Run lint:
```bash
pdm run format
pdm run fix
pdm run check
```

View File

@ -6,7 +6,7 @@ import psycopg
from pgvecto_rs.psycopg import register_vector
URL = "postgresql://{username}:{password}@{host}:{port}/{db_name}".format(
port=os.getenv("DB_PORT", 5432),
port=os.getenv("DB_PORT", "5432"),
host=os.getenv("DB_HOST", "localhost"),
username=os.getenv("DB_USER", "postgres"),
password=os.getenv("DB_PASS", "mysecretpassword"),
@ -18,7 +18,7 @@ with psycopg.connect(URL) as conn:
conn.execute("CREATE EXTENSION IF NOT EXISTS vectors;")
register_vector(conn)
conn.execute(
"CREATE TABLE documents (id SERIAL PRIMARY KEY, text TEXT NOT NULL, embedding vector(3) NOT NULL);"
"CREATE TABLE documents (id SERIAL PRIMARY KEY, text TEXT NOT NULL, embedding vector(3) NOT NULL);",
)
conn.commit()
try:
@ -39,7 +39,8 @@ with psycopg.connect(URL) as conn:
# Select the row "hello pgvecto.rs"
cur = conn.execute(
"SELECT * FROM documents WHERE text = %s;", ("hello pgvecto.rs",)
"SELECT * FROM documents WHERE text = %s;",
("hello pgvecto.rs",),
)
target = cur.fetchone()[2]

View File

@ -5,7 +5,7 @@ from openai import OpenAI
from pgvecto_rs.sdk import PGVectoRs, Record, filters
URL = "postgresql+psycopg://{username}:{password}@{host}:{port}/{db_name}".format(
port=os.getenv("DB_PORT", 5432),
port=os.getenv("DB_PORT", "5432"),
host=os.getenv("DB_HOST", "localhost"),
username=os.getenv("DB_USER", "postgres"),
password=os.getenv("DB_PASS", "mysecretpassword"),
@ -43,7 +43,8 @@ try:
# Query (With a filter from the filters module)
print("#################### First Query ####################")
for record, dis in client.search(
target, filter=filters.meta_contains({"src": "one"})
target,
filter=filters.meta_contains({"src": "one"}),
):
print(f"DISTANCE SCORE: {dis}")
print(record)
@ -51,7 +52,8 @@ try:
# Another Query (Equivalent to the first one, but with a lambda filter written by hand)
print("#################### Second Query ####################")
for record, dis in client.search(
target, filter=lambda r: r.meta.contains({"src": "one"})
target,
filter=lambda r: r.meta.contains({"src": "one"}),
):
print(f"DISTANCE SCORE: {dis}")
print(record)

View File

@ -7,7 +7,7 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
from pgvecto_rs.sqlalchemy import Vector
URL = "postgresql+psycopg://{username}:{password}@{host}:{port}/{db_name}".format(
port=os.getenv("DB_PORT", 5432),
port=os.getenv("DB_PORT", "5432"),
host=os.getenv("DB_HOST", "localhost"),
username=os.getenv("DB_USER", "postgres"),
password=os.getenv("DB_PASS", "mysecretpassword"),
@ -53,7 +53,7 @@ with Session(engine) as session:
stmt = select(
Document.text,
Document.embedding.squared_euclidean_distance(target.embedding).label(
"distance"
"distance",
),
).order_by("distance")
for doc in session.execute(stmt):

View File

@ -5,7 +5,7 @@
groups = ["default", "lint", "psycopg3", "sdk", "sqlalchemy", "test"]
strategy = ["cross_platform", "direct_minimal_versions"]
lock_version = "4.4"
content_hash = "sha256:f65e8d98636d7592453753c6ba60e73a6912f0e21205dff0e0f92bb148befce3"
content_hash = "sha256:a7e2c999c870cd4bac136205366ffe7d592d8ca043fd0946c3df631f7326282f"
[[package]]
name = "annotated-types"
@ -506,27 +506,27 @@ files = [
[[package]]
name = "ruff"
version = "0.1.1"
version = "0.1.5"
requires_python = ">=3.7"
summary = "An extremely fast Python linter, written in Rust."
summary = "An extremely fast Python linter and code formatter, written in Rust."
files = [
{file = "ruff-0.1.1-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:b7cdc893aef23ccc14c54bd79a8109a82a2c527e11d030b62201d86f6c2b81c5"},
{file = "ruff-0.1.1-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:620d4b34302538dbd8bbbe8fdb8e8f98d72d29bd47e972e2b59ce6c1e8862257"},
{file = "ruff-0.1.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a909d3930afdbc2e9fd893b0034479e90e7981791879aab50ce3d9f55205bd6"},
{file = "ruff-0.1.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3305d1cb4eb8ff6d3e63a48d1659d20aab43b49fe987b3ca4900528342367145"},
{file = "ruff-0.1.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c34ae501d0ec71acf19ee5d4d889e379863dcc4b796bf8ce2934a9357dc31db7"},
{file = "ruff-0.1.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:6aa7e63c3852cf8fe62698aef31e563e97143a4b801b57f920012d0e07049a8d"},
{file = "ruff-0.1.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2d68367d1379a6b47e61bc9de144a47bcdb1aad7903bbf256e4c3d31f11a87ae"},
{file = "ruff-0.1.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bc11955f6ce3398d2afe81ad7e49d0ebf0a581d8bcb27b8c300281737735e3a3"},
{file = "ruff-0.1.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbbd8eead88ea83a250499074e2a8e9d80975f0b324b1e2e679e4594da318c25"},
{file = "ruff-0.1.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:f4780e2bb52f3863a565ec3f699319d3493b83ff95ebbb4993e59c62aaf6e75e"},
{file = "ruff-0.1.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8f5b24daddf35b6c207619301170cae5d2699955829cda77b6ce1e5fc69340df"},
{file = "ruff-0.1.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:d3f9ac658ba29e07b95c80fa742b059a55aefffa8b1e078bc3c08768bdd4b11a"},
{file = "ruff-0.1.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:3521bf910104bf781e6753282282acc145cbe3eff79a1ce6b920404cd756075a"},
{file = "ruff-0.1.1-py3-none-win32.whl", hash = "sha256:ba3208543ab91d3e4032db2652dcb6c22a25787b85b8dc3aeff084afdc612e5c"},
{file = "ruff-0.1.1-py3-none-win_amd64.whl", hash = "sha256:3ff3006c97d9dc396b87fb46bb65818e614ad0181f059322df82bbfe6944e264"},
{file = "ruff-0.1.1-py3-none-win_arm64.whl", hash = "sha256:e140bd717c49164c8feb4f65c644046fe929c46f42493672853e3213d7bdbce2"},
{file = "ruff-0.1.1.tar.gz", hash = "sha256:c90461ae4abec261609e5ea436de4a4b5f2822921cf04c16d2cc9327182dbbcc"},
{file = "ruff-0.1.5-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:32d47fc69261c21a4c48916f16ca272bf2f273eb635d91c65d5cd548bf1f3d96"},
{file = "ruff-0.1.5-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:171276c1df6c07fa0597fb946139ced1c2978f4f0b8254f201281729981f3c17"},
{file = "ruff-0.1.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17ef33cd0bb7316ca65649fc748acc1406dfa4da96a3d0cde6d52f2e866c7b39"},
{file = "ruff-0.1.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b2c205827b3f8c13b4a432e9585750b93fd907986fe1aec62b2a02cf4401eee6"},
{file = "ruff-0.1.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bb408e3a2ad8f6881d0f2e7ad70cddb3ed9f200eb3517a91a245bbe27101d379"},
{file = "ruff-0.1.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:f20dc5e5905ddb407060ca27267c7174f532375c08076d1a953cf7bb016f5a24"},
{file = "ruff-0.1.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aafb9d2b671ed934998e881e2c0f5845a4295e84e719359c71c39a5363cccc91"},
{file = "ruff-0.1.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a4894dddb476597a0ba4473d72a23151b8b3b0b5f958f2cf4d3f1c572cdb7af7"},
{file = "ruff-0.1.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a00a7ec893f665ed60008c70fe9eeb58d210e6b4d83ec6654a9904871f982a2a"},
{file = "ruff-0.1.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:a8c11206b47f283cbda399a654fd0178d7a389e631f19f51da15cbe631480c5b"},
{file = "ruff-0.1.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fa29e67b3284b9a79b1a85ee66e293a94ac6b7bb068b307a8a373c3d343aa8ec"},
{file = "ruff-0.1.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9b97fd6da44d6cceb188147b68db69a5741fbc736465b5cea3928fdac0bc1aeb"},
{file = "ruff-0.1.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:721f4b9d3b4161df8dc9f09aa8562e39d14e55a4dbaa451a8e55bdc9590e20f4"},
{file = "ruff-0.1.5-py3-none-win32.whl", hash = "sha256:f80c73bba6bc69e4fdc73b3991db0b546ce641bdcd5b07210b8ad6f64c79f1ab"},
{file = "ruff-0.1.5-py3-none-win_amd64.whl", hash = "sha256:c21fe20ee7d76206d290a76271c1af7a5096bc4c73ab9383ed2ad35f852a0087"},
{file = "ruff-0.1.5-py3-none-win_arm64.whl", hash = "sha256:82bfcb9927e88c1ed50f49ac6c9728dab3ea451212693fe40d08d314663e412f"},
{file = "ruff-0.1.5.tar.gz", hash = "sha256:5cbec0ef2ae1748fb194f420fb03fb2c25c3258c86129af7172ff8f198f125ab"},
]
[[package]]

View File

@ -3,8 +3,8 @@ name = "pgvecto-rs"
version = "0.1.3"
description = "Python binding for pgvecto.rs"
authors = [
{ name = "TensorChord", email = "envd-maintainers@tensorchord.ai" },
{ name = "盐粒 Yanli", email = "mail@yanli.one" },
{ name = "TensorChord", email = "envd-maintainers@tensorchord.ai" },
{ name = "盐粒 Yanli", email = "mail@yanli.one" },
]
dependencies = [
"numpy>=1.23",
@ -23,15 +23,15 @@ classifiers = [
[build-system]
build-backend = "pdm.backend"
requires = [
requires = [
"pdm-backend",
]
[project.optional-dependencies]
psycopg3 = [
psycopg3 = [
"psycopg[binary]>=3.1.12",
]
sdk = [
sdk = [
"openai>=1.2.2",
"pgvecto_rs[sqlalchemy]",
]
@ -40,19 +40,34 @@ sqlalchemy = [
"SQLAlchemy>=2.0.23",
]
[tool.pdm.dev-dependencies]
lint = ["ruff>=0.1.1"]
lint = ["ruff>=0.1.5"]
test = ["pytest>=7.4.3"]
[tool.pdm.scripts]
test = "pytest tests/"
test = "pytest tests/"
format = "ruff format ."
fix = "ruff --fix ."
check = { composite = ["ruff format . --check", "ruff ."] }
fix = "ruff --fix ."
check = { composite = ["ruff format . --check", "ruff ."] }
[tool.ruff]
select = ["E", "F", "I", "TID"]
ignore = ["E731", "E501"]
src = ["src"]
select = [
"E", #https://docs.astral.sh/ruff/rules/#error-e
"F", #https://docs.astral.sh/ruff/rules/#pyflakes-f
"I", #https://docs.astral.sh/ruff/rules/#isort-i
"TID", #https://docs.astral.sh/ruff/rules/#flake8-tidy-imports-tid
"S", #https://docs.astral.sh/ruff/rules/#flake8-bandit-s
"B", #https://docs.astral.sh/ruff/rules/#flake8-bugbear-b
"SIM", #https://docs.astral.sh/ruff/rules/#flake8-simplify-sim
"N", #https://docs.astral.sh/ruff/rules/#pep8-naming-n
"PT", #https://docs.astral.sh/ruff/rules/#flake8-pytest-style-pt
"TRY", #https://docs.astral.sh/ruff/rules/#tryceratops-try
"FLY", #https://docs.astral.sh/ruff/rules/#flynt-fly
"PL", #https://docs.astral.sh/ruff/rules/#pylint-pl
"NPY", #https://docs.astral.sh/ruff/rules/#numpy-specific-rules-npy
"RUF", #https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
]
ignore = ["S101", "E731", "E501"]
src = ["src"]
[tool.pytest.ini_options]
addopts = "-r aR"

View File

@ -0,0 +1,25 @@
import numpy as np


class PGVectoRsError(ValueError):
    """Base class for every error raised by the pgvecto_rs bindings."""


class NDArrayDimensionError(PGVectoRsError):
    """Raised when an ndarray passed as a vector is not one-dimensional."""

    def __init__(self, dim: int) -> None:
        # Include the offending dimensionality so callers can see what they passed.
        message = f"ndarray must be 1D for vector, got {dim}D"
        super().__init__(message)


class NDArrayDtypeError(PGVectoRsError):
    """Raised when an ndarray passed as a vector has a non-numeric dtype."""

    def __init__(self, dtype: np.dtype) -> None:
        message = f"ndarray data type must be numeric for vector, got {dtype}"
        super().__init__(message)


class BuiltinListTypeError(PGVectoRsError):
    """Raised when a built-in list holds non-numeric elements."""

    def __init__(self) -> None:
        super().__init__("list data type must be numeric for vector")


class VectorDimensionError(PGVectoRsError):
    """Raised when a vector is declared with a non-positive dimension."""

    def __init__(self, dim: int) -> None:
        message = f"vector dimension must be > 0, got {dim}"
        super().__init__(message)

View File

@ -37,7 +37,7 @@ async def register_vector_async(context: Connection):
def register_vector_info(context: Connection, info: TypeInfo):
if info is None:
raise ProgrammingError("vector type not found in the database")
raise ProgrammingError(info="vector type not found in the database")
info.register(context)
class VectorTextDumper(VectorDumper):

View File

@ -28,6 +28,7 @@ class PGVectoRs:
"""Connect to an existing table or create a new empty one.
Args:
----
db_url (str): url to the database.
table_name (str): name of the table.
dimension (int): dimension of the embeddings.
@ -36,7 +37,8 @@ class PGVectoRs:
class _Table(RecordORM):
__tablename__ = f"collection_{collection_name}"
id: Mapped[UUID] = mapped_column(
postgresql.UUID(as_uuid=True), primary_key=True
postgresql.UUID(as_uuid=True),
primary_key=True,
)
text: Mapped[str] = mapped_column(String)
meta: Mapped[dict] = mapped_column(postgresql.JSONB)
@ -59,7 +61,7 @@ class PGVectoRs:
text=record.text,
meta=record.meta,
embedding=record.embedding,
)
),
)
session.commit()
@ -73,6 +75,7 @@ class PGVectoRs:
"""Search for the nearest records.
Args:
----
embedding : Target embedding.
distance_op : Distance op.
top_k : Max records to return. Defaults to 4.
@ -80,6 +83,7 @@ class PGVectoRs:
order_by_dis : Order by distance. Defaults to True.
Returns:
-------
List of records and corresponding distances.
"""
@ -88,7 +92,7 @@ class PGVectoRs:
select(
self._table,
self._table.embedding.op(distance_op, return_type=Float)(
embedding
embedding,
).label("distance"),
)
.limit(top_k)

View File

@ -1,5 +1,6 @@
import sqlalchemy.types as types
from sqlalchemy import types
from pgvecto_rs.errors import VectorDimensionError
from pgvecto_rs.utils import serializer
@ -8,13 +9,13 @@ class Vector(types.UserDefinedType):
def __init__(self, dim):
if dim < 0:
raise ValueError("negative dim is not allowed")
raise VectorDimensionError(dim)
self.dim = dim
def get_col_spec(self, **kw):
if self.dim is None or self.dim == 0:
return "VECTOR"
return "VECTOR({})".format(self.dim)
return f"VECTOR({self.dim})"
def bind_processor(self, dialect):
def _processor(value):
@ -28,7 +29,7 @@ class Vector(types.UserDefinedType):
return _processor
class comparator_factory(types.UserDefinedType.Comparator):
class comparator_factory(types.UserDefinedType.Comparator): # noqa: N801
def squared_euclidean_distance(self, other):
return self.op("<->", return_type=types.Float)(other)

View File

@ -2,6 +2,12 @@ from functools import wraps
import numpy as np
from pgvecto_rs.errors import (
BuiltinListTypeError,
NDArrayDimensionError,
NDArrayDtypeError,
)
def ignore_none(func):
@wraps(func)
@ -26,9 +32,9 @@ def validate_ndarray(func):
def _func(value: np.ndarray, *args, **kwargs):
if isinstance(value, np.ndarray):
if value.ndim != 1:
raise ValueError("ndarray must be 1D for vector")
raise NDArrayDimensionError(value.ndim)
if not np.issubdtype(value.dtype, np.number):
raise ValueError("ndarray data type must be numeric for vector")
raise NDArrayDtypeError(value.dtype)
return func(value, *args, **kwargs)
return _func
@ -41,7 +47,7 @@ def validate_builtin_list(func):
def _func(value: list, *args, **kwargs):
if isinstance(value, list):
if not all(isinstance(x, (int, float)) for x in value):
raise ValueError("list data type must be numeric for vector")
raise BuiltinListTypeError()
value = np.array(value, dtype=np.float32)
return func(value, *args, **kwargs)

View File

@ -3,7 +3,7 @@ import os
import numpy as np
import toml
PORT = os.getenv("DB_PORT", 5432)
PORT = os.getenv("DB_PORT", "5432")
HOST = os.getenv("DB_HOST", "localhost")
USER = os.getenv("DB_USER", "postgres")
PASS = os.getenv("DB_PASS", "mysecretpassword")
@ -11,13 +11,7 @@ DB_NAME = os.getenv("DB_NAME", "postgres")
# Run tests with shell:
# DB_HOST=localhost DB_USER=postgres DB_PASS=password DB_NAME=postgres python3 -m pytest bindings/python/tests/
URL = "postgresql://{username}:{password}@{host}:{port}/{db_name}".format(
port=PORT,
host=HOST,
username=USER,
password=PASS,
db_name=DB_NAME,
)
URL = f"postgresql://{USER}:{PASS}@{HOST}:{PORT}/{DB_NAME}"
# ==== test_create_index ====
@ -27,13 +21,13 @@ TOML_SETTINGS = {
{
"capacity": 2097152,
"algorithm": {"flat": {}},
}
},
),
"hnsw": toml.dumps(
{
"capacity": 2097152,
"algorithm": {"hnsw": {}},
}
},
),
}

View File

@ -25,7 +25,7 @@ def conn():
register_vector(conn)
conn.execute("DROP TABLE IF EXISTS tb_test_item;")
conn.execute(
"CREATE TABLE tb_test_item (id bigserial PRIMARY KEY, embedding vector(3) NOT NULL);"
"CREATE TABLE tb_test_item (id bigserial PRIMARY KEY, embedding vector(3) NOT NULL);",
)
conn.commit()
try:
@ -35,7 +35,7 @@ def conn():
conn.commit()
@pytest.mark.parametrize("index_name,index_setting", TOML_SETTINGS.items())
@pytest.mark.parametrize(("index_name", "index_setting"), TOML_SETTINGS.items())
def test_create_index(conn: Connection, index_name: str, index_setting: str):
stat = sql.SQL(
"CREATE INDEX {} ON tb_test_item USING vectors (embedding l2_ops) WITH (options={});",
@ -68,7 +68,8 @@ def test_create_index(conn: Connection, index_name: str, index_setting: str):
def test_insert(conn: Connection):
with conn.cursor() as cur:
cur.executemany(
"INSERT INTO tb_test_item (embedding) VALUES (%s);", [(e,) for e in VECTORS]
"INSERT INTO tb_test_item (embedding) VALUES (%s);",
[(e,) for e in VECTORS],
)
cur.execute("SELECT * FROM tb_test_item;")
conn.commit()
@ -80,7 +81,8 @@ def test_insert(conn: Connection):
def test_squared_euclidean_distance(conn: Connection):
cur = conn.execute(
"SELECT embedding <-> %s FROM tb_test_item;", (OP_SQRT_EUCLID_DIS,)
"SELECT embedding <-> %s FROM tb_test_item;",
(OP_SQRT_EUCLID_DIS,),
)
for i, row in enumerate(cur.fetchall()):
assert np.allclose(EXPECTED_SQRT_EUCLID_DIS[i], row[0], atol=1e-10)
@ -88,7 +90,8 @@ def test_squared_euclidean_distance(conn: Connection):
def test_negative_dot_product_distance(conn: Connection):
cur = conn.execute(
"SELECT embedding <#> %s FROM tb_test_item;", (OP_NEG_DOT_PROD_DIS,)
"SELECT embedding <#> %s FROM tb_test_item;",
(OP_NEG_DOT_PROD_DIS,),
)
for i, row in enumerate(cur.fetchall()):
assert np.allclose(EXPECTED_NEG_DOT_PROD_DIS[i], row[0], atol=1e-10)

View File

@ -16,7 +16,7 @@ from tests import (
)
URL = URL.replace("postgresql", "postgresql+psycopg")
mockTexts = {
MockTexts = {
"text0": VECTORS[0],
"text1": VECTORS[1],
"text2": VECTORS[2],
@ -25,9 +25,9 @@ mockTexts = {
class MockEmbedder:
def embed(self, text: str) -> np.ndarray:
if isinstance(mockTexts[text], list):
return np.array(mockTexts[text], dtype=np.float32)
return mockTexts[text]
if isinstance(MockTexts[text], list):
return np.array(MockTexts[text], dtype=np.float32)
return MockTexts[text]
@pytest.fixture(scope="module")
@ -35,10 +35,10 @@ def client():
client = PGVectoRs(db_url=URL, collection_name="empty", dimension=3)
try:
records1 = [
Record.from_text(t, v, {"src": "src1"}) for t, v in mockTexts.items()
Record.from_text(t, v, {"src": "src1"}) for t, v in MockTexts.items()
]
records2 = [
Record.from_text(t, v, {"src": "src2"}) for t, v in mockTexts.items()
Record.from_text(t, v, {"src": "src2"}) for t, v in MockTexts.items()
]
client.insert(records1)
client.insert(records2)
@ -53,7 +53,7 @@ filter_src2: Filter = lambda r: r.meta.contains({"src": "src2"})
@pytest.mark.parametrize("filter", [filter_src1, filter_src2])
@pytest.mark.parametrize(
"dis_op, dis_oprand, dis_expected",
("dis_op", "dis_oprand", "dis_expected"),
zip(
["<->", "<#>", "<=>"],
[OP_SQRT_EUCLID_DIS, OP_NEG_DOT_PROD_DIS, OP_NEG_COS_DIS],
@ -77,7 +77,7 @@ def test_search_filter_and_op(
@pytest.mark.parametrize(
"dis_op, dis_oprand, dis_expected",
("dis_op", "dis_oprand", "dis_expected"),
zip(
["<->", "<#>", "<=>"],
[OP_SQRT_EUCLID_DIS, OP_NEG_DOT_PROD_DIS, OP_NEG_COS_DIS],
@ -92,5 +92,5 @@ def test_search_order_and_limit(
):
dis_expected = dis_expected.copy()
dis_expected.sort()
for i, (rec, dis) in enumerate(client.search(dis_oprand, dis_op, top_k=4)):
for i, (_rec, dis) in enumerate(client.search(dis_oprand, dis_op, top_k=4)):
assert np.allclose(dis, dis_expected[i // 2])

View File

@ -36,11 +36,9 @@ class Document(Base):
@pytest.fixture(scope="module")
def session():
"""
Connect to the test db pointed by the URL. Can check more details
"""Connect to the test db pointed by the URL. Can check more details
in `tests/__init__.py`
"""
engine = create_engine(URL.replace("postgresql", "postgresql+psycopg"))
# ensure that we have installed the pgvecto.rs extension
@ -63,7 +61,7 @@ def session():
# =================================
@pytest.mark.parametrize("index_name,index_setting", TOML_SETTINGS.items())
@pytest.mark.parametrize(("index_name", "index_setting"), TOML_SETTINGS.items())
def test_create_index(session: Session, index_name: str, index_setting: str):
index = Index(
index_name,
@ -76,15 +74,15 @@ def test_create_index(session: Session, index_name: str, index_setting: str):
session.commit()
@pytest.mark.parametrize("i,e", enumerate(INVALID_VECTORS))
@pytest.mark.parametrize(("i", "e"), enumerate(INVALID_VECTORS))
def test_invalid_insert(session: Session, i: int, e: np.ndarray):
try:
session.execute(insert(Document).values(id=i, embedding=e))
except StatementError:
pass
else:
raise AssertionError(
"failed to raise invalid value error for {}th vector {}".format(i, e),
raise AssertionError( # noqa: TRY003
f"failed to raise invalid value error for {i}th vector {e}",
)
finally:
session.rollback()
@ -110,7 +108,7 @@ def test_squared_euclidean_distance(session: Session):
select(
Document.id,
Document.embedding.squared_euclidean_distance(OP_SQRT_EUCLID_DIS),
)
),
):
(i, res) = row
assert np.allclose(EXPECTED_SQRT_EUCLID_DIS[i], res, atol=1e-10)
@ -121,7 +119,7 @@ def test_negative_dot_product_distance(session: Session):
select(
Document.id,
Document.embedding.negative_dot_product_distance(OP_NEG_DOT_PROD_DIS),
)
),
):
(i, res) = row
assert np.allclose(EXPECTED_NEG_DOT_PROD_DIS[i], res, atol=1e-10)
@ -129,7 +127,9 @@ def test_negative_dot_product_distance(session: Session):
def test_negative_cosine_distance(session: Session):
for row in session.execute(
select(Document.id, Document.embedding.negative_cosine_distance(OP_NEG_COS_DIS))
select(
Document.id, Document.embedding.negative_cosine_distance(OP_NEG_COS_DIS)
),
):
(i, res) = row
assert np.allclose(EXPECTED_NEG_COS_DIS[i], res, atol=1e-10)

View File

@ -4,6 +4,6 @@ env_list = py3{10, 11}
[testenv]
deps = pdm
commands =
pdm sync -d -G lint
pdm sync -d -G :all
pdm run -v check
pdm run -v test