1
0
mirror of https://github.com/tensorchord/pgvecto.rs.git synced 2025-07-30 19:23:05 +03:00

feat: add more ruff rules (#138)

* feat: add more ruff rules

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* chore: modified readme

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* rename error class

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

---------

Signed-off-by: 盐粒 Yanli <mail@yanli.one>
This commit is contained in:
盐粒 Yanli
2023-11-17 17:47:28 +08:00
committed by GitHub
parent f8344dd039
commit f6e382d0fc
16 changed files with 138 additions and 86 deletions

View File

@@ -3,7 +3,7 @@ import os
import numpy as np
import toml
PORT = os.getenv("DB_PORT", 5432)
PORT = os.getenv("DB_PORT", "5432")
HOST = os.getenv("DB_HOST", "localhost")
USER = os.getenv("DB_USER", "postgres")
PASS = os.getenv("DB_PASS", "mysecretpassword")
@@ -11,13 +11,7 @@ DB_NAME = os.getenv("DB_NAME", "postgres")
# Run tests with shell:
# DB_HOST=localhost DB_USER=postgres DB_PASS=password DB_NAME=postgres python3 -m pytest bindings/python/tests/
URL = "postgresql://{username}:{password}@{host}:{port}/{db_name}".format(
port=PORT,
host=HOST,
username=USER,
password=PASS,
db_name=DB_NAME,
)
URL = f"postgresql://{USER}:{PASS}@{HOST}:{PORT}/{DB_NAME}"
# ==== test_create_index ====
@@ -27,13 +21,13 @@ TOML_SETTINGS = {
{
"capacity": 2097152,
"algorithm": {"flat": {}},
}
},
),
"hnsw": toml.dumps(
{
"capacity": 2097152,
"algorithm": {"hnsw": {}},
}
},
),
}

View File

@@ -25,7 +25,7 @@ def conn():
register_vector(conn)
conn.execute("DROP TABLE IF EXISTS tb_test_item;")
conn.execute(
"CREATE TABLE tb_test_item (id bigserial PRIMARY KEY, embedding vector(3) NOT NULL);"
"CREATE TABLE tb_test_item (id bigserial PRIMARY KEY, embedding vector(3) NOT NULL);",
)
conn.commit()
try:
@@ -35,7 +35,7 @@ def conn():
conn.commit()
@pytest.mark.parametrize("index_name,index_setting", TOML_SETTINGS.items())
@pytest.mark.parametrize(("index_name", "index_setting"), TOML_SETTINGS.items())
def test_create_index(conn: Connection, index_name: str, index_setting: str):
stat = sql.SQL(
"CREATE INDEX {} ON tb_test_item USING vectors (embedding l2_ops) WITH (options={});",
@@ -68,7 +68,8 @@ def test_create_index(conn: Connection, index_name: str, index_setting: str):
def test_insert(conn: Connection):
with conn.cursor() as cur:
cur.executemany(
"INSERT INTO tb_test_item (embedding) VALUES (%s);", [(e,) for e in VECTORS]
"INSERT INTO tb_test_item (embedding) VALUES (%s);",
[(e,) for e in VECTORS],
)
cur.execute("SELECT * FROM tb_test_item;")
conn.commit()
@@ -80,7 +81,8 @@ def test_insert(conn: Connection):
def test_squared_euclidean_distance(conn: Connection):
cur = conn.execute(
"SELECT embedding <-> %s FROM tb_test_item;", (OP_SQRT_EUCLID_DIS,)
"SELECT embedding <-> %s FROM tb_test_item;",
(OP_SQRT_EUCLID_DIS,),
)
for i, row in enumerate(cur.fetchall()):
assert np.allclose(EXPECTED_SQRT_EUCLID_DIS[i], row[0], atol=1e-10)
@@ -88,7 +90,8 @@ def test_squared_euclidean_distance(conn: Connection):
def test_negative_dot_product_distance(conn: Connection):
cur = conn.execute(
"SELECT embedding <#> %s FROM tb_test_item;", (OP_NEG_DOT_PROD_DIS,)
"SELECT embedding <#> %s FROM tb_test_item;",
(OP_NEG_DOT_PROD_DIS,),
)
for i, row in enumerate(cur.fetchall()):
assert np.allclose(EXPECTED_NEG_DOT_PROD_DIS[i], row[0], atol=1e-10)

View File

@@ -16,7 +16,7 @@ from tests import (
)
URL = URL.replace("postgresql", "postgresql+psycopg")
mockTexts = {
MockTexts = {
"text0": VECTORS[0],
"text1": VECTORS[1],
"text2": VECTORS[2],
@@ -25,9 +25,9 @@ mockTexts = {
class MockEmbedder:
def embed(self, text: str) -> np.ndarray:
if isinstance(mockTexts[text], list):
return np.array(mockTexts[text], dtype=np.float32)
return mockTexts[text]
if isinstance(MockTexts[text], list):
return np.array(MockTexts[text], dtype=np.float32)
return MockTexts[text]
@pytest.fixture(scope="module")
@@ -35,10 +35,10 @@ def client():
client = PGVectoRs(db_url=URL, collection_name="empty", dimension=3)
try:
records1 = [
Record.from_text(t, v, {"src": "src1"}) for t, v in mockTexts.items()
Record.from_text(t, v, {"src": "src1"}) for t, v in MockTexts.items()
]
records2 = [
Record.from_text(t, v, {"src": "src2"}) for t, v in mockTexts.items()
Record.from_text(t, v, {"src": "src2"}) for t, v in MockTexts.items()
]
client.insert(records1)
client.insert(records2)
@@ -53,7 +53,7 @@ filter_src2: Filter = lambda r: r.meta.contains({"src": "src2"})
@pytest.mark.parametrize("filter", [filter_src1, filter_src2])
@pytest.mark.parametrize(
"dis_op, dis_oprand, dis_expected",
("dis_op", "dis_oprand", "dis_expected"),
zip(
["<->", "<#>", "<=>"],
[OP_SQRT_EUCLID_DIS, OP_NEG_DOT_PROD_DIS, OP_NEG_COS_DIS],
@@ -77,7 +77,7 @@ def test_search_filter_and_op(
@pytest.mark.parametrize(
"dis_op, dis_oprand, dis_expected",
("dis_op", "dis_oprand", "dis_expected"),
zip(
["<->", "<#>", "<=>"],
[OP_SQRT_EUCLID_DIS, OP_NEG_DOT_PROD_DIS, OP_NEG_COS_DIS],
@@ -92,5 +92,5 @@ def test_search_order_and_limit(
):
dis_expected = dis_expected.copy()
dis_expected.sort()
for i, (rec, dis) in enumerate(client.search(dis_oprand, dis_op, top_k=4)):
for i, (_rec, dis) in enumerate(client.search(dis_oprand, dis_op, top_k=4)):
assert np.allclose(dis, dis_expected[i // 2])

View File

@@ -36,11 +36,9 @@ class Document(Base):
@pytest.fixture(scope="module")
def session():
"""
Connect to the test db pointed by the URL. Can check more details
"""Connect to the test db pointed by the URL. Can check more details
in `tests/__init__.py`
"""
engine = create_engine(URL.replace("postgresql", "postgresql+psycopg"))
# ensure that we have installed pgvector.rs extension
@@ -63,7 +61,7 @@ def session():
# =================================
@pytest.mark.parametrize("index_name,index_setting", TOML_SETTINGS.items())
@pytest.mark.parametrize(("index_name", "index_setting"), TOML_SETTINGS.items())
def test_create_index(session: Session, index_name: str, index_setting: str):
index = Index(
index_name,
@@ -76,15 +74,15 @@ def test_create_index(session: Session, index_name: str, index_setting: str):
session.commit()
@pytest.mark.parametrize("i,e", enumerate(INVALID_VECTORS))
@pytest.mark.parametrize(("i", "e"), enumerate(INVALID_VECTORS))
def test_invalid_insert(session: Session, i: int, e: np.ndarray):
try:
session.execute(insert(Document).values(id=i, embedding=e))
except StatementError:
pass
else:
raise AssertionError(
"failed to raise invalid value error for {}th vector {}".format(i, e),
raise AssertionError( # noqa: TRY003
f"failed to raise invalid value error for {i}th vector {e}",
)
finally:
session.rollback()
@@ -110,7 +108,7 @@ def test_squared_euclidean_distance(session: Session):
select(
Document.id,
Document.embedding.squared_euclidean_distance(OP_SQRT_EUCLID_DIS),
)
),
):
(i, res) = row
assert np.allclose(EXPECTED_SQRT_EUCLID_DIS[i], res, atol=1e-10)
@@ -121,7 +119,7 @@ def test_negative_dot_product_distance(session: Session):
select(
Document.id,
Document.embedding.negative_dot_product_distance(OP_NEG_DOT_PROD_DIS),
)
),
):
(i, res) = row
assert np.allclose(EXPECTED_NEG_DOT_PROD_DIS[i], res, atol=1e-10)
@@ -129,7 +127,9 @@ def test_negative_cosine_distance(session: Session):
def test_negative_cosine_distance(session: Session):
for row in session.execute(
select(Document.id, Document.embedding.negative_cosine_distance(OP_NEG_COS_DIS))
select(
Document.id, Document.embedding.negative_cosine_distance(OP_NEG_COS_DIS)
),
):
(i, res) = row
assert np.allclose(EXPECTED_NEG_COS_DIS[i], res, atol=1e-10)