1
0
mirror of https://github.com/tensorchord/pgvecto.rs.git synced 2025-07-30 19:23:05 +03:00

feat: enhance the __init__ of client (#164)

* enhance the __init__ of client

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

* bump version

Signed-off-by: 盐粒 Yanli <mail@yanli.one>

---------

Signed-off-by: 盐粒 Yanli <mail@yanli.one>
This commit is contained in:
盐粒 Yanli
2023-11-23 15:32:07 +08:00
committed by GitHub
parent eafb2f66f2
commit 20b6e0081f
4 changed files with 45 additions and 52 deletions

View File

@ -34,43 +34,41 @@ client = PGVectoRs(
db_url=URL, db_url=URL,
collection_name="example", collection_name="example",
dimension=1536, dimension=1536,
recreate=True,
) )
try: # Add some records
# Add some records client.insert(records1)
client.insert(records1) client.insert(records2)
client.insert(records2)
# Query (With a filter from the filters module) # Query (With a filter from the filters module)
print("#################### First Query ####################") print("#################### First Query ####################")
for record, dis in client.search( for record, dis in client.search(
target, target,
filter=filters.meta_contains({"src": "one"}), filter=filters.meta_contains({"src": "one"}),
): ):
print(f"DISTANCE SCORE: {dis}") print(f"DISTANCE SCORE: {dis}")
print(record) print(record)
# Another Query (Equivalent to the first one, but with a lambda filter written by hand) # Another Query (Equivalent to the first one, but with a lambda filter written by hand)
print("#################### Second Query ####################") print("#################### Second Query ####################")
for record, dis in client.search( for record, dis in client.search(
target, target,
filter=lambda r: r.meta.contains({"src": "one"}), filter=lambda r: r.meta.contains({"src": "one"}),
): ):
print(f"DISTANCE SCORE: {dis}") print(f"DISTANCE SCORE: {dis}")
print(record) print(record)
# Yet Another Query (With a more complex filter) # Yet Another Query (With a more complex filter)
print("#################### Third Query ####################") print("#################### Third Query ####################")
def complex_filter(r: filters.FilterInput) -> filters.FilterOutput:
t1 = r.text.endswith("!") == False # noqa: E712
t2 = r.meta.contains({"src": "two"})
t = t1 & t2
return t
for record, dis in client.search(target, filter=complex_filter): def complex_filter(r: filters.FilterInput) -> filters.FilterOutput:
print(f"DISTANCE SCORE: {dis}") t1 = r.text.endswith("!") == False # noqa: E712
print(record) t2 = r.meta.contains({"src": "two"})
t = t1 & t2
return t
finally:
# Clean up for record, dis in client.search(target, filter=complex_filter):
client.drop() print(f"DISTANCE SCORE: {dis}")
print(record)

View File

@ -1,6 +1,6 @@
[project] [project]
name = "pgvecto-rs" name = "pgvecto-rs"
version = "0.1.3" version = "0.1.4"
description = "Python binding for pgvecto.rs" description = "Python binding for pgvecto.rs"
authors = [ authors = [
{ name = "TensorChord", email = "envd-maintainers@tensorchord.ai" }, { name = "TensorChord", email = "envd-maintainers@tensorchord.ai" },

View File

@ -20,22 +20,22 @@ class PGVectoRs:
dimension: int dimension: int
def __init__( def __init__(
self, self, db_url: str, collection_name: str, dimension: int, recreate: bool = False
db_url: str,
collection_name: str,
dimension: int,
) -> None: ) -> None:
"""Connect to an existing table or create a new empty one. """Connect to an existing table or create a new empty one.
If `recreate=True`, the table will be dropped if it already exists.
Args: Args:
---- ----
db_url (str): url to the database. db_url (str): url to the database.
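collection_name (str): name of the collection; the table is named `collection_<collection_name>`. collection_name (str): name of the collection; the table is named `collection_<collection_name>`.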
dimension (int): dimension of the embeddings. dimension (int): dimension of the embeddings.
recreate (bool): drop the table if it exists. Defaults to False.
""" """
class _Table(RecordORM): class _Table(RecordORM):
__tablename__ = f"collection_{collection_name}" __tablename__ = f"collection_{collection_name}"
__table_args__ = {"extend_existing": True} # noqa: RUF012
id: Mapped[UUID] = mapped_column( id: Mapped[UUID] = mapped_column(
postgresql.UUID(as_uuid=True), postgresql.UUID(as_uuid=True),
primary_key=True, primary_key=True,
@ -47,9 +47,11 @@ class PGVectoRs:
self._engine = create_engine(db_url) self._engine = create_engine(db_url)
with Session(self._engine) as session: with Session(self._engine) as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors")) session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
if recreate:
session.execute(text(f"DROP TABLE IF EXISTS {_Table.__tablename__}"))
session.commit() session.commit()
self._table = _Table self._table = _Table
self._table.__table__.create(self._engine) self._table.__table__.create(self._engine, checkfirst=True)
self.dimension = dimension self.dimension = dimension
def insert(self, records: List[Record]) -> None: def insert(self, records: List[Record]) -> None:

View File

@ -32,19 +32,12 @@ class MockEmbedder:
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def client(): def client():
client = PGVectoRs(db_url=URL, collection_name="empty", dimension=3) client = PGVectoRs(db_url=URL, collection_name="empty", dimension=3, recreate=True)
try: records1 = [Record.from_text(t, v, {"src": "src1"}) for t, v in MockTexts.items()]
records1 = [ records2 = [Record.from_text(t, v, {"src": "src2"}) for t, v in MockTexts.items()]
Record.from_text(t, v, {"src": "src1"}) for t, v in MockTexts.items() client.insert(records1)
] client.insert(records2)
records2 = [ return client
Record.from_text(t, v, {"src": "src2"}) for t, v in MockTexts.items()
]
client.insert(records1)
client.insert(records2)
yield client
finally:
client.drop()
filter_src1 = filters.meta_contains({"src": "src1"}) filter_src1 = filters.meta_contains({"src": "src1"})