You've already forked pgvecto.rs
mirror of
https://github.com/tensorchord/pgvecto.rs.git
synced 2025-07-30 19:23:05 +03:00
feat: add more ruff rules (#138)
* feat: add more ruff rules Signed-off-by: 盐粒 Yanli <mail@yanli.one> * chore: modified readme Signed-off-by: 盐粒 Yanli <mail@yanli.one> * rename error class Signed-off-by: 盐粒 Yanli <mail@yanli.one> --------- Signed-off-by: 盐粒 Yanli <mail@yanli.one>
This commit is contained in:
@ -3,7 +3,7 @@ import os
|
||||
import numpy as np
|
||||
import toml
|
||||
|
||||
PORT = os.getenv("DB_PORT", 5432)
|
||||
PORT = os.getenv("DB_PORT", "5432")
|
||||
HOST = os.getenv("DB_HOST", "localhost")
|
||||
USER = os.getenv("DB_USER", "postgres")
|
||||
PASS = os.getenv("DB_PASS", "mysecretpassword")
|
||||
@ -11,13 +11,7 @@ DB_NAME = os.getenv("DB_NAME", "postgres")
|
||||
|
||||
# Run tests with shell:
|
||||
# DB_HOST=localhost DB_USER=postgres DB_PASS=password DB_NAME=postgres python3 -m pytest bindings/python/tests/
|
||||
URL = "postgresql://{username}:{password}@{host}:{port}/{db_name}".format(
|
||||
port=PORT,
|
||||
host=HOST,
|
||||
username=USER,
|
||||
password=PASS,
|
||||
db_name=DB_NAME,
|
||||
)
|
||||
URL = f"postgresql://{USER}:{PASS}@{HOST}:{PORT}/{DB_NAME}"
|
||||
|
||||
|
||||
# ==== test_create_index ====
|
||||
@ -27,13 +21,13 @@ TOML_SETTINGS = {
|
||||
{
|
||||
"capacity": 2097152,
|
||||
"algorithm": {"flat": {}},
|
||||
}
|
||||
},
|
||||
),
|
||||
"hnsw": toml.dumps(
|
||||
{
|
||||
"capacity": 2097152,
|
||||
"algorithm": {"hnsw": {}},
|
||||
}
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
|
@ -25,7 +25,7 @@ def conn():
|
||||
register_vector(conn)
|
||||
conn.execute("DROP TABLE IF EXISTS tb_test_item;")
|
||||
conn.execute(
|
||||
"CREATE TABLE tb_test_item (id bigserial PRIMARY KEY, embedding vector(3) NOT NULL);"
|
||||
"CREATE TABLE tb_test_item (id bigserial PRIMARY KEY, embedding vector(3) NOT NULL);",
|
||||
)
|
||||
conn.commit()
|
||||
try:
|
||||
@ -35,7 +35,7 @@ def conn():
|
||||
conn.commit()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name,index_setting", TOML_SETTINGS.items())
|
||||
@pytest.mark.parametrize(("index_name", "index_setting"), TOML_SETTINGS.items())
|
||||
def test_create_index(conn: Connection, index_name: str, index_setting: str):
|
||||
stat = sql.SQL(
|
||||
"CREATE INDEX {} ON tb_test_item USING vectors (embedding l2_ops) WITH (options={});",
|
||||
@ -68,7 +68,8 @@ def test_create_index(conn: Connection, index_name: str, index_setting: str):
|
||||
def test_insert(conn: Connection):
|
||||
with conn.cursor() as cur:
|
||||
cur.executemany(
|
||||
"INSERT INTO tb_test_item (embedding) VALUES (%s);", [(e,) for e in VECTORS]
|
||||
"INSERT INTO tb_test_item (embedding) VALUES (%s);",
|
||||
[(e,) for e in VECTORS],
|
||||
)
|
||||
cur.execute("SELECT * FROM tb_test_item;")
|
||||
conn.commit()
|
||||
@ -80,7 +81,8 @@ def test_insert(conn: Connection):
|
||||
|
||||
def test_squared_euclidean_distance(conn: Connection):
|
||||
cur = conn.execute(
|
||||
"SELECT embedding <-> %s FROM tb_test_item;", (OP_SQRT_EUCLID_DIS,)
|
||||
"SELECT embedding <-> %s FROM tb_test_item;",
|
||||
(OP_SQRT_EUCLID_DIS,),
|
||||
)
|
||||
for i, row in enumerate(cur.fetchall()):
|
||||
assert np.allclose(EXPECTED_SQRT_EUCLID_DIS[i], row[0], atol=1e-10)
|
||||
@ -88,7 +90,8 @@ def test_squared_euclidean_distance(conn: Connection):
|
||||
|
||||
def test_negative_dot_product_distance(conn: Connection):
|
||||
cur = conn.execute(
|
||||
"SELECT embedding <#> %s FROM tb_test_item;", (OP_NEG_DOT_PROD_DIS,)
|
||||
"SELECT embedding <#> %s FROM tb_test_item;",
|
||||
(OP_NEG_DOT_PROD_DIS,),
|
||||
)
|
||||
for i, row in enumerate(cur.fetchall()):
|
||||
assert np.allclose(EXPECTED_NEG_DOT_PROD_DIS[i], row[0], atol=1e-10)
|
||||
|
@ -16,7 +16,7 @@ from tests import (
|
||||
)
|
||||
|
||||
URL = URL.replace("postgresql", "postgresql+psycopg")
|
||||
mockTexts = {
|
||||
MockTexts = {
|
||||
"text0": VECTORS[0],
|
||||
"text1": VECTORS[1],
|
||||
"text2": VECTORS[2],
|
||||
@ -25,9 +25,9 @@ mockTexts = {
|
||||
|
||||
class MockEmbedder:
|
||||
def embed(self, text: str) -> np.ndarray:
|
||||
if isinstance(mockTexts[text], list):
|
||||
return np.array(mockTexts[text], dtype=np.float32)
|
||||
return mockTexts[text]
|
||||
if isinstance(MockTexts[text], list):
|
||||
return np.array(MockTexts[text], dtype=np.float32)
|
||||
return MockTexts[text]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@ -35,10 +35,10 @@ def client():
|
||||
client = PGVectoRs(db_url=URL, collection_name="empty", dimension=3)
|
||||
try:
|
||||
records1 = [
|
||||
Record.from_text(t, v, {"src": "src1"}) for t, v in mockTexts.items()
|
||||
Record.from_text(t, v, {"src": "src1"}) for t, v in MockTexts.items()
|
||||
]
|
||||
records2 = [
|
||||
Record.from_text(t, v, {"src": "src2"}) for t, v in mockTexts.items()
|
||||
Record.from_text(t, v, {"src": "src2"}) for t, v in MockTexts.items()
|
||||
]
|
||||
client.insert(records1)
|
||||
client.insert(records2)
|
||||
@ -53,7 +53,7 @@ filter_src2: Filter = lambda r: r.meta.contains({"src": "src2"})
|
||||
|
||||
@pytest.mark.parametrize("filter", [filter_src1, filter_src2])
|
||||
@pytest.mark.parametrize(
|
||||
"dis_op, dis_oprand, dis_expected",
|
||||
("dis_op", "dis_oprand", "dis_expected"),
|
||||
zip(
|
||||
["<->", "<#>", "<=>"],
|
||||
[OP_SQRT_EUCLID_DIS, OP_NEG_DOT_PROD_DIS, OP_NEG_COS_DIS],
|
||||
@ -77,7 +77,7 @@ def test_search_filter_and_op(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dis_op, dis_oprand, dis_expected",
|
||||
("dis_op", "dis_oprand", "dis_expected"),
|
||||
zip(
|
||||
["<->", "<#>", "<=>"],
|
||||
[OP_SQRT_EUCLID_DIS, OP_NEG_DOT_PROD_DIS, OP_NEG_COS_DIS],
|
||||
@ -92,5 +92,5 @@ def test_search_order_and_limit(
|
||||
):
|
||||
dis_expected = dis_expected.copy()
|
||||
dis_expected.sort()
|
||||
for i, (rec, dis) in enumerate(client.search(dis_oprand, dis_op, top_k=4)):
|
||||
for i, (_rec, dis) in enumerate(client.search(dis_oprand, dis_op, top_k=4)):
|
||||
assert np.allclose(dis, dis_expected[i // 2])
|
||||
|
@ -36,11 +36,9 @@ class Document(Base):
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def session():
|
||||
"""
|
||||
Connect to the test db pointed by the URL. Can check more details
|
||||
"""Connect to the test db pointed by the URL. Can check more details
|
||||
in `tests/__init__.py`
|
||||
"""
|
||||
|
||||
engine = create_engine(URL.replace("postgresql", "postgresql+psycopg"))
|
||||
|
||||
# ensure that we have installed pgvector.rs extension
|
||||
@ -63,7 +61,7 @@ def session():
|
||||
# =================================
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name,index_setting", TOML_SETTINGS.items())
|
||||
@pytest.mark.parametrize(("index_name", "index_setting"), TOML_SETTINGS.items())
|
||||
def test_create_index(session: Session, index_name: str, index_setting: str):
|
||||
index = Index(
|
||||
index_name,
|
||||
@ -76,15 +74,15 @@ def test_create_index(session: Session, index_name: str, index_setting: str):
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("i,e", enumerate(INVALID_VECTORS))
|
||||
@pytest.mark.parametrize(("i", "e"), enumerate(INVALID_VECTORS))
|
||||
def test_invalid_insert(session: Session, i: int, e: np.ndarray):
|
||||
try:
|
||||
session.execute(insert(Document).values(id=i, embedding=e))
|
||||
except StatementError:
|
||||
pass
|
||||
else:
|
||||
raise AssertionError(
|
||||
"failed to raise invalid value error for {}th vector {}".format(i, e),
|
||||
raise AssertionError( # noqa: TRY003
|
||||
f"failed to raise invalid value error for {i}th vector {e}",
|
||||
)
|
||||
finally:
|
||||
session.rollback()
|
||||
@ -110,7 +108,7 @@ def test_squared_euclidean_distance(session: Session):
|
||||
select(
|
||||
Document.id,
|
||||
Document.embedding.squared_euclidean_distance(OP_SQRT_EUCLID_DIS),
|
||||
)
|
||||
),
|
||||
):
|
||||
(i, res) = row
|
||||
assert np.allclose(EXPECTED_SQRT_EUCLID_DIS[i], res, atol=1e-10)
|
||||
@ -121,7 +119,7 @@ def test_negative_dot_product_distance(session: Session):
|
||||
select(
|
||||
Document.id,
|
||||
Document.embedding.negative_dot_product_distance(OP_NEG_DOT_PROD_DIS),
|
||||
)
|
||||
),
|
||||
):
|
||||
(i, res) = row
|
||||
assert np.allclose(EXPECTED_NEG_DOT_PROD_DIS[i], res, atol=1e-10)
|
||||
@ -129,7 +127,9 @@ def test_negative_dot_product_distance(session: Session):
|
||||
|
||||
def test_negative_cosine_distance(session: Session):
|
||||
for row in session.execute(
|
||||
select(Document.id, Document.embedding.negative_cosine_distance(OP_NEG_COS_DIS))
|
||||
select(
|
||||
Document.id, Document.embedding.negative_cosine_distance(OP_NEG_COS_DIS)
|
||||
),
|
||||
):
|
||||
(i, res) = row
|
||||
assert np.allclose(EXPECTED_NEG_COS_DIS[i], res, atol=1e-10)
|
||||
|
Reference in New Issue
Block a user