1
0
mirror of https://github.com/tensorchord/pgvecto.rs.git synced 2025-07-29 08:21:12 +03:00

feat: enhance the __init__ of client (#164)

* enhance the __init__ of client

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* bump version

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

---------

Signed-off-by: 盐粒 Yanli <mail@yanli.one>
This commit is contained in:
盐粒 Yanli
2023-11-23 15:32:07 +08:00
committed by GitHub
parent eafb2f66f2
commit 20b6e0081f
4 changed files with 45 additions and 52 deletions

View File

@ -34,43 +34,41 @@ client = PGVectoRs(
db_url=URL,
collection_name="example",
dimension=1536,
recreate=True,
)
try:
# Add some records
client.insert(records1)
client.insert(records2)
# Add some records
client.insert(records1)
client.insert(records2)
# Query (With a filter from the filters module)
print("#################### First Query ####################")
for record, dis in client.search(
target,
filter=filters.meta_contains({"src": "one"}),
):
print(f"DISTANCE SCORE: {dis}")
print(record)
# Query (With a filter from the filters module)
print("#################### First Query ####################")
for record, dis in client.search(
target,
filter=filters.meta_contains({"src": "one"}),
):
print(f"DISTANCE SCORE: {dis}")
print(record)
# Another Query (Equivalent to the first one, but with a lambda filter written by hand)
print("#################### Second Query ####################")
for record, dis in client.search(
target,
filter=lambda r: r.meta.contains({"src": "one"}),
):
print(f"DISTANCE SCORE: {dis}")
print(record)
# Another Query (Equivalent to the first one, but with a lambda filter written by hand)
print("#################### Second Query ####################")
for record, dis in client.search(
target,
filter=lambda r: r.meta.contains({"src": "one"}),
):
print(f"DISTANCE SCORE: {dis}")
print(record)
# Yet Another Query (With a more complex filter)
print("#################### Third Query ####################")
# Yet Another Query (With a more complex filter)
print("#################### Third Query ####################")
def complex_filter(r: filters.FilterInput) -> filters.FilterOutput:
t1 = r.text.endswith("!") == False # noqa: E712
t2 = r.meta.contains({"src": "two"})
t = t1 & t2
return t
for record, dis in client.search(target, filter=complex_filter):
print(f"DISTANCE SCORE: {dis}")
print(record)
def complex_filter(r: filters.FilterInput) -> filters.FilterOutput:
t1 = r.text.endswith("!") == False # noqa: E712
t2 = r.meta.contains({"src": "two"})
t = t1 & t2
return t
finally:
# Clean up
client.drop()
for record, dis in client.search(target, filter=complex_filter):
print(f"DISTANCE SCORE: {dis}")
print(record)

View File

@ -1,6 +1,6 @@
[project]
name = "pgvecto-rs"
version = "0.1.3"
version = "0.1.4"
description = "Python binding for pgvecto.rs"
authors = [
{ name = "TensorChord", email = "envd-maintainers@tensorchord.ai" },

View File

@ -20,22 +20,22 @@ class PGVectoRs:
dimension: int
def __init__(
self,
db_url: str,
collection_name: str,
dimension: int,
self, db_url: str, collection_name: str, dimension: int, recreate: bool = False
) -> None:
"""Connect to an existing table or create a new empty one.
If `recreate=True`, the table is dropped first if it already exists.
Args:
----
db_url (str): url to the database.
collection_name (str): name of the collection; the table is named `collection_{collection_name}`.
dimension (int): dimension of the embeddings.
recreate (bool): drop the table if it exists. Defaults to False.
"""
class _Table(RecordORM):
__tablename__ = f"collection_{collection_name}"
__table_args__ = {"extend_existing": True} # noqa: RUF012
id: Mapped[UUID] = mapped_column(
postgresql.UUID(as_uuid=True),
primary_key=True,
@ -47,9 +47,11 @@ class PGVectoRs:
self._engine = create_engine(db_url)
with Session(self._engine) as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
if recreate:
session.execute(text(f"DROP TABLE IF EXISTS {_Table.__tablename__}"))
session.commit()
self._table = _Table
self._table.__table__.create(self._engine)
self._table.__table__.create(self._engine, checkfirst=True)
self.dimension = dimension
def insert(self, records: List[Record]) -> None:

View File

@ -32,19 +32,12 @@ class MockEmbedder:
@pytest.fixture(scope="module")
def client():
client = PGVectoRs(db_url=URL, collection_name="empty", dimension=3)
try:
records1 = [
Record.from_text(t, v, {"src": "src1"}) for t, v in MockTexts.items()
]
records2 = [
Record.from_text(t, v, {"src": "src2"}) for t, v in MockTexts.items()
]
client.insert(records1)
client.insert(records2)
yield client
finally:
client.drop()
client = PGVectoRs(db_url=URL, collection_name="empty", dimension=3, recreate=True)
records1 = [Record.from_text(t, v, {"src": "src1"}) for t, v in MockTexts.items()]
records2 = [Record.from_text(t, v, {"src": "src2"}) for t, v in MockTexts.items()]
client.insert(records1)
client.insert(records2)
return client
filter_src1 = filters.meta_contains({"src": "src1"})