1
0
mirror of https://github.com/tensorchord/pgvecto.rs.git synced 2025-04-18 21:44:00 +03:00
pgvecto.rs/bindings/python/tests/test_psycopg.py
Keming eb44c263b6
chore: fix typos (#228)
Signed-off-by: Keming <kemingyang@tensorchord.ai>
2024-01-04 06:24:45 +00:00

116 lines
3.5 KiB
Python

import numpy as np
import psycopg
import pytest
from psycopg import Connection, sql
from pgvecto_rs.psycopg import register_vector
from tests import (
EXPECTED_NEG_COS_DIS,
EXPECTED_NEG_DOT_PROD_DIS,
EXPECTED_SQRT_EUCLID_DIS,
LEN_AFT_DEL,
OP_NEG_COS_DIS,
OP_NEG_DOT_PROD_DIS,
OP_SQRT_EUCLID_DIS,
TOML_SETTINGS,
URL,
VECTORS,
)
@pytest.fixture(scope="module")
def conn():
    """Yield one connection shared by every test in this module.

    On setup this creates the ``vectors`` extension if it is missing and calls
    ``register_vector`` on the connection (presumably to install the pgvecto.rs
    type adapters). It then recreates ``tb_test_item`` with a ``vector(3)``
    column. On teardown the table is dropped again.
    """
    with psycopg.connect(URL) as connection:
        connection.execute("CREATE EXTENSION IF NOT EXISTS vectors;")
        register_vector(connection)
        connection.execute("DROP TABLE IF EXISTS tb_test_item;")
        connection.execute(
            "CREATE TABLE tb_test_item (id bigserial PRIMARY KEY, embedding vector(3) NOT NULL);",
        )
        connection.commit()
        try:
            yield connection
        finally:
            # A failing test can leave the transaction aborted or still open.
            # Roll it back first so the DROP below runs instead of raising
            # InFailedSqlTransaction. Rolling back is a no-op when idle.
            connection.rollback()
            connection.execute("DROP TABLE IF EXISTS tb_test_item;")
            connection.commit()
@pytest.mark.parametrize(("index_name", "index_setting"), TOML_SETTINGS.items())
def test_create_index(conn: Connection, index_name: str, index_setting: str):
    """Create one ``vectors`` index on ``embedding`` per TOML option set.

    ``index_name`` is quoted as an SQL identifier. ``index_setting`` is passed
    to ``format`` as a plain value, which psycopg embeds as a quoted literal.
    """
    query = sql.SQL(
        "CREATE INDEX {} ON tb_test_item USING vectors (embedding vector_l2_ops) WITH (options={});",
    )
    quoted_name = sql.Identifier(index_name)
    statement = query.format(quoted_name, index_setting)
    conn.execute(statement)
    conn.commit()
# Disabled: the server cannot yet handle invalid vectors; see https://github.com/tensorchord/pgvecto.rs/issues/96
# def test_invalid_insert(conn: Connection):
# for i, e in enumerate(INVALID_VECTORS):
# try:
# conn.execute("INSERT INTO tb_test_item (embedding) VALUES (%s);", (e, ) )
# pass
# except:
# conn.rollback()
# else:
# conn.rollback()
# raise AssertionError(
# 'failed to raise invalid value error for {}th vector {}'
# .format(i, e),
# )
# =================================
# Semantic search tests
# =================================
def test_insert(conn: Connection):
    """Insert every sample vector and read them back in insertion order."""
    with conn.cursor() as cur:
        cur.executemany(
            "INSERT INTO tb_test_item (embedding) VALUES (%s);",
            [(e,) for e in VECTORS],
        )
        # PostgreSQL guarantees no row order without ORDER BY. The bigserial
        # ids are handed out as the rows are inserted, so sorting by id
        # lines the rows up with VECTORS.
        cur.execute("SELECT * FROM tb_test_item ORDER BY id;")
        rows = cur.fetchall()
        conn.commit()
    assert len(rows) == len(VECTORS)
    for i, row in enumerate(rows):
        # row[1] is the embedding column (row[0] is id).
        assert np.allclose(row[1], VECTORS[i], atol=1e-10)
def test_squared_euclidean_distance(conn: Connection):
    """Check ``<->`` distances to OP_SQRT_EUCLID_DIS against the expected values."""
    # ORDER BY id keeps results aligned with the expected list by position.
    cur = conn.execute(
        "SELECT embedding <-> %s FROM tb_test_item ORDER BY id;",
        (OP_SQRT_EUCLID_DIS,),
    )
    rows = cur.fetchall()
    # Without this check, an empty or short result would pass vacuously.
    assert len(rows) == len(EXPECTED_SQRT_EUCLID_DIS)
    for i, row in enumerate(rows):
        assert np.allclose(EXPECTED_SQRT_EUCLID_DIS[i], row[0], atol=1e-10)
def test_negative_dot_product_distance(conn: Connection):
    """Check ``<#>`` distances to OP_NEG_DOT_PROD_DIS against the expected values."""
    # ORDER BY id keeps results aligned with the expected list by position.
    cur = conn.execute(
        "SELECT embedding <#> %s FROM tb_test_item ORDER BY id;",
        (OP_NEG_DOT_PROD_DIS,),
    )
    rows = cur.fetchall()
    # Without this check, an empty or short result would pass vacuously.
    assert len(rows) == len(EXPECTED_NEG_DOT_PROD_DIS)
    for i, row in enumerate(rows):
        assert np.allclose(EXPECTED_NEG_DOT_PROD_DIS[i], row[0], atol=1e-10)
def test_negative_cosine_distance(conn: Connection):
    """Check ``<=>`` distances to OP_NEG_COS_DIS against the expected values."""
    # ORDER BY id keeps results aligned with the expected list by position.
    cur = conn.execute(
        "SELECT embedding <=> %s FROM tb_test_item ORDER BY id;",
        (OP_NEG_COS_DIS,),
    )
    rows = cur.fetchall()
    # Without this check, an empty or short result would pass vacuously.
    assert len(rows) == len(EXPECTED_NEG_COS_DIS)
    for i, row in enumerate(rows):
        assert np.allclose(EXPECTED_NEG_COS_DIS[i], row[0], atol=1e-10)
# =================================
# Suffix functional tests
# =================================
def test_delete(conn: Connection):
    """Delete rows matching the first sample vector and check how many remain."""
    target = VECTORS[0]
    with conn.cursor() as cur:
        cur.execute("DELETE FROM tb_test_item WHERE embedding = %s;", (target,))
        conn.commit()
        cur.execute("SELECT * FROM tb_test_item;")
        remaining = cur.fetchall()
    assert len(remaining) == LEN_AFT_DEL