diff --git a/bindings/python/README.md b/bindings/python/README.md index a8dd25d..94fb775 100644 --- a/bindings/python/README.md +++ b/bindings/python/README.md @@ -106,6 +106,7 @@ pdm sync Run lint: ```bash pdm run format +pdm run fix pdm run check ``` diff --git a/bindings/python/examples/psycopg_example.py b/bindings/python/examples/psycopg_example.py index 148c7f6..542fc08 100644 --- a/bindings/python/examples/psycopg_example.py +++ b/bindings/python/examples/psycopg_example.py @@ -6,7 +6,7 @@ import psycopg from pgvecto_rs.psycopg import register_vector URL = "postgresql://{username}:{password}@{host}:{port}/{db_name}".format( - port=os.getenv("DB_PORT", 5432), + port=os.getenv("DB_PORT", "5432"), host=os.getenv("DB_HOST", "localhost"), username=os.getenv("DB_USER", "postgres"), password=os.getenv("DB_PASS", "mysecretpassword"), @@ -18,7 +18,7 @@ with psycopg.connect(URL) as conn: conn.execute("CREATE EXTENSION IF NOT EXISTS vectors;") register_vector(conn) conn.execute( - "CREATE TABLE documents (id SERIAL PRIMARY KEY, text TEXT NOT NULL, embedding vector(3) NOT NULL);" + "CREATE TABLE documents (id SERIAL PRIMARY KEY, text TEXT NOT NULL, embedding vector(3) NOT NULL);", ) conn.commit() try: @@ -39,7 +39,8 @@ with psycopg.connect(URL) as conn: # Select the row "hello pgvecto.rs" cur = conn.execute( - "SELECT * FROM documents WHERE text = %s;", ("hello pgvecto.rs",) + "SELECT * FROM documents WHERE text = %s;", + ("hello pgvecto.rs",), ) target = cur.fetchone()[2] diff --git a/bindings/python/examples/sdk_example.py b/bindings/python/examples/sdk_example.py index 8791eff..d4cff4a 100644 --- a/bindings/python/examples/sdk_example.py +++ b/bindings/python/examples/sdk_example.py @@ -5,7 +5,7 @@ from openai import OpenAI from pgvecto_rs.sdk import PGVectoRs, Record, filters URL = "postgresql+psycopg://{username}:{password}@{host}:{port}/{db_name}".format( - port=os.getenv("DB_PORT", 5432), + port=os.getenv("DB_PORT", "5432"), host=os.getenv("DB_HOST", "localhost"), username=os.getenv("DB_USER", "postgres"), password=os.getenv("DB_PASS", "mysecretpassword"), @@ -43,7 +43,8 @@ try: # Query (With a filter from the filters module) print("#################### First Query ####################") for record, dis in client.search( - target, filter=filters.meta_contains({"src": "one"}) + target, + filter=filters.meta_contains({"src": "one"}), ): print(f"DISTANCE SCORE: {dis}") print(record) @@ -51,7 +52,8 @@ try: # Another Query (Equivalent to the first one, but with a lambda filter written by hand) print("#################### Second Query ####################") for record, dis in client.search( - target, filter=lambda r: r.meta.contains({"src": "one"}) + target, + filter=lambda r: r.meta.contains({"src": "one"}), ): print(f"DISTANCE SCORE: {dis}") print(record) diff --git a/bindings/python/examples/sqlalchemy_example.py b/bindings/python/examples/sqlalchemy_example.py index d8e6c26..2894367 100644 --- a/bindings/python/examples/sqlalchemy_example.py +++ b/bindings/python/examples/sqlalchemy_example.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column from pgvecto_rs.sqlalchemy import Vector URL = "postgresql+psycopg://{username}:{password}@{host}:{port}/{db_name}".format( - port=os.getenv("DB_PORT", 5432), + port=os.getenv("DB_PORT", "5432"), host=os.getenv("DB_HOST", "localhost"), username=os.getenv("DB_USER", "postgres"), password=os.getenv("DB_PASS", "mysecretpassword"), @@ -53,7 +53,7 @@ with Session(engine) as session: stmt = select( Document.text, Document.embedding.squared_euclidean_distance(target.embedding).label( - "distance" + "distance", ), ).order_by("distance") for doc in session.execute(stmt): diff --git a/bindings/python/pdm.lock b/bindings/python/pdm.lock index 4c5b0d4..568ccb4 100644 --- a/bindings/python/pdm.lock +++ b/bindings/python/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "lint", "psycopg3", "sdk", "sqlalchemy", "test"] strategy = ["cross_platform", "direct_minimal_versions"] lock_version = "4.4" -content_hash = "sha256:f65e8d98636d7592453753c6ba60e73a6912f0e21205dff0e0f92bb148befce3" +content_hash = "sha256:a7e2c999c870cd4bac136205366ffe7d592d8ca043fd0946c3df631f7326282f" [[package]] name = "annotated-types" @@ -506,27 +506,27 @@ files = [ [[package]] name = "ruff" -version = "0.1.1" +version = "0.1.5" requires_python = ">=3.7" -summary = "An extremely fast Python linter, written in Rust." +summary = "An extremely fast Python linter and code formatter, written in Rust." files = [ - {file = "ruff-0.1.1-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:b7cdc893aef23ccc14c54bd79a8109a82a2c527e11d030b62201d86f6c2b81c5"}, - {file = "ruff-0.1.1-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:620d4b34302538dbd8bbbe8fdb8e8f98d72d29bd47e972e2b59ce6c1e8862257"}, - {file = "ruff-0.1.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a909d3930afdbc2e9fd893b0034479e90e7981791879aab50ce3d9f55205bd6"}, - {file = "ruff-0.1.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3305d1cb4eb8ff6d3e63a48d1659d20aab43b49fe987b3ca4900528342367145"}, - {file = "ruff-0.1.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c34ae501d0ec71acf19ee5d4d889e379863dcc4b796bf8ce2934a9357dc31db7"}, - {file = "ruff-0.1.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:6aa7e63c3852cf8fe62698aef31e563e97143a4b801b57f920012d0e07049a8d"}, - {file = "ruff-0.1.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2d68367d1379a6b47e61bc9de144a47bcdb1aad7903bbf256e4c3d31f11a87ae"}, - {file = "ruff-0.1.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bc11955f6ce3398d2afe81ad7e49d0ebf0a581d8bcb27b8c300281737735e3a3"}, - {file = "ruff-0.1.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbbd8eead88ea83a250499074e2a8e9d80975f0b324b1e2e679e4594da318c25"}, - {file = "ruff-0.1.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:f4780e2bb52f3863a565ec3f699319d3493b83ff95ebbb4993e59c62aaf6e75e"}, - {file = "ruff-0.1.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8f5b24daddf35b6c207619301170cae5d2699955829cda77b6ce1e5fc69340df"}, - {file = "ruff-0.1.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:d3f9ac658ba29e07b95c80fa742b059a55aefffa8b1e078bc3c08768bdd4b11a"}, - {file = "ruff-0.1.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:3521bf910104bf781e6753282282acc145cbe3eff79a1ce6b920404cd756075a"}, - {file = "ruff-0.1.1-py3-none-win32.whl", hash = "sha256:ba3208543ab91d3e4032db2652dcb6c22a25787b85b8dc3aeff084afdc612e5c"}, - {file = "ruff-0.1.1-py3-none-win_amd64.whl", hash = "sha256:3ff3006c97d9dc396b87fb46bb65818e614ad0181f059322df82bbfe6944e264"}, - {file = "ruff-0.1.1-py3-none-win_arm64.whl", hash = "sha256:e140bd717c49164c8feb4f65c644046fe929c46f42493672853e3213d7bdbce2"}, - {file = "ruff-0.1.1.tar.gz", hash = "sha256:c90461ae4abec261609e5ea436de4a4b5f2822921cf04c16d2cc9327182dbbcc"}, + {file = "ruff-0.1.5-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:32d47fc69261c21a4c48916f16ca272bf2f273eb635d91c65d5cd548bf1f3d96"}, + {file = "ruff-0.1.5-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:171276c1df6c07fa0597fb946139ced1c2978f4f0b8254f201281729981f3c17"}, + {file = "ruff-0.1.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17ef33cd0bb7316ca65649fc748acc1406dfa4da96a3d0cde6d52f2e866c7b39"}, + {file = "ruff-0.1.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b2c205827b3f8c13b4a432e9585750b93fd907986fe1aec62b2a02cf4401eee6"}, + {file = "ruff-0.1.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bb408e3a2ad8f6881d0f2e7ad70cddb3ed9f200eb3517a91a245bbe27101d379"}, + {file = "ruff-0.1.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:f20dc5e5905ddb407060ca27267c7174f532375c08076d1a953cf7bb016f5a24"}, + {file = "ruff-0.1.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aafb9d2b671ed934998e881e2c0f5845a4295e84e719359c71c39a5363cccc91"}, + {file = "ruff-0.1.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a4894dddb476597a0ba4473d72a23151b8b3b0b5f958f2cf4d3f1c572cdb7af7"}, + {file = "ruff-0.1.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a00a7ec893f665ed60008c70fe9eeb58d210e6b4d83ec6654a9904871f982a2a"}, + {file = "ruff-0.1.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:a8c11206b47f283cbda399a654fd0178d7a389e631f19f51da15cbe631480c5b"}, + {file = "ruff-0.1.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fa29e67b3284b9a79b1a85ee66e293a94ac6b7bb068b307a8a373c3d343aa8ec"}, + {file = "ruff-0.1.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9b97fd6da44d6cceb188147b68db69a5741fbc736465b5cea3928fdac0bc1aeb"}, + {file = "ruff-0.1.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:721f4b9d3b4161df8dc9f09aa8562e39d14e55a4dbaa451a8e55bdc9590e20f4"}, + {file = "ruff-0.1.5-py3-none-win32.whl", hash = "sha256:f80c73bba6bc69e4fdc73b3991db0b546ce641bdcd5b07210b8ad6f64c79f1ab"}, + {file = "ruff-0.1.5-py3-none-win_amd64.whl", hash = "sha256:c21fe20ee7d76206d290a76271c1af7a5096bc4c73ab9383ed2ad35f852a0087"}, + {file = "ruff-0.1.5-py3-none-win_arm64.whl", hash = "sha256:82bfcb9927e88c1ed50f49ac6c9728dab3ea451212693fe40d08d314663e412f"}, + {file = "ruff-0.1.5.tar.gz", hash = "sha256:5cbec0ef2ae1748fb194f420fb03fb2c25c3258c86129af7172ff8f198f125ab"}, ] [[package]] diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml index dff01ce..ca92105 100644 --- a/bindings/python/pyproject.toml +++ b/bindings/python/pyproject.toml @@ -3,8 +3,8 @@ name = "pgvecto-rs" version = "0.1.3" description = "Python binding for pgvecto.rs" authors = [ - { name = "TensorChord", email = "envd-maintainers@tensorchord.ai" }, - { name = "盐粒 Yanli", email = "mail@yanli.one" }, +{ name = "TensorChord", email = "envd-maintainers@tensorchord.ai" }, +{ name = "盐粒 Yanli", email = "mail@yanli.one" }, ] dependencies = [ "numpy>=1.23", @@ -23,15 +23,15 @@ classifiers = [ [build-system] build-backend = "pdm.backend" -requires = [ +requires = [ "pdm-backend", ] [project.optional-dependencies] -psycopg3 = [ +psycopg3 = [ "psycopg[binary]>=3.1.12", ] -sdk = [ +sdk = [ "openai>=1.2.2", "pgvecto_rs[sqlalchemy]", ] @@ -40,19 +40,34 @@ sqlalchemy = [ "SQLAlchemy>=2.0.23", ] [tool.pdm.dev-dependencies] -lint = ["ruff>=0.1.1"] +lint = ["ruff>=0.1.5"] test = ["pytest>=7.4.3"] [tool.pdm.scripts] -test = "pytest tests/" +test = "pytest tests/" format = "ruff format ." -fix = "ruff --fix ." -check = { composite = ["ruff format . --check", "ruff ."] } +fix = "ruff --fix ." +check = { composite = ["ruff format . --check", "ruff ."] } [tool.ruff] -select = ["E", "F", "I", "TID"] -ignore = ["E731", "E501"] -src = ["src"] +select = [ + "E", #https://docs.astral.sh/ruff/rules/#error-e + "F", #https://docs.astral.sh/ruff/rules/#pyflakes-f + "I", #https://docs.astral.sh/ruff/rules/#isort-i + "TID", #https://docs.astral.sh/ruff/rules/#flake8-tidy-imports-tid + "S", #https://docs.astral.sh/ruff/rules/#flake8-bandit-s + "B", #https://docs.astral.sh/ruff/rules/#flake8-bugbear-b + "SIM", #https://docs.astral.sh/ruff/rules/#flake8-simplify-sim + "N", #https://docs.astral.sh/ruff/rules/#pep8-naming-n + "PT", #https://docs.astral.sh/ruff/rules/#flake8-pytest-style-pt + "TRY", #https://docs.astral.sh/ruff/rules/#tryceratops-try + "FLY", #https://docs.astral.sh/ruff/rules/#flynt-fly + "PL", #https://docs.astral.sh/ruff/rules/#pylint-pl + "NPY", #https://docs.astral.sh/ruff/rules/#numpy-specific-rules-npy + "RUF", #https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf +] +ignore = ["S101", "E731", "E501"] +src = ["src"] [tool.pytest.ini_options] addopts = "-r aR" diff --git a/bindings/python/src/pgvecto_rs/errors.py b/bindings/python/src/pgvecto_rs/errors.py new file mode 100644 index 0000000..7487182 --- /dev/null +++ b/bindings/python/src/pgvecto_rs/errors.py @@ -0,0 +1,25 @@ +import numpy as np + + +class PGVectoRsError(ValueError): + pass + + +class NDArrayDimensionError(PGVectoRsError): + def __init__(self, dim: int) -> None: + super().__init__(f"ndarray must be 1D for vector, got {dim}D") + + +class NDArrayDtypeError(PGVectoRsError): + def __init__(self, dtype: np.dtype) -> None: + super().__init__(f"ndarray data type must be numeric for vector, got {dtype}") + + +class BuiltinListTypeError(PGVectoRsError): + def __init__(self) -> None: + super().__init__("list data type must be numeric for vector") + + +class VectorDimensionError(PGVectoRsError): + def __init__(self, dim: int) -> None: + super().__init__(f"vector dimension must be > 0, got {dim}") diff --git a/bindings/python/src/pgvecto_rs/psycopg/__init__.py b/bindings/python/src/pgvecto_rs/psycopg/__init__.py index 7298d5c..9f04978 100644 --- a/bindings/python/src/pgvecto_rs/psycopg/__init__.py +++ b/bindings/python/src/pgvecto_rs/psycopg/__init__.py @@ -37,7 +37,7 @@ async def register_vector_async(context: Connection): def register_vector_info(context: Connection, info: TypeInfo): if info is None: - raise ProgrammingError("vector type not found in the database") + raise ProgrammingError(info="vector type not found in the database") info.register(context) class VectorTextDumper(VectorDumper): diff --git a/bindings/python/src/pgvecto_rs/sdk/client.py b/bindings/python/src/pgvecto_rs/sdk/client.py index 3e51da6..17f1f8c 100644 --- a/bindings/python/src/pgvecto_rs/sdk/client.py +++ b/bindings/python/src/pgvecto_rs/sdk/client.py @@ -28,6 +28,7 @@ class PGVectoRs: """Connect to an existing table or create a new empty one. Args: + ---- db_url (str): url to the database. table_name (str): name of the table. dimension (int): dimension of the embeddings. @@ -36,7 +37,8 @@ class PGVectoRs: class _Table(RecordORM): __tablename__ = f"collection_{collection_name}" id: Mapped[UUID] = mapped_column( - postgresql.UUID(as_uuid=True), primary_key=True + postgresql.UUID(as_uuid=True), + primary_key=True, ) text: Mapped[str] = mapped_column(String) meta: Mapped[dict] = mapped_column(postgresql.JSONB) @@ -59,7 +61,7 @@ class PGVectoRs: text=record.text, meta=record.meta, embedding=record.embedding, - ) + ), ) session.commit() @@ -73,6 +75,7 @@ class PGVectoRs: """Search for the nearest records. Args: + ---- embedding : Target embedding. distance_op : Distance op. top_k : Max records to return. Defaults to 4. @@ -80,6 +83,7 @@ class PGVectoRs: order_by_dis : Order by distance. Defaults to True. Returns: + ------- List of records and coresponding distances. """ @@ -88,7 +92,7 @@ class PGVectoRs: select( self._table, self._table.embedding.op(distance_op, return_type=Float)( - embedding + embedding, ).label("distance"), ) .limit(top_k) diff --git a/bindings/python/src/pgvecto_rs/sqlalchemy/__init__.py b/bindings/python/src/pgvecto_rs/sqlalchemy/__init__.py index eae4f7e..83fe58d 100644 --- a/bindings/python/src/pgvecto_rs/sqlalchemy/__init__.py +++ b/bindings/python/src/pgvecto_rs/sqlalchemy/__init__.py @@ -1,5 +1,6 @@ -import sqlalchemy.types as types +from sqlalchemy import types +from pgvecto_rs.errors import VectorDimensionError from pgvecto_rs.utils import serializer @@ -8,13 +9,13 @@ class Vector(types.UserDefinedType): def __init__(self, dim): if dim < 0: - raise ValueError("negative dim is not allowed") + raise VectorDimensionError(dim) self.dim = dim def get_col_spec(self, **kw): if self.dim is None or self.dim == 0: return "VECTOR" - return "VECTOR({})".format(self.dim) + return f"VECTOR({self.dim})" def bind_processor(self, dialect): def _processor(value): @@ -28,7 +29,7 @@ class Vector(types.UserDefinedType): return _processor - class comparator_factory(types.UserDefinedType.Comparator): + class comparator_factory(types.UserDefinedType.Comparator): # noqa: N801 def squared_euclidean_distance(self, other): return self.op("<->", return_type=types.Float)(other) diff --git a/bindings/python/src/pgvecto_rs/utils/decorators.py b/bindings/python/src/pgvecto_rs/utils/decorators.py index c483c46..1e812b1 100644 --- a/bindings/python/src/pgvecto_rs/utils/decorators.py +++ b/bindings/python/src/pgvecto_rs/utils/decorators.py @@ -2,6 +2,12 @@ from functools import wraps import numpy as np +from pgvecto_rs.errors import ( + BuiltinListTypeError, + NDArrayDimensionError, + NDArrayDtypeError, +) + def ignore_none(func): @wraps(func) @@ -26,9 +32,9 @@ def validate_ndarray(func): def _func(value: np.ndarray, *args, **kwargs): if isinstance(value, np.ndarray): if value.ndim != 1: - raise ValueError("ndarray must be 1D for vector") + raise NDArrayDimensionError(value.ndim) if not np.issubdtype(value.dtype, np.number): - raise ValueError("ndarray data type must be numeric for vector") + raise NDArrayDtypeError(value.dtype) return func(value, *args, **kwargs) return _func @@ -41,7 +47,7 @@ def validate_builtin_list(func): def _func(value: list, *args, **kwargs): if isinstance(value, list): if not all(isinstance(x, (int, float)) for x in value): - raise ValueError("list data type must be numeric for vector") + raise BuiltinListTypeError() value = np.array(value, dtype=np.float32) return func(value, *args, **kwargs) diff --git a/bindings/python/tests/__init__.py b/bindings/python/tests/__init__.py index 17e682a..1e28cc7 100644 --- a/bindings/python/tests/__init__.py +++ b/bindings/python/tests/__init__.py @@ -3,7 +3,7 @@ import os import numpy as np import toml -PORT = os.getenv("DB_PORT", 5432) +PORT = os.getenv("DB_PORT", "5432") HOST = os.getenv("DB_HOST", "localhost") USER = os.getenv("DB_USER", "postgres") PASS = os.getenv("DB_PASS", "mysecretpassword") @@ -11,13 +11,7 @@ DB_NAME = os.getenv("DB_NAME", "postgres") # Run tests with shell: # DB_HOST=localhost DB_USER=postgres DB_PASS=password DB_NAME=postgres python3 -m pytest bindings/python/tests/ -URL = "postgresql://{username}:{password}@{host}:{port}/{db_name}".format( - port=PORT, - host=HOST, - username=USER, - password=PASS, - db_name=DB_NAME, -) +URL = f"postgresql://{USER}:{PASS}@{HOST}:{PORT}/{DB_NAME}" # ==== test_create_index ==== @@ -27,13 +21,13 @@ TOML_SETTINGS = { { "capacity": 2097152, "algorithm": {"flat": {}}, - } + }, ), "hnsw": toml.dumps( { "capacity": 2097152, "algorithm": {"hnsw": {}}, - } + }, ), } diff --git a/bindings/python/tests/test_psycopg.py b/bindings/python/tests/test_psycopg.py index 68b3db0..f79b4e2 100644 --- a/bindings/python/tests/test_psycopg.py +++ b/bindings/python/tests/test_psycopg.py @@ -25,7 +25,7 @@ def conn(): register_vector(conn) conn.execute("DROP TABLE IF EXISTS tb_test_item;") conn.execute( - "CREATE TABLE tb_test_item (id bigserial PRIMARY KEY, embedding vector(3) NOT NULL);" + "CREATE TABLE tb_test_item (id bigserial PRIMARY KEY, embedding vector(3) NOT NULL);", ) conn.commit() try: @@ -35,7 +35,7 @@ def conn(): conn.commit() -@pytest.mark.parametrize("index_name,index_setting", TOML_SETTINGS.items()) +@pytest.mark.parametrize(("index_name", "index_setting"), TOML_SETTINGS.items()) def test_create_index(conn: Connection, index_name: str, index_setting: str): stat = sql.SQL( "CREATE INDEX {} ON tb_test_item USING vectors (embedding l2_ops) WITH (options={});", @@ -68,7 +68,8 @@ def test_create_index(conn: Connection, index_name: str, index_setting: str): def test_insert(conn: Connection): with conn.cursor() as cur: cur.executemany( - "INSERT INTO tb_test_item (embedding) VALUES (%s);", [(e,) for e in VECTORS] + "INSERT INTO tb_test_item (embedding) VALUES (%s);", + [(e,) for e in VECTORS], ) cur.execute("SELECT * FROM tb_test_item;") conn.commit() @@ -80,7 +81,8 @@ def test_insert(conn: Connection): def test_squared_euclidean_distance(conn: Connection): cur = conn.execute( - "SELECT embedding <-> %s FROM tb_test_item;", (OP_SQRT_EUCLID_DIS,) + "SELECT embedding <-> %s FROM tb_test_item;", + (OP_SQRT_EUCLID_DIS,), ) for i, row in enumerate(cur.fetchall()): assert np.allclose(EXPECTED_SQRT_EUCLID_DIS[i], row[0], atol=1e-10) @@ -88,7 +90,8 @@ def test_squared_euclidean_distance(conn: Connection): def test_negative_dot_product_distance(conn: Connection): cur = conn.execute( - "SELECT embedding <#> %s FROM tb_test_item;", (OP_NEG_DOT_PROD_DIS,) + "SELECT embedding <#> %s FROM tb_test_item;", + (OP_NEG_DOT_PROD_DIS,), ) for i, row in enumerate(cur.fetchall()): assert np.allclose(EXPECTED_NEG_DOT_PROD_DIS[i], row[0], atol=1e-10) diff --git a/bindings/python/tests/test_sdk.py b/bindings/python/tests/test_sdk.py index 6409e90..162604b 100644 --- a/bindings/python/tests/test_sdk.py +++ b/bindings/python/tests/test_sdk.py @@ -16,7 +16,7 @@ from tests import ( ) URL = URL.replace("postgresql", "postgresql+psycopg") -mockTexts = { +MockTexts = { "text0": VECTORS[0], "text1": VECTORS[1], "text2": VECTORS[2], @@ -25,9 +25,9 @@ mockTexts = { class MockEmbedder: def embed(self, text: str) -> np.ndarray: - if isinstance(mockTexts[text], list): - return np.array(mockTexts[text], dtype=np.float32) - return mockTexts[text] + if isinstance(MockTexts[text], list): + return np.array(MockTexts[text], dtype=np.float32) + return MockTexts[text] @pytest.fixture(scope="module") @@ -35,10 +35,10 @@ def client(): client = PGVectoRs(db_url=URL, collection_name="empty", dimension=3) try: records1 = [ - Record.from_text(t, v, {"src": "src1"}) for t, v in mockTexts.items() + Record.from_text(t, v, {"src": "src1"}) for t, v in MockTexts.items() ] records2 = [ - Record.from_text(t, v, {"src": "src2"}) for t, v in mockTexts.items() + Record.from_text(t, v, {"src": "src2"}) for t, v in MockTexts.items() ] client.insert(records1) client.insert(records2) @@ -53,7 +53,7 @@ filter_src2: Filter = lambda r: r.meta.contains({"src": "src2"}) @pytest.mark.parametrize("filter", [filter_src1, filter_src2]) @pytest.mark.parametrize( - "dis_op, dis_oprand, dis_expected", + ("dis_op", "dis_oprand", "dis_expected"), zip( ["<->", "<#>", "<=>"], [OP_SQRT_EUCLID_DIS, OP_NEG_DOT_PROD_DIS, OP_NEG_COS_DIS], @@ -77,7 +77,7 @@ def test_search_filter_and_op( @pytest.mark.parametrize( - "dis_op, dis_oprand, dis_expected", + ("dis_op", "dis_oprand", "dis_expected"), zip( ["<->", "<#>", "<=>"], [OP_SQRT_EUCLID_DIS, OP_NEG_DOT_PROD_DIS, OP_NEG_COS_DIS], @@ -92,5 +92,5 @@ def test_search_order_and_limit( ): dis_expected = dis_expected.copy() dis_expected.sort() - for i, (rec, dis) in enumerate(client.search(dis_oprand, dis_op, top_k=4)): + for i, (_rec, dis) in enumerate(client.search(dis_oprand, dis_op, top_k=4)): assert np.allclose(dis, dis_expected[i // 2]) diff --git a/bindings/python/tests/test_sqlalchemy.py b/bindings/python/tests/test_sqlalchemy.py index 9156368..d51e58b 100644 --- a/bindings/python/tests/test_sqlalchemy.py +++ b/bindings/python/tests/test_sqlalchemy.py @@ -36,11 +36,9 @@ class Document(Base): @pytest.fixture(scope="module") def session(): - """ - Connect to the test db pointed by the URL. Can check more details + """Connect to the test db pointed by the URL. Can check more details in `tests/__init__.py` """ - engine = create_engine(URL.replace("postgresql", "postgresql+psycopg")) # ensure that we have installed pgvector.rs extension @@ -63,7 +61,7 @@ def session(): # ================================= -@pytest.mark.parametrize("index_name,index_setting", TOML_SETTINGS.items()) +@pytest.mark.parametrize(("index_name", "index_setting"), TOML_SETTINGS.items()) def test_create_index(session: Session, index_name: str, index_setting: str): index = Index( index_name, @@ -76,15 +74,15 @@ def test_create_index(session: Session, index_name: str, index_setting: str): session.commit() -@pytest.mark.parametrize("i,e", enumerate(INVALID_VECTORS)) +@pytest.mark.parametrize(("i", "e"), enumerate(INVALID_VECTORS)) def test_invalid_insert(session: Session, i: int, e: np.ndarray): try: session.execute(insert(Document).values(id=i, embedding=e)) except StatementError: pass else: - raise AssertionError( - "failed to raise invalid value error for {}th vector {}".format(i, e), + raise AssertionError( # noqa: TRY003 + f"failed to raise invalid value error for {i}th vector {e}", ) finally: session.rollback() @@ -110,7 +108,7 @@ def test_squared_euclidean_distance(session: Session): select( Document.id, Document.embedding.squared_euclidean_distance(OP_SQRT_EUCLID_DIS), - ) + ), ): (i, res) = row assert np.allclose(EXPECTED_SQRT_EUCLID_DIS[i], res, atol=1e-10) @@ -121,7 +119,7 @@ def test_negative_dot_product_distance(session: Session): select( Document.id, Document.embedding.negative_dot_product_distance(OP_NEG_DOT_PROD_DIS), - ) + ), ): (i, res) = row assert np.allclose(EXPECTED_NEG_DOT_PROD_DIS[i], res, atol=1e-10) @@ -129,7 +127,9 @@ def test_negative_dot_product_distance(session: Session): def test_negative_cosine_distance(session: Session): for row in session.execute( - select(Document.id, Document.embedding.negative_cosine_distance(OP_NEG_COS_DIS)) + select( + Document.id, Document.embedding.negative_cosine_distance(OP_NEG_COS_DIS) + ), ): (i, res) = row assert np.allclose(EXPECTED_NEG_COS_DIS[i], res, atol=1e-10) diff --git a/bindings/python/tox.ini b/bindings/python/tox.ini index 2ea13b2..03575dc 100644 --- a/bindings/python/tox.ini +++ b/bindings/python/tox.ini @@ -4,6 +4,6 @@ env_list = py3{10, 11} [testenv] deps = pdm commands = - pdm sync -d -G lint + pdm sync -d -G :all pdm run -v check pdm run -v test