You've already forked pgvecto.rs
mirror of
https://github.com/tensorchord/pgvecto.rs.git
synced 2025-07-29 08:21:12 +03:00
feat: add Python bindings by psycopg 3 (#102)
* feat: support psycopg Signed-off-by: 盐粒 Yanli <mail@yanli.one> * chore: lint && add comment Signed-off-by: 盐粒 Yanli <mail@yanli.one> * test: update tests Signed-off-by: 盐粒 Yanli <mail@yanli.one> test: update tests Signed-off-by: 盐粒 Yanli <mail@yanli.one> * test: fix test of psycopg Signed-off-by: 盐粒 Yanli <mail@yanli.one> * chore: update readme Signed-off-by: 盐粒 Yanli <mail@yanli.one> * chore: write examples && modify readme Signed-off-by: 盐粒 Yanli <mail@yanli.one> * chore: bump version no. Signed-off-by: 盐粒 Yanli <mail@yanli.one> * feat: use normal defined class for Dumper Signed-off-by: 盐粒 Yanli <mail@yanli.one> --------- Signed-off-by: 盐粒 Yanli <mail@yanli.one>
This commit is contained in:
61
bindings/python/examples/psycopg_example.py
Normal file
61
bindings/python/examples/psycopg_example.py
Normal file
@ -0,0 +1,61 @@
|
||||
import os
|
||||
import psycopg
|
||||
import numpy as np
|
||||
from pgvecto_rs.psycopg import register_vector
|
||||
|
||||
URL = "postgresql://{username}:{password}@{host}:{port}/{db_name}".format(
|
||||
port=os.getenv("DB_PORT", 5432),
|
||||
host=os.getenv("DB_HOST", "localhost"),
|
||||
username=os.getenv("DB_USER", "postgres"),
|
||||
password=os.getenv("DB_PASS", "mysecretpassword"),
|
||||
db_name=os.getenv("DB_NAME", "postgres"),
|
||||
)
|
||||
|
||||
# Connect to the DB and init things
|
||||
with psycopg.connect(URL) as conn:
|
||||
conn.execute("CREATE EXTENSION IF NOT EXISTS vectors;")
|
||||
register_vector(conn)
|
||||
conn.execute(
|
||||
"CREATE TABLE documents (id SERIAL PRIMARY KEY, text TEXT NOT NULL, embedding vector(3) NOT NULL);"
|
||||
)
|
||||
conn.commit()
|
||||
try:
|
||||
# Insert 3 rows into the table
|
||||
conn.execute(
|
||||
"INSERT INTO documents (text, embedding) VALUES (%s, %s);",
|
||||
("hello world", [1, 2, 3]),
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO documents (text, embedding) VALUES (%s, %s);",
|
||||
("hello postgres", [1.0, 2.0, 4.0]),
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO documents (text, embedding) VALUES (%s, %s);",
|
||||
("hello pgvecto.rs", np.array([1, 3, 4])),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# Select the row "hello pgvecto.rs"
|
||||
cur = conn.execute(
|
||||
"SELECT * FROM documents WHERE text = %s;", ("hello pgvecto.rs",)
|
||||
)
|
||||
target = cur.fetchone()[2]
|
||||
|
||||
# Select all the rows and sort them
|
||||
# by the squared_euclidean_distance to "hello pgvecto.rs"
|
||||
cur = conn.execute(
|
||||
"SELECT text, embedding <-> %s AS distance FROM documents ORDER BY distance;",
|
||||
(target,),
|
||||
)
|
||||
for row in cur.fetchall():
|
||||
print(row)
|
||||
# The output will be:
|
||||
# ```
|
||||
# ('hello pgvecto.rs', 0.0)
|
||||
# ('hello postgres', 1.0)
|
||||
# ('hello world', 2.0)
|
||||
# ```
|
||||
finally:
|
||||
# Drop the table
|
||||
conn.execute("DROP TABLE IF EXISTS documents;")
|
||||
conn.commit()
|
69
bindings/python/examples/sqlalchemy_example.py
Normal file
69
bindings/python/examples/sqlalchemy_example.py
Normal file
@ -0,0 +1,69 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from sqlalchemy import create_engine, select, insert
|
||||
from sqlalchemy import Integer, String
|
||||
from pgvecto_rs.sqlalchemy import Vector
|
||||
from sqlalchemy.orm import Session, DeclarativeBase, mapped_column, Mapped
|
||||
|
||||
URL = "postgresql+psycopg://{username}:{password}@{host}:{port}/{db_name}".format(
|
||||
port=os.getenv("DB_PORT", 5432),
|
||||
host=os.getenv("DB_HOST", "localhost"),
|
||||
username=os.getenv("DB_USER", "postgres"),
|
||||
password=os.getenv("DB_PASS", "mysecretpassword"),
|
||||
db_name=os.getenv("DB_NAME", "postgres"),
|
||||
)
|
||||
|
||||
|
||||
# Define the ORM model
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class Document(Base):
|
||||
__tablename__ = "documents"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
text: Mapped[str] = mapped_column(String)
|
||||
embedding: Mapped[np.ndarray] = mapped_column(Vector(3))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.text}: {self.embedding}"
|
||||
|
||||
|
||||
# Connect to the DB and create the table
|
||||
engine = create_engine(URL)
|
||||
Document.metadata.create_all(engine)
|
||||
|
||||
with Session(engine) as session:
|
||||
# Insert 3 rows into the table
|
||||
t1 = insert(Document).values(text="hello world", embedding=[1, 2, 3])
|
||||
t2 = insert(Document).values(text="hello postgres", embedding=[1.0, 2.0, 4.0])
|
||||
t3 = insert(Document).values(text="hello pgvecto.rs", embedding=np.array([1, 3, 4]))
|
||||
for t in [t1, t2, t3]:
|
||||
session.execute(t)
|
||||
session.commit()
|
||||
|
||||
# Select the row "hello pgvecto.rs"
|
||||
stmt = select(Document).where(Document.text == "hello pgvecto.rs")
|
||||
target = session.scalar(stmt)
|
||||
|
||||
# Select all the rows and sort them
|
||||
# by the squared_euclidean_distance to "hello pgvecto.rs"
|
||||
stmt = select(
|
||||
Document.text,
|
||||
Document.embedding.squared_euclidean_distance(target.embedding).label(
|
||||
"distance"
|
||||
),
|
||||
).order_by("distance")
|
||||
for doc in session.execute(stmt):
|
||||
print(doc)
|
||||
|
||||
# The output will be:
|
||||
# ```
|
||||
# ('hello pgvecto.rs', 0.0)
|
||||
# ('hello postgres', 1.0)
|
||||
# ('hello world', 2.0)
|
||||
# ```
|
||||
|
||||
# Drop the table
|
||||
Document.metadata.drop_all(engine)
|
Reference in New Issue
Block a user