1
0
mirror of https://github.com/tensorchord/pgvecto.rs.git synced 2025-08-07 03:22:55 +03:00

feat: mmap transport for macos (#137)

* feat: mmap transport for macos

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: building with feature pg12, pg13

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: call unlink for shmem

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: reduce shmem filename on macos

Signed-off-by: usamoi <usamoi@outlook.com>

* chore: enable testing on all Postgresql versions

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: use file instead of shmem for macos

Signed-off-by: usamoi <usamoi@outlook.com>

* chore: select simpler matrix for pull requests in CI

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: remove macos-latest-16

Signed-off-by: usamoi <usamoi@outlook.com>

* chore: reduce use of nightly features

Signed-off-by: usamoi <usamoi@outlook.com>

---------

Signed-off-by: usamoi <usamoi@outlook.com>
This commit is contained in:
Usamoi
2023-11-18 00:26:15 +08:00
committed by GitHub
parent f6e382d0fc
commit 39e8ee9797
26 changed files with 483 additions and 305 deletions

View File

@@ -1,7 +1,7 @@
[build] [build]
rustdocflags = ["--document-private-items"] rustdocflags = ["--document-private-items"]
[target.'cfg(target_arch="x86_64")'] [target.'cfg(all(target_os="linux", target_arch="x86_64"))']
rustflags = ["-Ctarget-cpu=x86-64-v3"] rustflags = ["-Ctarget-cpu=x86-64-v3"]
[target.'cfg(target_os="macos")'] [target.'cfg(target_os="macos")']

View File

@@ -2,23 +2,29 @@ name: Rust check
on: on:
push: push:
branches: [ "main" ] branches: ["main"]
paths: paths:
- '.github/workflows/check.yml' - ".cargo/**"
- 'src/**' - ".github/**"
- 'Cargo.toml' - "scripts/**"
- 'Cargo.lock' - "src/**"
- 'rust-toolchain.toml' - "tests/**"
- 'tests/**' - "Cargo.lock"
- "Cargo.toml"
- "rust-toolchain.toml"
- "vectors.control"
pull_request: pull_request:
branches: [ "main" ] branches: ["main"]
paths: paths:
- '.github/workflows/check.yml' - ".cargo/**"
- 'src/**' - ".github/**"
- 'Cargo.toml' - "scripts/**"
- 'Cargo.lock' - "src/**"
- 'rust-toolchain.toml' - "tests/**"
- 'tests/**' - "Cargo.lock"
- "Cargo.toml"
- "rust-toolchain.toml"
- "vectors.control"
merge_group: merge_group:
workflow_dispatch: workflow_dispatch:
@@ -28,116 +34,81 @@ concurrency:
env: env:
CARGO_TERM_COLOR: always CARGO_TERM_COLOR: always
RUST_BACKTRACE: 1
SCCACHE_GHA_ENABLED: true SCCACHE_GHA_ENABLED: true
RUSTC_WRAPPER: sccache RUSTC_WRAPPER: sccache
jobs: jobs:
lint: matrix:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.main.outputs.matrix }}
steps:
- uses: actions/github-script@v7
id: main
with:
script: |
let matrix;
if("${{ github.event_name }}" == "push" || "${{ github.event_name }}" == "pull_request"){
matrix = [
{ version: 15, os: "ubuntu-latest" },
];
}
if("${{ github.event_name }}" == "merge_group" || "${{ github.event_name }}" == "workflow_dispatch"){
matrix = [
{ version: 12, os: "ubuntu-latest" },
{ version: 13, os: "ubuntu-latest" },
{ version: 14, os: "ubuntu-latest" },
{ version: 15, os: "ubuntu-latest" },
{ version: 16, os: "ubuntu-latest" },
{ version: 12, os: "macos-latest" },
{ version: 13, os: "macos-latest" },
{ version: 14, os: "macos-latest" },
{ version: 15, os: "macos-latest" },
];
}
core.setOutput('matrix', JSON.stringify(matrix));
check:
needs: matrix
strategy: strategy:
matrix: matrix:
version: [15] include: ${{ fromJson(needs.matrix.outputs.matrix) }}
runs-on: ubuntu-latest runs-on: ${{ matrix.os }}
env:
VERSION: ${{ matrix.version }}
OS: ${{ matrix.os }}
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/cache/restore@v3 - uses: actions/cache/restore@v3
with: with:
path: | path: |
~/.cargo/registry/index/ ~/.cargo/registry/index/
~/.cargo/registry/cache/ ~/.cargo/registry/cache/
~/.cargo/git/db/ ~/.cargo/git/db/
key: cargo-${{ runner.os }}-pg${{ matrix.version }}-${{ hashFiles('./Cargo.toml') }} key: cargo-${{ matrix.os }}-pg${{ matrix.version }}-${{ hashFiles('./Cargo.toml') }}
restore-keys: cargo-${{ runner.os }}-pg${{ matrix.version }} restore-keys: cargo-${{ matrix.os }}-pg${{ matrix.version }}
- uses: mozilla-actions/sccache-action@v0.0.3 - uses: mozilla-actions/sccache-action@v0.0.3
- name: Prepare - name: Setup
run: | shell: bash
sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' run: ./scripts/ci_setup.sh
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - - name: Format check
sudo apt-get update run: cargo fmt --check
sudo apt-get -y install libpq-dev postgresql-${{ matrix.version }} postgresql-server-dev-${{ matrix.version }} - name: Semantic check
cargo install cargo-pgrx --version $(grep '^pgrx ' Cargo.toml | awk -F'\"' '{print $2}') run: cargo clippy --no-default-features --features "pg${{ matrix.version }} pg_test"
cargo pgrx init --pg${{ matrix.version }}=/usr/lib/postgresql/${{ matrix.version }}/bin/pg_config - name: Debug build
- name: Format check run: cargo build --no-default-features --features "pg${{ matrix.version }} pg_test"
run: cargo fmt --check - name: Test
- name: Semantic check run: cargo test --all --no-default-features --features "pg${{ matrix.version }} pg_test" -- --nocapture
run: cargo clippy - name: Install release
run: ./scripts/ci_install.sh
build: - name: Sqllogictest
strategy: run: |
matrix: psql -f ./tests/init.sql
version: [15] sqllogictest -u runner -d runner './tests/**/*.slt'
runs-on: ubuntu-latest - uses: actions/cache/save@v3
steps: with:
- uses: actions/checkout@v3 path: |
- uses: actions/cache/restore@v3 ~/.cargo/registry/index/
with: ~/.cargo/registry/cache/
path: | ~/.cargo/git/db/
~/.cargo/registry/index/ key: cargo-${{ matrix.os }}-pg${{ matrix.version }}-${{ hashFiles('./Cargo.toml') }}
~/.cargo/registry/cache/
~/.cargo/git/db/
key: cargo-${{ runner.os }}-pg${{ matrix.version }}-${{ hashFiles('./Cargo.toml') }}
restore-keys: cargo-${{ runner.os }}-pg${{ matrix.version }}
- uses: mozilla-actions/sccache-action@v0.0.3
- name: Prepare
run: |
sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list'
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
sudo apt-get update
sudo apt-get -y install libpq-dev postgresql-${{ matrix.version }} postgresql-server-dev-${{ matrix.version }}
cargo install cargo-pgrx --version $(grep '^pgrx ' Cargo.toml | awk -F'\"' '{print $2}')
cargo pgrx init --pg${{ matrix.version }}=/usr/lib/postgresql/${{ matrix.version }}/bin/pg_config
- name: Build
run: cargo build --verbose
- name: Test
env:
RUST_BACKTRACE: 1
run: cargo test --all --no-default-features --features "pg${{ matrix.version }} pg_test" -- --nocapture
test:
strategy:
matrix:
version: [15]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/cache/restore@v3
with:
path: |
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
key: cargo-${{ runner.os }}-pg${{ matrix.version }}-${{ hashFiles('./Cargo.toml') }}
restore-keys: cargo-${{ runner.os }}-pg${{ matrix.version }}
- uses: mozilla-actions/sccache-action@v0.0.3
- name: Prepare
run: |
sudo pg_dropcluster 14 main
sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list'
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
sudo apt-get update
sudo apt-get -y install libpq-dev postgresql-${{ matrix.version }} postgresql-server-dev-${{ matrix.version }}
cargo install cargo-pgrx --version $(grep '^pgrx ' Cargo.toml | awk -F'\"' '{print $2}')
cargo pgrx init --pg${{ matrix.version }}=/usr/lib/postgresql/${{ matrix.version }}/bin/pg_config
cargo install sqllogictest-bin
- uses: actions/cache/save@v3
with:
path: |
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
key: cargo-${{ runner.os }}-pg${{ matrix.version }}-${{ hashFiles('./Cargo.toml') }}
- name: Build
run: |
sudo chmod -R 777 /usr/share/postgresql/${{ matrix.version }}/extension
sudo chmod -R 777 /usr/lib/postgresql/${{ matrix.version }}/lib
cargo pgrx install --release
sudo systemctl start postgresql@${{ matrix.version }}-main
sudo -u postgres psql -c "CREATE USER $USER LOGIN SUPERUSER"
sudo -u postgres psql -c "CREATE DATABASE $USER OWNER $USER"
psql -c 'ALTER SYSTEM SET shared_preload_libraries = "vectors.so"'
sudo systemctl restart postgresql@${{ matrix.version }}-main
- name: Sqllogictest
run: |
export password=$(openssl rand -base64 32)
psql -c "ALTER USER $USER WITH PASSWORD '$password'"
psql -f ./tests/init.sql
sqllogictest -u "$USER" -w "$password" -d "$USER" './tests/**/*.slt'

View File

@@ -41,7 +41,7 @@ validator = { version = "0.16.1", features = ["derive"] }
toml = "0.7.6" toml = "0.7.6"
rayon = "1.6.1" rayon = "1.6.1"
uuid = { version = "1.4.1", features = ["serde"] } uuid = { version = "1.4.1", features = ["serde"] }
rustix = { version = "0.38.20", features = ["net", "mm"] } rustix = { version = "0.38.20", features = ["net", "mm", "shm"] }
arc-swap = "1.6.0" arc-swap = "1.6.0"
bytemuck = "1.14.0" bytemuck = "1.14.0"
serde_with = "3.4.0" serde_with = "3.4.0"
@@ -53,6 +53,9 @@ pgrx-tests = "0.11.0"
httpmock = "0.6" httpmock = "0.6"
mockall = "0.11.4" mockall = "0.11.4"
[target.'cfg(target_os = "macos")'.dependencies]
ulock-sys = "0.1.0"
[profile.dev] [profile.dev]
panic = "unwind" panic = "unwind"

View File

@@ -1,4 +1,9 @@
[toolchain] [toolchain]
channel = "nightly-2023-08-03" channel = "nightly-2023-11-15"
components = ["rustfmt", "clippy", "miri"] components = ["rustfmt", "clippy", "miri"]
targets = ["x86_64-unknown-linux-gnu"] targets = [
"x86_64-unknown-linux-gnu",
"x86_64-apple-darwin",
"aarch64-unknown-linux-gnu",
"aarch64-apple-darwin",
]

13
scripts/ci_install.sh Executable file
View File

@@ -0,0 +1,13 @@
#!/usr/bin/env bash
set -e
cargo pgrx install --no-default-features --features "pg$VERSION" --release
psql -c 'ALTER SYSTEM SET shared_preload_libraries = "vectors.so"'
if [ "$OS" == "ubuntu-latest" ]; then
sudo systemctl restart postgresql
pg_lsclusters
fi
if [ "$OS" == "macos-latest" ]; then
brew services restart postgresql@$VERSION
fi

39
scripts/ci_setup.sh Executable file
View File

@@ -0,0 +1,39 @@
#!/usr/bin/env bash
set -e
if [ "$OS" == "ubuntu-latest" ]; then
if [ $VERSION != 14 ]; then
sudo pg_dropcluster 14 main
fi
sudo apt-get remove -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*'
sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list'
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
sudo apt-get update
sudo apt-get -y install build-essential libpq-dev postgresql-$VERSION postgresql-server-dev-$VERSION
echo "local all all trust" | sudo tee /etc/postgresql/$VERSION/main/pg_hba.conf
echo "host all all 127.0.0.1/32 trust" | sudo tee -a /etc/postgresql/$VERSION/main/pg_hba.conf
echo "host all all ::1/128 trust" | sudo tee -a /etc/postgresql/$VERSION/main/pg_hba.conf
pg_lsclusters
sudo systemctl restart postgresql
pg_lsclusters
sudo -iu postgres createuser -s -r runner
createdb
fi
if [ "$OS" == "macos-latest" ]; then
brew uninstall postgresql
brew install postgresql@$VERSION
export PATH="$PATH:$(brew --prefix postgresql@$VERSION)/bin"
echo "$(brew --prefix postgresql@$VERSION)/bin" >> $GITHUB_PATH
brew services start postgresql@$VERSION
sleep 30
createdb
fi
sudo chmod -R 777 `pg_config --pkglibdir`
sudo chmod -R 777 `pg_config --sharedir`/extension
curl -L --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh | bash
cargo binstall sqllogictest-bin -y --force
cargo install cargo-pgrx --version $(grep '^pgrx ' Cargo.toml | awk -F'\"' '{print $2}') --debug
cargo pgrx init --pg$VERSION=$(which pg_config)

View File

@@ -7,13 +7,13 @@ use crate::index::segments::sealed::SealedSegment;
use crate::index::IndexOptions; use crate::index::IndexOptions;
use crate::index::VectorOptions; use crate::index::VectorOptions;
use crate::prelude::*; use crate::prelude::*;
use crate::utils::cells::SyncUnsafeCell;
use crate::utils::dir_ops::sync_dir; use crate::utils::dir_ops::sync_dir;
use crate::utils::mmap_array::MmapArray; use crate::utils::mmap_array::MmapArray;
use crate::utils::vec2::Vec2; use crate::utils::vec2::Vec2;
use rand::seq::index::sample; use rand::seq::index::sample;
use rand::thread_rng; use rand::thread_rng;
use rayon::prelude::{IntoParallelIterator, ParallelIterator}; use rayon::prelude::{IntoParallelIterator, ParallelIterator};
use std::cell::SyncUnsafeCell;
use std::fs::create_dir; use std::fs::create_dir;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::atomic::AtomicU32; use std::sync::atomic::AtomicU32;

View File

@@ -8,12 +8,12 @@ use crate::index::segments::sealed::SealedSegment;
use crate::index::IndexOptions; use crate::index::IndexOptions;
use crate::index::VectorOptions; use crate::index::VectorOptions;
use crate::prelude::*; use crate::prelude::*;
use crate::utils::cells::SyncUnsafeCell;
use crate::utils::dir_ops::sync_dir; use crate::utils::dir_ops::sync_dir;
use crate::utils::mmap_array::MmapArray; use crate::utils::mmap_array::MmapArray;
use crate::utils::vec2::Vec2; use crate::utils::vec2::Vec2;
use rand::seq::index::sample; use rand::seq::index::sample;
use rand::thread_rng; use rand::thread_rng;
use std::cell::SyncUnsafeCell;
use std::fs::create_dir; use std::fs::create_dir;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::atomic::AtomicU32; use std::sync::atomic::AtomicU32;

View File

@@ -4,7 +4,7 @@ use self::bgworker::Bgworker;
use crate::ipc::server::RpcHandler; use crate::ipc::server::RpcHandler;
use crate::ipc::IpcError; use crate::ipc::IpcError;
use std::fs::OpenOptions; use std::fs::OpenOptions;
use std::path::PathBuf; use std::path::{Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
pub fn main() { pub fn main() {
@@ -39,7 +39,7 @@ pub fn main() {
log::error!("Panickied. Info: {:?}. Backtrace: {}.", info, backtrace); log::error!("Panickied. Info: {:?}. Backtrace: {}.", info, backtrace);
})); }));
let bgworker; let bgworker;
if std::fs::try_exists("pg_vectors").unwrap() { if Path::new("pg_vectors").try_exists().unwrap() {
bgworker = Bgworker::open(PathBuf::from("pg_vectors")); bgworker = Bgworker::open(PathBuf::from("pg_vectors"));
} else { } else {
bgworker = Bgworker::create(PathBuf::from("pg_vectors")); bgworker = Bgworker::create(PathBuf::from("pg_vectors"));

View File

@@ -26,17 +26,10 @@ pub fn listen_unix() -> impl Iterator<Item = RpcHandler> {
} }
pub fn listen_mmap() -> impl Iterator<Item = RpcHandler> { pub fn listen_mmap() -> impl Iterator<Item = RpcHandler> {
#[cfg(target_os = "linux")] std::iter::from_fn(move || {
{ let socket = self::transport::Socket::Mmap(self::transport::mmap::accept());
std::iter::from_fn(move || { Some(self::server::RpcHandler::new(socket))
let socket = self::transport::Socket::Mmap(self::transport::mmap::accept()); })
Some(self::server::RpcHandler::new(socket))
})
}
#[cfg(not(target_os = "linux"))]
{
std::iter::empty()
}
} }
pub fn connect_unix() -> Rpc { pub fn connect_unix() -> Rpc {
@@ -45,14 +38,6 @@ pub fn connect_unix() -> Rpc {
} }
pub fn connect_mmap() -> Rpc { pub fn connect_mmap() -> Rpc {
#[cfg(target_os = "linux")] let socket = self::transport::Socket::Mmap(self::transport::mmap::connect());
{ self::client::Rpc::new(socket)
let socket = self::transport::Socket::Mmap(self::transport::mmap::connect());
self::client::Rpc::new(socket)
}
#[cfg(not(target_os = "linux"))]
{
use crate::prelude::FriendlyError;
FriendlyError::MmapTransportNotSupported.friendly();
}
} }

View File

@@ -1,21 +1,16 @@
use crate::ipc::IpcError; use crate::ipc::IpcError;
use crate::utils::file_socket::FileSocket; use crate::utils::file_socket::FileSocket;
use crate::utils::os::{futex_wait, futex_wake, memfd_create, mmap_populate};
use rustix::fd::{AsFd, OwnedFd}; use rustix::fd::{AsFd, OwnedFd};
use rustix::fs::{FlockOperation, MemfdFlags}; use rustix::fs::FlockOperation;
use rustix::mm::{MapFlags, ProtFlags};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::cell::UnsafeCell; use std::cell::UnsafeCell;
use std::io::ErrorKind; use std::io::ErrorKind;
use std::ptr::null_mut;
use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::OnceLock; use std::sync::OnceLock;
const BUFFER_SIZE: usize = 512 * 1024; const BUFFER_SIZE: usize = 512 * 1024;
const SPIN_LIMIT: usize = 8; const SPIN_LIMIT: usize = 8;
const FUTEX_TIMEOUT: libc::timespec = libc::timespec {
tv_sec: 15,
tv_nsec: 0,
};
static CHANNEL: OnceLock<FileSocket> = OnceLock::new(); static CHANNEL: OnceLock<FileSocket> = OnceLock::new();
@@ -26,18 +21,7 @@ pub fn init() {
pub fn accept() -> Socket { pub fn accept() -> Socket {
let memfd = CHANNEL.get().unwrap().recv().unwrap(); let memfd = CHANNEL.get().unwrap().recv().unwrap();
rustix::fs::fcntl_lock(&memfd, FlockOperation::NonBlockingLockShared).unwrap(); rustix::fs::fcntl_lock(&memfd, FlockOperation::NonBlockingLockShared).unwrap();
let addr; let addr = unsafe { mmap_populate(BUFFER_SIZE, &memfd).unwrap() };
unsafe {
addr = rustix::mm::mmap(
null_mut(),
BUFFER_SIZE,
ProtFlags::READ | ProtFlags::WRITE,
MapFlags::POPULATE | MapFlags::SHARED,
&memfd,
0,
)
.unwrap();
}
Socket { Socket {
is_server: true, is_server: true,
addr: addr as _, addr: addr as _,
@@ -46,22 +30,11 @@ pub fn accept() -> Socket {
} }
pub fn connect() -> Socket { pub fn connect() -> Socket {
let memfd = rustix::fs::memfd_create("transport", MemfdFlags::empty()).unwrap(); let memfd = memfd_create().unwrap();
rustix::fs::ftruncate(&memfd, BUFFER_SIZE as u64).unwrap(); rustix::fs::ftruncate(&memfd, BUFFER_SIZE as u64).unwrap();
rustix::fs::fcntl_lock(&memfd, FlockOperation::NonBlockingLockShared).unwrap(); rustix::fs::fcntl_lock(&memfd, FlockOperation::NonBlockingLockShared).unwrap();
CHANNEL.get().unwrap().send(memfd.as_fd()).unwrap(); CHANNEL.get().unwrap().send(memfd.as_fd()).unwrap();
let addr; let addr = unsafe { mmap_populate(BUFFER_SIZE, &memfd).unwrap() };
unsafe {
addr = rustix::mm::mmap(
null_mut(),
BUFFER_SIZE,
ProtFlags::READ | ProtFlags::WRITE,
MapFlags::POPULATE | MapFlags::SHARED,
&memfd,
0,
)
.unwrap();
}
Socket { Socket {
is_server: false, is_server: false,
addr: addr as _, addr: addr as _,
@@ -159,25 +132,13 @@ impl Channel {
{ {
break; break;
} }
libc::syscall( futex_wait(&self.futex, Y);
libc::SYS_futex,
self.futex.as_ptr(),
libc::FUTEX_WAIT,
Y,
&FUTEX_TIMEOUT,
);
} }
Y => { Y => {
if !test() { if !test() {
return Err(IpcError::Closed); return Err(IpcError::Closed);
} }
libc::syscall( futex_wait(&self.futex, Y);
libc::SYS_futex,
self.futex.as_ptr(),
libc::FUTEX_WAIT,
Y,
&FUTEX_TIMEOUT,
);
} }
_ => std::hint::unreachable_unchecked(), _ => std::hint::unreachable_unchecked(),
} }
@@ -193,17 +154,8 @@ impl Channel {
debug_assert!(matches!(self.futex.load(Ordering::Relaxed), S | X)); debug_assert!(matches!(self.futex.load(Ordering::Relaxed), S | X));
*self.len.get() = data.len() as u32; *self.len.get() = data.len() as u32;
(*self.bytes.get())[0..data.len()].copy_from_slice(data); (*self.bytes.get())[0..data.len()].copy_from_slice(data);
match self.futex.swap(T, Ordering::Release) { if X == self.futex.swap(T, Ordering::Release) {
S => (), futex_wake(&self.futex);
X => {
libc::syscall(
libc::SYS_futex,
self.futex.as_ptr(),
libc::FUTEX_WAKE,
i32::MAX,
);
}
_ => std::hint::unreachable_unchecked(),
} }
} }
unsafe fn server_recv(&self, test: impl Fn() -> bool) -> Result<Vec<u8>, IpcError> { unsafe fn server_recv(&self, test: impl Fn() -> bool) -> Result<Vec<u8>, IpcError> {
@@ -229,25 +181,13 @@ impl Channel {
{ {
break; break;
} }
libc::syscall( futex_wait(&self.futex, Y);
libc::SYS_futex,
self.futex.as_ptr(),
libc::FUTEX_WAIT,
Y,
&FUTEX_TIMEOUT,
);
} }
Y => { Y => {
if !test() { if !test() {
return Err(IpcError::Closed); return Err(IpcError::Closed);
} }
libc::syscall( futex_wait(&self.futex, Y);
libc::SYS_futex,
self.futex.as_ptr(),
libc::FUTEX_WAIT,
Y,
&FUTEX_TIMEOUT,
);
} }
_ => std::hint::unreachable_unchecked(), _ => std::hint::unreachable_unchecked(),
} }
@@ -263,17 +203,8 @@ impl Channel {
debug_assert!(matches!(self.futex.load(Ordering::Relaxed), S | X)); debug_assert!(matches!(self.futex.load(Ordering::Relaxed), S | X));
*self.len.get() = data.len() as u32; *self.len.get() = data.len() as u32;
(*self.bytes.get())[0..data.len()].copy_from_slice(data); (*self.bytes.get())[0..data.len()].copy_from_slice(data);
match self.futex.swap(T, Ordering::Release) { if X == self.futex.swap(T, Ordering::Release) {
S => (), futex_wake(&self.futex);
X => {
libc::syscall(
libc::SYS_futex,
self.futex.as_ptr(),
libc::FUTEX_WAKE,
i32::MAX,
);
}
_ => std::hint::unreachable_unchecked(),
} }
} }
} }

View File

@@ -1,4 +1,3 @@
#[cfg(target_os = "linux")]
pub mod mmap; pub mod mmap;
pub mod unix; pub mod unix;
@@ -7,7 +6,6 @@ use serde::{Deserialize, Serialize};
pub enum Socket { pub enum Socket {
Unix(unix::Socket), Unix(unix::Socket),
#[cfg(target_os = "linux")]
Mmap(mmap::Socket), Mmap(mmap::Socket),
} }
@@ -15,14 +13,12 @@ impl Socket {
pub fn send<T: Serialize>(&mut self, packet: T) -> Result<(), IpcError> { pub fn send<T: Serialize>(&mut self, packet: T) -> Result<(), IpcError> {
match self { match self {
Socket::Unix(x) => x.send(packet), Socket::Unix(x) => x.send(packet),
#[cfg(target_os = "linux")]
Socket::Mmap(x) => x.send(packet), Socket::Mmap(x) => x.send(packet),
} }
} }
pub fn recv<T: for<'a> Deserialize<'a>>(&mut self) -> Result<T, IpcError> { pub fn recv<T: for<'a> Deserialize<'a>>(&mut self) -> Result<T, IpcError> {
match self { match self {
Socket::Unix(x) => x.recv(), Socket::Unix(x) => x.recv(),
#[cfg(target_os = "linux")]
Socket::Mmap(x) => x.recv(), Socket::Mmap(x) => x.recv(),
} }
} }

View File

@@ -3,17 +3,7 @@
//! Provides an easy-to-use extension for vector similarity search. //! Provides an easy-to-use extension for vector similarity search.
#![feature(core_intrinsics)] #![feature(core_intrinsics)]
#![feature(allocator_api)] #![feature(allocator_api)]
#![feature(thread_local)]
#![feature(auto_traits)]
#![feature(negative_impls)]
#![feature(ptr_metadata)]
#![feature(new_uninit)] #![feature(new_uninit)]
#![feature(int_roundings)]
#![feature(never_type)]
#![feature(lazy_cell)]
#![feature(const_maybe_uninit_zeroed)]
#![feature(fs_try_exists)]
#![feature(sync_unsafe_cell)]
#![allow(clippy::complexity)] #![allow(clippy::complexity)]
#![allow(clippy::style)] #![allow(clippy::style)]
@@ -48,7 +38,6 @@ pub unsafe extern "C" fn _PG_init() {
.load(); .load();
self::postgres::init(); self::postgres::init();
self::ipc::transport::unix::init(); self::ipc::transport::unix::init();
#[cfg(target_os = "linux")]
self::ipc::transport::mmap::init(); self::ipc::transport::mmap::init();
} }

View File

@@ -10,14 +10,7 @@ pub enum Transport {
impl Transport { impl Transport {
pub const fn default() -> Transport { pub const fn default() -> Transport {
#[cfg(target_os = "linux")] Transport::mmap
{
Transport::mmap
}
#[cfg(not(target_os = "linux"))]
{
Transport::unix
}
} }
} }

View File

@@ -126,7 +126,7 @@ unsafe extern "C" fn rewrite_plan_state(
_ => (), _ => (),
} }
let walker = std::mem::transmute::<PlanstateTreeWalker, _>(rewrite_plan_state); let walker = std::mem::transmute::<PlanstateTreeWalker, _>(rewrite_plan_state);
#[cfg(not(feature = "pg16"))] #[cfg(any(feature = "pg12", feature = "pg13", feature = "pg14", feature = "pg15"))]
{ {
pgrx::pg_sys::planstate_tree_walker(node, Some(walker), context) pgrx::pg_sys::planstate_tree_walker(node, Some(walker), context)
} }

View File

@@ -3,17 +3,14 @@ use super::gucs::TRANSPORT;
use crate::ipc::client::Rpc; use crate::ipc::client::Rpc;
use crate::ipc::{connect_mmap, connect_unix}; use crate::ipc::{connect_mmap, connect_unix};
use crate::prelude::*; use crate::prelude::*;
use std::cell::RefCell; use crate::utils::cells::PgRefCell;
use std::collections::BTreeSet; use std::collections::BTreeSet;
#[thread_local] static FLUSH_IF_COMMIT: PgRefCell<BTreeSet<Id>> = unsafe { PgRefCell::new(BTreeSet::new()) };
static FLUSH_IF_COMMIT: RefCell<BTreeSet<Id>> = RefCell::new(BTreeSet::new());
#[thread_local] static DROP_IF_COMMIT: PgRefCell<BTreeSet<Id>> = unsafe { PgRefCell::new(BTreeSet::new()) };
static DROP_IF_COMMIT: RefCell<BTreeSet<Id>> = RefCell::new(BTreeSet::new());
#[thread_local] static CLIENT: PgRefCell<Option<Rpc>> = unsafe { PgRefCell::new(None) };
static CLIENT: RefCell<Option<Rpc>> = RefCell::new(None);
pub fn aborting() { pub fn aborting() {
*FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new(); *FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new();

View File

@@ -14,6 +14,34 @@ unsafe extern "C" fn vectors_executor_start(
super::hook_executor::post_executor_start(query_desc); super::hook_executor::post_executor_start(query_desc);
} }
#[cfg(any(feature = "pg12", feature = "pg13"))]
#[pgrx::pg_guard]
unsafe extern "C" fn vectors_process_utility(
pstmt: *mut pgrx::pg_sys::PlannedStmt,
query_string: *const ::std::os::raw::c_char,
context: pgrx::pg_sys::ProcessUtilityContext,
params: pgrx::pg_sys::ParamListInfo,
query_env: *mut pgrx::pg_sys::QueryEnvironment,
dest: *mut pgrx::pg_sys::DestReceiver,
qc: *mut pgrx::pg_sys::QueryCompletion,
) {
super::hook_executor::pre_process_utility(pstmt);
if let Some(prev_process_utility) = PREV_PROCESS_UTILITY {
prev_process_utility(pstmt, query_string, context, params, query_env, dest, qc);
} else {
pgrx::pg_sys::standard_ProcessUtility(
pstmt,
query_string,
context,
params,
query_env,
dest,
qc,
);
}
}
#[cfg(any(feature = "pg14", feature = "pg15", feature = "pg16"))]
#[pgrx::pg_guard] #[pgrx::pg_guard]
unsafe extern "C" fn vectors_process_utility( unsafe extern "C" fn vectors_process_utility(
pstmt: *mut pgrx::pg_sys::PlannedStmt, pstmt: *mut pgrx::pg_sys::PlannedStmt,

View File

@@ -5,10 +5,9 @@ use super::index_update;
use crate::postgres::datatype::VectorInput; use crate::postgres::datatype::VectorInput;
use crate::postgres::gucs::ENABLE_VECTOR_INDEX; use crate::postgres::gucs::ENABLE_VECTOR_INDEX;
use crate::prelude::*; use crate::prelude::*;
use std::cell::Cell; use crate::utils::cells::PgCell;
#[thread_local] static RELOPT_KIND: PgCell<pgrx::pg_sys::relopt_kind> = unsafe { PgCell::new(0) };
static RELOPT_KIND: Cell<pgrx::pg_sys::relopt_kind> = Cell::new(0);
pub unsafe fn init() { pub unsafe fn init() {
use pgrx::pg_sys::AsPgCStr; use pgrx::pg_sys::AsPgCStr;
@@ -50,7 +49,6 @@ const AM_HANDLER: pgrx::pg_sys::IndexAmRoutine = {
am_routine.amstrategies = 1; am_routine.amstrategies = 1;
am_routine.amsupport = 0; am_routine.amsupport = 0;
am_routine.amoptsprocnum = 0;
am_routine.amcanorder = false; am_routine.amcanorder = false;
am_routine.amcanorderbyop = true; am_routine.amcanorderbyop = true;
@@ -64,7 +62,6 @@ const AM_HANDLER: pgrx::pg_sys::IndexAmRoutine = {
am_routine.amclusterable = false; am_routine.amclusterable = false;
am_routine.ampredlocks = false; am_routine.ampredlocks = false;
am_routine.amcaninclude = false; am_routine.amcaninclude = false;
am_routine.amusemaintenanceworkmem = false;
am_routine.amkeytype = pgrx::pg_sys::InvalidOid; am_routine.amkeytype = pgrx::pg_sys::InvalidOid;
am_routine.amvalidate = Some(amvalidate); am_routine.amvalidate = Some(amvalidate);
@@ -95,25 +92,27 @@ pub unsafe extern "C" fn amvalidate(opclass_oid: pgrx::pg_sys::Oid) -> bool {
#[cfg(feature = "pg12")] #[cfg(feature = "pg12")]
#[pgrx::pg_guard] #[pgrx::pg_guard]
pub unsafe extern "C" fn amoptions( pub unsafe extern "C" fn amoptions(
reloptions: pg_sys::Datum, reloptions: pgrx::pg_sys::Datum,
validate: bool, validate: bool,
) -> *mut pg_sys::bytea { ) -> *mut pgrx::pg_sys::bytea {
use pg_sys::AsPgCStr; use pgrx::pg_sys::AsPgCStr;
let tab: &[pg_sys::relopt_parse_elt] = &[pg_sys::relopt_parse_elt { let tab: &[pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt {
optname: "options".as_pg_cstr(), optname: "options".as_pg_cstr(),
opttype: pg_sys::relopt_type_RELOPT_TYPE_STRING, opttype: pgrx::pg_sys::relopt_type_RELOPT_TYPE_STRING,
offset: index_setup::helper_offset() as i32, offset: index_setup::helper_offset() as i32,
}]; }];
let mut noptions = 0; let mut noptions = 0;
let options = pg_sys::parseRelOptions(reloptions, validate, RELOPT_KIND.get(), &mut noptions); let options =
pgrx::pg_sys::parseRelOptions(reloptions, validate, RELOPT_KIND.get(), &mut noptions);
if noptions == 0 { if noptions == 0 {
return std::ptr::null_mut(); return std::ptr::null_mut();
} }
for relopt in std::slice::from_raw_parts_mut(options, noptions as usize) { for relopt in std::slice::from_raw_parts_mut(options, noptions as usize) {
relopt.gen.as_mut().unwrap().lockmode = pg_sys::AccessExclusiveLock as pg_sys::LOCKMODE; relopt.gen.as_mut().unwrap().lockmode =
pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE;
} }
let rdopts = pg_sys::allocateReloptStruct(index_setup::helper_size(), options, noptions); let rdopts = pgrx::pg_sys::allocateReloptStruct(index_setup::helper_size(), options, noptions);
pg_sys::fillRelOptions( pgrx::pg_sys::fillRelOptions(
rdopts, rdopts,
index_setup::helper_size(), index_setup::helper_size(),
options, options,
@@ -122,8 +121,8 @@ pub unsafe extern "C" fn amoptions(
tab.as_ptr(), tab.as_ptr(),
tab.len() as i32, tab.len() as i32,
); );
pg_sys::pfree(options as pgrx::void_mut_ptr); pgrx::pg_sys::pfree(options as pgrx::void_mut_ptr);
rdopts as *mut pg_sys::bytea rdopts as *mut pgrx::pg_sys::bytea
} }
#[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15", feature = "pg16"))] #[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15", feature = "pg16"))]
@@ -196,21 +195,21 @@ pub unsafe extern "C" fn ambuildempty(index_relation: pgrx::pg_sys::Relation) {
} }
#[cfg(any(feature = "pg12", feature = "pg13"))] #[cfg(any(feature = "pg12", feature = "pg13"))]
#[pg_guard] #[pgrx::pg_guard]
pub unsafe extern "C" fn aminsert( pub unsafe extern "C" fn aminsert(
index_relation: pg_sys::Relation, index_relation: pgrx::pg_sys::Relation,
values: *mut pg_sys::Datum, values: *mut pgrx::pg_sys::Datum,
is_null: *mut bool, is_null: *mut bool,
heap_tid: pg_sys::ItemPointer, heap_tid: pgrx::pg_sys::ItemPointer,
_heap_relation: pg_sys::Relation, _heap_relation: pgrx::pg_sys::Relation,
_check_unique: pg_sys::IndexUniqueCheck, _check_unique: pgrx::pg_sys::IndexUniqueCheck,
_index_info: *mut pg_sys::IndexInfo, _index_info: *mut pgrx::pg_sys::IndexInfo,
) -> bool { ) -> bool {
use pgrx::FromDatum; use pgrx::FromDatum;
let oid = (*index_relation).rd_id; let oid = (*index_relation).rd_id;
let id = Id::from_sys(oid); let id = Id::from_sys(oid);
let vector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); let vector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap();
let vector = vector.data().to_vec().into_boxed_slice(); let vector = vector.data().to_vec();
index_update::update_insert(id, vector, heap_tid); index_update::update_insert(id, vector, heap_tid);
true true
} }

View File

@@ -44,11 +44,11 @@ pub unsafe fn build(
} }
#[cfg(feature = "pg12")] #[cfg(feature = "pg12")]
#[pg_guard] #[pgrx::pg_guard]
unsafe extern "C" fn callback( unsafe extern "C" fn callback(
index_relation: pg_sys::Relation, index_relation: pgrx::pg_sys::Relation,
htup: pg_sys::HeapTuple, htup: pgrx::pg_sys::HeapTuple,
values: *mut pg_sys::Datum, values: *mut pgrx::pg_sys::Datum,
is_null: *mut bool, is_null: *mut bool,
_tuple_is_alive: bool, _tuple_is_alive: bool,
state: *mut std::os::raw::c_void, state: *mut std::os::raw::c_void,

View File

@@ -114,6 +114,7 @@ pub unsafe fn start_scan(
} }
} }
#[allow(clippy::never_loop)]
pub unsafe fn next_scan(scan: pgrx::pg_sys::IndexScanDesc) -> bool { pub unsafe fn next_scan(scan: pgrx::pg_sys::IndexScanDesc) -> bool {
let scanner = &mut *((*scan).opaque as *mut Scanner); let scanner = &mut *((*scan).opaque as *mut Scanner);
if matches!(scanner, Scanner::Initial { .. }) { if matches!(scanner, Scanner::Initial { .. }) {

View File

@@ -6,7 +6,13 @@ use std::ops::Deref;
#[pgrx::opname(+)] #[pgrx::opname(+)]
#[pgrx::commutator(+)] #[pgrx::commutator(+)]
fn operator_add(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput { fn operator_add(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
let n = lhs.len(); let n = lhs.len();
let mut v = Vector::new_zeroed(n); let mut v = Vector::new_zeroed(n);
for i in 0..n { for i in 0..n {
@@ -18,7 +24,13 @@ fn operator_add(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput {
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] #[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])]
#[pgrx::opname(-)] #[pgrx::opname(-)]
fn operator_minus(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput { fn operator_minus(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
let n = lhs.len(); let n = lhs.len();
let mut v = Vector::new_zeroed(n); let mut v = Vector::new_zeroed(n);
for i in 0..n { for i in 0..n {
@@ -34,7 +46,13 @@ fn operator_minus(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput {
#[pgrx::restrict(scalarltsel)] #[pgrx::restrict(scalarltsel)]
#[pgrx::join(scalarltjoinsel)] #[pgrx::join(scalarltjoinsel)]
fn operator_lt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { fn operator_lt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() < rhs.deref() lhs.deref() < rhs.deref()
} }
@@ -45,7 +63,13 @@ fn operator_lt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
#[pgrx::restrict(scalarltsel)] #[pgrx::restrict(scalarltsel)]
#[pgrx::join(scalarltjoinsel)] #[pgrx::join(scalarltjoinsel)]
fn operator_lte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { fn operator_lte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() <= rhs.deref() lhs.deref() <= rhs.deref()
} }
@@ -56,7 +80,13 @@ fn operator_lte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
#[pgrx::restrict(scalargtsel)] #[pgrx::restrict(scalargtsel)]
#[pgrx::join(scalargtjoinsel)] #[pgrx::join(scalargtjoinsel)]
fn operator_gt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { fn operator_gt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() > rhs.deref() lhs.deref() > rhs.deref()
} }
@@ -67,7 +97,13 @@ fn operator_gt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
#[pgrx::restrict(scalargtsel)] #[pgrx::restrict(scalargtsel)]
#[pgrx::join(scalargtjoinsel)] #[pgrx::join(scalargtjoinsel)]
fn operator_gte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { fn operator_gte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() >= rhs.deref() lhs.deref() >= rhs.deref()
} }
@@ -78,7 +114,13 @@ fn operator_gte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
#[pgrx::restrict(eqsel)] #[pgrx::restrict(eqsel)]
#[pgrx::join(eqjoinsel)] #[pgrx::join(eqjoinsel)]
fn operator_eq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { fn operator_eq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() == rhs.deref() lhs.deref() == rhs.deref()
} }
@@ -89,7 +131,13 @@ fn operator_eq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
#[pgrx::restrict(eqsel)] #[pgrx::restrict(eqsel)]
#[pgrx::join(eqjoinsel)] #[pgrx::join(eqjoinsel)]
fn operator_neq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { fn operator_neq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() != rhs.deref() lhs.deref() != rhs.deref()
} }
@@ -97,7 +145,13 @@ fn operator_neq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
#[pgrx::opname(<=>)] #[pgrx::opname(<=>)]
#[pgrx::commutator(<=>)] #[pgrx::commutator(<=>)]
fn operator_cosine(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { fn operator_cosine(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
Distance::Cosine.distance(&lhs, &rhs) Distance::Cosine.distance(&lhs, &rhs)
} }
@@ -105,7 +159,13 @@ fn operator_cosine(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar {
#[pgrx::opname(<#>)] #[pgrx::opname(<#>)]
#[pgrx::commutator(<#>)] #[pgrx::commutator(<#>)]
fn operator_dot(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { fn operator_dot(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
Distance::Dot.distance(&lhs, &rhs) Distance::Dot.distance(&lhs, &rhs)
} }
@@ -113,6 +173,12 @@ fn operator_dot(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar {
#[pgrx::opname(<->)] #[pgrx::opname(<->)]
#[pgrx::commutator(<->)] #[pgrx::commutator(<->)]
fn operator_l2(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { fn operator_l2(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar {
assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
Distance::L2.distance(&lhs, &rhs) Distance::L2.distance(&lhs, &rhs)
} }

63
src/utils/cells.rs Normal file
View File

@@ -0,0 +1,63 @@
use std::cell::{Cell, RefCell, UnsafeCell};
pub struct PgCell<T>(Cell<T>);
unsafe impl<T: Send> Send for PgCell<T> {}
unsafe impl<T: Sync> Sync for PgCell<T> {}
impl<T> PgCell<T> {
pub const unsafe fn new(x: T) -> Self {
Self(Cell::new(x))
}
}
impl<T: Copy> PgCell<T> {
pub fn get(&self) -> T {
self.0.get()
}
pub fn set(&self, value: T) {
self.0.set(value);
}
}
pub struct PgRefCell<T>(RefCell<T>);
unsafe impl<T: Send> Send for PgRefCell<T> {}
unsafe impl<T: Sync> Sync for PgRefCell<T> {}
impl<T> PgRefCell<T> {
pub const unsafe fn new(x: T) -> Self {
Self(RefCell::new(x))
}
pub fn borrow_mut(&self) -> std::cell::RefMut<'_, T> {
self.0.borrow_mut()
}
pub fn borrow(&self) -> std::cell::Ref<'_, T> {
self.0.borrow()
}
}
#[repr(transparent)]
pub struct SyncUnsafeCell<T: ?Sized> {
value: UnsafeCell<T>,
}
unsafe impl<T: ?Sized + Sync> Sync for SyncUnsafeCell<T> {}
impl<T> SyncUnsafeCell<T> {
pub const fn new(value: T) -> Self {
Self {
value: UnsafeCell::new(value),
}
}
}
impl<T: ?Sized> SyncUnsafeCell<T> {
pub fn get(&self) -> *mut T {
self.value.get()
}
pub fn get_mut(&mut self) -> &mut T {
self.value.get_mut()
}
}

View File

@@ -1,5 +1,5 @@
use super::dir_ops::sync_dir; use super::dir_ops::sync_dir;
use std::fs::{try_exists, File}; use std::fs::File;
use std::io::Write; use std::io::Write;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
@@ -43,7 +43,7 @@ where
{ {
pub fn open(path: impl AsRef<Path>) -> Self { pub fn open(path: impl AsRef<Path>) -> Self {
let path = path.as_ref().to_owned(); let path = path.as_ref().to_owned();
if try_exists("1").unwrap() { if path.join("1").try_exists().unwrap() {
std::fs::remove_file(path.join("1")).unwrap(); std::fs::remove_file(path.join("1")).unwrap();
sync_dir(&path); sync_dir(&path);
} }

View File

@@ -1,8 +1,10 @@
pub mod cells;
pub mod clean; pub mod clean;
pub mod dir_ops; pub mod dir_ops;
pub mod file_atomic; pub mod file_atomic;
pub mod file_socket; pub mod file_socket;
pub mod file_wal; pub mod file_wal;
pub mod mmap_array; pub mod mmap_array;
pub mod os;
pub mod semaphore; pub mod semaphore;
pub mod vec2; pub mod vec2;

98
src/utils/os.rs Normal file
View File

@@ -0,0 +1,98 @@
use rustix::fd::{AsFd, OwnedFd};
use rustix::mm::{MapFlags, ProtFlags};
use std::sync::atomic::AtomicU32;
#[cfg(target_os = "linux")]
pub unsafe fn futex_wait(futex: &AtomicU32, value: u32) {
const FUTEX_TIMEOUT: libc::timespec = libc::timespec {
tv_sec: 15,
tv_nsec: 0,
};
libc::syscall(
libc::SYS_futex,
futex.as_ptr(),
libc::FUTEX_WAIT,
value,
&FUTEX_TIMEOUT,
);
}
#[cfg(target_os = "linux")]
pub unsafe fn futex_wake(futex: &AtomicU32) {
libc::syscall(libc::SYS_futex, futex.as_ptr(), libc::FUTEX_WAKE, i32::MAX);
}
#[cfg(target_os = "linux")]
pub fn memfd_create() -> std::io::Result<OwnedFd> {
use rustix::fs::MemfdFlags;
Ok(rustix::fs::memfd_create(
&format!(".memfd.VECTORS.{:x}", std::process::id()),
MemfdFlags::empty(),
)?)
}
#[cfg(target_os = "linux")]
pub unsafe fn mmap_populate(len: usize, fd: impl AsFd) -> std::io::Result<*mut libc::c_void> {
use std::ptr::null_mut;
Ok(rustix::mm::mmap(
null_mut(),
len,
ProtFlags::READ | ProtFlags::WRITE,
MapFlags::SHARED | MapFlags::POPULATE,
fd,
0,
)?)
}
#[cfg(target_os = "macos")]
pub unsafe fn futex_wait(futex: &AtomicU32, value: u32) {
const ULOCK_TIMEOUT: u32 = 15_000_000;
ulock_sys::__ulock_wait(
ulock_sys::darwin19::UL_COMPARE_AND_WAIT_SHARED,
futex.as_ptr().cast(),
value as _,
ULOCK_TIMEOUT,
);
}
#[cfg(target_os = "macos")]
pub unsafe fn futex_wake(futex: &AtomicU32) {
ulock_sys::__ulock_wake(
ulock_sys::darwin19::UL_COMPARE_AND_WAIT_SHARED,
futex.as_ptr().cast(),
0,
);
}
#[cfg(target_os = "macos")]
pub fn memfd_create() -> std::io::Result<OwnedFd> {
use rustix::fs::Mode;
use rustix::fs::OFlags;
// POSIX fcntl locking do not support shmem, so we use a regular file here.
// reference: https://man7.org/linux/man-pages/man3/fcntl.3p.html
let name = format!(
".shm.VECTORS.{:x}.{:x}",
std::process::id(),
rand::random::<u32>()
);
let fd = rustix::fs::open(
&name,
OFlags::RDWR | OFlags::CREATE | OFlags::EXCL,
Mode::RUSR | Mode::WUSR,
)?;
rustix::fs::unlink(&name)?;
Ok(fd)
}
#[cfg(target_os = "macos")]
pub unsafe fn mmap_populate(len: usize, fd: impl AsFd) -> std::io::Result<*mut libc::c_void> {
use std::ptr::null_mut;
Ok(rustix::mm::mmap(
null_mut(),
len,
ProtFlags::READ | ProtFlags::WRITE,
MapFlags::SHARED,
fd,
0,
)?)
}

View File

@@ -29,8 +29,7 @@ SELECT '[1,2]'::vector < '[1,3]';
---- ----
t t
# TODO: may need better error message statement error differs in dimensions
statement error assertion failed: `\(left == right\)`
SELECT '[1,2]'::vector < '[1,2,3]'; SELECT '[1,2]'::vector < '[1,2,3]';
query I query I