feat: add more ruff rules (#138)
* feat: add more ruff rules
* chore: modified readme
* rename error class

Signed-off-by: 盐粒 Yanli <mail@yanli.one>
@@ -106,6 +106,7 @@ pdm sync
 Run lint:
 ```bash
 pdm run format
+pdm run fix
 pdm run check
 ```
@@ -6,7 +6,7 @@ import psycopg
 from pgvecto_rs.psycopg import register_vector
 
 URL = "postgresql://{username}:{password}@{host}:{port}/{db_name}".format(
-    port=os.getenv("DB_PORT", 5432),
+    port=os.getenv("DB_PORT", "5432"),
     host=os.getenv("DB_HOST", "localhost"),
     username=os.getenv("DB_USER", "postgres"),
     password=os.getenv("DB_PASS", "mysecretpassword"),
@@ -18,7 +18,7 @@ with psycopg.connect(URL) as conn:
     conn.execute("CREATE EXTENSION IF NOT EXISTS vectors;")
     register_vector(conn)
     conn.execute(
-        "CREATE TABLE documents (id SERIAL PRIMARY KEY, text TEXT NOT NULL, embedding vector(3) NOT NULL);"
+        "CREATE TABLE documents (id SERIAL PRIMARY KEY, text TEXT NOT NULL, embedding vector(3) NOT NULL);",
     )
     conn.commit()
     try:
@@ -39,7 +39,8 @@ with psycopg.connect(URL) as conn:
 
         # Select the row "hello pgvecto.rs"
         cur = conn.execute(
-            "SELECT * FROM documents WHERE text = %s;", ("hello pgvecto.rs",)
+            "SELECT * FROM documents WHERE text = %s;",
+            ("hello pgvecto.rs",),
         )
         target = cur.fetchone()[2]
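A note on the `os.getenv` change above (context added here, not part of the diff): `os.getenv` returns a string whenever the variable is set, so a string default keeps the value's type consistent in both cases. A minimal illustration:

```python
import os

# With a string default, the value is always a str, whether or not DB_PORT is set.
port = os.getenv("DB_PORT", "5432")
assert isinstance(port, str)

# Convert explicitly at the point where an integer is actually required.
port_number = int(port)
```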
@@ -5,7 +5,7 @@ from openai import OpenAI
 from pgvecto_rs.sdk import PGVectoRs, Record, filters
 
 URL = "postgresql+psycopg://{username}:{password}@{host}:{port}/{db_name}".format(
-    port=os.getenv("DB_PORT", 5432),
+    port=os.getenv("DB_PORT", "5432"),
     host=os.getenv("DB_HOST", "localhost"),
     username=os.getenv("DB_USER", "postgres"),
     password=os.getenv("DB_PASS", "mysecretpassword"),
@@ -43,7 +43,8 @@ try:
     # Query (With a filter from the filters module)
     print("#################### First Query ####################")
     for record, dis in client.search(
-        target, filter=filters.meta_contains({"src": "one"})
+        target,
+        filter=filters.meta_contains({"src": "one"}),
     ):
         print(f"DISTANCE SCORE: {dis}")
         print(record)
@@ -51,7 +52,8 @@ try:
     # Another Query (Equivalent to the first one, but with a lambda filter written by hand)
     print("#################### Second Query ####################")
     for record, dis in client.search(
-        target, filter=lambda r: r.meta.contains({"src": "one"})
+        target,
+        filter=lambda r: r.meta.contains({"src": "one"}),
     ):
         print(f"DISTANCE SCORE: {dis}")
         print(record)
|
@ -7,7 +7,7 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
|
||||
from pgvecto_rs.sqlalchemy import Vector
|
||||
|
||||
URL = "postgresql+psycopg://{username}:{password}@{host}:{port}/{db_name}".format(
|
||||
port=os.getenv("DB_PORT", 5432),
|
||||
port=os.getenv("DB_PORT", "5432"),
|
||||
host=os.getenv("DB_HOST", "localhost"),
|
||||
username=os.getenv("DB_USER", "postgres"),
|
||||
password=os.getenv("DB_PASS", "mysecretpassword"),
|
||||
@ -53,7 +53,7 @@ with Session(engine) as session:
|
||||
stmt = select(
|
||||
Document.text,
|
||||
Document.embedding.squared_euclidean_distance(target.embedding).label(
|
||||
"distance"
|
||||
"distance",
|
||||
),
|
||||
).order_by("distance")
|
||||
for doc in session.execute(stmt):
|
||||
|
bindings/python/pdm.lock (generated, 40 lines changed)
@@ -5,7 +5,7 @@
 groups = ["default", "lint", "psycopg3", "sdk", "sqlalchemy", "test"]
 strategy = ["cross_platform", "direct_minimal_versions"]
 lock_version = "4.4"
-content_hash = "sha256:f65e8d98636d7592453753c6ba60e73a6912f0e21205dff0e0f92bb148befce3"
+content_hash = "sha256:a7e2c999c870cd4bac136205366ffe7d592d8ca043fd0946c3df631f7326282f"
 
 [[package]]
 name = "annotated-types"
@@ -506,27 +506,27 @@ files = [
 
 [[package]]
 name = "ruff"
-version = "0.1.1"
+version = "0.1.5"
 requires_python = ">=3.7"
-summary = "An extremely fast Python linter, written in Rust."
+summary = "An extremely fast Python linter and code formatter, written in Rust."
 files = [
-    {file = "ruff-0.1.1-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:b7cdc893aef23ccc14c54bd79a8109a82a2c527e11d030b62201d86f6c2b81c5"},
-    {file = "ruff-0.1.1-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:620d4b34302538dbd8bbbe8fdb8e8f98d72d29bd47e972e2b59ce6c1e8862257"},
-    {file = "ruff-0.1.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a909d3930afdbc2e9fd893b0034479e90e7981791879aab50ce3d9f55205bd6"},
-    {file = "ruff-0.1.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3305d1cb4eb8ff6d3e63a48d1659d20aab43b49fe987b3ca4900528342367145"},
-    {file = "ruff-0.1.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c34ae501d0ec71acf19ee5d4d889e379863dcc4b796bf8ce2934a9357dc31db7"},
-    {file = "ruff-0.1.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:6aa7e63c3852cf8fe62698aef31e563e97143a4b801b57f920012d0e07049a8d"},
-    {file = "ruff-0.1.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2d68367d1379a6b47e61bc9de144a47bcdb1aad7903bbf256e4c3d31f11a87ae"},
-    {file = "ruff-0.1.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bc11955f6ce3398d2afe81ad7e49d0ebf0a581d8bcb27b8c300281737735e3a3"},
-    {file = "ruff-0.1.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbbd8eead88ea83a250499074e2a8e9d80975f0b324b1e2e679e4594da318c25"},
-    {file = "ruff-0.1.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:f4780e2bb52f3863a565ec3f699319d3493b83ff95ebbb4993e59c62aaf6e75e"},
-    {file = "ruff-0.1.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8f5b24daddf35b6c207619301170cae5d2699955829cda77b6ce1e5fc69340df"},
-    {file = "ruff-0.1.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:d3f9ac658ba29e07b95c80fa742b059a55aefffa8b1e078bc3c08768bdd4b11a"},
-    {file = "ruff-0.1.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:3521bf910104bf781e6753282282acc145cbe3eff79a1ce6b920404cd756075a"},
-    {file = "ruff-0.1.1-py3-none-win32.whl", hash = "sha256:ba3208543ab91d3e4032db2652dcb6c22a25787b85b8dc3aeff084afdc612e5c"},
-    {file = "ruff-0.1.1-py3-none-win_amd64.whl", hash = "sha256:3ff3006c97d9dc396b87fb46bb65818e614ad0181f059322df82bbfe6944e264"},
-    {file = "ruff-0.1.1-py3-none-win_arm64.whl", hash = "sha256:e140bd717c49164c8feb4f65c644046fe929c46f42493672853e3213d7bdbce2"},
-    {file = "ruff-0.1.1.tar.gz", hash = "sha256:c90461ae4abec261609e5ea436de4a4b5f2822921cf04c16d2cc9327182dbbcc"},
+    {file = "ruff-0.1.5-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:32d47fc69261c21a4c48916f16ca272bf2f273eb635d91c65d5cd548bf1f3d96"},
+    {file = "ruff-0.1.5-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:171276c1df6c07fa0597fb946139ced1c2978f4f0b8254f201281729981f3c17"},
+    {file = "ruff-0.1.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17ef33cd0bb7316ca65649fc748acc1406dfa4da96a3d0cde6d52f2e866c7b39"},
+    {file = "ruff-0.1.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b2c205827b3f8c13b4a432e9585750b93fd907986fe1aec62b2a02cf4401eee6"},
+    {file = "ruff-0.1.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bb408e3a2ad8f6881d0f2e7ad70cddb3ed9f200eb3517a91a245bbe27101d379"},
+    {file = "ruff-0.1.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:f20dc5e5905ddb407060ca27267c7174f532375c08076d1a953cf7bb016f5a24"},
+    {file = "ruff-0.1.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aafb9d2b671ed934998e881e2c0f5845a4295e84e719359c71c39a5363cccc91"},
+    {file = "ruff-0.1.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a4894dddb476597a0ba4473d72a23151b8b3b0b5f958f2cf4d3f1c572cdb7af7"},
+    {file = "ruff-0.1.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a00a7ec893f665ed60008c70fe9eeb58d210e6b4d83ec6654a9904871f982a2a"},
+    {file = "ruff-0.1.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:a8c11206b47f283cbda399a654fd0178d7a389e631f19f51da15cbe631480c5b"},
+    {file = "ruff-0.1.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fa29e67b3284b9a79b1a85ee66e293a94ac6b7bb068b307a8a373c3d343aa8ec"},
+    {file = "ruff-0.1.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9b97fd6da44d6cceb188147b68db69a5741fbc736465b5cea3928fdac0bc1aeb"},
+    {file = "ruff-0.1.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:721f4b9d3b4161df8dc9f09aa8562e39d14e55a4dbaa451a8e55bdc9590e20f4"},
+    {file = "ruff-0.1.5-py3-none-win32.whl", hash = "sha256:f80c73bba6bc69e4fdc73b3991db0b546ce641bdcd5b07210b8ad6f64c79f1ab"},
+    {file = "ruff-0.1.5-py3-none-win_amd64.whl", hash = "sha256:c21fe20ee7d76206d290a76271c1af7a5096bc4c73ab9383ed2ad35f852a0087"},
+    {file = "ruff-0.1.5-py3-none-win_arm64.whl", hash = "sha256:82bfcb9927e88c1ed50f49ac6c9728dab3ea451212693fe40d08d314663e412f"},
+    {file = "ruff-0.1.5.tar.gz", hash = "sha256:5cbec0ef2ae1748fb194f420fb03fb2c25c3258c86129af7172ff8f198f125ab"},
 ]
 
 [[package]]
@ -3,8 +3,8 @@ name = "pgvecto-rs"
|
||||
version = "0.1.3"
|
||||
description = "Python binding for pgvecto.rs"
|
||||
authors = [
|
||||
{ name = "TensorChord", email = "envd-maintainers@tensorchord.ai" },
|
||||
{ name = "盐粒 Yanli", email = "mail@yanli.one" },
|
||||
{ name = "TensorChord", email = "envd-maintainers@tensorchord.ai" },
|
||||
{ name = "盐粒 Yanli", email = "mail@yanli.one" },
|
||||
]
|
||||
dependencies = [
|
||||
"numpy>=1.23",
|
||||
@ -23,15 +23,15 @@ classifiers = [
|
||||
|
||||
[build-system]
|
||||
build-backend = "pdm.backend"
|
||||
requires = [
|
||||
requires = [
|
||||
"pdm-backend",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
psycopg3 = [
|
||||
psycopg3 = [
|
||||
"psycopg[binary]>=3.1.12",
|
||||
]
|
||||
sdk = [
|
||||
sdk = [
|
||||
"openai>=1.2.2",
|
||||
"pgvecto_rs[sqlalchemy]",
|
||||
]
|
||||
@@ -40,19 +40,34 @@ sqlalchemy = [
     "SQLAlchemy>=2.0.23",
 ]
 [tool.pdm.dev-dependencies]
-lint = ["ruff>=0.1.1"]
+lint = ["ruff>=0.1.5"]
 test = ["pytest>=7.4.3"]
 
 [tool.pdm.scripts]
 test = "pytest tests/"
 format = "ruff format ."
 fix = "ruff --fix ."
 check = { composite = ["ruff format . --check", "ruff ."] }
 
 [tool.ruff]
-select = ["E", "F", "I", "TID"]
-ignore = ["E731", "E501"]
+select = [
+    "E",    #https://docs.astral.sh/ruff/rules/#error-e
+    "F",    #https://docs.astral.sh/ruff/rules/#pyflakes-f
+    "I",    #https://docs.astral.sh/ruff/rules/#isort-i
+    "TID",  #https://docs.astral.sh/ruff/rules/#flake8-tidy-imports-tid
+    "S",    #https://docs.astral.sh/ruff/rules/#flake8-bandit-s
+    "B",    #https://docs.astral.sh/ruff/rules/#flake8-bugbear-b
+    "SIM",  #https://docs.astral.sh/ruff/rules/#flake8-simplify-sim
+    "N",    #https://docs.astral.sh/ruff/rules/#pep8-naming-n
+    "PT",   #https://docs.astral.sh/ruff/rules/#flake8-pytest-style-pt
+    "TRY",  #https://docs.astral.sh/ruff/rules/#tryceratops-try
+    "FLY",  #https://docs.astral.sh/ruff/rules/#flynt-fly
+    "PL",   #https://docs.astral.sh/ruff/rules/#pylint-pl
+    "NPY",  #https://docs.astral.sh/ruff/rules/#numpy-specific-rules-npy
+    "RUF",  #https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
+]
+ignore = ["S101", "E731", "E501"]
 src = ["src"]
 
 [tool.pytest.ini_options]
 addopts = "-r aR"
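For context, a sketch (not taken from the repository) of the kind of code the newly enabled tryceratops (TRY) rules discourage, which lines up with the dedicated exception classes introduced in `errors.py` below; the `check_dim_*` helpers are hypothetical:

```python
# Flagged by TRY003: a long message formatted at the raise site.
def check_dim_bad(dim: int) -> None:
    if dim <= 0:
        raise ValueError(f"vector dimension must be > 0, got {dim}")


# Preferred shape: the message lives inside the exception class itself.
class VectorDimensionError(ValueError):
    def __init__(self, dim: int) -> None:
        super().__init__(f"vector dimension must be > 0, got {dim}")


def check_dim_good(dim: int) -> None:
    if dim <= 0:
        raise VectorDimensionError(dim)
```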
bindings/python/src/pgvecto_rs/errors.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+import numpy as np
+
+
+class PGVectoRsError(ValueError):
+    pass
+
+
+class NDArrayDimensionError(PGVectoRsError):
+    def __init__(self, dim: int) -> None:
+        super().__init__(f"ndarray must be 1D for vector, got {dim}D")
+
+
+class NDArrayDtypeError(PGVectoRsError):
+    def __init__(self, dtype: np.dtype) -> None:
+        super().__init__(f"ndarray data type must be numeric for vector, got {dtype}")
+
+
+class BuiltinListTypeError(PGVectoRsError):
+    def __init__(self) -> None:
+        super().__init__("list data type must be numeric for vector")
+
+
+class VectorDimensionError(PGVectoRsError):
+    def __init__(self, dim: int) -> None:
+        super().__init__(f"vector dimension must be > 0, got {dim}")
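A minimal sketch of how a caller might work with these classes, assuming the package layout shown in this diff; the `to_vector` helper is illustrative and only mirrors the validation done in the bindings:

```python
import numpy as np

from pgvecto_rs.errors import NDArrayDimensionError, PGVectoRsError


def to_vector(value: np.ndarray) -> np.ndarray:
    # Mirrors the dimension check the bindings perform before serializing a vector.
    if value.ndim != 1:
        raise NDArrayDimensionError(value.ndim)
    return value.astype(np.float32)


try:
    to_vector(np.zeros((2, 3)))
except PGVectoRsError as e:  # all binding errors share this base class
    print(e)  # ndarray must be 1D for vector, got 2D
```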
@@ -37,7 +37,7 @@ async def register_vector_async(context: Connection):
 
 def register_vector_info(context: Connection, info: TypeInfo):
     if info is None:
-        raise ProgrammingError("vector type not found in the database")
+        raise ProgrammingError(info="vector type not found in the database")
     info.register(context)
 
 
 class VectorTextDumper(VectorDumper):
@@ -28,6 +28,7 @@ class PGVectoRs:
         """Connect to an existing table or create a new empty one.
+
         Args:
         ----
             db_url (str): url to the database.
             table_name (str): name of the table.
             dimension (int): dimension of the embeddings.
@@ -36,7 +37,8 @@ class PGVectoRs:
         class _Table(RecordORM):
             __tablename__ = f"collection_{collection_name}"
             id: Mapped[UUID] = mapped_column(
-                postgresql.UUID(as_uuid=True), primary_key=True
+                postgresql.UUID(as_uuid=True),
+                primary_key=True,
             )
             text: Mapped[str] = mapped_column(String)
             meta: Mapped[dict] = mapped_column(postgresql.JSONB)
@@ -59,7 +61,7 @@ class PGVectoRs:
                         text=record.text,
                         meta=record.meta,
                         embedding=record.embedding,
-                    )
+                    ),
                 )
             session.commit()
 
@@ -73,6 +75,7 @@ class PGVectoRs:
         """Search for the nearest records.
+
         Args:
         ----
            embedding : Target embedding.
            distance_op : Distance op.
            top_k : Max records to return. Defaults to 4.
@@ -80,6 +83,7 @@ class PGVectoRs:
            order_by_dis : Order by distance. Defaults to True.
 
        Returns:
        -------
            List of records and coresponding distances.
+
        """
@@ -88,7 +92,7 @@ class PGVectoRs:
                select(
                    self._table,
                    self._table.embedding.op(distance_op, return_type=Float)(
-                        embedding
+                        embedding,
                    ).label("distance"),
                )
                .limit(top_k)
@@ -1,5 +1,6 @@
-import sqlalchemy.types as types
+from sqlalchemy import types
 
+from pgvecto_rs.errors import VectorDimensionError
 from pgvecto_rs.utils import serializer
 
 
@@ -8,13 +9,13 @@ class Vector(types.UserDefinedType):
 
     def __init__(self, dim):
         if dim < 0:
-            raise ValueError("negative dim is not allowed")
+            raise VectorDimensionError(dim)
         self.dim = dim
 
     def get_col_spec(self, **kw):
         if self.dim is None or self.dim == 0:
             return "VECTOR"
-        return "VECTOR({})".format(self.dim)
+        return f"VECTOR({self.dim})"
 
     def bind_processor(self, dialect):
         def _processor(value):
@@ -28,7 +29,7 @@ class Vector(types.UserDefinedType):
 
         return _processor
 
-    class comparator_factory(types.UserDefinedType.Comparator):
+    class comparator_factory(types.UserDefinedType.Comparator):  # noqa: N801
         def squared_euclidean_distance(self, other):
             return self.op("<->", return_type=types.Float)(other)
 
@@ -2,6 +2,12 @@ from functools import wraps
 
 import numpy as np
 
+from pgvecto_rs.errors import (
+    BuiltinListTypeError,
+    NDArrayDimensionError,
+    NDArrayDtypeError,
+)
+
 
 def ignore_none(func):
     @wraps(func)
@@ -26,9 +32,9 @@ def validate_ndarray(func):
     def _func(value: np.ndarray, *args, **kwargs):
         if isinstance(value, np.ndarray):
             if value.ndim != 1:
-                raise ValueError("ndarray must be 1D for vector")
+                raise NDArrayDimensionError(value.ndim)
             if not np.issubdtype(value.dtype, np.number):
-                raise ValueError("ndarray data type must be numeric for vector")
+                raise NDArrayDtypeError(value.dtype)
         return func(value, *args, **kwargs)
 
     return _func
@@ -41,7 +47,7 @@ def validate_builtin_list(func):
     def _func(value: list, *args, **kwargs):
         if isinstance(value, list):
             if not all(isinstance(x, (int, float)) for x in value):
-                raise ValueError("list data type must be numeric for vector")
+                raise BuiltinListTypeError()
             value = np.array(value, dtype=np.float32)
         return func(value, *args, **kwargs)
 
@@ -3,7 +3,7 @@ import os
 import numpy as np
 import toml
 
-PORT = os.getenv("DB_PORT", 5432)
+PORT = os.getenv("DB_PORT", "5432")
 HOST = os.getenv("DB_HOST", "localhost")
 USER = os.getenv("DB_USER", "postgres")
 PASS = os.getenv("DB_PASS", "mysecretpassword")
@@ -11,13 +11,7 @@ DB_NAME = os.getenv("DB_NAME", "postgres")
 
 # Run tests with shell:
 # DB_HOST=localhost DB_USER=postgres DB_PASS=password DB_NAME=postgres python3 -m pytest bindings/python/tests/
-URL = "postgresql://{username}:{password}@{host}:{port}/{db_name}".format(
-    port=PORT,
-    host=HOST,
-    username=USER,
-    password=PASS,
-    db_name=DB_NAME,
-)
+URL = f"postgresql://{USER}:{PASS}@{HOST}:{PORT}/{DB_NAME}"
 
 
 # ==== test_create_index ====
@@ -27,13 +21,13 @@ TOML_SETTINGS = {
         {
             "capacity": 2097152,
             "algorithm": {"flat": {}},
-        }
+        },
     ),
     "hnsw": toml.dumps(
         {
             "capacity": 2097152,
             "algorithm": {"hnsw": {}},
-        }
+        },
     ),
 }
 
@@ -25,7 +25,7 @@ def conn():
         register_vector(conn)
         conn.execute("DROP TABLE IF EXISTS tb_test_item;")
         conn.execute(
-            "CREATE TABLE tb_test_item (id bigserial PRIMARY KEY, embedding vector(3) NOT NULL);"
+            "CREATE TABLE tb_test_item (id bigserial PRIMARY KEY, embedding vector(3) NOT NULL);",
         )
         conn.commit()
         try:
@@ -35,7 +35,7 @@ def conn():
             conn.commit()
 
 
-@pytest.mark.parametrize("index_name,index_setting", TOML_SETTINGS.items())
+@pytest.mark.parametrize(("index_name", "index_setting"), TOML_SETTINGS.items())
 def test_create_index(conn: Connection, index_name: str, index_setting: str):
     stat = sql.SQL(
         "CREATE INDEX {} ON tb_test_item USING vectors (embedding l2_ops) WITH (options={});",
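The parametrize rewrites in these test hunks follow flake8-pytest-style's PT006 rule (enabled through the new "PT" selection), which prefers a tuple of argument names over a comma-separated string; a small illustrative sketch, not taken from the repository:

```python
import pytest


# PT006-style: argument names passed as a tuple rather than "index_name,index_setting".
@pytest.mark.parametrize(("index_name", "index_setting"), [("flat", "{}"), ("hnsw", "{}")])
def test_index_params(index_name: str, index_setting: str) -> None:
    assert index_name in {"flat", "hnsw"}
    assert index_setting == "{}"
```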
@@ -68,7 +68,8 @@ def test_create_index(conn: Connection, index_name: str, index_setting: str):
 def test_insert(conn: Connection):
     with conn.cursor() as cur:
         cur.executemany(
-            "INSERT INTO tb_test_item (embedding) VALUES (%s);", [(e,) for e in VECTORS]
+            "INSERT INTO tb_test_item (embedding) VALUES (%s);",
+            [(e,) for e in VECTORS],
         )
         cur.execute("SELECT * FROM tb_test_item;")
         conn.commit()
@@ -80,7 +81,8 @@ def test_insert(conn: Connection):
 
 def test_squared_euclidean_distance(conn: Connection):
     cur = conn.execute(
-        "SELECT embedding <-> %s FROM tb_test_item;", (OP_SQRT_EUCLID_DIS,)
+        "SELECT embedding <-> %s FROM tb_test_item;",
+        (OP_SQRT_EUCLID_DIS,),
     )
     for i, row in enumerate(cur.fetchall()):
         assert np.allclose(EXPECTED_SQRT_EUCLID_DIS[i], row[0], atol=1e-10)
@@ -88,7 +90,8 @@ def test_squared_euclidean_distance(conn: Connection):
 
 def test_negative_dot_product_distance(conn: Connection):
     cur = conn.execute(
-        "SELECT embedding <#> %s FROM tb_test_item;", (OP_NEG_DOT_PROD_DIS,)
+        "SELECT embedding <#> %s FROM tb_test_item;",
+        (OP_NEG_DOT_PROD_DIS,),
     )
     for i, row in enumerate(cur.fetchall()):
         assert np.allclose(EXPECTED_NEG_DOT_PROD_DIS[i], row[0], atol=1e-10)
@@ -16,7 +16,7 @@ from tests import (
 )
 
 URL = URL.replace("postgresql", "postgresql+psycopg")
-mockTexts = {
+MockTexts = {
     "text0": VECTORS[0],
     "text1": VECTORS[1],
     "text2": VECTORS[2],
@@ -25,9 +25,9 @@ mockTexts = {
 
 class MockEmbedder:
     def embed(self, text: str) -> np.ndarray:
-        if isinstance(mockTexts[text], list):
-            return np.array(mockTexts[text], dtype=np.float32)
-        return mockTexts[text]
+        if isinstance(MockTexts[text], list):
+            return np.array(MockTexts[text], dtype=np.float32)
+        return MockTexts[text]
 
 
 @pytest.fixture(scope="module")
@@ -35,10 +35,10 @@ def client():
     client = PGVectoRs(db_url=URL, collection_name="empty", dimension=3)
     try:
         records1 = [
-            Record.from_text(t, v, {"src": "src1"}) for t, v in mockTexts.items()
+            Record.from_text(t, v, {"src": "src1"}) for t, v in MockTexts.items()
         ]
         records2 = [
-            Record.from_text(t, v, {"src": "src2"}) for t, v in mockTexts.items()
+            Record.from_text(t, v, {"src": "src2"}) for t, v in MockTexts.items()
        ]
         client.insert(records1)
         client.insert(records2)
@@ -53,7 +53,7 @@ filter_src2: Filter = lambda r: r.meta.contains({"src": "src2"})
 
 @pytest.mark.parametrize("filter", [filter_src1, filter_src2])
 @pytest.mark.parametrize(
-    "dis_op, dis_oprand, dis_expected",
+    ("dis_op", "dis_oprand", "dis_expected"),
     zip(
         ["<->", "<#>", "<=>"],
         [OP_SQRT_EUCLID_DIS, OP_NEG_DOT_PROD_DIS, OP_NEG_COS_DIS],
@@ -77,7 +77,7 @@ def test_search_filter_and_op(
 
 
 @pytest.mark.parametrize(
-    "dis_op, dis_oprand, dis_expected",
+    ("dis_op", "dis_oprand", "dis_expected"),
     zip(
         ["<->", "<#>", "<=>"],
         [OP_SQRT_EUCLID_DIS, OP_NEG_DOT_PROD_DIS, OP_NEG_COS_DIS],
@@ -92,5 +92,5 @@ def test_search_order_and_limit(
 ):
     dis_expected = dis_expected.copy()
     dis_expected.sort()
-    for i, (rec, dis) in enumerate(client.search(dis_oprand, dis_op, top_k=4)):
+    for i, (_rec, dis) in enumerate(client.search(dis_oprand, dis_op, top_k=4)):
         assert np.allclose(dis, dis_expected[i // 2])
@@ -36,11 +36,9 @@ class Document(Base):
 
 @pytest.fixture(scope="module")
 def session():
-    """
-    Connect to the test db pointed by the URL. Can check more details
+    """Connect to the test db pointed by the URL. Can check more details
     in `tests/__init__.py`
     """
-
     engine = create_engine(URL.replace("postgresql", "postgresql+psycopg"))
 
     # ensure that we have installed pgvector.rs extension
@@ -63,7 +61,7 @@ def session():
 # =================================
 
 
-@pytest.mark.parametrize("index_name,index_setting", TOML_SETTINGS.items())
+@pytest.mark.parametrize(("index_name", "index_setting"), TOML_SETTINGS.items())
 def test_create_index(session: Session, index_name: str, index_setting: str):
     index = Index(
         index_name,
@@ -76,15 +74,15 @@ def test_create_index(session: Session, index_name: str, index_setting: str):
     session.commit()
 
 
-@pytest.mark.parametrize("i,e", enumerate(INVALID_VECTORS))
+@pytest.mark.parametrize(("i", "e"), enumerate(INVALID_VECTORS))
 def test_invalid_insert(session: Session, i: int, e: np.ndarray):
     try:
         session.execute(insert(Document).values(id=i, embedding=e))
     except StatementError:
         pass
     else:
-        raise AssertionError(
-            "failed to raise invalid value error for {}th vector {}".format(i, e),
+        raise AssertionError(  # noqa: TRY003
+            f"failed to raise invalid value error for {i}th vector {e}",
         )
     finally:
         session.rollback()
@@ -110,7 +108,7 @@ def test_squared_euclidean_distance(session: Session):
         select(
             Document.id,
             Document.embedding.squared_euclidean_distance(OP_SQRT_EUCLID_DIS),
-        )
+        ),
     ):
         (i, res) = row
         assert np.allclose(EXPECTED_SQRT_EUCLID_DIS[i], res, atol=1e-10)
@@ -121,7 +119,7 @@ def test_negative_dot_product_distance(session: Session):
         select(
             Document.id,
             Document.embedding.negative_dot_product_distance(OP_NEG_DOT_PROD_DIS),
-        )
+        ),
     ):
         (i, res) = row
         assert np.allclose(EXPECTED_NEG_DOT_PROD_DIS[i], res, atol=1e-10)
@@ -129,7 +127,9 @@ def test_negative_dot_product_distance(session: Session):
 
 def test_negative_cosine_distance(session: Session):
     for row in session.execute(
-        select(Document.id, Document.embedding.negative_cosine_distance(OP_NEG_COS_DIS))
+        select(
+            Document.id, Document.embedding.negative_cosine_distance(OP_NEG_COS_DIS)
+        ),
     ):
         (i, res) = row
         assert np.allclose(EXPECTED_NEG_COS_DIS[i], res, atol=1e-10)
@@ -4,6 +4,6 @@ env_list = py3{10, 11}
 [testenv]
 deps = pdm
 commands =
-    pdm sync -d -G lint
+    pdm sync -d -G :all
     pdm run -v check
     pdm run -v test