You've already forked pgvecto.rs
mirror of
https://github.com/tensorchord/pgvecto.rs.git
synced 2025-07-30 19:23:05 +03:00
feat: enhance the __init__ of client (#164)
* enhance the __init__ of client Signed-off-by: 盐粒 Yanli <mail@yanli.one> * bump version Signed-off-by: 盐粒 Yanli <mail@yanli.one> --------- Signed-off-by: 盐粒 Yanli <mail@yanli.one>
This commit is contained in:
@ -34,43 +34,41 @@ client = PGVectoRs(
|
|||||||
db_url=URL,
|
db_url=URL,
|
||||||
collection_name="example",
|
collection_name="example",
|
||||||
dimension=1536,
|
dimension=1536,
|
||||||
|
recreate=True,
|
||||||
)
|
)
|
||||||
try:
|
# Add some records
|
||||||
# Add some records
|
client.insert(records1)
|
||||||
client.insert(records1)
|
client.insert(records2)
|
||||||
client.insert(records2)
|
|
||||||
|
|
||||||
# Query (With a filter from the filters module)
|
# Query (With a filter from the filters module)
|
||||||
print("#################### First Query ####################")
|
print("#################### First Query ####################")
|
||||||
for record, dis in client.search(
|
for record, dis in client.search(
|
||||||
target,
|
target,
|
||||||
filter=filters.meta_contains({"src": "one"}),
|
filter=filters.meta_contains({"src": "one"}),
|
||||||
):
|
):
|
||||||
print(f"DISTANCE SCORE: {dis}")
|
print(f"DISTANCE SCORE: {dis}")
|
||||||
print(record)
|
print(record)
|
||||||
|
|
||||||
# Another Query (Equivalent to the first one, but with a lambda filter written by hand)
|
# Another Query (Equivalent to the first one, but with a lambda filter written by hand)
|
||||||
print("#################### Second Query ####################")
|
print("#################### Second Query ####################")
|
||||||
for record, dis in client.search(
|
for record, dis in client.search(
|
||||||
target,
|
target,
|
||||||
filter=lambda r: r.meta.contains({"src": "one"}),
|
filter=lambda r: r.meta.contains({"src": "one"}),
|
||||||
):
|
):
|
||||||
print(f"DISTANCE SCORE: {dis}")
|
print(f"DISTANCE SCORE: {dis}")
|
||||||
print(record)
|
print(record)
|
||||||
|
|
||||||
# Yet Another Query (With a more complex filter)
|
# Yet Another Query (With a more complex filter)
|
||||||
print("#################### Third Query ####################")
|
print("#################### Third Query ####################")
|
||||||
|
|
||||||
def complex_filter(r: filters.FilterInput) -> filters.FilterOutput:
|
|
||||||
t1 = r.text.endswith("!") == False # noqa: E712
|
|
||||||
t2 = r.meta.contains({"src": "two"})
|
|
||||||
t = t1 & t2
|
|
||||||
return t
|
|
||||||
|
|
||||||
for record, dis in client.search(target, filter=complex_filter):
|
def complex_filter(r: filters.FilterInput) -> filters.FilterOutput:
|
||||||
print(f"DISTANCE SCORE: {dis}")
|
t1 = r.text.endswith("!") == False # noqa: E712
|
||||||
print(record)
|
t2 = r.meta.contains({"src": "two"})
|
||||||
|
t = t1 & t2
|
||||||
|
return t
|
||||||
|
|
||||||
finally:
|
|
||||||
# Clean up
|
for record, dis in client.search(target, filter=complex_filter):
|
||||||
client.drop()
|
print(f"DISTANCE SCORE: {dis}")
|
||||||
|
print(record)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "pgvecto-rs"
|
name = "pgvecto-rs"
|
||||||
version = "0.1.3"
|
version = "0.1.4"
|
||||||
description = "Python binding for pgvecto.rs"
|
description = "Python binding for pgvecto.rs"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "TensorChord", email = "envd-maintainers@tensorchord.ai" },
|
{ name = "TensorChord", email = "envd-maintainers@tensorchord.ai" },
|
||||||
|
@ -20,22 +20,22 @@ class PGVectoRs:
|
|||||||
dimension: int
|
dimension: int
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, db_url: str, collection_name: str, dimension: int, recreate: bool = False
|
||||||
db_url: str,
|
|
||||||
collection_name: str,
|
|
||||||
dimension: int,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Connect to an existing table or create a new empty one.
|
"""Connect to an existing table or create a new empty one.
|
||||||
|
If the `recreate=True`, the table will be dropped if it exists.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
----
|
----
|
||||||
db_url (str): url to the database.
|
db_url (str): url to the database.
|
||||||
table_name (str): name of the table.
|
table_name (str): name of the table.
|
||||||
dimension (int): dimension of the embeddings.
|
dimension (int): dimension of the embeddings.
|
||||||
|
recreate (bool): drop the table if it exists. Defaults to False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class _Table(RecordORM):
|
class _Table(RecordORM):
|
||||||
__tablename__ = f"collection_{collection_name}"
|
__tablename__ = f"collection_{collection_name}"
|
||||||
|
__table_args__ = {"extend_existing": True} # noqa: RUF012
|
||||||
id: Mapped[UUID] = mapped_column(
|
id: Mapped[UUID] = mapped_column(
|
||||||
postgresql.UUID(as_uuid=True),
|
postgresql.UUID(as_uuid=True),
|
||||||
primary_key=True,
|
primary_key=True,
|
||||||
@ -47,9 +47,11 @@ class PGVectoRs:
|
|||||||
self._engine = create_engine(db_url)
|
self._engine = create_engine(db_url)
|
||||||
with Session(self._engine) as session:
|
with Session(self._engine) as session:
|
||||||
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
|
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
|
||||||
|
if recreate:
|
||||||
|
session.execute(text(f"DROP TABLE IF EXISTS {_Table.__tablename__}"))
|
||||||
session.commit()
|
session.commit()
|
||||||
self._table = _Table
|
self._table = _Table
|
||||||
self._table.__table__.create(self._engine)
|
self._table.__table__.create(self._engine, checkfirst=True)
|
||||||
self.dimension = dimension
|
self.dimension = dimension
|
||||||
|
|
||||||
def insert(self, records: List[Record]) -> None:
|
def insert(self, records: List[Record]) -> None:
|
||||||
|
@ -32,19 +32,12 @@ class MockEmbedder:
|
|||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def client():
|
def client():
|
||||||
client = PGVectoRs(db_url=URL, collection_name="empty", dimension=3)
|
client = PGVectoRs(db_url=URL, collection_name="empty", dimension=3, recreate=True)
|
||||||
try:
|
records1 = [Record.from_text(t, v, {"src": "src1"}) for t, v in MockTexts.items()]
|
||||||
records1 = [
|
records2 = [Record.from_text(t, v, {"src": "src2"}) for t, v in MockTexts.items()]
|
||||||
Record.from_text(t, v, {"src": "src1"}) for t, v in MockTexts.items()
|
client.insert(records1)
|
||||||
]
|
client.insert(records2)
|
||||||
records2 = [
|
return client
|
||||||
Record.from_text(t, v, {"src": "src2"}) for t, v in MockTexts.items()
|
|
||||||
]
|
|
||||||
client.insert(records1)
|
|
||||||
client.insert(records2)
|
|
||||||
yield client
|
|
||||||
finally:
|
|
||||||
client.drop()
|
|
||||||
|
|
||||||
|
|
||||||
filter_src1 = filters.meta_contains({"src": "src1"})
|
filter_src1 = filters.meta_contains({"src": "src1"})
|
||||||
|
Reference in New Issue
Block a user