You've already forked pgvecto.rs
mirror of
https://github.com/tensorchord/pgvecto.rs.git
synced 2025-07-30 19:23:05 +03:00
feat: add Python bindings by psycopg 3 (#102)
* feat: support psycopg Signed-off-by: 盐粒 Yanli <mail@yanli.one> * chore: lint && add comment Signed-off-by: 盐粒 Yanli <mail@yanli.one> * test: update tests Signed-off-by: 盐粒 Yanli <mail@yanli.one> test: update tests Signed-off-by: 盐粒 Yanli <mail@yanli.one> * test: fix test of psycopg Signed-off-by: 盐粒 Yanli <mail@yanli.one> * chore: update readme Signed-off-by: 盐粒 Yanli <mail@yanli.one> * chore: write examples && modify readme Signed-off-by: 盐粒 Yanli <mail@yanli.one> * chore: bump version no. Signed-off-by: 盐粒 Yanli <mail@yanli.one> * feat: use normal defined class for Dumper Signed-off-by: 盐粒 Yanli <mail@yanli.one> --------- Signed-off-by: 盐粒 Yanli <mail@yanli.one>
This commit is contained in:
@ -22,38 +22,18 @@ URL = "postgresql://{username}:{password}@{host}:{port}/{db_name}".format(
|
||||
# ==== test_create_index ====
|
||||
|
||||
TOML_SETTINGS = {
|
||||
"flat": "$${}$$".format(
|
||||
toml.dumps(
|
||||
{
|
||||
"capacity": 2097152,
|
||||
"algorithm": {"flat": {}},
|
||||
}
|
||||
)
|
||||
"flat": toml.dumps(
|
||||
{
|
||||
"capacity": 2097152,
|
||||
"algorithm": {"flat": {}},
|
||||
}
|
||||
),
|
||||
"hnsw": "$${}$$".format(
|
||||
toml.dumps(
|
||||
{
|
||||
"capacity": 2097152,
|
||||
"algorithm": {"hnsw": {}},
|
||||
}
|
||||
)
|
||||
"hnsw": toml.dumps(
|
||||
{
|
||||
"capacity": 2097152,
|
||||
"algorithm": {"hnsw": {}},
|
||||
}
|
||||
),
|
||||
# "ivf": "$${}$$".format(
|
||||
# toml.dumps(
|
||||
# {
|
||||
# "capacity": 2097152,
|
||||
# "algorithm": {"ivf": {}},
|
||||
# }
|
||||
# )
|
||||
# ),
|
||||
# "vamana": "$${}$$".format(
|
||||
# toml.dumps(
|
||||
# {
|
||||
# "capacity": 2097152,
|
||||
# "algorithm": {"vamana": {}},
|
||||
# }
|
||||
# )
|
||||
# ),
|
||||
}
|
||||
|
||||
# ==== test_invalid_insert ====
|
||||
|
111
bindings/python/tests/test_psycopg.py
Normal file
111
bindings/python/tests/test_psycopg.py
Normal file
@ -0,0 +1,111 @@
|
||||
import pytest
|
||||
import psycopg
|
||||
import numpy as np
|
||||
from psycopg import Connection, sql
|
||||
from pgvecto_rs.psycopg import register_vector
|
||||
from tests import (
|
||||
URL,
|
||||
TOML_SETTINGS,
|
||||
VECTORS,
|
||||
OP_SQRT_EUCLID_DIS,
|
||||
OP_NEG_DOT_PROD_DIS,
|
||||
OP_NEG_COS_DIS,
|
||||
EXPECTED_SQRT_EUCLID_DIS,
|
||||
EXPECTED_NEG_DOT_PROD_DIS,
|
||||
EXPECTED_NEG_COS_DIS,
|
||||
LEN_AFT_DEL,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def conn():
|
||||
with psycopg.connect(URL) as conn:
|
||||
conn.execute("CREATE EXTENSION IF NOT EXISTS vectors;")
|
||||
register_vector(conn)
|
||||
conn.execute("DROP TABLE IF EXISTS tb_test_item;")
|
||||
conn.execute(
|
||||
"CREATE TABLE tb_test_item (id bigserial PRIMARY KEY, embedding vector(3) NOT NULL);"
|
||||
)
|
||||
conn.commit()
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
conn.execute("DROP TABLE IF EXISTS tb_test_item;")
|
||||
conn.commit()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name,index_setting", TOML_SETTINGS.items())
|
||||
def test_create_index(conn: Connection, index_name: str, index_setting: str):
|
||||
stat = sql.SQL(
|
||||
"CREATE INDEX {} ON tb_test_item USING vectors (embedding l2_ops) WITH (options={});",
|
||||
).format(sql.Identifier(index_name), index_setting)
|
||||
|
||||
conn.execute(stat)
|
||||
conn.commit()
|
||||
|
||||
|
||||
# The server cannot handle invalid vectors curently, see https://github.com/tensorchord/pgvecto.rs/issues/96
|
||||
# def test_invalid_insert(conn: Connection):
|
||||
# for i, e in enumerate(INVALID_VECTORS):
|
||||
# try:
|
||||
# conn.execute("INSERT INTO tb_test_item (embedding) VALUES (%s);", (e, ) )
|
||||
# pass
|
||||
# except:
|
||||
# conn.rollback()
|
||||
# else:
|
||||
# conn.rollback()
|
||||
# raise AssertionError(
|
||||
# 'failed to raise invalid value error for {}th vector {}'
|
||||
# .format(i, e),
|
||||
# )
|
||||
|
||||
# =================================
|
||||
# Semetic search tests
|
||||
# =================================
|
||||
|
||||
|
||||
def test_insert(conn: Connection):
|
||||
with conn.cursor() as cur:
|
||||
cur.executemany(
|
||||
"INSERT INTO tb_test_item (embedding) VALUES (%s);", [(e,) for e in VECTORS]
|
||||
)
|
||||
cur.execute("SELECT * FROM tb_test_item;")
|
||||
conn.commit()
|
||||
rows = cur.fetchall()
|
||||
assert len(rows) == len(VECTORS)
|
||||
for i, e in enumerate(rows):
|
||||
assert np.allclose(e[1], VECTORS[i], atol=1e-10)
|
||||
|
||||
|
||||
def test_squared_euclidean_distance(conn: Connection):
|
||||
cur = conn.execute(
|
||||
"SELECT embedding <-> %s FROM tb_test_item;", (OP_SQRT_EUCLID_DIS,)
|
||||
)
|
||||
for i, row in enumerate(cur.fetchall()):
|
||||
assert np.allclose(EXPECTED_SQRT_EUCLID_DIS[i], row[0], atol=1e-10)
|
||||
|
||||
|
||||
def test_negative_dot_product_distance(conn: Connection):
|
||||
cur = conn.execute(
|
||||
"SELECT embedding <#> %s FROM tb_test_item;", (OP_NEG_DOT_PROD_DIS,)
|
||||
)
|
||||
for i, row in enumerate(cur.fetchall()):
|
||||
assert np.allclose(EXPECTED_NEG_DOT_PROD_DIS[i], row[0], atol=1e-10)
|
||||
|
||||
|
||||
def test_negative_cosine_distance(conn: Connection):
|
||||
cur = conn.execute("SELECT embedding <=> %s FROM tb_test_item;", (OP_NEG_COS_DIS,))
|
||||
for i, row in enumerate(cur.fetchall()):
|
||||
assert np.allclose(EXPECTED_NEG_COS_DIS[i], row[0], atol=1e-10)
|
||||
|
||||
|
||||
# # =================================
|
||||
# # Suffix functional tests
|
||||
# # =================================
|
||||
|
||||
|
||||
def test_delete(conn: Connection):
|
||||
conn.execute("DELETE FROM tb_test_item WHERE embedding = %s;", (VECTORS[0],))
|
||||
conn.commit()
|
||||
cur = conn.execute("SELECT * FROM tb_test_item;")
|
||||
assert len(cur.fetchall()) == LEN_AFT_DEL
|
@ -69,11 +69,11 @@ def test_create_index(session: Session, index_name: str, index_setting: str):
|
||||
index_name,
|
||||
Document.embedding,
|
||||
postgresql_using="vectors",
|
||||
postgresql_with={"options": index_setting},
|
||||
postgresql_with={"options": f"$${index_setting}$$"},
|
||||
postgresql_ops={"embedding": "l2_ops"},
|
||||
)
|
||||
index.create(session.bind)
|
||||
session.rollback()
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("i,e", enumerate(INVALID_VECTORS))
|
||||
|
Reference in New Issue
Block a user