1
0
mirror of https://github.com/tensorchord/pgvecto.rs.git synced 2025-08-07 03:22:55 +03:00

feat: mmap transport for macos (#137)

* feat: mmap transport for macos

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: building with feature pg12, pg13

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: call unlink for shmem

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: reduce shmem filename on macos

Signed-off-by: usamoi <usamoi@outlook.com>

* chore: enable testing on all PostgreSQL versions

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: use file instead of shmem for macos

Signed-off-by: usamoi <usamoi@outlook.com>

* chore: select simpler matrix for pull requests in CI

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: remove macos-latest-16

Signed-off-by: usamoi <usamoi@outlook.com>

* chore: reduce use of nightly features

Signed-off-by: usamoi <usamoi@outlook.com>

---------

Signed-off-by: usamoi <usamoi@outlook.com>
This commit is contained in:
Usamoi
2023-11-18 00:26:15 +08:00
committed by GitHub
parent f6e382d0fc
commit 39e8ee9797
26 changed files with 483 additions and 305 deletions

View File

@@ -1,7 +1,7 @@
[build]
rustdocflags = ["--document-private-items"]
[target.'cfg(target_arch="x86_64")']
[target.'cfg(all(target_os="linux", target_arch="x86_64"))']
rustflags = ["-Ctarget-cpu=x86-64-v3"]
[target.'cfg(target_os="macos")']

View File

@@ -4,21 +4,27 @@ on:
push:
branches: ["main"]
paths:
- '.github/workflows/check.yml'
- 'src/**'
- 'Cargo.toml'
- 'Cargo.lock'
- 'rust-toolchain.toml'
- 'tests/**'
- ".cargo/**"
- ".github/**"
- "scripts/**"
- "src/**"
- "tests/**"
- "Cargo.lock"
- "Cargo.toml"
- "rust-toolchain.toml"
- "vectors.control"
pull_request:
branches: ["main"]
paths:
- '.github/workflows/check.yml'
- 'src/**'
- 'Cargo.toml'
- 'Cargo.lock'
- 'rust-toolchain.toml'
- 'tests/**'
- ".cargo/**"
- ".github/**"
- "scripts/**"
- "src/**"
- "tests/**"
- "Cargo.lock"
- "Cargo.toml"
- "rust-toolchain.toml"
- "vectors.control"
merge_group:
workflow_dispatch:
@@ -28,15 +34,49 @@ concurrency:
env:
CARGO_TERM_COLOR: always
RUST_BACKTRACE: 1
SCCACHE_GHA_ENABLED: true
RUSTC_WRAPPER: sccache
jobs:
lint:
matrix:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.main.outputs.matrix }}
steps:
- uses: actions/github-script@v7
id: main
with:
script: |
let matrix;
if("${{ github.event_name }}" == "push" || "${{ github.event_name }}" == "pull_request"){
matrix = [
{ version: 15, os: "ubuntu-latest" },
];
}
if("${{ github.event_name }}" == "merge_group" || "${{ github.event_name }}" == "workflow_dispatch"){
matrix = [
{ version: 12, os: "ubuntu-latest" },
{ version: 13, os: "ubuntu-latest" },
{ version: 14, os: "ubuntu-latest" },
{ version: 15, os: "ubuntu-latest" },
{ version: 16, os: "ubuntu-latest" },
{ version: 12, os: "macos-latest" },
{ version: 13, os: "macos-latest" },
{ version: 14, os: "macos-latest" },
{ version: 15, os: "macos-latest" },
];
}
core.setOutput('matrix', JSON.stringify(matrix));
check:
needs: matrix
strategy:
matrix:
version: [15]
runs-on: ubuntu-latest
include: ${{ fromJson(needs.matrix.outputs.matrix) }}
runs-on: ${{ matrix.os }}
env:
VERSION: ${{ matrix.version }}
OS: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- uses: actions/cache/restore@v3
@@ -45,99 +85,30 @@ jobs:
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
key: cargo-${{ runner.os }}-pg${{ matrix.version }}-${{ hashFiles('./Cargo.toml') }}
restore-keys: cargo-${{ runner.os }}-pg${{ matrix.version }}
key: cargo-${{ matrix.os }}-pg${{ matrix.version }}-${{ hashFiles('./Cargo.toml') }}
restore-keys: cargo-${{ matrix.os }}-pg${{ matrix.version }}
- uses: mozilla-actions/sccache-action@v0.0.3
- name: Prepare
run: |
sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list'
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
sudo apt-get update
sudo apt-get -y install libpq-dev postgresql-${{ matrix.version }} postgresql-server-dev-${{ matrix.version }}
cargo install cargo-pgrx --version $(grep '^pgrx ' Cargo.toml | awk -F'\"' '{print $2}')
cargo pgrx init --pg${{ matrix.version }}=/usr/lib/postgresql/${{ matrix.version }}/bin/pg_config
- name: Setup
shell: bash
run: ./scripts/ci_setup.sh
- name: Format check
run: cargo fmt --check
- name: Semantic check
run: cargo clippy
build:
strategy:
matrix:
version: [15]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/cache/restore@v3
with:
path: |
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
key: cargo-${{ runner.os }}-pg${{ matrix.version }}-${{ hashFiles('./Cargo.toml') }}
restore-keys: cargo-${{ runner.os }}-pg${{ matrix.version }}
- uses: mozilla-actions/sccache-action@v0.0.3
- name: Prepare
run: |
sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list'
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
sudo apt-get update
sudo apt-get -y install libpq-dev postgresql-${{ matrix.version }} postgresql-server-dev-${{ matrix.version }}
cargo install cargo-pgrx --version $(grep '^pgrx ' Cargo.toml | awk -F'\"' '{print $2}')
cargo pgrx init --pg${{ matrix.version }}=/usr/lib/postgresql/${{ matrix.version }}/bin/pg_config
- name: Build
run: cargo build --verbose
run: cargo clippy --no-default-features --features "pg${{ matrix.version }} pg_test"
- name: Debug build
run: cargo build --no-default-features --features "pg${{ matrix.version }} pg_test"
- name: Test
env:
RUST_BACKTRACE: 1
run: cargo test --all --no-default-features --features "pg${{ matrix.version }} pg_test" -- --nocapture
test:
strategy:
matrix:
version: [15]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/cache/restore@v3
with:
path: |
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
key: cargo-${{ runner.os }}-pg${{ matrix.version }}-${{ hashFiles('./Cargo.toml') }}
restore-keys: cargo-${{ runner.os }}-pg${{ matrix.version }}
- uses: mozilla-actions/sccache-action@v0.0.3
- name: Prepare
- name: Install release
run: ./scripts/ci_install.sh
- name: Sqllogictest
run: |
sudo pg_dropcluster 14 main
sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list'
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
sudo apt-get update
sudo apt-get -y install libpq-dev postgresql-${{ matrix.version }} postgresql-server-dev-${{ matrix.version }}
cargo install cargo-pgrx --version $(grep '^pgrx ' Cargo.toml | awk -F'\"' '{print $2}')
cargo pgrx init --pg${{ matrix.version }}=/usr/lib/postgresql/${{ matrix.version }}/bin/pg_config
cargo install sqllogictest-bin
psql -f ./tests/init.sql
sqllogictest -u runner -d runner './tests/**/*.slt'
- uses: actions/cache/save@v3
with:
path: |
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
key: cargo-${{ runner.os }}-pg${{ matrix.version }}-${{ hashFiles('./Cargo.toml') }}
- name: Build
run: |
sudo chmod -R 777 /usr/share/postgresql/${{ matrix.version }}/extension
sudo chmod -R 777 /usr/lib/postgresql/${{ matrix.version }}/lib
cargo pgrx install --release
sudo systemctl start postgresql@${{ matrix.version }}-main
sudo -u postgres psql -c "CREATE USER $USER LOGIN SUPERUSER"
sudo -u postgres psql -c "CREATE DATABASE $USER OWNER $USER"
psql -c 'ALTER SYSTEM SET shared_preload_libraries = "vectors.so"'
sudo systemctl restart postgresql@${{ matrix.version }}-main
- name: Sqllogictest
run: |
export password=$(openssl rand -base64 32)
psql -c "ALTER USER $USER WITH PASSWORD '$password'"
psql -f ./tests/init.sql
sqllogictest -u "$USER" -w "$password" -d "$USER" './tests/**/*.slt'
key: cargo-${{ matrix.os }}-pg${{ matrix.version }}-${{ hashFiles('./Cargo.toml') }}

View File

@@ -41,7 +41,7 @@ validator = { version = "0.16.1", features = ["derive"] }
toml = "0.7.6"
rayon = "1.6.1"
uuid = { version = "1.4.1", features = ["serde"] }
rustix = { version = "0.38.20", features = ["net", "mm"] }
rustix = { version = "0.38.20", features = ["net", "mm", "shm"] }
arc-swap = "1.6.0"
bytemuck = "1.14.0"
serde_with = "3.4.0"
@@ -53,6 +53,9 @@ pgrx-tests = "0.11.0"
httpmock = "0.6"
mockall = "0.11.4"
[target.'cfg(target_os = "macos")'.dependencies]
ulock-sys = "0.1.0"
[profile.dev]
panic = "unwind"

View File

@@ -1,4 +1,9 @@
[toolchain]
channel = "nightly-2023-08-03"
channel = "nightly-2023-11-15"
components = ["rustfmt", "clippy", "miri"]
targets = ["x86_64-unknown-linux-gnu"]
targets = [
"x86_64-unknown-linux-gnu",
"x86_64-apple-darwin",
"aarch64-unknown-linux-gnu",
"aarch64-apple-darwin",
]

13
scripts/ci_install.sh Executable file
View File

@@ -0,0 +1,13 @@
#!/usr/bin/env bash
# CI helper: installs the release build of the extension into the PostgreSQL
# instance under test and restarts the server.
#
# Reads two environment variables:
#   VERSION - PostgreSQL major version (selects the `pg$VERSION` cargo feature)
#   OS      - runner label, either "ubuntu-latest" or "macos-latest"
set -e

# Build in release mode and install for the selected PostgreSQL version only.
cargo pgrx install --no-default-features --features "pg$VERSION" --release

# Register the library for preloading; the restart below applies the setting.
psql -c 'ALTER SYSTEM SET shared_preload_libraries = "vectors.so"'

if [ "$OS" == "ubuntu-latest" ]; then
    sudo systemctl restart postgresql
    # Print cluster status into the CI log.
    pg_lsclusters
fi

if [ "$OS" == "macos-latest" ]; then
    # Quoted so the formula name is always passed as a single argument.
    brew services restart "postgresql@$VERSION"
fi

39
scripts/ci_setup.sh Executable file
View File

@@ -0,0 +1,39 @@
#!/usr/bin/env bash
# CI helper: installs PostgreSQL and the Rust tooling needed to build and test
# the extension on a CI runner.
#
# Reads these environment variables:
#   VERSION     - PostgreSQL major version to install
#   OS          - runner label, either "ubuntu-latest" or "macos-latest"
#   GITHUB_PATH - file that GitHub Actions reads to extend PATH for later steps
#                 (only written on macOS)
set -e

if [ "$OS" == "ubuntu-latest" ]; then
    # Drop the existing version-14 cluster unless 14 is the version under test.
    # NOTE(review): presumably the runner image ships a PostgreSQL 14 cluster - confirm.
    # VERSION is quoted so an empty value cannot produce a malformed test.
    if [ "$VERSION" != 14 ]; then
        sudo pg_dropcluster 14 main
    fi

    # Remove existing PostgreSQL and LLVM/clang packages before installing.
    sudo apt-get remove -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*'

    # Add the PGDG apt repository and install the requested server version.
    sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list'
    wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
    sudo apt-get update
    sudo apt-get -y install build-essential libpq-dev "postgresql-$VERSION" "postgresql-server-dev-$VERSION"

    # Replace pg_hba.conf with trust authentication for local and loopback connections.
    echo "local all all trust" | sudo tee "/etc/postgresql/$VERSION/main/pg_hba.conf"
    echo "host all all 127.0.0.1/32 trust" | sudo tee -a "/etc/postgresql/$VERSION/main/pg_hba.conf"
    echo "host all all ::1/128 trust" | sudo tee -a "/etc/postgresql/$VERSION/main/pg_hba.conf"

    # Cluster status is printed before and after the restart for the CI log.
    pg_lsclusters
    sudo systemctl restart postgresql
    pg_lsclusters

    # Create a superuser role for the CI user and that user's default database.
    sudo -iu postgres createuser -s -r runner
    createdb
fi

if [ "$OS" == "macos-latest" ]; then
    brew uninstall postgresql
    brew install "postgresql@$VERSION"

    # Make the versioned binaries visible to this script and to later workflow steps.
    export PATH="$PATH:$(brew --prefix "postgresql@$VERSION")/bin"
    echo "$(brew --prefix "postgresql@$VERSION")/bin" >> "$GITHUB_PATH"

    brew services start "postgresql@$VERSION"
    # Fixed delay before first use, presumably to let the server finish starting.
    sleep 30
    createdb
fi

# Make the library and extension directories writable for the later install step.
sudo chmod -R 777 "$(pg_config --pkglibdir)"
sudo chmod -R 777 "$(pg_config --sharedir)/extension"

# Install cargo-binstall, then fetch a prebuilt sqllogictest binary with it.
curl -L --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh | bash
cargo binstall sqllogictest-bin -y --force

# Install the cargo-pgrx version that matches the `pgrx` entry in Cargo.toml.
cargo install cargo-pgrx --version "$(grep '^pgrx ' Cargo.toml | awk -F'\"' '{print $2}')" --debug
cargo pgrx init "--pg$VERSION=$(which pg_config)"

View File

@@ -7,13 +7,13 @@ use crate::index::segments::sealed::SealedSegment;
use crate::index::IndexOptions;
use crate::index::VectorOptions;
use crate::prelude::*;
use crate::utils::cells::SyncUnsafeCell;
use crate::utils::dir_ops::sync_dir;
use crate::utils::mmap_array::MmapArray;
use crate::utils::vec2::Vec2;
use rand::seq::index::sample;
use rand::thread_rng;
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
use std::cell::SyncUnsafeCell;
use std::fs::create_dir;
use std::path::PathBuf;
use std::sync::atomic::AtomicU32;

View File

@@ -8,12 +8,12 @@ use crate::index::segments::sealed::SealedSegment;
use crate::index::IndexOptions;
use crate::index::VectorOptions;
use crate::prelude::*;
use crate::utils::cells::SyncUnsafeCell;
use crate::utils::dir_ops::sync_dir;
use crate::utils::mmap_array::MmapArray;
use crate::utils::vec2::Vec2;
use rand::seq::index::sample;
use rand::thread_rng;
use std::cell::SyncUnsafeCell;
use std::fs::create_dir;
use std::path::PathBuf;
use std::sync::atomic::AtomicU32;

View File

@@ -4,7 +4,7 @@ use self::bgworker::Bgworker;
use crate::ipc::server::RpcHandler;
use crate::ipc::IpcError;
use std::fs::OpenOptions;
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use std::sync::Arc;
pub fn main() {
@@ -39,7 +39,7 @@ pub fn main() {
log::error!("Panickied. Info: {:?}. Backtrace: {}.", info, backtrace);
}));
let bgworker;
if std::fs::try_exists("pg_vectors").unwrap() {
if Path::new("pg_vectors").try_exists().unwrap() {
bgworker = Bgworker::open(PathBuf::from("pg_vectors"));
} else {
bgworker = Bgworker::create(PathBuf::from("pg_vectors"));

View File

@@ -26,18 +26,11 @@ pub fn listen_unix() -> impl Iterator<Item = RpcHandler> {
}
pub fn listen_mmap() -> impl Iterator<Item = RpcHandler> {
#[cfg(target_os = "linux")]
{
std::iter::from_fn(move || {
let socket = self::transport::Socket::Mmap(self::transport::mmap::accept());
Some(self::server::RpcHandler::new(socket))
})
}
#[cfg(not(target_os = "linux"))]
{
std::iter::empty()
}
}
pub fn connect_unix() -> Rpc {
let socket = self::transport::Socket::Unix(self::transport::unix::connect());
@@ -45,14 +38,6 @@ pub fn connect_unix() -> Rpc {
}
pub fn connect_mmap() -> Rpc {
#[cfg(target_os = "linux")]
{
let socket = self::transport::Socket::Mmap(self::transport::mmap::connect());
self::client::Rpc::new(socket)
}
#[cfg(not(target_os = "linux"))]
{
use crate::prelude::FriendlyError;
FriendlyError::MmapTransportNotSupported.friendly();
}
}

View File

@@ -1,21 +1,16 @@
use crate::ipc::IpcError;
use crate::utils::file_socket::FileSocket;
use crate::utils::os::{futex_wait, futex_wake, memfd_create, mmap_populate};
use rustix::fd::{AsFd, OwnedFd};
use rustix::fs::{FlockOperation, MemfdFlags};
use rustix::mm::{MapFlags, ProtFlags};
use rustix::fs::FlockOperation;
use serde::{Deserialize, Serialize};
use std::cell::UnsafeCell;
use std::io::ErrorKind;
use std::ptr::null_mut;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::OnceLock;
const BUFFER_SIZE: usize = 512 * 1024;
const SPIN_LIMIT: usize = 8;
const FUTEX_TIMEOUT: libc::timespec = libc::timespec {
tv_sec: 15,
tv_nsec: 0,
};
static CHANNEL: OnceLock<FileSocket> = OnceLock::new();
@@ -26,18 +21,7 @@ pub fn init() {
pub fn accept() -> Socket {
let memfd = CHANNEL.get().unwrap().recv().unwrap();
rustix::fs::fcntl_lock(&memfd, FlockOperation::NonBlockingLockShared).unwrap();
let addr;
unsafe {
addr = rustix::mm::mmap(
null_mut(),
BUFFER_SIZE,
ProtFlags::READ | ProtFlags::WRITE,
MapFlags::POPULATE | MapFlags::SHARED,
&memfd,
0,
)
.unwrap();
}
let addr = unsafe { mmap_populate(BUFFER_SIZE, &memfd).unwrap() };
Socket {
is_server: true,
addr: addr as _,
@@ -46,22 +30,11 @@ pub fn accept() -> Socket {
}
pub fn connect() -> Socket {
let memfd = rustix::fs::memfd_create("transport", MemfdFlags::empty()).unwrap();
let memfd = memfd_create().unwrap();
rustix::fs::ftruncate(&memfd, BUFFER_SIZE as u64).unwrap();
rustix::fs::fcntl_lock(&memfd, FlockOperation::NonBlockingLockShared).unwrap();
CHANNEL.get().unwrap().send(memfd.as_fd()).unwrap();
let addr;
unsafe {
addr = rustix::mm::mmap(
null_mut(),
BUFFER_SIZE,
ProtFlags::READ | ProtFlags::WRITE,
MapFlags::POPULATE | MapFlags::SHARED,
&memfd,
0,
)
.unwrap();
}
let addr = unsafe { mmap_populate(BUFFER_SIZE, &memfd).unwrap() };
Socket {
is_server: false,
addr: addr as _,
@@ -159,25 +132,13 @@ impl Channel {
{
break;
}
libc::syscall(
libc::SYS_futex,
self.futex.as_ptr(),
libc::FUTEX_WAIT,
Y,
&FUTEX_TIMEOUT,
);
futex_wait(&self.futex, Y);
}
Y => {
if !test() {
return Err(IpcError::Closed);
}
libc::syscall(
libc::SYS_futex,
self.futex.as_ptr(),
libc::FUTEX_WAIT,
Y,
&FUTEX_TIMEOUT,
);
futex_wait(&self.futex, Y);
}
_ => std::hint::unreachable_unchecked(),
}
@@ -193,17 +154,8 @@ impl Channel {
debug_assert!(matches!(self.futex.load(Ordering::Relaxed), S | X));
*self.len.get() = data.len() as u32;
(*self.bytes.get())[0..data.len()].copy_from_slice(data);
match self.futex.swap(T, Ordering::Release) {
S => (),
X => {
libc::syscall(
libc::SYS_futex,
self.futex.as_ptr(),
libc::FUTEX_WAKE,
i32::MAX,
);
}
_ => std::hint::unreachable_unchecked(),
if X == self.futex.swap(T, Ordering::Release) {
futex_wake(&self.futex);
}
}
unsafe fn server_recv(&self, test: impl Fn() -> bool) -> Result<Vec<u8>, IpcError> {
@@ -229,25 +181,13 @@ impl Channel {
{
break;
}
libc::syscall(
libc::SYS_futex,
self.futex.as_ptr(),
libc::FUTEX_WAIT,
Y,
&FUTEX_TIMEOUT,
);
futex_wait(&self.futex, Y);
}
Y => {
if !test() {
return Err(IpcError::Closed);
}
libc::syscall(
libc::SYS_futex,
self.futex.as_ptr(),
libc::FUTEX_WAIT,
Y,
&FUTEX_TIMEOUT,
);
futex_wait(&self.futex, Y);
}
_ => std::hint::unreachable_unchecked(),
}
@@ -263,17 +203,8 @@ impl Channel {
debug_assert!(matches!(self.futex.load(Ordering::Relaxed), S | X));
*self.len.get() = data.len() as u32;
(*self.bytes.get())[0..data.len()].copy_from_slice(data);
match self.futex.swap(T, Ordering::Release) {
S => (),
X => {
libc::syscall(
libc::SYS_futex,
self.futex.as_ptr(),
libc::FUTEX_WAKE,
i32::MAX,
);
}
_ => std::hint::unreachable_unchecked(),
if X == self.futex.swap(T, Ordering::Release) {
futex_wake(&self.futex);
}
}
}

View File

@@ -1,4 +1,3 @@
#[cfg(target_os = "linux")]
pub mod mmap;
pub mod unix;
@@ -7,7 +6,6 @@ use serde::{Deserialize, Serialize};
pub enum Socket {
Unix(unix::Socket),
#[cfg(target_os = "linux")]
Mmap(mmap::Socket),
}
@@ -15,14 +13,12 @@ impl Socket {
pub fn send<T: Serialize>(&mut self, packet: T) -> Result<(), IpcError> {
match self {
Socket::Unix(x) => x.send(packet),
#[cfg(target_os = "linux")]
Socket::Mmap(x) => x.send(packet),
}
}
pub fn recv<T: for<'a> Deserialize<'a>>(&mut self) -> Result<T, IpcError> {
match self {
Socket::Unix(x) => x.recv(),
#[cfg(target_os = "linux")]
Socket::Mmap(x) => x.recv(),
}
}

View File

@@ -3,17 +3,7 @@
//! Provides an easy-to-use extension for vector similarity search.
#![feature(core_intrinsics)]
#![feature(allocator_api)]
#![feature(thread_local)]
#![feature(auto_traits)]
#![feature(negative_impls)]
#![feature(ptr_metadata)]
#![feature(new_uninit)]
#![feature(int_roundings)]
#![feature(never_type)]
#![feature(lazy_cell)]
#![feature(const_maybe_uninit_zeroed)]
#![feature(fs_try_exists)]
#![feature(sync_unsafe_cell)]
#![allow(clippy::complexity)]
#![allow(clippy::style)]
@@ -48,7 +38,6 @@ pub unsafe extern "C" fn _PG_init() {
.load();
self::postgres::init();
self::ipc::transport::unix::init();
#[cfg(target_os = "linux")]
self::ipc::transport::mmap::init();
}

View File

@@ -10,15 +10,8 @@ pub enum Transport {
impl Transport {
pub const fn default() -> Transport {
#[cfg(target_os = "linux")]
{
Transport::mmap
}
#[cfg(not(target_os = "linux"))]
{
Transport::unix
}
}
}
pub static OPENAI_API_KEY_GUC: GucSetting<Option<&'static CStr>> =

View File

@@ -126,7 +126,7 @@ unsafe extern "C" fn rewrite_plan_state(
_ => (),
}
let walker = std::mem::transmute::<PlanstateTreeWalker, _>(rewrite_plan_state);
#[cfg(not(feature = "pg16"))]
#[cfg(any(feature = "pg12", feature = "pg13", feature = "pg14", feature = "pg15"))]
{
pgrx::pg_sys::planstate_tree_walker(node, Some(walker), context)
}

View File

@@ -3,17 +3,14 @@ use super::gucs::TRANSPORT;
use crate::ipc::client::Rpc;
use crate::ipc::{connect_mmap, connect_unix};
use crate::prelude::*;
use std::cell::RefCell;
use crate::utils::cells::PgRefCell;
use std::collections::BTreeSet;
#[thread_local]
static FLUSH_IF_COMMIT: RefCell<BTreeSet<Id>> = RefCell::new(BTreeSet::new());
static FLUSH_IF_COMMIT: PgRefCell<BTreeSet<Id>> = unsafe { PgRefCell::new(BTreeSet::new()) };
#[thread_local]
static DROP_IF_COMMIT: RefCell<BTreeSet<Id>> = RefCell::new(BTreeSet::new());
static DROP_IF_COMMIT: PgRefCell<BTreeSet<Id>> = unsafe { PgRefCell::new(BTreeSet::new()) };
#[thread_local]
static CLIENT: RefCell<Option<Rpc>> = RefCell::new(None);
static CLIENT: PgRefCell<Option<Rpc>> = unsafe { PgRefCell::new(None) };
pub fn aborting() {
*FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new();

View File

@@ -14,6 +14,34 @@ unsafe extern "C" fn vectors_executor_start(
super::hook_executor::post_executor_start(query_desc);
}
/// `ProcessUtility` hook used when building for PostgreSQL 12 or 13.
///
/// Runs this extension's `pre_process_utility` step on the planned statement
/// and then hands control on, unchanged, to whichever implementation comes
/// next: the previously installed hook if there is one, otherwise PostgreSQL's
/// `standard_ProcessUtility`.
///
/// NOTE(review): upstream PostgreSQL 12 declares the last argument of this
/// hook as `char *completionTag`, not `QueryCompletion *` — confirm that the
/// pgrx pg12 bindings accept this signature.
#[cfg(any(feature = "pg12", feature = "pg13"))]
#[pgrx::pg_guard]
unsafe extern "C" fn vectors_process_utility(
pstmt: *mut pgrx::pg_sys::PlannedStmt,
query_string: *const ::std::os::raw::c_char,
context: pgrx::pg_sys::ProcessUtilityContext,
params: pgrx::pg_sys::ParamListInfo,
query_env: *mut pgrx::pg_sys::QueryEnvironment,
dest: *mut pgrx::pg_sys::DestReceiver,
qc: *mut pgrx::pg_sys::QueryCompletion,
) {
// Extension-specific work happens before the utility statement executes.
super::hook_executor::pre_process_utility(pstmt);
// Chain to the hook that was registered before ours, if any.
if let Some(prev_process_utility) = PREV_PROCESS_UTILITY {
prev_process_utility(pstmt, query_string, context, params, query_env, dest, qc);
} else {
// No earlier hook: fall back to the built-in implementation.
pgrx::pg_sys::standard_ProcessUtility(
pstmt,
query_string,
context,
params,
query_env,
dest,
qc,
);
}
}
#[cfg(any(feature = "pg14", feature = "pg15", feature = "pg16"))]
#[pgrx::pg_guard]
unsafe extern "C" fn vectors_process_utility(
pstmt: *mut pgrx::pg_sys::PlannedStmt,

View File

@@ -5,10 +5,9 @@ use super::index_update;
use crate::postgres::datatype::VectorInput;
use crate::postgres::gucs::ENABLE_VECTOR_INDEX;
use crate::prelude::*;
use std::cell::Cell;
use crate::utils::cells::PgCell;
#[thread_local]
static RELOPT_KIND: Cell<pgrx::pg_sys::relopt_kind> = Cell::new(0);
static RELOPT_KIND: PgCell<pgrx::pg_sys::relopt_kind> = unsafe { PgCell::new(0) };
pub unsafe fn init() {
use pgrx::pg_sys::AsPgCStr;
@@ -50,7 +49,6 @@ const AM_HANDLER: pgrx::pg_sys::IndexAmRoutine = {
am_routine.amstrategies = 1;
am_routine.amsupport = 0;
am_routine.amoptsprocnum = 0;
am_routine.amcanorder = false;
am_routine.amcanorderbyop = true;
@@ -64,7 +62,6 @@ const AM_HANDLER: pgrx::pg_sys::IndexAmRoutine = {
am_routine.amclusterable = false;
am_routine.ampredlocks = false;
am_routine.amcaninclude = false;
am_routine.amusemaintenanceworkmem = false;
am_routine.amkeytype = pgrx::pg_sys::InvalidOid;
am_routine.amvalidate = Some(amvalidate);
@@ -95,25 +92,27 @@ pub unsafe extern "C" fn amvalidate(opclass_oid: pgrx::pg_sys::Oid) -> bool {
#[cfg(feature = "pg12")]
#[pgrx::pg_guard]
pub unsafe extern "C" fn amoptions(
reloptions: pg_sys::Datum,
reloptions: pgrx::pg_sys::Datum,
validate: bool,
) -> *mut pg_sys::bytea {
use pg_sys::AsPgCStr;
let tab: &[pg_sys::relopt_parse_elt] = &[pg_sys::relopt_parse_elt {
) -> *mut pgrx::pg_sys::bytea {
use pgrx::pg_sys::AsPgCStr;
let tab: &[pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt {
optname: "options".as_pg_cstr(),
opttype: pg_sys::relopt_type_RELOPT_TYPE_STRING,
opttype: pgrx::pg_sys::relopt_type_RELOPT_TYPE_STRING,
offset: index_setup::helper_offset() as i32,
}];
let mut noptions = 0;
let options = pg_sys::parseRelOptions(reloptions, validate, RELOPT_KIND.get(), &mut noptions);
let options =
pgrx::pg_sys::parseRelOptions(reloptions, validate, RELOPT_KIND.get(), &mut noptions);
if noptions == 0 {
return std::ptr::null_mut();
}
for relopt in std::slice::from_raw_parts_mut(options, noptions as usize) {
relopt.gen.as_mut().unwrap().lockmode = pg_sys::AccessExclusiveLock as pg_sys::LOCKMODE;
relopt.gen.as_mut().unwrap().lockmode =
pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE;
}
let rdopts = pg_sys::allocateReloptStruct(index_setup::helper_size(), options, noptions);
pg_sys::fillRelOptions(
let rdopts = pgrx::pg_sys::allocateReloptStruct(index_setup::helper_size(), options, noptions);
pgrx::pg_sys::fillRelOptions(
rdopts,
index_setup::helper_size(),
options,
@@ -122,8 +121,8 @@ pub unsafe extern "C" fn amoptions(
tab.as_ptr(),
tab.len() as i32,
);
pg_sys::pfree(options as pgrx::void_mut_ptr);
rdopts as *mut pg_sys::bytea
pgrx::pg_sys::pfree(options as pgrx::void_mut_ptr);
rdopts as *mut pgrx::pg_sys::bytea
}
#[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15", feature = "pg16"))]
@@ -196,21 +195,21 @@ pub unsafe extern "C" fn ambuildempty(index_relation: pgrx::pg_sys::Relation) {
}
#[cfg(any(feature = "pg12", feature = "pg13"))]
#[pg_guard]
#[pgrx::pg_guard]
pub unsafe extern "C" fn aminsert(
index_relation: pg_sys::Relation,
values: *mut pg_sys::Datum,
index_relation: pgrx::pg_sys::Relation,
values: *mut pgrx::pg_sys::Datum,
is_null: *mut bool,
heap_tid: pg_sys::ItemPointer,
_heap_relation: pg_sys::Relation,
_check_unique: pg_sys::IndexUniqueCheck,
_index_info: *mut pg_sys::IndexInfo,
heap_tid: pgrx::pg_sys::ItemPointer,
_heap_relation: pgrx::pg_sys::Relation,
_check_unique: pgrx::pg_sys::IndexUniqueCheck,
_index_info: *mut pgrx::pg_sys::IndexInfo,
) -> bool {
use pgrx::FromDatum;
let oid = (*index_relation).rd_id;
let id = Id::from_sys(oid);
let vector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap();
let vector = vector.data().to_vec().into_boxed_slice();
let vector = vector.data().to_vec();
index_update::update_insert(id, vector, heap_tid);
true
}

View File

@@ -44,11 +44,11 @@ pub unsafe fn build(
}
#[cfg(feature = "pg12")]
#[pg_guard]
#[pgrx::pg_guard]
unsafe extern "C" fn callback(
index_relation: pg_sys::Relation,
htup: pg_sys::HeapTuple,
values: *mut pg_sys::Datum,
index_relation: pgrx::pg_sys::Relation,
htup: pgrx::pg_sys::HeapTuple,
values: *mut pgrx::pg_sys::Datum,
is_null: *mut bool,
_tuple_is_alive: bool,
state: *mut std::os::raw::c_void,

View File

@@ -114,6 +114,7 @@ pub unsafe fn start_scan(
}
}
#[allow(clippy::never_loop)]
pub unsafe fn next_scan(scan: pgrx::pg_sys::IndexScanDesc) -> bool {
let scanner = &mut *((*scan).opaque as *mut Scanner);
if matches!(scanner, Scanner::Initial { .. }) {

View File

@@ -6,7 +6,13 @@ use std::ops::Deref;
#[pgrx::opname(+)]
#[pgrx::commutator(+)]
fn operator_add(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation.");
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
let n = lhs.len();
let mut v = Vector::new_zeroed(n);
for i in 0..n {
@@ -18,7 +24,13 @@ fn operator_add(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput {
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])]
#[pgrx::opname(-)]
fn operator_minus(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation.");
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
let n = lhs.len();
let mut v = Vector::new_zeroed(n);
for i in 0..n {
@@ -34,7 +46,13 @@ fn operator_minus(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput {
#[pgrx::restrict(scalarltsel)]
#[pgrx::join(scalarltjoinsel)]
fn operator_lt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation.");
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() < rhs.deref()
}
@@ -45,7 +63,13 @@ fn operator_lt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
#[pgrx::restrict(scalarltsel)]
#[pgrx::join(scalarltjoinsel)]
fn operator_lte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation.");
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() <= rhs.deref()
}
@@ -56,7 +80,13 @@ fn operator_lte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
#[pgrx::restrict(scalargtsel)]
#[pgrx::join(scalargtjoinsel)]
fn operator_gt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation.");
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() > rhs.deref()
}
@@ -67,7 +97,13 @@ fn operator_gt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
#[pgrx::restrict(scalargtsel)]
#[pgrx::join(scalargtjoinsel)]
fn operator_gte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation.");
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() >= rhs.deref()
}
@@ -78,7 +114,13 @@ fn operator_gte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
#[pgrx::restrict(eqsel)]
#[pgrx::join(eqjoinsel)]
fn operator_eq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation.");
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() == rhs.deref()
}
@@ -89,7 +131,13 @@ fn operator_eq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
#[pgrx::restrict(eqsel)]
#[pgrx::join(eqjoinsel)]
fn operator_neq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation.");
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() != rhs.deref()
}
@@ -97,7 +145,13 @@ fn operator_neq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
#[pgrx::opname(<=>)]
#[pgrx::commutator(<=>)]
fn operator_cosine(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation.");
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
Distance::Cosine.distance(&lhs, &rhs)
}
@@ -105,7 +159,13 @@ fn operator_cosine(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar {
#[pgrx::opname(<#>)]
#[pgrx::commutator(<#>)]
fn operator_dot(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation.");
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
Distance::Dot.distance(&lhs, &rhs)
}
@@ -113,6 +173,12 @@ fn operator_dot(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar {
#[pgrx::opname(<->)]
#[pgrx::commutator(<->)]
fn operator_l2(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation.");
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
Distance::L2.distance(&lhs, &rhs)
}

63
src/utils/cells.rs Normal file
View File

@@ -0,0 +1,63 @@
use std::cell::{Cell, RefCell, UnsafeCell};
/// A [`Cell`] wrapper that can be stored in a plain `static`.
///
/// `Cell` is never `Sync`, so it cannot normally live in a `static`. This
/// wrapper asserts `Send`/`Sync` by hand, which is why construction is
/// `unsafe`.
///
/// NOTE(review): presumably the caller promises that the value is only ever
/// accessed from a single thread — confirm against the call sites.
pub struct PgCell<T>(Cell<T>);

unsafe impl<T: Send> Send for PgCell<T> {}
unsafe impl<T: Sync> Sync for PgCell<T> {}

impl<T> PgCell<T> {
    /// Wraps `x` in a new cell; usable in `static` initializers.
    ///
    /// # Safety
    ///
    /// The resulting value claims to be `Sync` although the inner `Cell`
    /// performs no synchronisation. The caller must ensure it is never
    /// accessed concurrently.
    pub const unsafe fn new(x: T) -> Self {
        PgCell(Cell::new(x))
    }

    /// Returns a copy of the stored value.
    pub fn get(&self) -> T
    where
        T: Copy,
    {
        Cell::get(&self.0)
    }

    /// Overwrites the stored value with `value`.
    pub fn set(&self, value: T)
    where
        T: Copy,
    {
        Cell::set(&self.0, value);
    }
}
/// A [`RefCell`] wrapper that can be stored in a plain `static`.
///
/// `RefCell` is never `Sync`; this wrapper asserts `Send`/`Sync` by hand,
/// which is why construction is `unsafe`. Borrow rules are still enforced at
/// runtime by the inner `RefCell`.
///
/// NOTE(review): presumably the caller promises that the value is only ever
/// accessed from a single thread — confirm against the call sites.
pub struct PgRefCell<T>(RefCell<T>);

unsafe impl<T: Send> Send for PgRefCell<T> {}
unsafe impl<T: Sync> Sync for PgRefCell<T> {}

impl<T> PgRefCell<T> {
    /// Wraps `x` in a new cell; usable in `static` initializers.
    ///
    /// # Safety
    ///
    /// The resulting value claims to be `Sync` although the inner `RefCell`
    /// performs no synchronisation. The caller must ensure it is never
    /// accessed concurrently.
    pub const unsafe fn new(x: T) -> Self {
        PgRefCell(RefCell::new(x))
    }

    /// Mutably borrows the wrapped value.
    ///
    /// # Panics
    ///
    /// Panics if the value is currently borrowed, as `RefCell::borrow_mut` does.
    pub fn borrow_mut(&self) -> std::cell::RefMut<'_, T> {
        RefCell::borrow_mut(&self.0)
    }

    /// Immutably borrows the wrapped value.
    ///
    /// # Panics
    ///
    /// Panics if the value is currently mutably borrowed, as `RefCell::borrow` does.
    pub fn borrow(&self) -> std::cell::Ref<'_, T> {
        RefCell::borrow(&self.0)
    }
}
/// An [`UnsafeCell`] that is `Sync` whenever its contents are.
///
/// Crate-local replacement for the nightly-only `std::cell::SyncUnsafeCell`.
/// All synchronisation of accesses through the raw pointer is left to the
/// caller.
#[repr(transparent)]
pub struct SyncUnsafeCell<T: ?Sized> {
    value: UnsafeCell<T>,
}

unsafe impl<T: ?Sized + Sync> Sync for SyncUnsafeCell<T> {}

impl<T: ?Sized> SyncUnsafeCell<T> {
    /// Wraps `value` in a new cell; usable in `static` initializers.
    pub const fn new(value: T) -> Self
    where
        T: Sized,
    {
        SyncUnsafeCell {
            value: UnsafeCell::new(value),
        }
    }

    /// Returns a raw mutable pointer to the wrapped value.
    ///
    /// Dereferencing the pointer is the caller's responsibility.
    pub fn get(&self) -> *mut T {
        UnsafeCell::get(&self.value)
    }

    /// Returns a mutable reference to the wrapped value.
    ///
    /// Safe because `&mut self` already guarantees exclusive access.
    pub fn get_mut(&mut self) -> &mut T {
        UnsafeCell::get_mut(&mut self.value)
    }
}

View File

@@ -1,5 +1,5 @@
use super::dir_ops::sync_dir;
use std::fs::{try_exists, File};
use std::fs::File;
use std::io::Write;
use std::path::{Path, PathBuf};
@@ -43,7 +43,7 @@ where
{
pub fn open(path: impl AsRef<Path>) -> Self {
let path = path.as_ref().to_owned();
if try_exists("1").unwrap() {
if path.join("1").try_exists().unwrap() {
std::fs::remove_file(path.join("1")).unwrap();
sync_dir(&path);
}

View File

@@ -1,8 +1,10 @@
pub mod cells;
pub mod clean;
pub mod dir_ops;
pub mod file_atomic;
pub mod file_socket;
pub mod file_wal;
pub mod mmap_array;
pub mod os;
pub mod semaphore;
pub mod vec2;

98
src/utils/os.rs Normal file
View File

@@ -0,0 +1,98 @@
use rustix::fd::{AsFd, OwnedFd};
use rustix::mm::{MapFlags, ProtFlags};
use std::sync::atomic::AtomicU32;
/// Linux implementation: waits on `futex` via a raw `FUTEX_WAIT` syscall,
/// passing `value` as the expected futex word and a 15-second timeout.
///
/// The syscall's return value is discarded, so the caller cannot tell a
/// wake-up from a timeout or an error and has to re-examine the futex word
/// itself after this returns.
///
/// NOTE(review): no `# Safety` contract is stated in the original; the
/// function is `unsafe` presumably only because it issues a raw syscall —
/// confirm.
#[cfg(target_os = "linux")]
pub unsafe fn futex_wait(futex: &AtomicU32, value: u32) {
// Upper bound for a single wait.
const FUTEX_TIMEOUT: libc::timespec = libc::timespec {
tv_sec: 15,
tv_nsec: 0,
};
// Plain `FUTEX_WAIT` is used, without the process-private flag.
libc::syscall(
libc::SYS_futex,
futex.as_ptr(),
libc::FUTEX_WAIT,
value,
&FUTEX_TIMEOUT,
);
}
/// Linux implementation: wakes waiters blocked on `futex` via a raw
/// `FUTEX_WAKE` syscall.
///
/// `i32::MAX` is passed as the maximum number of waiters to wake. The
/// syscall's return value is discarded.
#[cfg(target_os = "linux")]
pub unsafe fn futex_wake(futex: &AtomicU32) {
libc::syscall(libc::SYS_futex, futex.as_ptr(), libc::FUTEX_WAKE, i32::MAX);
}
/// Linux implementation: creates an anonymous in-memory file with
/// `memfd_create(2)` and returns its descriptor.
///
/// The memfd is named `.memfd.VECTORS.<pid in hex>`, using the calling
/// process's id.
///
/// # Errors
///
/// Returns the OS error reported by `memfd_create`.
#[cfg(target_os = "linux")]
pub fn memfd_create() -> std::io::Result<OwnedFd> {
    let name = format!(".memfd.VECTORS.{:x}", std::process::id());
    let fd = rustix::fs::memfd_create(&name, rustix::fs::MemfdFlags::empty())?;
    Ok(fd)
}
/// Linux implementation: maps the first `len` bytes of `fd` as shared,
/// readable and writable memory, with `MAP_POPULATE` set.
///
/// The kernel chooses the address (a null hint is passed) and the mapping
/// starts at file offset 0.
///
/// # Errors
///
/// Returns the OS error reported by `mmap`.
///
/// # Safety
///
/// NOTE(review): no contract is stated in the original; presumably the usual
/// `mmap` requirements apply (e.g. the file must stay at least `len` bytes
/// long while the mapping is in use) — confirm.
#[cfg(target_os = "linux")]
pub unsafe fn mmap_populate(len: usize, fd: impl AsFd) -> std::io::Result<*mut libc::c_void> {
    let protection = ProtFlags::READ | ProtFlags::WRITE;
    let flags = MapFlags::SHARED | MapFlags::POPULATE;
    let mapping = rustix::mm::mmap(std::ptr::null_mut(), len, protection, flags, fd, 0)?;
    Ok(mapping)
}
/// macOS implementation of `futex_wait`, built on `__ulock_wait` with the
/// `UL_COMPARE_AND_WAIT_SHARED` operation.
///
/// `value` is passed as the expected content of `futex`. The return value of
/// `__ulock_wait` is discarded, so the caller has to re-examine the futex
/// word itself after this returns.
#[cfg(target_os = "macos")]
pub unsafe fn futex_wait(futex: &AtomicU32, value: u32) {
// NOTE(review): presumably microseconds (i.e. 15 s, matching the Linux
// variant's timeout) — confirm against the `__ulock_wait` interface.
const ULOCK_TIMEOUT: u32 = 15_000_000;
ulock_sys::__ulock_wait(
ulock_sys::darwin19::UL_COMPARE_AND_WAIT_SHARED,
futex.as_ptr().cast(),
value as _,
ULOCK_TIMEOUT,
);
}
/// macOS implementation of `futex_wake`, built on `__ulock_wake` with the
/// `UL_COMPARE_AND_WAIT_SHARED` operation.
///
/// The return value of `__ulock_wake` is discarded.
///
/// NOTE(review): no wake-all flag is OR-ed into the operation, so this
/// presumably wakes at most one waiter, whereas the Linux variant passes
/// `i32::MAX` to `FUTEX_WAKE` — confirm this difference is intended.
#[cfg(target_os = "macos")]
pub unsafe fn futex_wake(futex: &AtomicU32) {
ulock_sys::__ulock_wake(
ulock_sys::darwin19::UL_COMPARE_AND_WAIT_SHARED,
futex.as_ptr().cast(),
0,
);
}
/// macOS stand-in for `memfd_create`: returns a descriptor for a file that
/// has already been unlinked.
///
/// The file is created exclusively, readable and writable by the owner only,
/// under the relative name `.shm.VECTORS.<pid in hex>.<random u32 in hex>`,
/// and is unlinked again right after it has been opened.
///
/// # Errors
///
/// Returns the OS error if either creating or unlinking the file fails.
#[cfg(target_os = "macos")]
pub fn memfd_create() -> std::io::Result<OwnedFd> {
    use rustix::fs::{Mode, OFlags};
    // POSIX fcntl locking does not support shared-memory objects, so a regular
    // file is used here instead.
    // reference: https://man7.org/linux/man-pages/man3/fcntl.3p.html
    let pid = std::process::id();
    let nonce = rand::random::<u32>();
    let name = format!(".shm.VECTORS.{:x}.{:x}", pid, nonce);
    let flags = OFlags::RDWR | OFlags::CREATE | OFlags::EXCL;
    let mode = Mode::RUSR | Mode::WUSR;
    let fd = rustix::fs::open(&name, flags, mode)?;
    // Remove the directory entry straight away; `fd` is still returned to the caller.
    rustix::fs::unlink(&name)?;
    Ok(fd)
}
/// macOS implementation: maps the first `len` bytes of `fd` as shared,
/// readable and writable memory.
///
/// Unlike the Linux variant, no populate flag is passed. The kernel chooses
/// the address (a null hint is passed) and the mapping starts at file
/// offset 0.
///
/// # Errors
///
/// Returns the OS error reported by `mmap`.
///
/// # Safety
///
/// NOTE(review): no contract is stated in the original; presumably the usual
/// `mmap` requirements apply (e.g. the file must stay at least `len` bytes
/// long while the mapping is in use) — confirm.
#[cfg(target_os = "macos")]
pub unsafe fn mmap_populate(len: usize, fd: impl AsFd) -> std::io::Result<*mut libc::c_void> {
    let protection = ProtFlags::READ | ProtFlags::WRITE;
    let flags = MapFlags::SHARED;
    let mapping = rustix::mm::mmap(std::ptr::null_mut(), len, protection, flags, fd, 0)?;
    Ok(mapping)
}

View File

@@ -29,8 +29,7 @@ SELECT '[1,2]'::vector < '[1,3]';
----
t
# TODO: may need better error message
statement error assertion failed: `\(left == right\)`
statement error differs in dimensions
SELECT '[1,2]'::vector < '[1,2,3]';
query I