From 20b6e0081f9a89efa4a8f75c843e0c6684a9660e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=90=E7=B2=92=20Yanli?= Date: Thu, 23 Nov 2023 15:32:07 +0800 Subject: [PATCH] feat: enhance the __init__ of client (#164) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * enhance the __init__ of client Signed-off-by: 盐粒 Yanli * bump version Signed-off-by: 盐粒 Yanli --------- Signed-off-by: 盐粒 Yanli --- bindings/python/examples/sdk_example.py | 64 ++++++++++---------- bindings/python/pyproject.toml | 2 +- bindings/python/src/pgvecto_rs/sdk/client.py | 12 ++-- bindings/python/tests/test_sdk.py | 19 ++---- 4 files changed, 45 insertions(+), 52 deletions(-) diff --git a/bindings/python/examples/sdk_example.py b/bindings/python/examples/sdk_example.py index d4cff4a..dede330 100644 --- a/bindings/python/examples/sdk_example.py +++ b/bindings/python/examples/sdk_example.py @@ -34,43 +34,41 @@ client = PGVectoRs( db_url=URL, collection_name="example", dimension=1536, + recreate=True, ) -try: - # Add some records - client.insert(records1) - client.insert(records2) +# Add some records +client.insert(records1) +client.insert(records2) - # Query (With a filter from the filters module) - print("#################### First Query ####################") - for record, dis in client.search( - target, - filter=filters.meta_contains({"src": "one"}), - ): - print(f"DISTANCE SCORE: {dis}") - print(record) +# Query (With a filter from the filters module) +print("#################### First Query ####################") +for record, dis in client.search( + target, + filter=filters.meta_contains({"src": "one"}), +): + print(f"DISTANCE SCORE: {dis}") + print(record) - # Another Query (Equivalent to the first one, but with a lambda filter written by hand) - print("#################### Second Query ####################") - for record, dis in client.search( - target, - filter=lambda r: r.meta.contains({"src": "one"}), - ): - print(f"DISTANCE SCORE: {dis}") - print(record) +# Another Query (Equivalent to the first one, but with a lambda filter written by hand) +print("#################### Second Query ####################") +for record, dis in client.search( + target, + filter=lambda r: r.meta.contains({"src": "one"}), +): + print(f"DISTANCE SCORE: {dis}") + print(record) - # Yet Another Query (With a more complex filter) - print("#################### Third Query ####################") +# Yet Another Query (With a more complex filter) +print("#################### Third Query ####################") - def complex_filter(r: filters.FilterInput) -> filters.FilterOutput: - t1 = r.text.endswith("!") == False # noqa: E712 - t2 = r.meta.contains({"src": "two"}) - t = t1 & t2 - return t - for record, dis in client.search(target, filter=complex_filter): - print(f"DISTANCE SCORE: {dis}") - print(record) +def complex_filter(r: filters.FilterInput) -> filters.FilterOutput: + t1 = r.text.endswith("!") == False # noqa: E712 + t2 = r.meta.contains({"src": "two"}) + t = t1 & t2 + return t -finally: - # Clean up - client.drop() + +for record, dis in client.search(target, filter=complex_filter): + print(f"DISTANCE SCORE: {dis}") + print(record) diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml index ca92105..5a119e4 100644 --- a/bindings/python/pyproject.toml +++ b/bindings/python/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pgvecto-rs" -version = "0.1.3" +version = "0.1.4" description = "Python binding for pgvecto.rs" authors = [ { name = "TensorChord", email = "envd-maintainers@tensorchord.ai" }, diff --git a/bindings/python/src/pgvecto_rs/sdk/client.py b/bindings/python/src/pgvecto_rs/sdk/client.py index 17f1f8c..aed155b 100644 --- a/bindings/python/src/pgvecto_rs/sdk/client.py +++ b/bindings/python/src/pgvecto_rs/sdk/client.py @@ -20,22 +20,22 @@ class PGVectoRs: dimension: int def __init__( - self, - db_url: str, - collection_name: str, - dimension: int, + self, db_url: str, collection_name: str, dimension: int, recreate: bool = False ) -> None: """Connect to an existing table or create a new empty one. + If the `recreate=True`, the table will be dropped if it exists. Args: ---- db_url (str): url to the database. table_name (str): name of the table. dimension (int): dimension of the embeddings. + recreate (bool): drop the table if it exists. Defaults to False. """ class _Table(RecordORM): __tablename__ = f"collection_{collection_name}" + __table_args__ = {"extend_existing": True} # noqa: RUF012 id: Mapped[UUID] = mapped_column( postgresql.UUID(as_uuid=True), primary_key=True, @@ -47,9 +47,11 @@ class PGVectoRs: self._engine = create_engine(db_url) with Session(self._engine) as session: session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors")) + if recreate: + session.execute(text(f"DROP TABLE IF EXISTS {_Table.__tablename__}")) session.commit() self._table = _Table - self._table.__table__.create(self._engine) + self._table.__table__.create(self._engine, checkfirst=True) self.dimension = dimension def insert(self, records: List[Record]) -> None: diff --git a/bindings/python/tests/test_sdk.py b/bindings/python/tests/test_sdk.py index 162604b..0e62a86 100644 --- a/bindings/python/tests/test_sdk.py +++ b/bindings/python/tests/test_sdk.py @@ -32,19 +32,12 @@ class MockEmbedder: @pytest.fixture(scope="module") def client(): - client = PGVectoRs(db_url=URL, collection_name="empty", dimension=3) - try: - records1 = [ - Record.from_text(t, v, {"src": "src1"}) for t, v in MockTexts.items() - ] - records2 = [ - Record.from_text(t, v, {"src": "src2"}) for t, v in MockTexts.items() - ] - client.insert(records1) - client.insert(records2) - yield client - finally: - client.drop() + client = PGVectoRs(db_url=URL, collection_name="empty", dimension=3, recreate=True) + records1 = [Record.from_text(t, v, {"src": "src1"}) for t, v in MockTexts.items()] + records2 = [Record.from_text(t, v, {"src": "src2"}) for t, v in MockTexts.items()] + client.insert(records1) + client.insert(records2) + return client filter_src1 = filters.meta_contains({"src": "src1"})