1
0
mirror of https://github.com/tensorchord/pgvecto.rs.git synced 2025-08-08 14:22:07 +03:00

feat: fp16 vector (#178)

* feat: fp16 vector

Signed-off-by: usamoi <usamoi@outlook.com>

* feat: detect avx512fp16

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: install clang-16 for ci

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: clippy

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: rename c to pgvectorsc

Signed-off-by: usamoi <usamoi@outlook.com>

* feat: hand-writing avx512fp16

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: index on fp16

Signed-off-by: usamoi <usamoi@outlook.com>

* feat: hand-writing avx2

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: clippy

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: add rerun in build script

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: cross compilation

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: do not leave uninitialized bytes in datatype input function

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: compiler built-in function calling convention workaround

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: cross compile on aarch64

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: fix detect avx512fp16

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: avx512 codegen by multiversion

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: enable more target features for c

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: use tensorchord/stdarch

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: ci

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: remove no-run cross test

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: vbase

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: error and document

Signed-off-by: usamoi <usamoi@outlook.com>

* [skip ci]

Signed-off-by: usamoi <usamoi@outlook.com>

---------

Signed-off-by: usamoi <usamoi@outlook.com>
This commit is contained in:
Usamoi
2023-12-14 17:50:52 +08:00
committed by GitHub
parent 2ab76118fc
commit 5c0450274d
146 changed files with 7436 additions and 4108 deletions

View File

@@ -4,3 +4,13 @@ rustdocflags = ["--document-private-items"]
[target.'cfg(target_os="macos")']
# Postgres symbols won't be available until runtime
rustflags = ["-Clink-arg=-Wl,-undefined,dynamic_lookup"]
[target.x86_64-unknown-linux-gnu]
linker = "x86_64-linux-gnu-gcc"
[target.aarch64-unknown-linux-gnu]
linker = "aarch64-linux-gnu-gcc"
[env]
BINDGEN_EXTRA_CLANG_ARGS_x86_64_unknown_linux_gnu = "-isystem /usr/x86_64-linux-gnu/include/ -ccc-gcc-name x86_64-linux-gnu-gcc"
BINDGEN_EXTRA_CLANG_ARGS_aarch64_unknown_linux_gnu = "-isystem /usr/aarch64-linux-gnu/include/ -ccc-gcc-name aarch64-linux-gnu-gcc"

View File

@@ -6,6 +6,7 @@ on:
paths:
- ".cargo/**"
- ".github/**"
- "crates/**"
- "scripts/**"
- "src/**"
- "tests/**"
@@ -18,6 +19,7 @@ on:
paths:
- ".cargo/**"
- ".github/**"
- "crates/**"
- "scripts/**"
- "src/**"
- "tests/**"
@@ -90,11 +92,16 @@ jobs:
- name: Format check
run: cargo fmt --check
- name: Semantic check
run: cargo clippy --no-default-features --features "pg${{ matrix.version }} pg_test"
run: |
cargo clippy --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu
cargo clippy --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu
- name: Debug build
run: cargo build --no-default-features --features "pg${{ matrix.version }} pg_test"
run: |
cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu
cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu
- name: Test
run: cargo test --all --no-default-features --features "pg${{ matrix.version }} pg_test" -- --nocapture
run: |
cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu
- name: Install release
run: ./scripts/ci_install.sh
- name: Sqllogictest

View File

@@ -112,15 +112,17 @@ jobs:
- uses: mozilla-actions/sccache-action@v0.0.3
- name: Prepare
run: |
sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list'
sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" >> /etc/apt/sources.list.d/pgdg.list'
sudo sh -c 'echo "deb http://apt.llvm.org/$(lsb_release -cs)/ llvm-toolchain-$(lsb_release -cs)-16 main" >> /etc/apt/sources.list'
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
wget --quiet -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
sudo apt-get update
sudo apt-get -y install libpq-dev postgresql-${{ matrix.version }} postgresql-server-dev-${{ matrix.version }}
sudo apt-get -y install clang-16
cargo install cargo-pgrx --git https://github.com/tensorchord/pgrx.git --rev $(cat Cargo.toml | grep "pgrx =" | awk -F'rev = "' '{print $2}' | cut -d'"' -f1)
cargo pgrx init --pg${{ matrix.version }}=/usr/lib/postgresql/${{ matrix.version }}/bin/pg_config
if [[ "${{ matrix.arch }}" == "arm64" ]]; then
sudo apt-get -y install crossbuild-essential-arm64
rustup target add aarch64-unknown-linux-gnu
fi
- name: Build Release
id: build_release
@@ -130,8 +132,6 @@ jobs:
mkdir ./artifacts
cargo pgrx package
if [[ "${{ matrix.arch }}" == "arm64" ]]; then
export CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER=aarch64-linux-gnu-gcc
export BINDGEN_EXTRA_CLANG_ARGS_aarch64_unknown_linux_gnu="-target aarch64-unknown-linux-gnu -isystem /usr/aarch64-linux-gnu/include/ -ccc-gcc-name aarch64-linux-gnu-gcc"
cargo build --target aarch64-unknown-linux-gnu --release --features "pg${{ matrix.version }}" --no-default-features
mv ./target/aarch64-unknown-linux-gnu/release/libvectors.so ./target/release/vectors-pg${{ matrix.version }}/usr/lib/postgresql/${{ matrix.version }}/lib/vectors.so
fi

3
.gitignore vendored
View File

@@ -6,4 +6,5 @@
.vscode
.ignore
__pycache__
.pytest_cache
.pytest_cache
rustc-ice-*.txt

576
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,7 @@
[package]
name = "vectors"
version = "0.1.1"
edition = "2021"
version.workspace = true
edition.workspace = true
[lib]
crate-type = ["cdylib"]
@@ -16,45 +16,60 @@ pg16 = ["pgrx/pg16", "pgrx-tests/pg16"]
pg_test = []
[dependencies]
libc.workspace = true
log.workspace = true
serde.workspace = true
serde_json.workspace = true
validator.workspace = true
rustix.workspace = true
thiserror.workspace = true
byteorder.workspace = true
bincode.workspace = true
half.workspace = true
num-traits.workspace = true
service = { path = "crates/service" }
pgrx = { git = "https://github.com/tensorchord/pgrx.git", rev = "7c30e2023876c1efce613756f5ec81f3ab05696b", default-features = false, features = [
] }
openai_api_rust = { git = "https://github.com/tensorchord/openai-api.git", rev = "228d54b6002e98257b3c81501a054942342f585f" }
static_assertions = "1.1.0"
libc = "~0.2"
serde = "1.0.163"
bincode = "1.3.3"
rand = "0.8.5"
byteorder = "1.4.3"
crc32fast = "1.3.2"
log = "0.4.18"
env_logger = "0.10.0"
crossbeam = "0.8.2"
dashmap = "5.4.0"
parking_lot = "0.12.1"
memoffset = "0.9.0"
serde_json = "1"
thiserror = "1.0.40"
tempfile = "3.6.0"
cstr = "0.2.11"
arrayvec = { version = "0.7.3", features = ["serde"] }
memmap2 = "0.9.0"
validator = { version = "0.16.1", features = ["derive"] }
toml = "0.8.8"
rayon = "1.6.1"
uuid = { version = "1.4.1", features = ["serde"] }
rustix = { version = "0.38.20", features = ["net", "mm"] }
arc-swap = "1.6.0"
bytemuck = { version = "1.14.0", features = ["extern_crate_alloc"] }
serde_with = "3.4.0"
multiversion = "0.7.3"
[dev-dependencies]
pgrx-tests = { git = "https://github.com/tensorchord/pgrx.git", rev = "7c30e2023876c1efce613756f5ec81f3ab05696b" }
httpmock = "0.6"
mockall = "0.11.4"
[target.'cfg(target_os = "macos")'.dependencies]
ulock-sys = "0.1.0"
[lints]
clippy.too_many_arguments = "allow"
clippy.unnecessary_literal_unwrap = "allow"
clippy.unnecessary_unwrap = "allow"
rust.unsafe_op_in_unsafe_fn = "warn"
[workspace]
resolver = "2"
members = ["crates/*"]
[workspace.package]
version = "0.0.0"
edition = "2021"
[workspace.dependencies]
libc = "~0.2"
log = "~0.4"
serde = "~1.0"
serde_json = "1"
thiserror = "~1.0"
bincode = "~1.3"
byteorder = "~1.4"
half = { version = "~2.3", features = [
"bytemuck",
"num-traits",
"serde",
"use-intrinsics",
] }
num-traits = "~0.2"
validator = { version = "~0.16", features = ["derive"] }
rustix = { version = "~0.38", features = ["net", "mm"] }
[profile.dev]
panic = "unwind"
@@ -65,10 +80,3 @@ opt-level = 3
lto = "fat"
codegen-units = 1
debug = true
[lints.clippy]
needless_range_loop = "allow"
derivable_impls = "allow"
unnecessary_literal_unwrap = "allow"
too_many_arguments = "allow"
unnecessary_unwrap = "allow"

View File

@@ -21,13 +21,13 @@ pgvecto.rs is a Postgres extension that provides vector similarity search functi
## Comparison with pgvector
| | pgvecto.rs | pgvector |
| ------------------------------------------- | ------------------------------------------------------ | ------------------------ |
| Transaction support | ✅ | ⚠️ |
| Sufficient Result with Delete/Update/Filter | ✅ | ⚠️ |
| Vector Dimension Limit | 65535 | 2000 |
| Prefilter on HNSW | ✅ | ❌ |
| Parallel HNSW Index build | ⚡️ Linearly faster with more cores | 🐌 Only single core used |
| | pgvecto.rs | pgvector |
| ------------------------------------------- | ------------------------------------------------------ | ----------------------- |
| Transaction support | ✅ | ⚠️ |
| Sufficient Result with Delete/Update/Filter | ✅ | ⚠️ |
| Vector Dimension Limit | 65535 | 2000 |
| Prefilter on HNSW | ✅ | ❌ |
| Parallel HNSW Index build | ⚡️ Linearly faster with more cores | 🐌 Only single core used |
| Async Index build | Ready for queries anytime and do not block insertions. | ❌ |
| Quantization | Scalar/Product Quantization | ❌ |
@@ -45,7 +45,11 @@ More details at [./docs/comparison-pgvector.md](./docs/comparison-pgvector.md)
For users, we recommend you to try pgvecto.rs using our pre-built docker image, by running
```sh
docker run --name pgvecto-rs-demo -e POSTGRES_PASSWORD=mysecretpassword -p 5432:5432 -d tensorchord/pgvecto-rs:pg16-latest
docker run \
--name pgvecto-rs-demo \
-e POSTGRES_PASSWORD=mysecretpassword \
-p 5432:5432 \
-d tensorchord/pgvecto-rs:pg16-latest
```
## Development with envd

View File

@@ -19,14 +19,12 @@ URL = f"postgresql://{USER}:{PASS}@{HOST}:{PORT}/{DB_NAME}"
TOML_SETTINGS = {
"flat": toml.dumps(
{
"capacity": 2097152,
"algorithm": {"flat": {}},
"indexing": {"flat": {}},
},
),
"hnsw": toml.dumps(
{
"capacity": 2097152,
"algorithm": {"hnsw": {}},
"indexing": {"hnsw": {}},
},
),
}

View File

@@ -38,7 +38,7 @@ def conn():
@pytest.mark.parametrize(("index_name", "index_setting"), TOML_SETTINGS.items())
def test_create_index(conn: Connection, index_name: str, index_setting: str):
stat = sql.SQL(
"CREATE INDEX {} ON tb_test_item USING vectors (embedding l2_ops) WITH (options={});",
"CREATE INDEX {} ON tb_test_item USING vectors (embedding vector_l2_ops) WITH (options={});",
).format(sql.Identifier(index_name), index_setting)
conn.execute(stat)

View File

@@ -68,7 +68,7 @@ def test_create_index(session: Session, index_name: str, index_setting: str):
Document.embedding,
postgresql_using="vectors",
postgresql_with={"options": f"$${index_setting}$$"},
postgresql_ops={"embedding": "l2_ops"},
postgresql_ops={"embedding": "vector_l2_ops"},
)
index.create(session.bind)
session.commit()

3
crates/c/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
*.s
*.o
*.out

10
crates/c/Cargo.toml Normal file
View File

@@ -0,0 +1,10 @@
[package]
name = "c"
version.workspace = true
edition.workspace = true
[dependencies]
half = { version = "~2.3", features = ["use-intrinsics"] }
[build-dependencies]
cc = "1.0"

10
crates/c/build.rs Normal file
View File

@@ -0,0 +1,10 @@
fn main() {
println!("cargo:rerun-if-changed=src/c.h");
println!("cargo:rerun-if-changed=src/c.c");
cc::Build::new()
.compiler("/usr/bin/clang-16")
.file("./src/c.c")
.opt_level(3)
.debug(true)
.compile("pgvectorsc");
}

118
crates/c/src/c.c Normal file
View File

@@ -0,0 +1,118 @@
#include "c.h"
#include <math.h>
#if defined(__x86_64__)
#include <immintrin.h>
#endif
#if defined(__x86_64__)
__attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float
v_f16_cosine_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
__m512h xy = _mm512_set1_ph(0);
__m512h xx = _mm512_set1_ph(0);
__m512h yy = _mm512_set1_ph(0);
while (n >= 32) {
__m512h x = _mm512_loadu_ph(a);
__m512h y = _mm512_loadu_ph(b);
a += 32, b += 32, n -= 32;
xy = _mm512_fmadd_ph(x, y, xy);
xx = _mm512_fmadd_ph(x, x, xx);
yy = _mm512_fmadd_ph(y, y, yy);
}
if (n > 0) {
__mmask32 mask = _bzhi_u32(0xFFFFFFFF, n);
__m512h x = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a));
__m512h y = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b));
xy = _mm512_fmadd_ph(x, y, xy);
xx = _mm512_fmadd_ph(x, x, xx);
yy = _mm512_fmadd_ph(y, y, yy);
}
return (float)(_mm512_reduce_add_ph(xy) /
sqrt(_mm512_reduce_add_ph(xx) * _mm512_reduce_add_ph(yy)));
}
__attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float
v_f16_dot_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
__m512h xy = _mm512_set1_ph(0);
while (n >= 32) {
__m512h x = _mm512_loadu_ph(a);
__m512h y = _mm512_loadu_ph(b);
a += 32, b += 32, n -= 32;
xy = _mm512_fmadd_ph(x, y, xy);
}
if (n > 0) {
__mmask32 mask = _bzhi_u32(0xFFFFFFFF, n);
__m512h x = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a));
__m512h y = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b));
xy = _mm512_fmadd_ph(x, y, xy);
}
return (float)_mm512_reduce_add_ph(xy);
}
__attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float
v_f16_sl2_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
__m512h dd = _mm512_set1_ph(0);
while (n >= 32) {
__m512h x = _mm512_loadu_ph(a);
__m512h y = _mm512_loadu_ph(b);
a += 32, b += 32, n -= 32;
__m512h d = _mm512_sub_ph(x, y);
dd = _mm512_fmadd_ph(d, d, dd);
}
if (n > 0) {
__mmask32 mask = _bzhi_u32(0xFFFFFFFF, n);
__m512h x = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a));
__m512h y = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b));
__m512h d = _mm512_sub_ph(x, y);
dd = _mm512_fmadd_ph(d, d, dd);
}
return (float)_mm512_reduce_add_ph(dd);
}
__attribute__((target("arch=x86-64-v3"))) extern float
v_f16_cosine_v3(_Float16 *a, _Float16 *b, size_t n) {
float xy = 0;
float xx = 0;
float yy = 0;
#pragma clang loop vectorize_width(8)
for (size_t i = 0; i < n; i++) {
float x = a[i];
float y = b[i];
xy += x * y;
xx += x * x;
yy += y * y;
}
return xy / sqrt(xx * yy);
}
__attribute__((target("arch=x86-64-v3"))) extern float
v_f16_dot_v3(_Float16 *a, _Float16 *b, size_t n) {
float xy = 0;
#pragma clang loop vectorize_width(8)
for (size_t i = 0; i < n; i++) {
float x = a[i];
float y = b[i];
xy += x * y;
}
return xy;
}
__attribute__((target("arch=x86-64-v3"))) extern float
v_f16_sl2_v3(_Float16 *a, _Float16 *b, size_t n) {
float dd = 0;
#pragma clang loop vectorize_width(8)
for (size_t i = 0; i < n; i++) {
float x = a[i];
float y = b[i];
float d = x - y;
dd += d * d;
}
return dd;
}
#endif

13
crates/c/src/c.h Normal file
View File

@@ -0,0 +1,13 @@
#include <stddef.h>
#include <stdint.h>
#if defined(__x86_64__)
extern float v_f16_cosine_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_dot_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_sl2_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_cosine_v3(_Float16 *, _Float16 *, size_t n);
extern float v_f16_dot_v3(_Float16 *, _Float16 *, size_t n);
extern float v_f16_sl2_v3(_Float16 *, _Float16 *, size_t n);
#endif

24
crates/c/src/c.rs Normal file
View File

@@ -0,0 +1,24 @@
#[cfg(target_arch = "x86_64")]
#[link(name = "pgvectorsc", kind = "static")]
extern "C" {
pub fn v_f16_cosine_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_dot_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_sl2_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_cosine_v3(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_dot_v3(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_sl2_v3(a: *const u16, b: *const u16, n: usize) -> f32;
}
// `compiler_builtin` defines `__extendhfsf2` with integer calling convention.
// However C compilers links `__extendhfsf2` with floating calling convention.
// The code should be removed once Rust offically supports `f16`.
#[cfg(target_arch = "x86_64")]
#[no_mangle]
#[linkage = "external"]
extern "C" fn __extendhfsf2(f: f64) -> f32 {
unsafe {
let f: half::f16 = std::mem::transmute_copy(&f);
f.to_f32()
}
}

6
crates/c/src/lib.rs Normal file
View File

@@ -0,0 +1,6 @@
#![feature(linkage)]
mod c;
#[allow(unused_imports)]
pub use self::c::*;

45
crates/service/Cargo.toml Normal file
View File

@@ -0,0 +1,45 @@
[package]
name = "service"
version.workspace = true
edition.workspace = true
[dependencies]
libc.workspace = true
log.workspace = true
serde.workspace = true
serde_json.workspace = true
validator.workspace = true
rustix.workspace = true
thiserror.workspace = true
byteorder.workspace = true
bincode.workspace = true
half.workspace = true
num-traits.workspace = true
c = { path = "../c" }
std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "avx512fp16" }
rand = "0.8.5"
crc32fast = "1.3.2"
crossbeam = "0.8.2"
dashmap = "5.4.0"
parking_lot = "0.12.1"
memoffset = "0.9.0"
tempfile = "3.6.0"
arrayvec = { version = "0.7.3", features = ["serde"] }
memmap2 = "0.9.0"
rayon = "1.6.1"
uuid = { version = "1.6.1", features = ["serde"] }
arc-swap = "1.6.0"
bytemuck = { version = "1.14.0", features = ["extern_crate_alloc"] }
serde_with = "3.4.0"
multiversion = "0.7.3"
ctor = "0.2.6"
[target.'cfg(target_os = "macos")'.dependencies]
ulock-sys = "0.1.0"
[lints]
clippy.derivable_impls = "allow"
clippy.len_without_is_empty = "allow"
clippy.needless_range_loop = "allow"
clippy.too_many_arguments = "allow"
rust.unsafe_op_in_unsafe_fn = "warn"

View File

@@ -4,38 +4,37 @@ use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::ops::{Index, IndexMut};
pub struct ElkanKMeans {
pub struct ElkanKMeans<S: G> {
dims: u16,
c: usize,
pub centroids: Vec2,
pub centroids: Vec2<S>,
lowerbound: Square,
upperbound: Vec<Scalar>,
upperbound: Vec<F32>,
assign: Vec<usize>,
rand: StdRng,
samples: Vec2,
d: Distance,
samples: Vec2<S>,
}
const DELTA: f32 = 1.0 / 1024.0;
impl ElkanKMeans {
pub fn new(c: usize, samples: Vec2, d: Distance) -> Self {
impl<S: G> ElkanKMeans<S> {
pub fn new(c: usize, samples: Vec2<S>) -> Self {
let n = samples.len();
let dims = samples.dims();
let mut rand = StdRng::from_entropy();
let mut centroids = Vec2::new(dims, c);
let mut lowerbound = Square::new(n, c);
let mut upperbound = vec![Scalar::Z; n];
let mut upperbound = vec![F32::zero(); n];
let mut assign = vec![0usize; n];
centroids[0].copy_from_slice(&samples[rand.gen_range(0..n)]);
let mut weight = vec![Scalar::INFINITY; n];
let mut weight = vec![F32::infinity(); n];
for i in 0..c {
let mut sum = Scalar::Z;
let mut sum = F32::zero();
for j in 0..n {
let dis = d.elkan_k_means_distance(&samples[j], &centroids[i]);
let dis = S::elkan_k_means_distance(&samples[j], &centroids[i]);
lowerbound[(j, i)] = dis;
if dis * dis < weight[j] {
weight[j] = dis * dis;
@@ -49,7 +48,7 @@ impl ElkanKMeans {
let mut choice = sum * rand.gen_range(0.0..1.0);
for j in 0..(n - 1) {
choice -= weight[j];
if choice <= Scalar::Z {
if choice <= F32::zero() {
break 'a j;
}
}
@@ -59,7 +58,7 @@ impl ElkanKMeans {
}
for i in 0..n {
let mut minimal = Scalar::INFINITY;
let mut minimal = F32::infinity();
let mut target = 0;
for j in 0..c {
let dis = lowerbound[(i, j)];
@@ -81,13 +80,11 @@ impl ElkanKMeans {
assign,
rand,
samples,
d,
}
}
pub fn iterate(&mut self) -> bool {
let c = self.c;
let f = |lhs: &[Scalar], rhs: &[Scalar]| self.d.elkan_k_means_distance(lhs, rhs);
let dims = self.dims;
let samples = &self.samples;
let rand = &mut self.rand;
@@ -100,16 +97,16 @@ impl ElkanKMeans {
// Step 1
let mut dist0 = Square::new(c, c);
let mut sp = vec![Scalar::Z; c];
let mut sp = vec![F32::zero(); c];
for i in 0..c {
for j in i + 1..c {
let dis = f(&centroids[i], &centroids[j]) * 0.5;
let dis = S::elkan_k_means_distance(&centroids[i], &centroids[j]) * 0.5;
dist0[(i, j)] = dis;
dist0[(j, i)] = dis;
}
}
for i in 0..c {
let mut minimal = Scalar::INFINITY;
let mut minimal = F32::infinity();
for j in 0..c {
if i == j {
continue;
@@ -127,7 +124,7 @@ impl ElkanKMeans {
if upperbound[i] <= sp[assign[i]] {
continue;
}
let mut minimal = f(&samples[i], &centroids[assign[i]]);
let mut minimal = S::elkan_k_means_distance(&samples[i], &centroids[assign[i]]);
lowerbound[(i, assign[i])] = minimal;
upperbound[i] = minimal;
// Step 3
@@ -142,7 +139,7 @@ impl ElkanKMeans {
continue;
}
if minimal > lowerbound[(i, j)] || minimal > dist0[(assign[i], j)] {
let dis = f(&samples[i], &centroids[j]);
let dis = S::elkan_k_means_distance(&samples[i], &centroids[j]);
lowerbound[(i, j)] = dis;
if dis < minimal {
minimal = dis;
@@ -156,8 +153,8 @@ impl ElkanKMeans {
// Step 4, 7
let old = std::mem::replace(centroids, Vec2::new(dims, c));
let mut count = vec![Scalar::Z; c];
centroids.fill(Scalar::Z);
let mut count = vec![F32::zero(); c];
centroids.fill(S::Scalar::zero());
for i in 0..n {
for j in 0..dims as usize {
centroids[assign[i]][j] += samples[i][j];
@@ -165,21 +162,21 @@ impl ElkanKMeans {
count[assign[i]] += 1.0;
}
for i in 0..c {
if count[i] == Scalar::Z {
if count[i] == F32::zero() {
continue;
}
for dim in 0..dims as usize {
centroids[i][dim] /= count[i];
centroids[i][dim] /= S::Scalar::from_f32(count[i].into());
}
}
for i in 0..c {
if count[i] != Scalar::Z {
if count[i] != F32::zero() {
continue;
}
let mut o = 0;
loop {
let alpha = Scalar(rand.gen_range(0.0..1.0));
let beta = (count[o] - 1.0) / (n - c) as Float;
let alpha = F32::from_f32(rand.gen_range(0.0..1.0f32));
let beta = (count[o] - 1.0) / (n - c) as f32;
if alpha < beta {
break;
}
@@ -188,28 +185,28 @@ impl ElkanKMeans {
centroids.copy_within(o, i);
for dim in 0..dims as usize {
if dim % 2 == 0 {
centroids[i][dim] *= 1.0 + DELTA;
centroids[o][dim] *= 1.0 - DELTA;
centroids[i][dim] *= S::Scalar::from_f32(1.0 + DELTA);
centroids[o][dim] *= S::Scalar::from_f32(1.0 - DELTA);
} else {
centroids[i][dim] *= 1.0 - DELTA;
centroids[o][dim] *= 1.0 + DELTA;
centroids[i][dim] *= S::Scalar::from_f32(1.0 - DELTA);
centroids[o][dim] *= S::Scalar::from_f32(1.0 + DELTA);
}
}
count[i] = count[o] / 2.0;
count[o] = count[o] - count[i];
}
for i in 0..c {
self.d.elkan_k_means_normalize(&mut centroids[i]);
S::elkan_k_means_normalize(&mut centroids[i]);
}
// Step 5, 6
let mut dist1 = vec![Scalar::Z; c];
let mut dist1 = vec![F32::zero(); c];
for i in 0..c {
dist1[i] = f(&old[i], &centroids[i]);
dist1[i] = S::elkan_k_means_distance(&old[i], &centroids[i]);
}
for i in 0..n {
for j in 0..c {
lowerbound[(i, j)] = (lowerbound[(i, j)] - dist1[j]).max(Scalar::Z);
lowerbound[(i, j)] = std::cmp::max(lowerbound[(i, j)] - dist1[j], F32::zero());
}
}
for i in 0..n {
@@ -219,7 +216,7 @@ impl ElkanKMeans {
change == 0
}
pub fn finish(self) -> Vec2 {
pub fn finish(self) -> Vec2<S> {
self.centroids
}
}
@@ -227,7 +224,7 @@ impl ElkanKMeans {
pub struct Square {
x: usize,
y: usize,
v: Box<[Scalar]>,
v: Vec<F32>,
}
impl Square {
@@ -235,13 +232,13 @@ impl Square {
Self {
x,
y,
v: bytemuck::zeroed_slice_box(x * y),
v: bytemuck::zeroed_vec(x * y),
}
}
}
impl Index<(usize, usize)> for Square {
type Output = Scalar;
type Output = F32;
fn index(&self, (x, y): (usize, usize)) -> &Self::Output {
debug_assert!(x < self.x);

View File

@@ -9,16 +9,16 @@ use std::fs::create_dir;
use std::path::PathBuf;
use std::sync::Arc;
pub struct Flat {
mmap: FlatMmap,
pub struct Flat<S: G> {
mmap: FlatMmap<S>,
}
impl Flat {
impl<S: G> Flat<S> {
pub fn create(
path: PathBuf,
options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self {
create_dir(&path).unwrap();
let ram = make(path.clone(), sealed, growing, options.clone());
@@ -35,7 +35,7 @@ impl Flat {
self.mmap.raw.len()
}
pub fn vector(&self, i: u32) -> &[Scalar] {
pub fn vector(&self, i: u32) -> &[S::Scalar] {
self.mmap.raw.vector(i)
}
@@ -43,35 +43,33 @@ impl Flat {
self.mmap.raw.payload(i)
}
pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap {
pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
search(&self.mmap, k, vector, filter)
}
}
unsafe impl Send for Flat {}
unsafe impl Sync for Flat {}
unsafe impl<S: G> Send for Flat<S> {}
unsafe impl<S: G> Sync for Flat<S> {}
pub struct FlatRam {
raw: Arc<Raw>,
quantization: Quantization,
d: Distance,
pub struct FlatRam<S: G> {
raw: Arc<Raw<S>>,
quantization: Quantization<S>,
}
pub struct FlatMmap {
raw: Arc<Raw>,
quantization: Quantization,
d: Distance,
pub struct FlatMmap<S: G> {
raw: Arc<Raw<S>>,
quantization: Quantization<S>,
}
unsafe impl Send for FlatMmap {}
unsafe impl Sync for FlatMmap {}
unsafe impl<S: G> Send for FlatMmap<S> {}
unsafe impl<S: G> Sync for FlatMmap<S> {}
pub fn make(
pub fn make<S: G>(
path: PathBuf,
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
options: IndexOptions,
) -> FlatRam {
) -> FlatRam<S> {
let idx_opts = options.indexing.clone().unwrap_flat();
let raw = Arc::new(Raw::create(
path.join("raw"),
@@ -85,22 +83,17 @@ pub fn make(
idx_opts.quantization,
&raw,
);
FlatRam {
raw,
quantization,
d: options.vector.d,
}
FlatRam { raw, quantization }
}
pub fn save(ram: FlatRam, _: PathBuf) -> FlatMmap {
pub fn save<S: G>(ram: FlatRam<S>, _: PathBuf) -> FlatMmap<S> {
FlatMmap {
raw: ram.raw,
quantization: ram.quantization,
d: ram.d,
}
}
pub fn load(path: PathBuf, options: IndexOptions) -> FlatMmap {
pub fn load<S: G>(path: PathBuf, options: IndexOptions) -> FlatMmap<S> {
let idx_opts = options.indexing.clone().unwrap_flat();
let raw = Arc::new(Raw::open(path.join("raw"), options.clone()));
let quantization = Quantization::open(
@@ -109,17 +102,18 @@ pub fn load(path: PathBuf, options: IndexOptions) -> FlatMmap {
idx_opts.quantization,
&raw,
);
FlatMmap {
raw,
quantization,
d: options.vector.d,
}
FlatMmap { raw, quantization }
}
pub fn search(mmap: &FlatMmap, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap {
pub fn search<S: G>(
mmap: &FlatMmap<S>,
k: usize,
vector: &[S::Scalar],
filter: &mut impl Filter,
) -> Heap {
let mut result = Heap::new(k);
for i in 0..mmap.raw.len() {
let distance = mmap.quantization.distance(mmap.d, vector, i);
let distance = mmap.quantization.distance(vector, i);
let payload = mmap.raw.payload(i);
if filter.check(payload) {
result.push(HeapElement { distance, payload });

View File

@@ -3,7 +3,7 @@ use super::raw::Raw;
use crate::index::indexing::hnsw::HnswIndexingOptions;
use crate::index::segments::growing::GrowingSegment;
use crate::index::segments::sealed::SealedSegment;
use crate::index::{IndexOptions, VectorOptions};
use crate::index::IndexOptions;
use crate::prelude::*;
use crate::utils::dir_ops::sync_dir;
use crate::utils::mmap_array::MmapArray;
@@ -17,16 +17,16 @@ use std::ops::RangeInclusive;
use std::path::PathBuf;
use std::sync::Arc;
pub struct Hnsw {
mmap: HnswMmap,
pub struct Hnsw<S: G> {
mmap: HnswMmap<S>,
}
impl Hnsw {
impl<S: G> Hnsw<S> {
pub fn create(
path: PathBuf,
options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self {
create_dir(&path).unwrap();
let ram = make(path.clone(), sealed, growing, options.clone());
@@ -43,7 +43,7 @@ impl Hnsw {
self.mmap.raw.len()
}
pub fn vector(&self, i: u32) -> &[Scalar] {
pub fn vector(&self, i: u32) -> &[S::Scalar] {
self.mmap.raw.vector(i)
}
@@ -51,27 +51,21 @@ impl Hnsw {
self.mmap.raw.payload(i)
}
pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap {
pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
search(&self.mmap, k, vector, filter)
}
pub fn search_vbase<'index, 'vector>(
&'index self,
range: usize,
vector: &'vector [Scalar],
) -> HnswIndexIter<'index, 'vector> {
pub fn search_vbase(&self, range: usize, vector: &[S::Scalar]) -> HnswIndexIter<'_, S> {
search_vbase(&self.mmap, range, vector)
}
}
unsafe impl Send for Hnsw {}
unsafe impl Sync for Hnsw {}
unsafe impl<S: G> Send for Hnsw<S> {}
unsafe impl<S: G> Sync for Hnsw<S> {}
pub struct HnswRam {
raw: Arc<Raw>,
quantization: Quantization,
// ----------------------
d: Distance,
pub struct HnswRam<S: G> {
raw: Arc<Raw<S>>,
quantization: Quantization<S>,
// ----------------------
m: u32,
// ----------------------
@@ -95,14 +89,12 @@ impl HnswRamVertex {
}
struct HnswRamLayer {
edges: Vec<(Scalar, u32)>,
edges: Vec<(F32, u32)>,
}
pub struct HnswMmap {
raw: Arc<Raw>,
quantization: Quantization,
// ----------------------
d: Distance,
pub struct HnswMmap<S: G> {
raw: Arc<Raw<S>>,
quantization: Quantization<S>,
// ----------------------
m: u32,
// ----------------------
@@ -114,20 +106,19 @@ pub struct HnswMmap {
}
#[derive(Debug, Clone, Copy, Default)]
struct HnswMmapEdge(Scalar, u32);
struct HnswMmapEdge(F32, u32);
unsafe impl Send for HnswMmap {}
unsafe impl Sync for HnswMmap {}
unsafe impl<S: G> Send for HnswMmap<S> {}
unsafe impl<S: G> Sync for HnswMmap<S> {}
unsafe impl Pod for HnswMmapEdge {}
unsafe impl Zeroable for HnswMmapEdge {}
pub fn make(
pub fn make<S: G>(
path: PathBuf,
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
options: IndexOptions,
) -> HnswRam {
let VectorOptions { d, .. } = options.vector;
) -> HnswRam<S> {
let HnswIndexingOptions {
m,
ef_construction,
@@ -159,23 +150,22 @@ pub fn make(
let entry = RwLock::<Option<u32>>::new(None);
let visited = VisitedPool::new(raw.len());
(0..n).into_par_iter().for_each(|i| {
fn fast_search(
quantization: &Quantization,
fn fast_search<S: G>(
quantization: &Quantization<S>,
graph: &HnswRamGraph,
d: Distance,
levels: RangeInclusive<u8>,
u: u32,
target: &[Scalar],
target: &[S::Scalar],
) -> u32 {
let mut u = u;
let mut u_dis = quantization.distance(d, target, u);
let mut u_dis = quantization.distance(target, u);
for i in levels.rev() {
let mut changed = true;
while changed {
changed = false;
let guard = graph.vertexs[u as usize].layers[i as usize].read();
for &(_, v) in guard.edges.iter() {
let v_dis = quantization.distance(d, target, v);
let v_dis = quantization.distance(target, v);
if v_dis < u_dis {
u = v;
u_dis = v_dis;
@@ -186,21 +176,20 @@ pub fn make(
}
u
}
fn local_search(
quantization: &Quantization,
fn local_search<S: G>(
quantization: &Quantization<S>,
graph: &HnswRamGraph,
d: Distance,
visited: &mut VisitedGuard,
vector: &[Scalar],
vector: &[S::Scalar],
s: u32,
k: usize,
i: u8,
) -> Vec<(Scalar, u32)> {
) -> Vec<(F32, u32)> {
assert!(k > 0);
let mut visited = visited.fetch();
let mut candidates = BinaryHeap::<Reverse<(Scalar, u32)>>::new();
let mut candidates = BinaryHeap::<Reverse<(F32, u32)>>::new();
let mut results = BinaryHeap::new();
let s_dis = quantization.distance(d, vector, s);
let s_dis = quantization.distance(vector, s);
visited.mark(s);
candidates.push(Reverse((s_dis, s)));
results.push((s_dis, s));
@@ -217,7 +206,7 @@ pub fn make(
continue;
}
visited.mark(v);
let v_dis = quantization.distance(d, vector, v);
let v_dis = quantization.distance(vector, v);
if results.len() < k || v_dis < results.peek().unwrap().0 {
candidates.push(Reverse((v_dis, v)));
results.push((v_dis, v));
@@ -229,12 +218,7 @@ pub fn make(
}
results.into_sorted_vec()
}
fn select(
quantization: &Quantization,
d: Distance,
input: &mut Vec<(Scalar, u32)>,
size: u32,
) {
fn select<S: G>(quantization: &Quantization<S>, input: &mut Vec<(F32, u32)>, size: u32) {
if input.len() <= size as usize {
return;
}
@@ -245,7 +229,7 @@ pub fn make(
}
let check = res
.iter()
.map(|&(_, v)| quantization.distance2(d, u, v))
.map(|&(_, v)| quantization.distance2(u, v))
.all(|dist| dist > u_dis);
if check {
res.push((u_dis, u));
@@ -290,14 +274,13 @@ pub fn make(
};
let top = graph.vertexs[u as usize].levels();
if top > levels {
u = fast_search(&quantization, &graph, d, levels + 1..=top, u, target);
u = fast_search(&quantization, &graph, levels + 1..=top, u, target);
}
let mut result = Vec::with_capacity(1 + std::cmp::min(levels, top) as usize);
for j in (0..=std::cmp::min(levels, top)).rev() {
let mut edges = local_search(
&quantization,
&graph,
d,
&mut visited,
target,
u,
@@ -305,12 +288,7 @@ pub fn make(
j,
);
edges.sort();
select(
&quantization,
d,
&mut edges,
count_max_edges_of_a_layer(m, j),
);
select(&quantization, &mut edges, count_max_edges_of_a_layer(m, j));
u = edges.first().unwrap().1;
result.push(edges);
}
@@ -325,7 +303,6 @@ pub fn make(
write.edges.insert(index, element);
select(
&quantization,
d,
&mut write.edges,
count_max_edges_of_a_layer(m, j),
);
@@ -338,14 +315,13 @@ pub fn make(
HnswRam {
raw,
quantization,
d,
m,
graph,
visited,
}
}
pub fn save(mut ram: HnswRam, path: PathBuf) -> HnswMmap {
pub fn save<S: G>(mut ram: HnswRam<S>, path: PathBuf) -> HnswMmap<S> {
let edges = MmapArray::create(
path.join("edges"),
ram.graph
@@ -369,7 +345,6 @@ pub fn save(mut ram: HnswRam, path: PathBuf) -> HnswMmap {
HnswMmap {
raw: ram.raw,
quantization: ram.quantization,
d: ram.d,
m: ram.m,
edges,
by_layer_id,
@@ -378,7 +353,7 @@ pub fn save(mut ram: HnswRam, path: PathBuf) -> HnswMmap {
}
}
pub fn load(path: PathBuf, options: IndexOptions) -> HnswMmap {
pub fn load<S: G>(path: PathBuf, options: IndexOptions) -> HnswMmap<S> {
let idx_opts = options.indexing.clone().unwrap_hnsw();
let raw = Arc::new(Raw::open(path.join("raw"), options.clone()));
let quantization = Quantization::open(
@@ -395,7 +370,6 @@ pub fn load(path: PathBuf, options: IndexOptions) -> HnswMmap {
HnswMmap {
raw,
quantization,
d: options.vector.d,
m: idx_opts.m,
edges,
by_layer_id,
@@ -404,7 +378,12 @@ pub fn load(path: PathBuf, options: IndexOptions) -> HnswMmap {
}
}
pub fn search(mmap: &HnswMmap, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap {
pub fn search<S: G>(
mmap: &HnswMmap<S>,
k: usize,
vector: &[S::Scalar],
filter: &mut impl Filter,
) -> Heap {
let Some(s) = entry(mmap, filter) else {
return Heap::new(k);
};
@@ -413,11 +392,11 @@ pub fn search(mmap: &HnswMmap, k: usize, vector: &[Scalar], filter: &mut impl Fi
local_search(mmap, k, u, vector, filter)
}
pub fn search_vbase<'index, 'vector>(
mmap: &'index HnswMmap,
pub fn search_vbase<'a, S: G>(
mmap: &'a HnswMmap<S>,
range: usize,
vector: &'vector [Scalar],
) -> HnswIndexIter<'index, 'vector> {
vector: &[S::Scalar],
) -> HnswIndexIter<'a, S> {
let filter_fn = &mut |_| true;
let Some(s) = entry(mmap, filter_fn) else {
return HnswIndexIter(None);
@@ -427,7 +406,7 @@ pub fn search_vbase<'index, 'vector>(
local_search_vbase(mmap, range, u, vector)
}
pub fn entry(mmap: &HnswMmap, filter: &mut impl Filter) -> Option<u32> {
pub fn entry<S: G>(mmap: &HnswMmap<S>, filter: &mut impl Filter) -> Option<u32> {
let m = mmap.m;
let n = mmap.raw.len();
let mut shift = 1u64;
@@ -455,15 +434,15 @@ pub fn entry(mmap: &HnswMmap, filter: &mut impl Filter) -> Option<u32> {
None
}
pub fn fast_search(
mmap: &HnswMmap,
pub fn fast_search<S: G>(
mmap: &HnswMmap<S>,
levels: RangeInclusive<u8>,
u: u32,
vector: &[Scalar],
vector: &[S::Scalar],
filter: &mut impl Filter,
) -> u32 {
let mut u = u;
let mut u_dis = mmap.quantization.distance(mmap.d, vector, u);
let mut u_dis = mmap.quantization.distance(vector, u);
for i in levels.rev() {
let mut changed = true;
while changed {
@@ -473,7 +452,7 @@ pub fn fast_search(
if !filter.check(mmap.raw.payload(v)) {
continue;
}
let v_dis = mmap.quantization.distance(mmap.d, vector, v);
let v_dis = mmap.quantization.distance(vector, v);
if v_dis < u_dis {
u = v;
u_dis = v_dis;
@@ -485,20 +464,20 @@ pub fn fast_search(
u
}
pub fn local_search(
mmap: &HnswMmap,
pub fn local_search<S: G>(
mmap: &HnswMmap<S>,
k: usize,
s: u32,
vector: &[Scalar],
vector: &[S::Scalar],
filter: &mut impl Filter,
) -> Heap {
assert!(k > 0);
let mut visited = mmap.visited.fetch();
let mut visited = visited.fetch();
let mut candidates = BinaryHeap::<Reverse<(Scalar, u32)>>::new();
let mut candidates = BinaryHeap::<Reverse<(F32, u32)>>::new();
let mut results = Heap::new(k);
visited.mark(s);
let s_dis = mmap.quantization.distance(mmap.d, vector, s);
let s_dis = mmap.quantization.distance(vector, s);
candidates.push(Reverse((s_dis, s)));
results.push(HeapElement {
distance: s_dis,
@@ -517,7 +496,7 @@ pub fn local_search(
if !filter.check(mmap.raw.payload(v)) {
continue;
}
let v_dis = mmap.quantization.distance(mmap.d, vector, v);
let v_dis = mmap.quantization.distance(vector, v);
if !results.check(v_dis) {
continue;
}
@@ -531,20 +510,20 @@ pub fn local_search(
results
}
fn local_search_vbase<'mmap, 'vector>(
mmap: &'mmap HnswMmap,
fn local_search_vbase<'a, S: G>(
mmap: &'a HnswMmap<S>,
range: usize,
s: u32,
vector: &'vector [Scalar],
) -> HnswIndexIter<'mmap, 'vector> {
vector: &[S::Scalar],
) -> HnswIndexIter<'a, S> {
assert!(range > 0);
let mut visited_guard = mmap.visited.fetch();
let mut visited = visited_guard.fetch();
let mut candidates = BinaryHeap::<Reverse<(Scalar, u32)>>::new();
let mut candidates = BinaryHeap::<Reverse<(F32, u32)>>::new();
let mut results = Heap::new(range);
let mut lost = Vec::<Reverse<HeapElement>>::new();
visited.mark(s);
let s_dis = mmap.quantization.distance(mmap.d, vector, s);
let s_dis = mmap.quantization.distance(vector, s);
candidates.push(Reverse((s_dis, s)));
results.push(HeapElement {
distance: s_dis,
@@ -561,7 +540,7 @@ fn local_search_vbase<'mmap, 'vector>(
continue;
}
visited.mark(v);
let v_dis = mmap.quantization.distance(mmap.d, vector, v);
let v_dis = mmap.quantization.distance(vector, v);
if !results.check(v_dis) {
continue;
}
@@ -582,7 +561,7 @@ fn local_search_vbase<'mmap, 'vector>(
results: results.into_reversed_heap(),
lost,
visited: visited_guard,
vector,
vector: vector.to_vec(),
}))
}
@@ -614,7 +593,7 @@ fn caluate_offsets(iter: impl Iterator<Item = usize>) -> impl Iterator<Item = us
})
}
fn find_edges(mmap: &HnswMmap, u: u32, level: u8) -> &[HnswMmapEdge] {
fn find_edges<S: G>(mmap: &HnswMmap<S>, u: u32, level: u8) -> &[HnswMmapEdge] {
let offset = u as usize;
let index = mmap.by_vertex_id[offset]..mmap.by_vertex_id[offset + 1];
let offset = index.start + level as usize;
@@ -670,7 +649,7 @@ impl<'a> Drop for VisitedGuard<'a> {
fn drop(&mut self) {
let src = VisitedBuffer {
version: 0,
data: Box::new([]),
data: Vec::new(),
};
let buffer = std::mem::replace(&mut self.buffer, src);
self.pool.locked_buffers.lock().push(buffer);
@@ -692,39 +671,39 @@ impl<'a> VisitedChecker<'a> {
struct VisitedBuffer {
version: usize,
data: Box<[usize]>,
data: Vec<usize>,
}
impl VisitedBuffer {
fn new(capacity: usize) -> Self {
Self {
version: 0,
data: bytemuck::zeroed_slice_box(capacity),
data: bytemuck::zeroed_vec(capacity),
}
}
}
pub struct HnswIndexIter<'mmap, 'vector>(Option<HnswIndexIterInner<'mmap, 'vector>>);
pub struct HnswIndexIter<'mmap, S: G>(Option<HnswIndexIterInner<'mmap, S>>);
pub struct HnswIndexIterInner<'mmap, 'vector> {
mmap: &'mmap HnswMmap,
pub struct HnswIndexIterInner<'mmap, S: G> {
mmap: &'mmap HnswMmap<S>,
range: usize,
candidates: BinaryHeap<Reverse<(Scalar, u32)>>,
candidates: BinaryHeap<Reverse<(F32, u32)>>,
results: BinaryHeap<Reverse<HeapElement>>,
// The points lost in the first stage, we should keep it to the second stage.
lost: Vec<Reverse<HeapElement>>,
visited: VisitedGuard<'mmap>,
vector: &'vector [Scalar],
vector: Vec<S::Scalar>,
}
impl Iterator for HnswIndexIter<'_, '_> {
impl<S: G> Iterator for HnswIndexIter<'_, S> {
type Item = HeapElement;
fn next(&mut self) -> Option<Self::Item> {
self.0.as_mut()?.next()
}
}
impl Iterator for HnswIndexIterInner<'_, '_> {
impl<S: G> Iterator for HnswIndexIterInner<'_, S> {
type Item = HeapElement;
fn next(&mut self) -> Option<Self::Item> {
if self.results.len() > self.range {
@@ -739,7 +718,7 @@ impl Iterator for HnswIndexIterInner<'_, '_> {
continue;
}
visited.mark(v);
let v_dis = self.mmap.quantization.distance(self.mmap.d, self.vector, v);
let v_dis = self.mmap.quantization.distance(&self.vector, v);
self.candidates.push(Reverse((v_dis, v)));
self.results.push(Reverse(HeapElement {
distance: v_dis,
@@ -755,7 +734,7 @@ impl Iterator for HnswIndexIterInner<'_, '_> {
}
}
impl HnswIndexIterInner<'_, '_> {
impl<S: G> HnswIndexIterInner<'_, S> {
fn pop(&mut self) -> Option<HeapElement> {
if self.results.peek() > self.lost.last() {
self.results.pop().map(|x| x.0)

View File

@@ -20,16 +20,16 @@ use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::{Acquire, Relaxed, Release};
use std::sync::Arc;
pub struct IvfNaive {
mmap: IvfMmap,
pub struct IvfNaive<S: G> {
mmap: IvfMmap<S>,
}
impl IvfNaive {
impl<S: G> IvfNaive<S> {
pub fn create(
path: PathBuf,
options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self {
create_dir(&path).unwrap();
let ram = make(path.clone(), sealed, growing, options);
@@ -47,7 +47,7 @@ impl IvfNaive {
self.mmap.raw.len()
}
pub fn vector(&self, i: u32) -> &[Scalar] {
pub fn vector(&self, i: u32) -> &[S::Scalar] {
self.mmap.raw.vector(i)
}
@@ -55,65 +55,63 @@ impl IvfNaive {
self.mmap.raw.payload(i)
}
pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap {
pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
search(&self.mmap, k, vector, filter)
}
}
unsafe impl Send for IvfNaive {}
unsafe impl Sync for IvfNaive {}
unsafe impl<S: G> Send for IvfNaive<S> {}
unsafe impl<S: G> Sync for IvfNaive<S> {}
pub struct IvfRam {
raw: Arc<Raw>,
quantization: Quantization,
pub struct IvfRam<S: G> {
raw: Arc<Raw<S>>,
quantization: Quantization<S>,
// ----------------------
dims: u16,
d: Distance,
// ----------------------
nlist: u32,
nprobe: u32,
// ----------------------
centroids: Vec2,
centroids: Vec2<S>,
heads: Vec<AtomicU32>,
nexts: Vec<SyncUnsafeCell<u32>>,
}
unsafe impl Send for IvfRam {}
unsafe impl Sync for IvfRam {}
unsafe impl<S: G> Send for IvfRam<S> {}
unsafe impl<S: G> Sync for IvfRam<S> {}
pub struct IvfMmap {
raw: Arc<Raw>,
quantization: Quantization,
pub struct IvfMmap<S: G> {
raw: Arc<Raw<S>>,
quantization: Quantization<S>,
// ----------------------
dims: u16,
d: Distance,
// ----------------------
nlist: u32,
nprobe: u32,
// ----------------------
centroids: MmapArray<Scalar>,
centroids: MmapArray<S::Scalar>,
heads: MmapArray<u32>,
nexts: MmapArray<u32>,
}
unsafe impl Send for IvfMmap {}
unsafe impl Sync for IvfMmap {}
unsafe impl<S: G> Send for IvfMmap<S> {}
unsafe impl<S: G> Sync for IvfMmap<S> {}
impl IvfMmap {
fn centroids(&self, i: u32) -> &[Scalar] {
impl<S: G> IvfMmap<S> {
fn centroids(&self, i: u32) -> &[S::Scalar] {
let s = i as usize * self.dims as usize;
let e = (i + 1) as usize * self.dims as usize;
&self.centroids[s..e]
}
}
pub fn make(
pub fn make<S: G>(
path: PathBuf,
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
options: IndexOptions,
) -> IvfRam {
let VectorOptions { dims, d } = options.vector;
) -> IvfRam<S> {
let VectorOptions { dims, .. } = options.vector;
let IvfIndexingOptions {
least_iterations,
iterations,
@@ -140,9 +138,9 @@ pub fn make(
let mut samples = Vec2::new(dims, m as usize);
for i in 0..m {
samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32));
d.elkan_k_means_normalize(&mut samples[i as usize]);
S::elkan_k_means_normalize(&mut samples[i as usize]);
}
let mut k_means = ElkanKMeans::new(nlist as usize, samples, d);
let mut k_means = ElkanKMeans::new(nlist as usize, samples);
for _ in 0..least_iterations {
k_means.iterate();
}
@@ -164,10 +162,10 @@ pub fn make(
};
(0..n).into_par_iter().for_each(|i| {
let mut vector = raw.vector(i).to_vec();
d.elkan_k_means_normalize(&mut vector);
let mut result = (Scalar::INFINITY, 0);
S::elkan_k_means_normalize(&mut vector);
let mut result = (F32::infinity(), 0);
for i in 0..nlist {
let dis = d.elkan_k_means_distance(&vector, &centroids[i as usize]);
let dis = S::elkan_k_means_distance(&vector, &centroids[i as usize]);
result = std::cmp::min(result, (dis, i));
}
let centroid_id = result.1;
@@ -191,11 +189,10 @@ pub fn make(
nprobe,
nlist,
dims,
d,
}
}
pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap {
pub fn save<S: G>(mut ram: IvfRam<S>, path: PathBuf) -> IvfMmap<S> {
let centroids = MmapArray::create(
path.join("centroids"),
(0..ram.nlist)
@@ -214,7 +211,6 @@ pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap {
raw: ram.raw,
quantization: ram.quantization,
dims: ram.dims,
d: ram.d,
nlist: ram.nlist,
nprobe: ram.nprobe,
centroids,
@@ -223,7 +219,7 @@ pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap {
}
}
pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap {
pub fn load<S: G>(path: PathBuf, options: IndexOptions) -> IvfMmap<S> {
let raw = Arc::new(Raw::open(path.join("raw"), options.clone()));
let quantization = Quantization::open(
path.join("quantization"),
@@ -239,7 +235,6 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap {
raw,
quantization,
dims: options.vector.dims,
d: options.vector.d,
nlist,
nprobe,
centroids,
@@ -248,13 +243,18 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap {
}
}
pub fn search(mmap: &IvfMmap, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap {
pub fn search<S: G>(
mmap: &IvfMmap<S>,
k: usize,
vector: &[S::Scalar],
filter: &mut impl Filter,
) -> Heap {
let mut target = vector.to_vec();
mmap.d.elkan_k_means_normalize(&mut target);
S::elkan_k_means_normalize(&mut target);
let mut lists = Heap::new(mmap.nprobe as usize);
for i in 0..mmap.nlist {
let centroid = mmap.centroids(i);
let distance = mmap.d.elkan_k_means_distance(&target, centroid);
let distance = S::elkan_k_means_distance(&target, centroid);
if lists.check(distance) {
lists.push(HeapElement {
distance,
@@ -267,7 +267,7 @@ pub fn search(mmap: &IvfMmap, k: usize, vector: &[Scalar], filter: &mut impl Fil
for i in lists.iter().map(|e| e.payload as usize) {
let mut j = mmap.heads[i];
while u32::MAX != j {
let distance = mmap.quantization.distance(mmap.d, vector, j);
let distance = mmap.quantization.distance(vector, j);
let payload = mmap.raw.payload(j);
if result.check(distance) && filter.check(payload) {
result.push(HeapElement { distance, payload });

View File

@@ -20,16 +20,16 @@ use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::{Acquire, Relaxed, Release};
use std::sync::Arc;
pub struct IvfPq {
mmap: IvfMmap,
pub struct IvfPq<S: G> {
mmap: IvfMmap<S>,
}
impl IvfPq {
impl<S: G> IvfPq<S> {
pub fn create(
path: PathBuf,
options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self {
create_dir(&path).unwrap();
let ram = make(path.clone(), sealed, growing, options);
@@ -47,7 +47,7 @@ impl IvfPq {
self.mmap.raw.len()
}
pub fn vector(&self, i: u32) -> &[Scalar] {
pub fn vector(&self, i: u32) -> &[S::Scalar] {
self.mmap.raw.vector(i)
}
@@ -55,65 +55,63 @@ impl IvfPq {
self.mmap.raw.payload(i)
}
pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap {
pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
search(&self.mmap, k, vector, filter)
}
}
unsafe impl Send for IvfPq {}
unsafe impl Sync for IvfPq {}
unsafe impl<S: G> Send for IvfPq<S> {}
unsafe impl<S: G> Sync for IvfPq<S> {}
pub struct IvfRam {
raw: Arc<Raw>,
quantization: ProductQuantization,
pub struct IvfRam<S: G> {
raw: Arc<Raw<S>>,
quantization: ProductQuantization<S>,
// ----------------------
dims: u16,
d: Distance,
// ----------------------
nlist: u32,
nprobe: u32,
// ----------------------
centroids: Vec2,
centroids: Vec2<S>,
heads: Vec<AtomicU32>,
nexts: Vec<SyncUnsafeCell<u32>>,
}
unsafe impl Send for IvfRam {}
unsafe impl Sync for IvfRam {}
unsafe impl<S: G> Send for IvfRam<S> {}
unsafe impl<S: G> Sync for IvfRam<S> {}
pub struct IvfMmap {
raw: Arc<Raw>,
quantization: ProductQuantization,
pub struct IvfMmap<S: G> {
raw: Arc<Raw<S>>,
quantization: ProductQuantization<S>,
// ----------------------
dims: u16,
d: Distance,
// ----------------------
nlist: u32,
nprobe: u32,
// ----------------------
centroids: MmapArray<Scalar>,
centroids: MmapArray<S::Scalar>,
heads: MmapArray<u32>,
nexts: MmapArray<u32>,
}
unsafe impl Send for IvfMmap {}
unsafe impl Sync for IvfMmap {}
unsafe impl<S: G> Send for IvfMmap<S> {}
unsafe impl<S: G> Sync for IvfMmap<S> {}
impl IvfMmap {
fn centroids(&self, i: u32) -> &[Scalar] {
impl<S: G> IvfMmap<S> {
fn centroids(&self, i: u32) -> &[S::Scalar] {
let s = i as usize * self.dims as usize;
let e = (i + 1) as usize * self.dims as usize;
&self.centroids[s..e]
}
}
pub fn make(
pub fn make<S: G>(
path: PathBuf,
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
options: IndexOptions,
) -> IvfRam {
let VectorOptions { dims, d } = options.vector;
) -> IvfRam<S> {
let VectorOptions { dims, .. } = options.vector;
let IvfIndexingOptions {
least_iterations,
iterations,
@@ -134,9 +132,9 @@ pub fn make(
let mut samples = Vec2::new(dims, m as usize);
for i in 0..m {
samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32));
d.elkan_k_means_normalize(&mut samples[i as usize]);
S::elkan_k_means_normalize(&mut samples[i as usize]);
}
let mut k_means = ElkanKMeans::new(nlist as usize, samples, d);
let mut k_means = ElkanKMeans::new(nlist as usize, samples);
for _ in 0..least_iterations {
k_means.iterate();
}
@@ -163,10 +161,10 @@ pub fn make(
&raw,
|i, target| {
let mut vector = target.to_vec();
d.elkan_k_means_normalize(&mut vector);
let mut result = (Scalar::INFINITY, 0);
S::elkan_k_means_normalize(&mut vector);
let mut result = (F32::infinity(), 0);
for i in 0..nlist {
let dis = d.elkan_k_means_distance(&vector, &centroids[i as usize]);
let dis = S::elkan_k_means_distance(&vector, &centroids[i as usize]);
result = std::cmp::min(result, (dis, i));
}
let centroid_id = result.1;
@@ -194,11 +192,10 @@ pub fn make(
nprobe,
nlist,
dims,
d,
}
}
pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap {
pub fn save<S: G>(mut ram: IvfRam<S>, path: PathBuf) -> IvfMmap<S> {
let centroids = MmapArray::create(
path.join("centroids"),
(0..ram.nlist)
@@ -217,7 +214,6 @@ pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap {
raw: ram.raw,
quantization: ram.quantization,
dims: ram.dims,
d: ram.d,
nlist: ram.nlist,
nprobe: ram.nprobe,
centroids,
@@ -226,7 +222,7 @@ pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap {
}
}
pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap {
pub fn load<S: G>(path: PathBuf, options: IndexOptions) -> IvfMmap<S> {
let raw = Arc::new(Raw::open(path.join("raw"), options.clone()));
let quantization = ProductQuantization::open(
path.join("quantization"),
@@ -242,7 +238,6 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap {
raw,
quantization,
dims: options.vector.dims,
d: options.vector.d,
nlist,
nprobe,
centroids,
@@ -251,13 +246,18 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap {
}
}
pub fn search(mmap: &IvfMmap, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap {
pub fn search<S: G>(
mmap: &IvfMmap<S>,
k: usize,
vector: &[S::Scalar],
filter: &mut impl Filter,
) -> Heap {
let mut target = vector.to_vec();
mmap.d.elkan_k_means_normalize(&mut target);
S::elkan_k_means_normalize(&mut target);
let mut lists = Heap::new(mmap.nprobe as usize);
for i in 0..mmap.nlist {
let centroid = mmap.centroids(i);
let distance = mmap.d.elkan_k_means_distance(&target, centroid);
let distance = S::elkan_k_means_distance(&target, centroid);
if lists.check(distance) {
lists.push(HeapElement {
distance,
@@ -270,9 +270,9 @@ pub fn search(mmap: &IvfMmap, k: usize, vector: &[Scalar], filter: &mut impl Fil
for i in lists.iter().map(|e| e.payload as u32) {
let mut j = mmap.heads[i as usize];
while u32::MAX != j {
let distance =
mmap.quantization
.distance_with_delta(mmap.d, vector, j, mmap.centroids(i));
let distance = mmap
.quantization
.distance_with_delta(vector, j, mmap.centroids(i));
let payload = mmap.raw.payload(j);
if result.check(distance) && filter.check(payload) {
result.push(HeapElement { distance, payload });

View File

@@ -10,17 +10,17 @@ use crate::prelude::*;
use std::path::PathBuf;
use std::sync::Arc;
pub enum Ivf {
Naive(IvfNaive),
Pq(IvfPq),
pub enum Ivf<S: G> {
Naive(IvfNaive<S>),
Pq(IvfPq<S>),
}
impl Ivf {
impl<S: G> Ivf<S> {
pub fn create(
path: PathBuf,
options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self {
if options
.indexing
@@ -56,7 +56,7 @@ impl Ivf {
}
}
pub fn vector(&self, i: u32) -> &[Scalar] {
pub fn vector(&self, i: u32) -> &[S::Scalar] {
match self {
Ivf::Naive(x) => x.vector(i),
Ivf::Pq(x) => x.vector(i),
@@ -70,7 +70,7 @@ impl Ivf {
}
}
pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap {
pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
match self {
Ivf::Naive(x) => x.search(k, vector, filter),
Ivf::Pq(x) => x.search(k, vector, filter),

View File

@@ -1,5 +1,4 @@
pub mod clustering;
pub mod diskann;
pub mod flat;
pub mod hnsw;
pub mod ivf;

View File

@@ -56,35 +56,35 @@ impl QuantizationOptions {
}
}
pub trait Quan {
pub trait Quan<S: G> {
fn create(
path: PathBuf,
options: IndexOptions,
quantization_options: QuantizationOptions,
raw: &Arc<Raw>,
raw: &Arc<Raw<S>>,
) -> Self;
fn open(
path: PathBuf,
options: IndexOptions,
quantization_options: QuantizationOptions,
raw: &Arc<Raw>,
raw: &Arc<Raw<S>>,
) -> Self;
fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar;
fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar;
fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32;
fn distance2(&self, lhs: u32, rhs: u32) -> F32;
}
pub enum Quantization {
Trivial(TrivialQuantization),
Scalar(ScalarQuantization),
Product(ProductQuantization),
pub enum Quantization<S: G> {
Trivial(TrivialQuantization<S>),
Scalar(ScalarQuantization<S>),
Product(ProductQuantization<S>),
}
impl Quantization {
impl<S: G> Quantization<S> {
pub fn create(
path: PathBuf,
options: IndexOptions,
quantization_options: QuantizationOptions,
raw: &Arc<Raw>,
raw: &Arc<Raw<S>>,
) -> Self {
match quantization_options {
QuantizationOptions::Trivial(_) => Self::Trivial(TrivialQuantization::create(
@@ -112,7 +112,7 @@ impl Quantization {
path: PathBuf,
options: IndexOptions,
quantization_options: QuantizationOptions,
raw: &Arc<Raw>,
raw: &Arc<Raw<S>>,
) -> Self {
match quantization_options {
QuantizationOptions::Trivial(_) => Self::Trivial(TrivialQuantization::open(
@@ -136,21 +136,21 @@ impl Quantization {
}
}
pub fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar {
pub fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 {
use Quantization::*;
match self {
Trivial(x) => x.distance(d, lhs, rhs),
Scalar(x) => x.distance(d, lhs, rhs),
Product(x) => x.distance(d, lhs, rhs),
Trivial(x) => x.distance(lhs, rhs),
Scalar(x) => x.distance(lhs, rhs),
Product(x) => x.distance(lhs, rhs),
}
}
pub fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar {
pub fn distance2(&self, lhs: u32, rhs: u32) -> F32 {
use Quantization::*;
match self {
Trivial(x) => x.distance2(d, lhs, rhs),
Scalar(x) => x.distance2(d, lhs, rhs),
Product(x) => x.distance2(d, lhs, rhs),
Trivial(x) => x.distance2(lhs, rhs),
Scalar(x) => x.distance2(lhs, rhs),
Product(x) => x.distance2(lhs, rhs),
}
}
}

View File

@@ -55,17 +55,17 @@ impl Default for ProductQuantizationOptionsRatio {
}
}
pub struct ProductQuantization {
pub struct ProductQuantization<S: G> {
dims: u16,
ratio: u16,
centroids: Vec<Scalar>,
centroids: Vec<S::Scalar>,
codes: MmapArray<u8>,
}
unsafe impl Send for ProductQuantization {}
unsafe impl Sync for ProductQuantization {}
unsafe impl<S: G> Send for ProductQuantization<S> {}
unsafe impl<S: G> Sync for ProductQuantization<S> {}
impl ProductQuantization {
impl<S: G> ProductQuantization<S> {
fn codes(&self, i: u32) -> &[u8] {
let width = self.dims.div_ceil(self.ratio);
let s = i as usize * width as usize;
@@ -74,12 +74,12 @@ impl ProductQuantization {
}
}
impl Quan for ProductQuantization {
impl<S: G> Quan<S> for ProductQuantization<S> {
fn create(
path: PathBuf,
options: IndexOptions,
quantization_options: QuantizationOptions,
raw: &Arc<Raw>,
raw: &Arc<Raw<S>>,
) -> Self {
Self::with_normalizer(path, options, quantization_options, raw, |_, _| ())
}
@@ -88,7 +88,7 @@ impl Quan for ProductQuantization {
path: PathBuf,
options: IndexOptions,
quantization_options: QuantizationOptions,
_: &Arc<Raw>,
_: &Arc<Raw<S>>,
) -> Self {
let centroids =
serde_json::from_slice(&std::fs::read(path.join("centroids")).unwrap()).unwrap();
@@ -101,32 +101,32 @@ impl Quan for ProductQuantization {
}
}
fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar {
fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 {
let dims = self.dims;
let ratio = self.ratio;
let rhs = self.codes(rhs);
d.product_quantization_distance(dims, ratio, &self.centroids, lhs, rhs)
S::product_quantization_distance(dims, ratio, &self.centroids, lhs, rhs)
}
fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar {
fn distance2(&self, lhs: u32, rhs: u32) -> F32 {
let dims = self.dims;
let ratio = self.ratio;
let lhs = self.codes(lhs);
let rhs = self.codes(rhs);
d.product_quantization_distance2(dims, ratio, &self.centroids, lhs, rhs)
S::product_quantization_distance2(dims, ratio, &self.centroids, lhs, rhs)
}
}
impl ProductQuantization {
impl<S: G> ProductQuantization<S> {
pub fn with_normalizer<F>(
path: PathBuf,
options: IndexOptions,
quantization_options: QuantizationOptions,
raw: &Raw,
raw: &Raw<S>,
normalizer: F,
) -> Self
where
F: Fn(u32, &mut [Scalar]),
F: Fn(u32, &mut [S::Scalar]),
{
std::fs::create_dir(&path).unwrap();
let quantization_options = quantization_options.unwrap_product_quantization();
@@ -136,22 +136,22 @@ impl ProductQuantization {
let m = std::cmp::min(n, quantization_options.sample);
let samples = {
let f = sample(&mut thread_rng(), n as usize, m as usize).into_vec();
let mut samples = Vec2::new(options.vector.dims, m as usize);
let mut samples = Vec2::<S>::new(options.vector.dims, m as usize);
for i in 0..m {
samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32));
}
samples
};
let width = dims.div_ceil(ratio);
let mut centroids = vec![Scalar::Z; 256 * dims as usize];
let mut centroids = vec![S::Scalar::zero(); 256 * dims as usize];
for i in 0..width {
let subdims = std::cmp::min(ratio, dims - ratio * i);
let mut subsamples = Vec2::new(subdims, m as usize);
let mut subsamples = Vec2::<S::L2>::new(subdims, m as usize);
for j in 0..m {
let src = &samples[j as usize][(i * ratio) as usize..][..subdims as usize];
subsamples[j as usize].copy_from_slice(src);
}
let mut k_means = ElkanKMeans::new(256, subsamples, Distance::L2);
let mut k_means = ElkanKMeans::<S::L2>::new(256, subsamples);
for _ in 0..25 {
if k_means.iterate() {
break;
@@ -170,13 +170,13 @@ impl ProductQuantization {
let mut result = Vec::with_capacity(width as usize);
for i in 0..width {
let subdims = std::cmp::min(ratio, dims - ratio * i);
let mut minimal = Scalar::INFINITY;
let mut minimal = F32::infinity();
let mut target = 0u8;
let left = &vector[(i * ratio) as usize..][..subdims as usize];
for j in 0u8..=255 {
let right = &centroids[j as usize * dims as usize..][(i * ratio) as usize..]
[..subdims as usize];
let dis = Distance::L2.distance(left, right);
let dis = S::L2::distance(left, right);
if dis < minimal {
minimal = dis;
target = j;
@@ -201,16 +201,10 @@ impl ProductQuantization {
}
}
pub fn distance_with_delta(
&self,
d: Distance,
lhs: &[Scalar],
rhs: u32,
delta: &[Scalar],
) -> Scalar {
pub fn distance_with_delta(&self, lhs: &[S::Scalar], rhs: u32, delta: &[S::Scalar]) -> F32 {
let dims = self.dims;
let ratio = self.ratio;
let rhs = self.codes(rhs);
d.product_quantization_distance_with_delta(dims, ratio, &self.centroids, lhs, rhs, delta)
S::product_quantization_distance_with_delta(dims, ratio, &self.centroids, lhs, rhs, delta)
}
}

View File

@@ -19,17 +19,17 @@ impl Default for ScalarQuantizationOptions {
}
}
pub struct ScalarQuantization {
pub struct ScalarQuantization<S: G> {
dims: u16,
max: Vec<Scalar>,
min: Vec<Scalar>,
max: Vec<S::Scalar>,
min: Vec<S::Scalar>,
codes: MmapArray<u8>,
}
unsafe impl Send for ScalarQuantization {}
unsafe impl Sync for ScalarQuantization {}
unsafe impl<S: G> Send for ScalarQuantization<S> {}
unsafe impl<S: G> Sync for ScalarQuantization<S> {}
impl ScalarQuantization {
impl<S: G> ScalarQuantization<S> {
fn codes(&self, i: u32) -> &[u8] {
let s = i as usize * self.dims as usize;
let e = (i + 1) as usize * self.dims as usize;
@@ -37,17 +37,17 @@ impl ScalarQuantization {
}
}
impl Quan for ScalarQuantization {
impl<S: G> Quan<S> for ScalarQuantization<S> {
fn create(
path: PathBuf,
options: IndexOptions,
_: QuantizationOptions,
raw: &Arc<Raw>,
raw: &Arc<Raw<S>>,
) -> Self {
std::fs::create_dir(&path).unwrap();
let dims = options.vector.dims;
let mut max = vec![Scalar::NEG_INFINITY; dims as usize];
let mut min = vec![Scalar::INFINITY; dims as usize];
let mut max = vec![S::Scalar::neg_infinity(); dims as usize];
let mut min = vec![S::Scalar::infinity(); dims as usize];
let n = raw.len();
for i in 0..n {
let vector = raw.vector(i);
@@ -62,7 +62,7 @@ impl Quan for ScalarQuantization {
let vector = raw.vector(i);
let mut result = vec![0u8; dims as usize];
for i in 0..dims as usize {
let w = ((vector[i] - min[i]) / (max[i] - min[i]) * 256.0).0 as u32;
let w = (((vector[i] - min[i]) / (max[i] - min[i])).to_f32() * 256.0) as u32;
result[i] = w.clamp(0, 255) as u8;
}
result.into_iter()
@@ -77,7 +77,7 @@ impl Quan for ScalarQuantization {
}
}
fn open(path: PathBuf, options: IndexOptions, _: QuantizationOptions, _: &Arc<Raw>) -> Self {
fn open(path: PathBuf, options: IndexOptions, _: QuantizationOptions, _: &Arc<Raw<S>>) -> Self {
let dims = options.vector.dims;
let max = serde_json::from_slice(&std::fs::read("max").unwrap()).unwrap();
let min = serde_json::from_slice(&std::fs::read("min").unwrap()).unwrap();
@@ -90,16 +90,16 @@ impl Quan for ScalarQuantization {
}
}
fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar {
fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 {
let dims = self.dims;
let rhs = self.codes(rhs);
d.scalar_quantization_distance(dims, &self.max, &self.min, lhs, rhs)
S::scalar_quantization_distance(dims, &self.max, &self.min, lhs, rhs)
}
fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar {
fn distance2(&self, lhs: u32, rhs: u32) -> F32 {
let dims = self.dims;
let lhs = self.codes(lhs);
let rhs = self.codes(rhs);
d.scalar_quantization_distance2(dims, &self.max, &self.min, lhs, rhs)
S::scalar_quantization_distance2(dims, &self.max, &self.min, lhs, rhs)
}
}

View File

@@ -17,24 +17,24 @@ impl Default for TrivialQuantizationOptions {
}
}
pub struct TrivialQuantization {
raw: Arc<Raw>,
pub struct TrivialQuantization<S: G> {
raw: Arc<Raw<S>>,
}
impl Quan for TrivialQuantization {
fn create(_: PathBuf, _: IndexOptions, _: QuantizationOptions, raw: &Arc<Raw>) -> Self {
impl<S: G> Quan<S> for TrivialQuantization<S> {
fn create(_: PathBuf, _: IndexOptions, _: QuantizationOptions, raw: &Arc<Raw<S>>) -> Self {
Self { raw: raw.clone() }
}
fn open(_: PathBuf, _: IndexOptions, _: QuantizationOptions, raw: &Arc<Raw>) -> Self {
fn open(_: PathBuf, _: IndexOptions, _: QuantizationOptions, raw: &Arc<Raw<S>>) -> Self {
Self { raw: raw.clone() }
}
fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar {
d.distance(lhs, self.raw.vector(rhs))
fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 {
S::distance(lhs, self.raw.vector(rhs))
}
fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar {
d.distance(self.raw.vector(lhs), self.raw.vector(rhs))
fn distance2(&self, lhs: u32, rhs: u32) -> F32 {
S::distance(self.raw.vector(lhs), self.raw.vector(rhs))
}
}

View File

@@ -6,16 +6,16 @@ use crate::utils::mmap_array::MmapArray;
use std::path::PathBuf;
use std::sync::Arc;
pub struct Raw {
mmap: RawMmap,
pub struct Raw<S: G> {
mmap: RawMmap<S>,
}
impl Raw {
impl<S: G> Raw<S> {
pub fn create(
path: PathBuf,
options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self {
std::fs::create_dir(&path).unwrap();
let ram = make(sealed, growing, options);
@@ -33,7 +33,7 @@ impl Raw {
self.mmap.len()
}
pub fn vector(&self, i: u32) -> &[Scalar] {
pub fn vector(&self, i: u32) -> &[S::Scalar] {
self.mmap.vector(i)
}
@@ -42,21 +42,21 @@ impl Raw {
}
}
unsafe impl Send for Raw {}
unsafe impl Sync for Raw {}
unsafe impl<S: G> Send for Raw<S> {}
unsafe impl<S: G> Sync for Raw<S> {}
struct RawRam {
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
struct RawRam<S: G> {
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
dims: u16,
}
impl RawRam {
impl<S: G> RawRam<S> {
fn len(&self) -> u32 {
self.sealed.iter().map(|x| x.len()).sum::<u32>()
+ self.growing.iter().map(|x| x.len()).sum::<u32>()
}
fn vector(&self, mut index: u32) -> &[Scalar] {
fn vector(&self, mut index: u32) -> &[S::Scalar] {
for x in self.sealed.iter() {
if index < x.len() {
return x.vector(index);
@@ -88,18 +88,18 @@ impl RawRam {
}
}
struct RawMmap {
vectors: MmapArray<Scalar>,
struct RawMmap<S: G> {
vectors: MmapArray<S::Scalar>,
payload: MmapArray<Payload>,
dims: u16,
}
impl RawMmap {
impl<S: G> RawMmap<S> {
fn len(&self) -> u32 {
self.payload.len() as u32
}
fn vector(&self, i: u32) -> &[Scalar] {
fn vector(&self, i: u32) -> &[S::Scalar] {
let s = i as usize * self.dims as usize;
let e = (i + 1) as usize * self.dims as usize;
&self.vectors[s..e]
@@ -110,14 +110,14 @@ impl RawMmap {
}
}
unsafe impl Send for RawMmap {}
unsafe impl Sync for RawMmap {}
unsafe impl<S: G> Send for RawMmap<S> {}
unsafe impl<S: G> Sync for RawMmap<S> {}
fn make(
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
fn make<S: G>(
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
options: IndexOptions,
) -> RawRam {
) -> RawRam<S> {
RawRam {
sealed,
growing,
@@ -125,7 +125,7 @@ fn make(
}
}
fn save(ram: RawRam, path: PathBuf) -> RawMmap {
fn save<S: G>(ram: RawRam<S>, path: PathBuf) -> RawMmap<S> {
let n = ram.len();
let vectors_iter = (0..n).flat_map(|i| ram.vector(i)).copied();
let payload_iter = (0..n).map(|i| ram.payload(i));
@@ -138,8 +138,8 @@ fn save(ram: RawRam, path: PathBuf) -> RawMmap {
}
}
fn load(path: PathBuf, options: IndexOptions) -> RawMmap {
let vectors: MmapArray<Scalar> = MmapArray::open(path.join("vectors"));
fn load<S: G>(path: PathBuf, options: IndexOptions) -> RawMmap<S> {
let vectors = MmapArray::open(path.join("vectors"));
let payload = MmapArray::open(path.join("payload"));
RawMmap {
vectors,

View File

@@ -24,16 +24,16 @@ impl Default for FlatIndexingOptions {
}
}
pub struct FlatIndexing {
raw: crate::algorithms::flat::Flat,
pub struct FlatIndexing<S: G> {
raw: crate::algorithms::flat::Flat<S>,
}
impl AbstractIndexing for FlatIndexing {
impl<S: G> AbstractIndexing<S> for FlatIndexing<S> {
fn create(
path: PathBuf,
options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self {
let raw = Flat::create(path, options, sealed, growing);
Self { raw }
@@ -48,7 +48,7 @@ impl AbstractIndexing for FlatIndexing {
self.raw.len()
}
fn vector(&self, i: u32) -> &[Scalar] {
fn vector(&self, i: u32) -> &[S::Scalar] {
self.raw.vector(i)
}
@@ -56,7 +56,7 @@ impl AbstractIndexing for FlatIndexing {
self.raw.payload(i)
}
fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap {
fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
self.raw.search(k, vector, filter)
}
}

View File

@@ -41,16 +41,16 @@ impl Default for HnswIndexingOptions {
}
}
pub struct HnswIndexing {
raw: Hnsw,
pub struct HnswIndexing<S: G> {
raw: Hnsw<S>,
}
impl AbstractIndexing for HnswIndexing {
impl<S: G> AbstractIndexing<S> for HnswIndexing<S> {
fn create(
path: PathBuf,
options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self {
let raw = Hnsw::create(path, options, sealed, growing);
Self { raw }
@@ -65,7 +65,7 @@ impl AbstractIndexing for HnswIndexing {
self.raw.len()
}
fn vector(&self, i: u32) -> &[Scalar] {
fn vector(&self, i: u32) -> &[S::Scalar] {
self.raw.vector(i)
}
@@ -73,17 +73,13 @@ impl AbstractIndexing for HnswIndexing {
self.raw.payload(i)
}
fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap {
fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
self.raw.search(k, vector, filter)
}
}
impl HnswIndexing {
pub fn search_vbase<'index, 'vector>(
&'index self,
range: usize,
vector: &'vector [Scalar],
) -> HnswIndexIter<'index, 'vector> {
impl<S: G> HnswIndexing<S> {
pub fn search_vbase(&self, range: usize, vector: &[S::Scalar]) -> HnswIndexIter<'_, S> {
self.raw.search_vbase(range, vector)
}
}

View File

@@ -4,7 +4,6 @@ use crate::algorithms::quantization::QuantizationOptions;
use crate::index::segments::growing::GrowingSegment;
use crate::index::segments::sealed::SealedSegment;
use crate::index::IndexOptions;
use crate::prelude::Scalar;
use crate::prelude::*;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
@@ -64,16 +63,16 @@ impl Default for IvfIndexingOptions {
}
}
pub struct IvfIndexing {
raw: Ivf,
pub struct IvfIndexing<S: G> {
raw: Ivf<S>,
}
impl AbstractIndexing for IvfIndexing {
impl<S: G> AbstractIndexing<S> for IvfIndexing<S> {
fn create(
path: PathBuf,
options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self {
let raw = Ivf::create(path, options, sealed, growing);
Self { raw }
@@ -88,7 +87,7 @@ impl AbstractIndexing for IvfIndexing {
self.raw.len()
}
fn vector(&self, i: u32) -> &[Scalar] {
fn vector(&self, i: u32) -> &[S::Scalar] {
self.raw.vector(i)
}
@@ -96,7 +95,7 @@ impl AbstractIndexing for IvfIndexing {
self.raw.payload(i)
}
fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap {
fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
self.raw.search(k, vector, filter)
}
}

View File

@@ -60,36 +60,36 @@ impl Validate for IndexingOptions {
}
}
pub trait AbstractIndexing: Sized {
pub trait AbstractIndexing<S: G>: Sized {
fn create(
path: PathBuf,
options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self;
fn open(path: PathBuf, options: IndexOptions) -> Self;
fn len(&self) -> u32;
fn vector(&self, i: u32) -> &[Scalar];
fn vector(&self, i: u32) -> &[S::Scalar];
fn payload(&self, i: u32) -> Payload;
fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap;
fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap;
}
pub enum DynamicIndexing {
Flat(FlatIndexing),
Ivf(IvfIndexing),
Hnsw(HnswIndexing),
pub enum DynamicIndexing<S: G> {
Flat(FlatIndexing<S>),
Ivf(IvfIndexing<S>),
Hnsw(HnswIndexing<S>),
}
pub enum DynamicIndexIter<'index, 'vector> {
Hnsw(HnswIndexIter<'index, 'vector>),
pub enum DynamicIndexIter<'a, S: G> {
Hnsw(HnswIndexIter<'a, S>),
}
impl DynamicIndexing {
impl<S: G> DynamicIndexing<S> {
pub fn create(
path: PathBuf,
options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self {
match options.indexing {
IndexingOptions::Flat(_) => {
@@ -120,7 +120,7 @@ impl DynamicIndexing {
}
}
pub fn vector(&self, i: u32) -> &[Scalar] {
pub fn vector(&self, i: u32) -> &[S::Scalar] {
match self {
DynamicIndexing::Flat(x) => x.vector(i),
DynamicIndexing::Ivf(x) => x.vector(i),
@@ -136,7 +136,7 @@ impl DynamicIndexing {
}
}
pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap {
pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
match self {
DynamicIndexing::Flat(x) => x.search(k, vector, filter),
DynamicIndexing::Ivf(x) => x.search(k, vector, filter),
@@ -144,11 +144,7 @@ impl DynamicIndexing {
}
}
pub fn search_vbase<'index, 'vector>(
&'index self,
range: usize,
vector: &'vector [Scalar],
) -> DynamicIndexIter<'index, 'vector> {
pub fn vbase(&self, range: usize, vector: &[S::Scalar]) -> DynamicIndexIter<'_, S> {
use DynamicIndexIter::*;
match self {
DynamicIndexing::Hnsw(x) => Hnsw(x.search_vbase(range, vector)),
@@ -157,7 +153,7 @@ impl DynamicIndexing {
}
}
impl Iterator for DynamicIndexIter<'_, '_> {
impl<S: G> Iterator for DynamicIndexIter<'_, S> {
type Item = HeapElement;
fn next(&mut self) -> Option<Self::Item> {
use DynamicIndexIter::*;

View File

@@ -0,0 +1,506 @@
pub mod delete;
pub mod indexing;
pub mod optimizing;
pub mod segments;
use self::delete::Delete;
use self::indexing::IndexingOptions;
use self::optimizing::OptimizingOptions;
use self::segments::growing::GrowingSegment;
use self::segments::growing::GrowingSegmentInsertError;
use self::segments::sealed::SealedSegment;
use self::segments::SegmentsOptions;
use crate::index::indexing::DynamicIndexIter;
use crate::index::optimizing::indexing::OptimizerIndexing;
use crate::index::optimizing::sealing::OptimizerSealing;
use crate::prelude::*;
use crate::utils::clean::clean;
use crate::utils::dir_ops::sync_dir;
use crate::utils::file_atomic::FileAtomic;
use arc_swap::ArcSwap;
use crossbeam::atomic::AtomicCell;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use thiserror::Error;
use uuid::Uuid;
use validator::Validate;
#[derive(Debug, Error)]
#[error("The index view is outdated.")]
pub struct OutdatedError(#[from] pub Option<GrowingSegmentInsertError>);
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct VectorOptions {
#[validate(range(min = 1, max = 65535))]
#[serde(rename = "dimensions")]
pub dims: u16,
#[serde(rename = "distance")]
pub d: Distance,
#[serde(rename = "kind")]
pub k: Kind,
}
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct IndexOptions {
#[validate]
pub vector: VectorOptions,
#[validate]
pub segment: SegmentsOptions,
#[validate]
pub optimizing: OptimizingOptions,
#[validate]
pub indexing: IndexingOptions,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct IndexStat {
pub indexing: bool,
pub sealed: Vec<u32>,
pub growing: Vec<u32>,
pub write: u32,
pub options: IndexOptions,
}
pub struct Index<S: G> {
path: PathBuf,
options: IndexOptions,
delete: Arc<Delete>,
protect: Mutex<IndexProtect<S>>,
view: ArcSwap<IndexView<S>>,
instant_index: AtomicCell<Instant>,
instant_write: AtomicCell<Instant>,
_tracker: Arc<IndexTracker>,
}
impl<S: G> Index<S> {
pub fn create(path: PathBuf, options: IndexOptions) -> Arc<Self> {
assert!(options.validate().is_ok());
std::fs::create_dir(&path).unwrap();
std::fs::create_dir(path.join("segments")).unwrap();
let startup = FileAtomic::create(
path.join("startup"),
IndexStartup {
sealeds: HashSet::new(),
growings: HashSet::new(),
},
);
let delete = Delete::create(path.join("delete"));
sync_dir(&path);
let index = Arc::new(Index {
path: path.clone(),
options: options.clone(),
delete: delete.clone(),
protect: Mutex::new(IndexProtect {
startup,
sealed: HashMap::new(),
growing: HashMap::new(),
write: None,
}),
view: ArcSwap::new(Arc::new(IndexView {
options: options.clone(),
sealed: HashMap::new(),
growing: HashMap::new(),
delete: delete.clone(),
write: None,
})),
instant_index: AtomicCell::new(Instant::now()),
instant_write: AtomicCell::new(Instant::now()),
_tracker: Arc::new(IndexTracker { path }),
});
OptimizerIndexing::new(index.clone()).spawn();
OptimizerSealing::new(index.clone()).spawn();
index
}
pub fn open(path: PathBuf, options: IndexOptions) -> Arc<Self> {
let tracker = Arc::new(IndexTracker { path: path.clone() });
let startup = FileAtomic::<IndexStartup>::open(path.join("startup"));
clean(
path.join("segments"),
startup
.get()
.sealeds
.iter()
.map(|s| s.to_string())
.chain(startup.get().growings.iter().map(|s| s.to_string())),
);
let sealed = startup
.get()
.sealeds
.iter()
.map(|&uuid| {
(
uuid,
SealedSegment::open(
tracker.clone(),
path.join("segments").join(uuid.to_string()),
uuid,
options.clone(),
),
)
})
.collect::<HashMap<_, _>>();
let growing = startup
.get()
.growings
.iter()
.map(|&uuid| {
(
uuid,
GrowingSegment::open(
tracker.clone(),
path.join("segments").join(uuid.to_string()),
uuid,
),
)
})
.collect::<HashMap<_, _>>();
let delete = Delete::open(path.join("delete"));
let index = Arc::new(Index {
path: path.clone(),
options: options.clone(),
delete: delete.clone(),
protect: Mutex::new(IndexProtect {
startup,
sealed: sealed.clone(),
growing: growing.clone(),
write: None,
}),
view: ArcSwap::new(Arc::new(IndexView {
options: options.clone(),
delete: delete.clone(),
sealed,
growing,
write: None,
})),
instant_index: AtomicCell::new(Instant::now()),
instant_write: AtomicCell::new(Instant::now()),
_tracker: tracker,
});
OptimizerIndexing::new(index.clone()).spawn();
OptimizerSealing::new(index.clone()).spawn();
index
}
pub fn options(&self) -> &IndexOptions {
&self.options
}
pub fn view(&self) -> Arc<IndexView<S>> {
self.view.load_full()
}
pub fn refresh(&self) {
let mut protect = self.protect.lock();
if let Some((uuid, write)) = protect.write.clone() {
if !write.is_full() {
return;
}
write.seal();
protect.growing.insert(uuid, write);
}
let write_segment_uuid = Uuid::new_v4();
let write_segment = GrowingSegment::create(
self._tracker.clone(),
self.path
.join("segments")
.join(write_segment_uuid.to_string()),
write_segment_uuid,
self.options.clone(),
);
protect.write = Some((write_segment_uuid, write_segment));
protect.maintain(self.options.clone(), self.delete.clone(), &self.view);
self.instant_write.store(Instant::now());
}
pub fn seal(&self, check: Uuid) {
let mut protect = self.protect.lock();
if let Some((uuid, write)) = protect.write.clone() {
if check != uuid {
return;
}
write.seal();
protect.growing.insert(uuid, write);
}
protect.write = None;
protect.maintain(self.options.clone(), self.delete.clone(), &self.view);
self.instant_write.store(Instant::now());
}
pub fn stat(&self) -> IndexStat {
let view = self.view();
IndexStat {
indexing: self.instant_index.load() < self.instant_write.load(),
sealed: view.sealed.values().map(|x| x.len()).collect(),
growing: view.growing.values().map(|x| x.len()).collect(),
write: view.write.as_ref().map(|(_, x)| x.len()).unwrap_or(0),
options: self.options().clone(),
}
}
}
impl<S: G> Drop for Index<S> {
fn drop(&mut self) {}
}
#[derive(Debug, Clone)]
pub struct IndexTracker {
path: PathBuf,
}
impl Drop for IndexTracker {
fn drop(&mut self) {
std::fs::remove_dir_all(&self.path).unwrap();
}
}
pub struct IndexView<S: G> {
pub options: IndexOptions,
pub delete: Arc<Delete>,
pub sealed: HashMap<Uuid, Arc<SealedSegment<S>>>,
pub growing: HashMap<Uuid, Arc<GrowingSegment<S>>>,
pub write: Option<(Uuid, Arc<GrowingSegment<S>>)>,
}
impl<S: G> IndexView<S> {
pub fn search<F: FnMut(Pointer) -> bool>(
&self,
k: usize,
vector: &[S::Scalar],
mut filter: F,
) -> Vec<Pointer> {
assert_eq!(self.options.vector.dims as usize, vector.len());
struct Comparer(BinaryHeap<Reverse<HeapElement>>);
impl PartialEq for Comparer {
fn eq(&self, other: &Self) -> bool {
self.cmp(other).is_eq()
}
}
impl Eq for Comparer {}
impl PartialOrd for Comparer {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Comparer {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0.peek().cmp(&other.0.peek()).reverse()
}
}
let mut filter = |payload| {
if let Some(p) = self.delete.check(payload) {
filter(p)
} else {
false
}
};
let n = self.sealed.len() + self.growing.len() + 1;
let mut result = Heap::new(k);
let mut heaps = BinaryHeap::with_capacity(1 + n);
for (_, sealed) in self.sealed.iter() {
let p = sealed.search(k, vector, &mut filter).into_reversed_heap();
heaps.push(Comparer(p));
}
for (_, growing) in self.growing.iter() {
let p = growing.search(k, vector, &mut filter).into_reversed_heap();
heaps.push(Comparer(p));
}
if let Some((_, write)) = &self.write {
let p = write.search(k, vector, &mut filter).into_reversed_heap();
heaps.push(Comparer(p));
}
while let Some(Comparer(mut heap)) = heaps.pop() {
if let Some(Reverse(x)) = heap.pop() {
result.push(x);
heaps.push(Comparer(heap));
}
}
result
.into_sorted_vec()
.iter()
.map(|x| Pointer::from_u48(x.payload >> 16))
.collect()
}
pub fn vbase(&self, vector: &[S::Scalar]) -> impl Iterator<Item = Pointer> + '_ {
assert_eq!(self.options.vector.dims as usize, vector.len());
let range = 86;
struct Comparer<'a, S: G> {
iter: ComparerIter<'a, S>,
item: Option<HeapElement>,
}
enum ComparerIter<'a, S: G> {
Sealed(DynamicIndexIter<'a, S>),
Growing(std::vec::IntoIter<HeapElement>),
}
impl<S: G> PartialEq for Comparer<'_, S> {
fn eq(&self, other: &Self) -> bool {
self.cmp(other).is_eq()
}
}
impl<S: G> Eq for Comparer<'_, S> {}
impl<S: G> PartialOrd for Comparer<'_, S> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl<S: G> Ord for Comparer<'_, S> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.item.cmp(&other.item).reverse()
}
}
impl<S: G> Iterator for ComparerIter<'_, S> {
type Item = HeapElement;
fn next(&mut self) -> Option<Self::Item> {
match self {
Self::Sealed(iter) => iter.next(),
Self::Growing(iter) => iter.next(),
}
}
}
impl<S: G> Iterator for Comparer<'_, S> {
type Item = HeapElement;
fn next(&mut self) -> Option<Self::Item> {
let item = self.item.take();
self.item = self.iter.next();
item
}
}
fn from_iter<S: G>(mut iter: ComparerIter<'_, S>) -> Comparer<'_, S> {
let item = iter.next();
Comparer { iter, item }
}
use ComparerIter::*;
let filter = |payload| self.delete.check(payload).is_some();
let n = self.sealed.len() + self.growing.len() + 1;
let mut heaps: BinaryHeap<Comparer<S>> = BinaryHeap::with_capacity(1 + n);
for (_, sealed) in self.sealed.iter() {
let res = sealed.vbase(range, vector);
heaps.push(from_iter(Sealed(res)));
}
for (_, growing) in self.growing.iter() {
let mut res = growing.vbase(vector);
res.sort_unstable();
heaps.push(from_iter(Growing(res.into_iter())));
}
if let Some((_, write)) = &self.write {
let mut res = write.vbase(vector);
res.sort_unstable();
heaps.push(from_iter(Growing(res.into_iter())));
}
std::iter::from_fn(move || {
while let Some(mut iter) = heaps.pop() {
if let Some(x) = iter.next() {
if !filter(x.payload) {
continue;
}
heaps.push(iter);
return Some(Pointer::from_u48(x.payload >> 16));
}
}
None
})
}
pub fn insert(&self, vector: Vec<S::Scalar>, pointer: Pointer) -> Result<(), OutdatedError> {
assert_eq!(self.options.vector.dims as usize, vector.len());
let payload = (pointer.as_u48() << 16) | self.delete.version(pointer) as Payload;
if let Some((_, growing)) = self.write.as_ref() {
Ok(growing.insert(vector, payload)?)
} else {
Err(OutdatedError(None))
}
}
pub fn delete<F: FnMut(Pointer) -> bool>(&self, mut f: F) {
for (_, sealed) in self.sealed.iter() {
let n = sealed.len();
for i in 0..n {
if let Some(p) = self.delete.check(sealed.payload(i)) {
if f(p) {
self.delete.delete(p);
}
}
}
}
for (_, growing) in self.growing.iter() {
let n = growing.len();
for i in 0..n {
if let Some(p) = self.delete.check(growing.payload(i)) {
if f(p) {
self.delete.delete(p);
}
}
}
}
if let Some((_, write)) = &self.write {
let n = write.len();
for i in 0..n {
if let Some(p) = self.delete.check(write.payload(i)) {
if f(p) {
self.delete.delete(p);
}
}
}
}
}
pub fn flush(&self) {
self.delete.flush();
if let Some((_, write)) = &self.write {
write.flush();
}
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct IndexStartup {
sealeds: HashSet<Uuid>,
growings: HashSet<Uuid>,
}
struct IndexProtect<S: G> {
startup: FileAtomic<IndexStartup>,
sealed: HashMap<Uuid, Arc<SealedSegment<S>>>,
growing: HashMap<Uuid, Arc<GrowingSegment<S>>>,
write: Option<(Uuid, Arc<GrowingSegment<S>>)>,
}
impl<S: G> IndexProtect<S> {
fn maintain(
&mut self,
options: IndexOptions,
delete: Arc<Delete>,
swap: &ArcSwap<IndexView<S>>,
) {
let view = Arc::new(IndexView {
options,
delete,
sealed: self.sealed.clone(),
growing: self.growing.clone(),
write: self.write.clone(),
});
let startup_write = self.write.as_ref().map(|(uuid, _)| *uuid);
let startup_sealeds = self.sealed.keys().copied().collect();
let startup_growings = self.growing.keys().copied().chain(startup_write).collect();
self.startup.set(IndexStartup {
sealeds: startup_sealeds,
growings: startup_growings,
});
swap.swap(view);
}
}

View File

@@ -0,0 +1,148 @@
use crate::index::GrowingSegment;
use crate::index::Index;
use crate::index::SealedSegment;
use crate::prelude::*;
use std::cmp::Reverse;
use std::sync::Arc;
use std::time::Instant;
use uuid::Uuid;
pub struct OptimizerIndexing<S: G> {
index: Arc<Index<S>>,
}
impl<S: G> OptimizerIndexing<S> {
pub fn new(index: Arc<Index<S>>) -> Self {
Self { index }
}
pub fn spawn(self) {
std::thread::spawn(move || {
self.main();
});
}
pub fn main(self) {
let index = self.index;
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(index.options.optimizing.optimizing_threads)
.build()
.unwrap();
let weak_index = Arc::downgrade(&index);
std::mem::drop(index);
loop {
{
let Some(index) = weak_index.upgrade() else {
return;
};
if let Ok(()) = pool.install(|| optimizing_indexing(index.clone())) {
continue;
}
}
std::thread::sleep(std::time::Duration::from_secs(60));
}
}
}
enum Seg<S: G> {
Sealed(Arc<SealedSegment<S>>),
Growing(Arc<GrowingSegment<S>>),
}
impl<S: G> Seg<S> {
fn uuid(&self) -> Uuid {
use Seg::*;
match self {
Sealed(x) => x.uuid(),
Growing(x) => x.uuid(),
}
}
fn len(&self) -> u32 {
use Seg::*;
match self {
Sealed(x) => x.len(),
Growing(x) => x.len(),
}
}
fn get_sealed(&self) -> Option<Arc<SealedSegment<S>>> {
match self {
Seg::Sealed(x) => Some(x.clone()),
_ => None,
}
}
fn get_growing(&self) -> Option<Arc<GrowingSegment<S>>> {
match self {
Seg::Growing(x) => Some(x.clone()),
_ => None,
}
}
}
#[derive(Debug, thiserror::Error)]
#[error("Interrupted, retry again.")]
pub struct RetryError;
pub fn optimizing_indexing<S: G>(index: Arc<Index<S>>) -> Result<(), RetryError> {
use Seg::*;
let segs = {
let protect = index.protect.lock();
let mut segs_0 = Vec::new();
segs_0.extend(protect.growing.values().map(|x| Growing(x.clone())));
segs_0.extend(protect.sealed.values().map(|x| Sealed(x.clone())));
segs_0.sort_by_key(|case| Reverse(case.len()));
let mut segs_1 = Vec::new();
let mut total = 0u64;
let mut count = 0;
while let Some(seg) = segs_0.pop() {
if total + seg.len() as u64 <= index.options.segment.max_sealed_segment_size as u64 {
total += seg.len() as u64;
if let Growing(_) = seg {
count += 1;
}
segs_1.push(seg);
} else {
break;
}
}
if segs_1.is_empty() || (segs_1.len() == 1 && count == 0) {
index.instant_index.store(Instant::now());
return Err(RetryError);
}
segs_1
};
let sealed_segment = merge(&index, &segs);
{
let mut protect = index.protect.lock();
for seg in segs.iter() {
if protect.sealed.contains_key(&seg.uuid()) {
continue;
}
if protect.growing.contains_key(&seg.uuid()) {
continue;
}
return Ok(());
}
for seg in segs.iter() {
protect.sealed.remove(&seg.uuid());
protect.growing.remove(&seg.uuid());
}
protect.sealed.insert(sealed_segment.uuid(), sealed_segment);
protect.maintain(index.options.clone(), index.delete.clone(), &index.view);
}
Ok(())
}
fn merge<S: G>(index: &Arc<Index<S>>, segs: &[Seg<S>]) -> Arc<SealedSegment<S>> {
let sealed = segs.iter().filter_map(|x| x.get_sealed()).collect();
let growing = segs.iter().filter_map(|x| x.get_growing()).collect();
let sealed_segment_uuid = Uuid::new_v4();
SealedSegment::create(
index._tracker.clone(),
index
.path
.join("segments")
.join(sealed_segment_uuid.to_string()),
sealed_segment_uuid,
index.options.clone(),
sealed,
growing,
)
}

View File

@@ -1,4 +1,5 @@
pub mod indexing;
pub mod sealing;
pub mod vacuum;
use serde::{Deserialize, Serialize};
@@ -6,9 +7,12 @@ use validator::Validate;
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct OptimizingOptions {
#[serde(default = "OptimizingOptions::default_waiting_secs", skip)]
#[validate(range(min = 0, max = 600))]
pub waiting_secs: u64,
#[serde(default = "OptimizingOptions::default_sealing_secs")]
#[validate(range(min = 0, max = 60))]
pub sealing_secs: u64,
#[serde(default = "OptimizingOptions::default_sealing_size")]
#[validate(range(min = 1, max = 4_000_000_000))]
pub sealing_size: u32,
#[serde(default = "OptimizingOptions::default_deleted_threshold", skip)]
#[validate(range(min = 0.01, max = 1.00))]
pub deleted_threshold: f64,
@@ -18,9 +22,12 @@ pub struct OptimizingOptions {
}
impl OptimizingOptions {
fn default_waiting_secs() -> u64 {
fn default_sealing_secs() -> u64 {
60
}
fn default_sealing_size() -> u32 {
1
}
fn default_deleted_threshold() -> f64 {
0.2
}
@@ -35,7 +42,8 @@ impl OptimizingOptions {
impl Default for OptimizingOptions {
fn default() -> Self {
Self {
waiting_secs: Self::default_waiting_secs(),
sealing_secs: Self::default_sealing_secs(),
sealing_size: Self::default_sealing_size(),
deleted_threshold: Self::default_deleted_threshold(),
optimizing_threads: Self::default_optimizing_threads(),
}

View File

@@ -0,0 +1,49 @@
use crate::index::Index;
use crate::prelude::*;
use std::sync::Arc;
use std::time::Duration;
pub struct OptimizerSealing<S: G> {
index: Arc<Index<S>>,
}
impl<S: G> OptimizerSealing<S> {
pub fn new(index: Arc<Index<S>>) -> Self {
Self { index }
}
pub fn spawn(self) {
std::thread::spawn(move || {
self.main();
});
}
pub fn main(self) {
let index = self.index;
let dur = Duration::from_secs(index.options.optimizing.sealing_secs);
let least = index.options.optimizing.sealing_size;
let weak_index = Arc::downgrade(&index);
std::mem::drop(index);
let mut check = None;
loop {
{
let Some(index) = weak_index.upgrade() else {
return;
};
let view = index.view();
let stamp = view
.write
.as_ref()
.map(|(uuid, segment)| (*uuid, segment.len()));
if stamp == check {
if let Some((uuid, len)) = stamp {
if len >= least {
index.seal(uuid);
}
}
} else {
check = stamp;
}
}
std::thread::sleep(dur);
}
}
}

View File

@@ -1,7 +1,8 @@
#![allow(clippy::all)] // Clippy bug.
use super::SegmentTracker;
use crate::index::IndexOptions;
use crate::index::IndexTracker;
use crate::index::VectorOptions;
use crate::prelude::*;
use crate::utils::dir_ops::sync_dir;
use crate::utils::file_wal::FileWal;
@@ -19,17 +20,16 @@ use uuid::Uuid;
#[error("`GrowingSegment` stopped growing.")]
pub struct GrowingSegmentInsertError;
pub struct GrowingSegment {
pub struct GrowingSegment<S: G> {
uuid: Uuid,
options: VectorOptions,
vec: Vec<UnsafeCell<MaybeUninit<Log>>>,
vec: Vec<UnsafeCell<MaybeUninit<Log<S>>>>,
wal: Mutex<FileWal>,
len: AtomicUsize,
pro: Mutex<Protect>,
_tracker: Arc<SegmentTracker>,
}
impl GrowingSegment {
impl<S: G> GrowingSegment<S> {
pub fn create(
_tracker: Arc<IndexTracker>,
path: PathBuf,
@@ -42,7 +42,6 @@ impl GrowingSegment {
sync_dir(&path);
Arc::new(Self {
uuid,
options: options.vector,
vec: unsafe {
let mut vec = Vec::with_capacity(capacity as usize);
vec.set_len(capacity as usize);
@@ -57,23 +56,17 @@ impl GrowingSegment {
_tracker: Arc::new(SegmentTracker { path, _tracker }),
})
}
pub fn open(
_tracker: Arc<IndexTracker>,
path: PathBuf,
uuid: Uuid,
options: IndexOptions,
) -> Arc<Self> {
pub fn open(_tracker: Arc<IndexTracker>, path: PathBuf, uuid: Uuid) -> Arc<Self> {
let mut wal = FileWal::open(path.join("wal"));
let mut vec = Vec::new();
while let Some(log) = wal.read() {
let log = bincode::deserialize::<Log>(&log).unwrap();
let log = bincode::deserialize::<Log<S>>(&log).unwrap();
vec.push(UnsafeCell::new(MaybeUninit::new(log)));
}
wal.truncate();
let n = vec.len();
Arc::new(Self {
uuid,
options: options.vector,
vec,
wal: { Mutex::new(wal) },
len: AtomicUsize::new(n),
@@ -87,6 +80,20 @@ impl GrowingSegment {
pub fn uuid(&self) -> Uuid {
self.uuid
}
pub fn is_full(&self) -> bool {
let n;
{
let pro = self.pro.lock();
if pro.inflight < pro.capacity {
return false;
}
n = pro.inflight;
}
while self.len.load(Ordering::Acquire) != n {
std::hint::spin_loop();
}
true
}
pub fn seal(&self) {
let n;
{
@@ -104,7 +111,7 @@ impl GrowingSegment {
}
pub fn insert(
&self,
vector: Vec<Scalar>,
vector: Vec<S::Scalar>,
payload: Payload,
) -> Result<(), GrowingSegmentInsertError> {
let log = Log { vector, payload };
@@ -126,13 +133,13 @@ impl GrowingSegment {
self.len.store(1 + i, Ordering::Release);
self.wal
.lock()
.write(&bincode::serialize::<Log>(&log).unwrap());
.write(&bincode::serialize::<Log<S>>(&log).unwrap());
Ok(())
}
pub fn len(&self) -> u32 {
self.len.load(Ordering::Acquire) as u32
}
pub fn vector(&self, i: u32) -> &[Scalar] {
pub fn vector(&self, i: u32) -> &[S::Scalar] {
let i = i as usize;
if i >= self.len.load(Ordering::Acquire) {
panic!("Out of bound.");
@@ -148,12 +155,12 @@ impl GrowingSegment {
let log = unsafe { (*self.vec[i].get()).assume_init_ref() };
log.payload
}
pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap {
pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
let n = self.len.load(Ordering::Acquire);
let mut heap = Heap::new(k);
for i in 0..n {
let log = unsafe { (*self.vec[i].get()).assume_init_ref() };
let distance = self.options.d.distance(vector, &log.vector);
let distance = S::distance(vector, &log.vector);
if heap.check(distance) && filter.check(log.payload) {
heap.push(HeapElement {
distance,
@@ -163,12 +170,12 @@ impl GrowingSegment {
}
heap
}
pub fn search_all(&self, vector: &[Scalar]) -> Vec<HeapElement> {
pub fn vbase(&self, vector: &[S::Scalar]) -> Vec<HeapElement> {
let n = self.len.load(Ordering::Acquire);
let mut result = Vec::new();
for i in 0..n {
let log = unsafe { (*self.vec[i].get()).assume_init_ref() };
let distance = self.options.d.distance(vector, &log.vector);
let distance = S::distance(vector, &log.vector);
result.push(HeapElement {
distance,
payload: log.payload,
@@ -178,10 +185,10 @@ impl GrowingSegment {
}
}
unsafe impl Send for GrowingSegment {}
unsafe impl Sync for GrowingSegment {}
unsafe impl<S: G> Send for GrowingSegment<S> {}
unsafe impl<S: G> Sync for GrowingSegment<S> {}
impl Drop for GrowingSegment {
impl<S: G> Drop for GrowingSegment<S> {
fn drop(&mut self) {
let n = *self.len.get_mut();
for i in 0..n {
@@ -193,8 +200,8 @@ impl Drop for GrowingSegment {
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct Log {
vector: Vec<Scalar>,
struct Log<S: G> {
vector: Vec<S::Scalar>,
payload: Payload,
}

View File

@@ -10,14 +10,10 @@ use validator::ValidationError;
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[validate(schema(function = "Self::validate_0"))]
#[validate(schema(function = "Self::validate_1"))]
pub struct SegmentsOptions {
#[serde(default = "SegmentsOptions::default_max_growing_segment_size")]
#[validate(range(min = 1, max = 4_000_000_000))]
pub max_growing_segment_size: u32,
#[serde(default = "SegmentsOptions::default_min_sealed_segment_size")]
#[validate(range(min = 1, max = 4_000_000_000))]
pub min_sealed_segment_size: u32,
#[serde(default = "SegmentsOptions::default_max_sealed_segment_size")]
#[validate(range(min = 1, max = 4_000_000_000))]
pub max_sealed_segment_size: u32,
@@ -27,22 +23,11 @@ impl SegmentsOptions {
fn default_max_growing_segment_size() -> u32 {
20_000
}
fn default_min_sealed_segment_size() -> u32 {
1_000
}
fn default_max_sealed_segment_size() -> u32 {
1_000_000
}
// min_sealed_segment_size <= max_growing_segment_size <= max_sealed_segment_size
// max_growing_segment_size <= max_sealed_segment_size
fn validate_0(&self) -> Result<(), ValidationError> {
if self.min_sealed_segment_size > self.max_growing_segment_size {
return Err(ValidationError::new(
"`min_sealed_segment_size` must be less than or equal to `max_growing_segment_size`",
));
}
Ok(())
}
fn validate_1(&self) -> Result<(), ValidationError> {
if self.max_growing_segment_size > self.max_sealed_segment_size {
return Err(ValidationError::new(
"`max_growing_segment_size` must be less than or equal to `max_sealed_segment_size`",
@@ -56,7 +41,6 @@ impl Default for SegmentsOptions {
fn default() -> Self {
Self {
max_growing_segment_size: Self::default_max_growing_segment_size(),
min_sealed_segment_size: Self::default_min_sealed_segment_size(),
max_sealed_segment_size: Self::default_max_sealed_segment_size(),
}
}

View File

@@ -8,20 +8,20 @@ use std::path::PathBuf;
use std::sync::Arc;
use uuid::Uuid;
pub struct SealedSegment {
pub struct SealedSegment<S: G> {
uuid: Uuid,
indexing: DynamicIndexing,
indexing: DynamicIndexing<S>,
_tracker: Arc<SegmentTracker>,
}
impl SealedSegment {
impl<S: G> SealedSegment<S> {
pub fn create(
_tracker: Arc<IndexTracker>,
path: PathBuf,
uuid: Uuid,
options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>,
growing: Vec<Arc<GrowingSegment>>,
sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment<S>>>,
) -> Arc<Self> {
std::fs::create_dir(&path).unwrap();
let indexing = DynamicIndexing::create(path.join("indexing"), options, sealed, growing);
@@ -51,20 +51,16 @@ impl SealedSegment {
pub fn len(&self) -> u32 {
self.indexing.len()
}
pub fn vector(&self, i: u32) -> &[Scalar] {
pub fn vector(&self, i: u32) -> &[S::Scalar] {
self.indexing.vector(i)
}
pub fn payload(&self, i: u32) -> Payload {
self.indexing.payload(i)
}
pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap {
pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
self.indexing.search(k, vector, filter)
}
pub fn search_vbase<'index, 'vector>(
&'index self,
range: usize,
vector: &'vector [Scalar],
) -> DynamicIndexIter<'index, 'vector> {
self.indexing.search_vbase(range, vector)
pub fn vbase(&self, range: usize, vector: &[S::Scalar]) -> DynamicIndexIter<'_, S> {
self.indexing.vbase(range, vector)
}
}

View File

@@ -0,0 +1,9 @@
#![feature(core_intrinsics)]
#![feature(avx512_target_feature)]
pub mod algorithms;
pub mod index;
pub mod prelude;
pub mod worker;
mod utils;

View File

@@ -1,5 +1,3 @@
use crate::ipc::IpcError;
use crate::prelude::*;
use serde::{Deserialize, Serialize};
use thiserror::Error;
@@ -15,100 +13,77 @@ or simply run the command `psql -U postgres -c 'ALTER SYSTEM SET shared_preload_
")]
BadInit,
#[error("\
The given index option is invalid.
INFORMATION: reason = {0:?}\
Bad literal.
INFORMATION: hint = {hint}\
")]
BadOption(String),
#[error("\
The given vector is invalid for input.
INFORMATION: vector = {0:?}
ADVICE: Check if dimensions of the vector is matched with the index.\
")]
BadVector(Vec<Scalar>),
BadLiteral {
hint: String,
},
#[error("\
Modifier of the type is invalid.
ADVICE: Check if modifier of the type is an integer among 1 and 65535.\
")]
BadTypmod,
BadTypeDimensions,
#[error("\
Dimensions of the vector is invalid.
ADVICE: Check if dimensions of the vector are among 1 and 65535.\
")]
BadVecForDims,
BadValueDimensions,
#[error("\
Dimensions of the vector is unmatched with the type modifier.
INFORMATION: type_dimensions = {type_dimensions}, value_dimensions = {value_dimensions}\
The given index option is invalid.
INFORMATION: reason = {validation:?}\
")]
BadVecForUnmatchedDims {
value_dimensions: u16,
type_dimensions: u16,
},
BadOption { validation: String },
#[error("\
Operands of the operator differs in dimensions.
INFORMATION: left_dimensions = {left_dimensions}, right_dimensions = {right_dimensions}\
Dimensions type modifier of a vector column is needed for building the index.\
")]
DifferentVectorDims {
left_dimensions: u16,
right_dimensions: u16,
},
BadOption2,
#[error("\
Indexes can only be built on built-in distance functions.
ADVICE: If you want pgvecto.rs to support more distance functions, \
visit `https://github.com/tensorchord/pgvecto.rs/issues` and contribute your ideas.\
")]
UnsupportedOperator,
BadOptions3,
#[error("\
The index is not existing in the background worker.
ADVICE: Drop or rebuild the index.\
")]
Index404,
UnknownIndex,
#[error("\
Dimensions type modifier of a vector column is needed for building the index.\
Operands of the operator differs in dimensions or scalar type.
INFORMATION: left_dimensions = {left_dimensions}, right_dimensions = {right_dimensions}\
")]
DimsIsNeeded,
#[error("\
Bad vector string.
INFORMATION: hint = {hint}\
")]
BadVectorString {
hint: String,
Unmatched {
left_dimensions: u16,
right_dimensions: u16,
},
#[error("\
`mmap` transport is not supported by MacOS.\
The given vector is invalid for input.
ADVICE: Check if dimensions and scalar type of the vector is matched with the index.\
")]
MmapTransportNotSupported,
Unmatched2,
}
impl FriendlyError {
pub fn friendly(self) -> ! {
pub trait FriendlyErrorLike {
fn friendly(self) -> !;
}
impl FriendlyErrorLike for FriendlyError {
fn friendly(self) -> ! {
panic!("pgvecto.rs: {}", self);
}
}
impl IpcError {
pub fn friendly(self) -> ! {
panic!("pgvecto.rs: {}", self);
}
}
pub trait Friendly {
pub trait FriendlyResult {
type Output;
fn friendly(self) -> Self::Output;
}
impl<T> Friendly for Result<T, FriendlyError> {
type Output = T;
fn friendly(self) -> T {
match self {
Ok(x) => x,
Err(e) => e.friendly(),
}
}
}
impl<T> Friendly for Result<T, IpcError> {
impl<T, E> FriendlyResult for Result<T, E>
where
E: FriendlyErrorLike,
{
type Output = T;
fn friendly(self) -> T {

View File

@@ -0,0 +1,114 @@
use crate::prelude::*;
pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 {
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..n {
xy += lhs[i].to_f() * rhs[i].to_f();
x2 += lhs[i].to_f() * lhs[i].to_f();
y2 += rhs[i].to_f() * rhs[i].to_f();
}
xy / (x2 * y2).sqrt()
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_avx512fp16() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_cosine_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_v3() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_cosine_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
cosine(lhs, rhs)
}
pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 {
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
for i in 0..n {
xy += lhs[i].to_f() * rhs[i].to_f();
}
xy
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_avx512fp16() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_dot_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_v3() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_dot_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
cosine(lhs, rhs)
}
pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 {
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut d2 = F32::zero();
for i in 0..n {
let d = lhs[i].to_f() - rhs[i].to_f();
d2 += d * d;
}
d2
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_avx512fp16() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_sl2_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_v3() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_sl2_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
sl2(lhs, rhs)
}

View File

@@ -0,0 +1,244 @@
use super::G;
use crate::prelude::scalar::F32;
use crate::prelude::*;
#[derive(Debug, Clone, Copy)]
pub enum F16Cos {}
impl G for F16Cos {
type Scalar = F16;
const DISTANCE: Distance = Distance::Cos;
type L2 = F16L2;
fn distance(lhs: &[F16], rhs: &[F16]) -> F32 {
super::f16::cosine(lhs, rhs) * (-1.0)
}
fn elkan_k_means_normalize(vector: &mut [F16]) {
l2_normalize(vector)
}
fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 {
super::f16::dot(lhs, rhs).acos()
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance(
dims: u16,
max: &[F16],
min: &[F16],
lhs: &[F16],
rhs: &[u8],
) -> F32 {
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..dims as usize {
let _x = lhs[i].to_f();
let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
xy += _x * _y;
x2 += _x * _x;
y2 += _y * _y;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance2(
dims: u16,
max: &[F16],
min: &[F16],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..dims as usize {
let _x = F32(lhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
xy += _x * _y;
x2 += _x * _x;
y2 += _y * _y;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[F16],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs);
xy += _xy;
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance2(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhsp = lhs[i as usize] as usize * dims as usize;
let lhs = &centroids[lhsp..][(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs);
xy += _xy;
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance_with_delta(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[F16],
rhs: &[u8],
delta: &[F16],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let del = &delta[(i * ratio) as usize..][..k as usize];
let (_xy, _x2, _y2) = xy_x2_y2_delta(lhs, rhs, del);
xy += _xy;
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn length(vector: &[F16]) -> F16 {
let n = vector.len();
let mut dot = F16::zero();
for i in 0..n {
dot += vector[i] * vector[i];
}
dot.sqrt()
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn l2_normalize(vector: &mut [F16]) {
let n = vector.len();
let l = length(vector);
for i in 0..n {
vector[i] /= l;
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn xy_x2_y2(lhs: &[F16], rhs: &[F16]) -> (F32, F32, F32) {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..n {
xy += lhs[i].to_f() * rhs[i].to_f();
x2 += lhs[i].to_f() * lhs[i].to_f();
y2 += rhs[i].to_f() * rhs[i].to_f();
}
(xy, x2, y2)
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn xy_x2_y2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> (F32, F32, F32) {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..n {
xy += lhs[i].to_f() * (rhs[i].to_f() + del[i].to_f());
x2 += lhs[i].to_f() * lhs[i].to_f();
y2 += (rhs[i].to_f() + del[i].to_f()) * (rhs[i].to_f() + del[i].to_f());
}
(xy, x2, y2)
}

View File

@@ -0,0 +1,199 @@
use super::G;
use crate::prelude::scalar::F32;
use crate::prelude::*;
#[derive(Debug, Clone, Copy)]
pub enum F16Dot {}
impl G for F16Dot {
type Scalar = F16;
const DISTANCE: Distance = Distance::Dot;
type L2 = F16L2;
fn distance(lhs: &[F16], rhs: &[F16]) -> F32 {
super::f16::dot(lhs, rhs) * (-1.0)
}
fn elkan_k_means_normalize(vector: &mut [F16]) {
l2_normalize(vector)
}
fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 {
super::f16::dot(lhs, rhs).acos()
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance(
dims: u16,
max: &[F16],
min: &[F16],
lhs: &[F16],
rhs: &[u8],
) -> F32 {
let mut xy = F32::zero();
for i in 0..dims as usize {
let _x = lhs[i].to_f();
let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
xy += _x * _y;
}
xy * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance2(
dims: u16,
max: &[F16],
min: &[F16],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let mut xy = F32::zero();
for i in 0..dims as usize {
let _x = F32(lhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
xy += _x * _y;
}
xy * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[F16],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let _xy = super::f16::dot(lhs, rhs);
xy += _xy;
}
xy * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance2(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhsp = lhs[i as usize] as usize * dims as usize;
let lhs = &centroids[lhsp..][(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let _xy = super::f16::dot(lhs, rhs);
xy += _xy;
}
xy * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance_with_delta(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[F16],
rhs: &[u8],
delta: &[F16],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let del = &delta[(i * ratio) as usize..][..k as usize];
let _xy = dot_delta(lhs, rhs, del);
xy += _xy;
}
xy * (-1.0)
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn length(vector: &[F16]) -> F16 {
let n = vector.len();
let mut dot = F16::zero();
for i in 0..n {
dot += vector[i] * vector[i];
}
dot.sqrt()
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn l2_normalize(vector: &mut [F16]) {
let n = vector.len();
let l = length(vector);
for i in 0..n {
vector[i] /= l;
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn dot_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 {
assert!(lhs.len() == rhs.len());
let n: usize = lhs.len();
let mut xy = F32::zero();
for i in 0..n {
xy += lhs[i].to_f() * (rhs[i].to_f() + del[i].to_f());
}
xy
}

View File

@@ -0,0 +1,165 @@
use super::G;
use crate::prelude::scalar::F16;
use crate::prelude::scalar::F32;
use crate::prelude::*;
#[derive(Debug, Clone, Copy)]
pub enum F16L2 {}
impl G for F16L2 {
type Scalar = F16;
const DISTANCE: Distance = Distance::L2;
type L2 = F16L2;
fn distance(lhs: &[F16], rhs: &[F16]) -> F32 {
super::f16::sl2(lhs, rhs)
}
fn elkan_k_means_normalize(_: &mut [F16]) {}
fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 {
super::f16::sl2(lhs, rhs).sqrt()
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance(
dims: u16,
max: &[F16],
min: &[F16],
lhs: &[F16],
rhs: &[u8],
) -> F32 {
let mut result = F32::zero();
for i in 0..dims as usize {
let _x = lhs[i].to_f();
let _y = (F32(rhs[i] as f32) / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
result += (_x - _y) * (_x - _y);
}
result
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance2(
dims: u16,
max: &[F16],
min: &[F16],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let mut result = F32::zero();
for i in 0..dims as usize {
let _x = F32(lhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
result += (_x - _y) * (_x - _y);
}
result
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[F16],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut result = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
result += super::f16::sl2(lhs, rhs);
}
result
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance2(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut result = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhsp = lhs[i as usize] as usize * dims as usize;
let lhs = &centroids[lhsp..][(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
result += super::f16::sl2(lhs, rhs);
}
result
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance_with_delta(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[F16],
rhs: &[u8],
delta: &[F16],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut result = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let del = &delta[(i * ratio) as usize..][..k as usize];
result += distance_squared_l2_delta(lhs, rhs, del);
}
result
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn distance_squared_l2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut d2 = F32::zero();
for i in 0..n {
let d = lhs[i].to_f() - (rhs[i].to_f() + del[i].to_f());
d2 += d * d;
}
d2
}

View File

@@ -0,0 +1,265 @@
use super::G;
use crate::prelude::scalar::F32;
use crate::prelude::*;
#[derive(Debug, Clone, Copy)]
pub enum F32Cos {}
impl G for F32Cos {
type Scalar = F32;
const DISTANCE: Distance = Distance::Cos;
type L2 = F32L2;
fn distance(lhs: &[F32], rhs: &[F32]) -> F32 {
cosine(lhs, rhs) * (-1.0)
}
fn elkan_k_means_normalize(vector: &mut [F32]) {
l2_normalize(vector)
}
fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 {
super::f32_dot::dot(lhs, rhs).acos()
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance(
dims: u16,
max: &[F32],
min: &[F32],
lhs: &[F32],
rhs: &[u8],
) -> F32 {
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..dims as usize {
let _x = lhs[i];
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
xy += _x * _y;
x2 += _x * _x;
y2 += _y * _y;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance2(
dims: u16,
max: &[F32],
min: &[F32],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..dims as usize {
let _x = F32(lhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
xy += _x * _y;
x2 += _x * _x;
y2 += _y * _y;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[F32],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs);
xy += _xy;
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance2(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhsp = lhs[i as usize] as usize * dims as usize;
let lhs = &centroids[lhsp..][(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs);
xy += _xy;
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance_with_delta(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[F32],
rhs: &[u8],
delta: &[F32],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let del = &delta[(i * ratio) as usize..][..k as usize];
let (_xy, _x2, _y2) = xy_x2_y2_delta(lhs, rhs, del);
xy += _xy;
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn length(vector: &[F32]) -> F32 {
let n = vector.len();
let mut dot = F32::zero();
for i in 0..n {
dot += vector[i] * vector[i];
}
dot.sqrt()
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn l2_normalize(vector: &mut [F32]) {
let n = vector.len();
let l = length(vector);
for i in 0..n {
vector[i] /= l;
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..n {
xy += lhs[i] * rhs[i];
x2 += lhs[i] * lhs[i];
y2 += rhs[i] * rhs[i];
}
xy / (x2 * y2).sqrt()
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn xy_x2_y2(lhs: &[F32], rhs: &[F32]) -> (F32, F32, F32) {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..n {
xy += lhs[i] * rhs[i];
x2 += lhs[i] * lhs[i];
y2 += rhs[i] * rhs[i];
}
(xy, x2, y2)
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn xy_x2_y2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> (F32, F32, F32) {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..n {
xy += lhs[i] * (rhs[i] + del[i]);
x2 += lhs[i] * lhs[i];
y2 += (rhs[i] + del[i]) * (rhs[i] + del[i]);
}
(xy, x2, y2)
}

View File

@@ -0,0 +1,237 @@
use super::G;
use crate::prelude::scalar::F32;
use crate::prelude::*;
#[derive(Debug, Clone, Copy)]
pub enum F32Dot {}
impl G for F32Dot {
type Scalar = F32;
const DISTANCE: Distance = Distance::Dot;
type L2 = F32L2;
fn distance(lhs: &[F32], rhs: &[F32]) -> F32 {
dot(lhs, rhs) * (-1.0)
}
fn elkan_k_means_normalize(vector: &mut [F32]) {
l2_normalize(vector)
}
fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 {
super::f32_dot::dot(lhs, rhs).acos()
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance(
dims: u16,
max: &[F32],
min: &[F32],
lhs: &[F32],
rhs: &[u8],
) -> F32 {
let mut xy = F32::zero();
for i in 0..dims as usize {
let _x = lhs[i];
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
xy += _x * _y;
}
xy * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance2(
dims: u16,
max: &[F32],
min: &[F32],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let mut xy = F32::zero();
for i in 0..dims as usize {
let _x = F32(lhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
xy += _x * _y;
}
xy * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[F32],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let _xy = dot(lhs, rhs);
xy += _xy;
}
xy * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance2(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhsp = lhs[i as usize] as usize * dims as usize;
let lhs = &centroids[lhsp..][(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let _xy = dot(lhs, rhs);
xy += _xy;
}
xy * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance_with_delta(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[F32],
rhs: &[u8],
delta: &[F32],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let del = &delta[(i * ratio) as usize..][..k as usize];
let _xy = dot_delta(lhs, rhs, del);
xy += _xy;
}
xy * (-1.0)
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn length(vector: &[F32]) -> F32 {
let n = vector.len();
let mut dot = F32::zero();
for i in 0..n {
dot += vector[i] * vector[i];
}
dot.sqrt()
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn l2_normalize(vector: &mut [F32]) {
let n = vector.len();
let l = length(vector);
for i in 0..n {
vector[i] /= l;
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..n {
xy += lhs[i] * rhs[i];
x2 += lhs[i] * lhs[i];
y2 += rhs[i] * rhs[i];
}
xy / (x2 * y2).sqrt()
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
for i in 0..n {
xy += lhs[i] * rhs[i];
}
xy
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn dot_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
let n: usize = lhs.len();
let mut xy = F32::zero();
for i in 0..n {
xy += lhs[i] * (rhs[i] + del[i]);
}
xy
}

View File

@@ -0,0 +1,182 @@
use super::G;
use crate::prelude::scalar::F32;
use crate::prelude::*;
#[derive(Debug, Clone, Copy)]
pub enum F32L2 {}
impl G for F32L2 {
type Scalar = F32;
const DISTANCE: Distance = Distance::L2;
type L2 = F32L2;
fn distance(lhs: &[F32], rhs: &[F32]) -> F32 {
distance_squared_l2(lhs, rhs)
}
fn elkan_k_means_normalize(_: &mut [F32]) {}
fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 {
distance_squared_l2(lhs, rhs).sqrt()
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance(
dims: u16,
max: &[F32],
min: &[F32],
lhs: &[F32],
rhs: &[u8],
) -> F32 {
let mut result = F32::zero();
for i in 0..dims as usize {
let _x = lhs[i];
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
result += (_x - _y) * (_x - _y);
}
result
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance2(
dims: u16,
max: &[F32],
min: &[F32],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let mut result = F32::zero();
for i in 0..dims as usize {
let _x = F32(lhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
result += (_x - _y) * (_x - _y);
}
result
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[F32],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut result = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
result += distance_squared_l2(lhs, rhs);
}
result
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance2(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut result = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhsp = lhs[i as usize] as usize * dims as usize;
let lhs = &centroids[lhsp..][(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
result += distance_squared_l2(lhs, rhs);
}
result
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance_with_delta(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[F32],
rhs: &[u8],
delta: &[F32],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut result = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let del = &delta[(i * ratio) as usize..][..k as usize];
result += distance_squared_l2_delta(lhs, rhs, del);
}
result
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
pub fn distance_squared_l2(lhs: &[F32], rhs: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut d2 = F32::zero();
for i in 0..n {
let d = lhs[i] - rhs[i];
d2 += d * d;
}
d2
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn distance_squared_l2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut d2 = F32::zero();
for i in 0..n {
let d = lhs[i] - (rhs[i] + del[i]);
d2 += d * d;
}
d2
}

View File

@@ -0,0 +1,121 @@
mod f16;
mod f16_cos;
mod f16_dot;
mod f16_l2;
mod f32_cos;
mod f32_dot;
mod f32_l2;
pub use f16_cos::F16Cos;
pub use f16_dot::F16Dot;
pub use f16_l2::F16L2;
pub use f32_cos::F32Cos;
pub use f32_dot::F32Dot;
pub use f32_l2::F32L2;
use crate::prelude::*;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
pub trait G: Copy + std::fmt::Debug + 'static {
type Scalar: Copy
+ Send
+ Sync
+ std::fmt::Debug
+ std::fmt::Display
+ serde::Serialize
+ for<'a> serde::Deserialize<'a>
+ Ord
+ bytemuck::Zeroable
+ bytemuck::Pod
+ num_traits::Float
+ num_traits::NumOps
+ num_traits::NumAssignOps
+ FloatCast;
const DISTANCE: Distance;
type L2: G<Scalar = Self::Scalar>;
fn distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32;
fn elkan_k_means_normalize(vector: &mut [Self::Scalar]);
fn elkan_k_means_distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32;
fn scalar_quantization_distance(
dims: u16,
max: &[Self::Scalar],
min: &[Self::Scalar],
lhs: &[Self::Scalar],
rhs: &[u8],
) -> F32;
fn scalar_quantization_distance2(
dims: u16,
max: &[Self::Scalar],
min: &[Self::Scalar],
lhs: &[u8],
rhs: &[u8],
) -> F32;
fn product_quantization_distance(
dims: u16,
ratio: u16,
centroids: &[Self::Scalar],
lhs: &[Self::Scalar],
rhs: &[u8],
) -> F32;
fn product_quantization_distance2(
dims: u16,
ratio: u16,
centroids: &[Self::Scalar],
lhs: &[u8],
rhs: &[u8],
) -> F32;
fn product_quantization_distance_with_delta(
dims: u16,
ratio: u16,
centroids: &[Self::Scalar],
lhs: &[Self::Scalar],
rhs: &[u8],
delta: &[Self::Scalar],
) -> F32;
}
pub trait FloatCast: Sized {
fn from_f32(x: f32) -> Self;
fn to_f32(self) -> f32;
fn from_f(x: F32) -> Self {
Self::from_f32(x.0)
}
fn to_f(self) -> F32 {
F32(Self::to_f32(self))
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum DynamicVector {
F32(Vec<F32>),
F16(Vec<F16>),
}
impl From<Vec<F32>> for DynamicVector {
fn from(value: Vec<F32>) -> Self {
Self::F32(value)
}
}
impl From<Vec<F16>> for DynamicVector {
fn from(value: Vec<F16>) -> Self {
Self::F16(value)
}
}
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Distance {
L2,
Cos,
Dot,
}
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Kind {
F32,
F16,
}

View File

@@ -1,9 +1,9 @@
use crate::prelude::{Payload, Scalar};
use crate::prelude::{Payload, F32};
use std::{cmp::Reverse, collections::BinaryHeap};
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct HeapElement {
pub distance: Scalar,
pub distance: F32,
pub payload: Payload,
}
@@ -20,7 +20,7 @@ impl Heap {
k,
}
}
pub fn check(&self, distance: Scalar) -> bool {
pub fn check(&self, distance: F32) -> bool {
self.binary_heap.len() < self.k || distance < self.binary_heap.peek().unwrap().distance
}
pub fn push(&mut self, element: HeapElement) -> Option<HeapElement> {

View File

@@ -0,0 +1,16 @@
mod error;
mod filter;
mod global;
mod heap;
mod scalar;
mod sys;
pub use self::error::{FriendlyError, FriendlyErrorLike, FriendlyResult};
pub use self::global::*;
pub use self::scalar::{F16, F32};
pub use self::filter::{Filter, Payload};
pub use self::heap::{Heap, HeapElement};
pub use self::sys::{Id, Pointer};
pub use num_traits::{Float, Zero};

View File

@@ -0,0 +1,653 @@
use crate::prelude::global::FloatCast;
use half::f16;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::fmt::{Debug, Display};
use std::num::ParseFloatError;
use std::ops::*;
use std::str::FromStr;
#[derive(Clone, Copy, Default, Serialize, Deserialize)]
#[repr(transparent)]
#[serde(transparent)]
pub struct F16(pub f16);
impl Debug for F16 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.0, f)
}
}
impl Display for F16 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Display::fmt(&self.0, f)
}
}
impl PartialEq for F16 {
fn eq(&self, other: &Self) -> bool {
self.0.total_cmp(&other.0) == Ordering::Equal
}
}
impl Eq for F16 {}
impl PartialOrd for F16 {
#[inline(always)]
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(Ord::cmp(self, other))
}
}
impl Ord for F16 {
#[inline(always)]
fn cmp(&self, other: &Self) -> Ordering {
self.0.total_cmp(&other.0)
}
}
unsafe impl bytemuck::Zeroable for F16 {}
unsafe impl bytemuck::Pod for F16 {}
impl num_traits::Zero for F16 {
fn zero() -> Self {
Self(f16::zero())
}
fn is_zero(&self) -> bool {
self.0.is_zero()
}
}
impl num_traits::One for F16 {
fn one() -> Self {
Self(f16::one())
}
}
impl num_traits::FromPrimitive for F16 {
fn from_i64(n: i64) -> Option<Self> {
f16::from_i64(n).map(Self)
}
fn from_u64(n: u64) -> Option<Self> {
f16::from_u64(n).map(Self)
}
fn from_isize(n: isize) -> Option<Self> {
f16::from_isize(n).map(Self)
}
fn from_i8(n: i8) -> Option<Self> {
f16::from_i8(n).map(Self)
}
fn from_i16(n: i16) -> Option<Self> {
f16::from_i16(n).map(Self)
}
fn from_i32(n: i32) -> Option<Self> {
f16::from_i32(n).map(Self)
}
fn from_i128(n: i128) -> Option<Self> {
f16::from_i128(n).map(Self)
}
fn from_usize(n: usize) -> Option<Self> {
f16::from_usize(n).map(Self)
}
fn from_u8(n: u8) -> Option<Self> {
f16::from_u8(n).map(Self)
}
fn from_u16(n: u16) -> Option<Self> {
f16::from_u16(n).map(Self)
}
fn from_u32(n: u32) -> Option<Self> {
f16::from_u32(n).map(Self)
}
fn from_u128(n: u128) -> Option<Self> {
f16::from_u128(n).map(Self)
}
fn from_f32(n: f32) -> Option<Self> {
Some(Self(f16::from_f32(n)))
}
fn from_f64(n: f64) -> Option<Self> {
Some(Self(f16::from_f64(n)))
}
}
impl num_traits::ToPrimitive for F16 {
fn to_i64(&self) -> Option<i64> {
self.0.to_i64()
}
fn to_u64(&self) -> Option<u64> {
self.0.to_u64()
}
fn to_isize(&self) -> Option<isize> {
self.0.to_isize()
}
fn to_i8(&self) -> Option<i8> {
self.0.to_i8()
}
fn to_i16(&self) -> Option<i16> {
self.0.to_i16()
}
fn to_i32(&self) -> Option<i32> {
self.0.to_i32()
}
fn to_i128(&self) -> Option<i128> {
self.0.to_i128()
}
fn to_usize(&self) -> Option<usize> {
self.0.to_usize()
}
fn to_u8(&self) -> Option<u8> {
self.0.to_u8()
}
fn to_u16(&self) -> Option<u16> {
self.0.to_u16()
}
fn to_u32(&self) -> Option<u32> {
self.0.to_u32()
}
fn to_u128(&self) -> Option<u128> {
self.0.to_u128()
}
fn to_f32(&self) -> Option<f32> {
Some(self.0.to_f32())
}
fn to_f64(&self) -> Option<f64> {
Some(self.0.to_f64())
}
}
impl num_traits::NumCast for F16 {
fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
num_traits::NumCast::from(n).map(Self)
}
}
impl num_traits::Num for F16 {
type FromStrRadixErr = <f16 as num_traits::Num>::FromStrRadixErr;
fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
f16::from_str_radix(str, radix).map(Self)
}
}
impl num_traits::Float for F16 {
fn nan() -> Self {
Self(f16::nan())
}
fn infinity() -> Self {
Self(f16::infinity())
}
fn neg_infinity() -> Self {
Self(f16::neg_infinity())
}
fn neg_zero() -> Self {
Self(f16::neg_zero())
}
fn min_value() -> Self {
Self(f16::min_value())
}
fn min_positive_value() -> Self {
Self(f16::min_positive_value())
}
fn max_value() -> Self {
Self(f16::max_value())
}
fn is_nan(self) -> bool {
self.0.is_nan()
}
fn is_infinite(self) -> bool {
self.0.is_infinite()
}
fn is_finite(self) -> bool {
self.0.is_finite()
}
fn is_normal(self) -> bool {
self.0.is_normal()
}
fn classify(self) -> std::num::FpCategory {
self.0.classify()
}
fn floor(self) -> Self {
Self(self.0.floor())
}
fn ceil(self) -> Self {
Self(self.0.ceil())
}
fn round(self) -> Self {
Self(self.0.round())
}
fn trunc(self) -> Self {
Self(self.0.trunc())
}
fn fract(self) -> Self {
Self(self.0.fract())
}
fn abs(self) -> Self {
Self(self.0.abs())
}
fn signum(self) -> Self {
Self(self.0.signum())
}
fn is_sign_positive(self) -> bool {
self.0.is_sign_positive()
}
fn is_sign_negative(self) -> bool {
self.0.is_sign_negative()
}
fn mul_add(self, a: Self, b: Self) -> Self {
Self(self.0.mul_add(a.0, b.0))
}
fn recip(self) -> Self {
Self(self.0.recip())
}
fn powi(self, n: i32) -> Self {
Self(self.0.powi(n))
}
fn powf(self, n: Self) -> Self {
Self(self.0.powf(n.0))
}
fn sqrt(self) -> Self {
Self(self.0.sqrt())
}
fn exp(self) -> Self {
Self(self.0.exp())
}
fn exp2(self) -> Self {
Self(self.0.exp2())
}
fn ln(self) -> Self {
Self(self.0.ln())
}
fn log(self, base: Self) -> Self {
Self(self.0.log(base.0))
}
fn log2(self) -> Self {
Self(self.0.log2())
}
fn log10(self) -> Self {
Self(self.0.log10())
}
fn max(self, other: Self) -> Self {
Self(self.0.max(other.0))
}
fn min(self, other: Self) -> Self {
Self(self.0.min(other.0))
}
fn abs_sub(self, _: Self) -> Self {
unimplemented!()
}
fn cbrt(self) -> Self {
Self(self.0.cbrt())
}
fn hypot(self, other: Self) -> Self {
Self(self.0.hypot(other.0))
}
fn sin(self) -> Self {
Self(self.0.sin())
}
fn cos(self) -> Self {
Self(self.0.cos())
}
fn tan(self) -> Self {
Self(self.0.tan())
}
fn asin(self) -> Self {
Self(self.0.asin())
}
fn acos(self) -> Self {
Self(self.0.acos())
}
fn atan(self) -> Self {
Self(self.0.atan())
}
fn atan2(self, other: Self) -> Self {
Self(self.0.atan2(other.0))
}
fn sin_cos(self) -> (Self, Self) {
let (_x, _y) = self.0.sin_cos();
(Self(_x), Self(_y))
}
fn exp_m1(self) -> Self {
Self(self.0.exp_m1())
}
fn ln_1p(self) -> Self {
Self(self.0.ln_1p())
}
fn sinh(self) -> Self {
Self(self.0.sinh())
}
fn cosh(self) -> Self {
Self(self.0.cosh())
}
fn tanh(self) -> Self {
Self(self.0.tanh())
}
fn asinh(self) -> Self {
Self(self.0.asinh())
}
fn acosh(self) -> Self {
Self(self.0.acosh())
}
fn atanh(self) -> Self {
Self(self.0.atanh())
}
fn integer_decode(self) -> (u64, i16, i8) {
self.0.integer_decode()
}
fn epsilon() -> Self {
Self(f16::EPSILON)
}
fn is_subnormal(self) -> bool {
self.0.classify() == std::num::FpCategory::Subnormal
}
fn to_degrees(self) -> Self {
Self(self.0.to_degrees())
}
fn to_radians(self) -> Self {
Self(self.0.to_radians())
}
fn copysign(self, sign: Self) -> Self {
Self(self.0.copysign(sign.0))
}
}
impl Add<F16> for F16 {
type Output = F16;
#[inline(always)]
fn add(self, rhs: F16) -> F16 {
unsafe { self::intrinsics::fadd_fast(self.0, rhs.0).into() }
}
}
impl AddAssign<F16> for F16 {
#[inline(always)]
fn add_assign(&mut self, rhs: F16) {
unsafe { self.0 = self::intrinsics::fadd_fast(self.0, rhs.0) }
}
}
impl Sub<F16> for F16 {
type Output = F16;
#[inline(always)]
fn sub(self, rhs: F16) -> F16 {
unsafe { self::intrinsics::fsub_fast(self.0, rhs.0).into() }
}
}
impl SubAssign<F16> for F16 {
#[inline(always)]
fn sub_assign(&mut self, rhs: F16) {
unsafe { self.0 = self::intrinsics::fsub_fast(self.0, rhs.0) }
}
}
impl Mul<F16> for F16 {
type Output = F16;
#[inline(always)]
fn mul(self, rhs: F16) -> F16 {
unsafe { self::intrinsics::fmul_fast(self.0, rhs.0).into() }
}
}
impl MulAssign<F16> for F16 {
#[inline(always)]
fn mul_assign(&mut self, rhs: F16) {
unsafe { self.0 = self::intrinsics::fmul_fast(self.0, rhs.0) }
}
}
impl Div<F16> for F16 {
type Output = F16;
#[inline(always)]
fn div(self, rhs: F16) -> F16 {
unsafe { self::intrinsics::fdiv_fast(self.0, rhs.0).into() }
}
}
impl DivAssign<F16> for F16 {
#[inline(always)]
fn div_assign(&mut self, rhs: F16) {
unsafe { self.0 = self::intrinsics::fdiv_fast(self.0, rhs.0) }
}
}
impl Rem<F16> for F16 {
type Output = F16;
#[inline(always)]
fn rem(self, rhs: F16) -> F16 {
unsafe { self::intrinsics::frem_fast(self.0, rhs.0).into() }
}
}
impl RemAssign<F16> for F16 {
#[inline(always)]
fn rem_assign(&mut self, rhs: F16) {
unsafe { self.0 = self::intrinsics::frem_fast(self.0, rhs.0) }
}
}
impl Neg for F16 {
type Output = Self;
fn neg(self) -> Self::Output {
Self(self.0.neg())
}
}
impl FromStr for F16 {
type Err = ParseFloatError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
f16::from_str(s).map(|x| x.into())
}
}
impl FloatCast for F16 {
fn from_f32(x: f32) -> Self {
Self(f16::from_f32(x))
}
fn to_f32(self) -> f32 {
f16::to_f32(self.0)
}
}
impl From<f16> for F16 {
fn from(value: f16) -> Self {
Self(value)
}
}
impl From<F16> for f16 {
fn from(F16(float): F16) -> Self {
float
}
}
impl Add<f16> for F16 {
type Output = F16;
#[inline(always)]
fn add(self, rhs: f16) -> F16 {
unsafe { self::intrinsics::fadd_fast(self.0, rhs).into() }
}
}
impl AddAssign<f16> for F16 {
fn add_assign(&mut self, rhs: f16) {
unsafe { self.0 = self::intrinsics::fadd_fast(self.0, rhs) }
}
}
impl Sub<f16> for F16 {
type Output = F16;
#[inline(always)]
fn sub(self, rhs: f16) -> F16 {
unsafe { self::intrinsics::fsub_fast(self.0, rhs).into() }
}
}
impl SubAssign<f16> for F16 {
#[inline(always)]
fn sub_assign(&mut self, rhs: f16) {
unsafe { self.0 = self::intrinsics::fsub_fast(self.0, rhs) }
}
}
impl Mul<f16> for F16 {
type Output = F16;
#[inline(always)]
fn mul(self, rhs: f16) -> F16 {
unsafe { self::intrinsics::fmul_fast(self.0, rhs).into() }
}
}
impl MulAssign<f16> for F16 {
#[inline(always)]
fn mul_assign(&mut self, rhs: f16) {
unsafe { self.0 = self::intrinsics::fmul_fast(self.0, rhs) }
}
}
impl Div<f16> for F16 {
type Output = F16;
#[inline(always)]
fn div(self, rhs: f16) -> F16 {
unsafe { self::intrinsics::fdiv_fast(self.0, rhs).into() }
}
}
impl DivAssign<f16> for F16 {
#[inline(always)]
fn div_assign(&mut self, rhs: f16) {
unsafe { self.0 = self::intrinsics::fdiv_fast(self.0, rhs) }
}
}
impl Rem<f16> for F16 {
type Output = F16;
#[inline(always)]
fn rem(self, rhs: f16) -> F16 {
unsafe { self::intrinsics::frem_fast(self.0, rhs).into() }
}
}
impl RemAssign<f16> for F16 {
#[inline(always)]
fn rem_assign(&mut self, rhs: f16) {
unsafe { self.0 = self::intrinsics::frem_fast(self.0, rhs) }
}
}
mod intrinsics {
use half::f16;
pub unsafe fn fadd_fast(lhs: f16, rhs: f16) -> f16 {
lhs + rhs
}
pub unsafe fn fsub_fast(lhs: f16, rhs: f16) -> f16 {
lhs - rhs
}
pub unsafe fn fmul_fast(lhs: f16, rhs: f16) -> f16 {
lhs * rhs
}
pub unsafe fn fdiv_fast(lhs: f16, rhs: f16) -> f16 {
lhs / rhs
}
pub unsafe fn frem_fast(lhs: f16, rhs: f16) -> f16 {
lhs % rhs
}
}

View File

@@ -0,0 +1,632 @@
use crate::prelude::global::FloatCast;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::fmt::{Debug, Display};
use std::num::ParseFloatError;
use std::ops::*;
use std::str::FromStr;
#[derive(Clone, Copy, Default, Serialize, Deserialize)]
#[repr(transparent)]
#[serde(transparent)]
pub struct F32(pub f32);
impl Debug for F32 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.0, f)
}
}
impl Display for F32 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Display::fmt(&self.0, f)
}
}
impl PartialEq for F32 {
fn eq(&self, other: &Self) -> bool {
self.0.total_cmp(&other.0) == Ordering::Equal
}
}
impl Eq for F32 {}
impl PartialOrd for F32 {
#[inline(always)]
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(Ord::cmp(self, other))
}
}
impl Ord for F32 {
#[inline(always)]
fn cmp(&self, other: &Self) -> Ordering {
self.0.total_cmp(&other.0)
}
}
unsafe impl bytemuck::Zeroable for F32 {}
unsafe impl bytemuck::Pod for F32 {}
impl num_traits::Zero for F32 {
fn zero() -> Self {
Self(f32::zero())
}
fn is_zero(&self) -> bool {
self.0.is_zero()
}
}
impl num_traits::One for F32 {
fn one() -> Self {
Self(f32::one())
}
}
impl num_traits::FromPrimitive for F32 {
fn from_i64(n: i64) -> Option<Self> {
f32::from_i64(n).map(Self)
}
fn from_u64(n: u64) -> Option<Self> {
f32::from_u64(n).map(Self)
}
fn from_isize(n: isize) -> Option<Self> {
f32::from_isize(n).map(Self)
}
fn from_i8(n: i8) -> Option<Self> {
f32::from_i8(n).map(Self)
}
fn from_i16(n: i16) -> Option<Self> {
f32::from_i16(n).map(Self)
}
fn from_i32(n: i32) -> Option<Self> {
f32::from_i32(n).map(Self)
}
fn from_i128(n: i128) -> Option<Self> {
f32::from_i128(n).map(Self)
}
fn from_usize(n: usize) -> Option<Self> {
f32::from_usize(n).map(Self)
}
fn from_u8(n: u8) -> Option<Self> {
f32::from_u8(n).map(Self)
}
fn from_u16(n: u16) -> Option<Self> {
f32::from_u16(n).map(Self)
}
fn from_u32(n: u32) -> Option<Self> {
f32::from_u32(n).map(Self)
}
fn from_u128(n: u128) -> Option<Self> {
f32::from_u128(n).map(Self)
}
fn from_f32(n: f32) -> Option<Self> {
f32::from_f32(n).map(Self)
}
fn from_f64(n: f64) -> Option<Self> {
f32::from_f64(n).map(Self)
}
}
impl num_traits::ToPrimitive for F32 {
fn to_i64(&self) -> Option<i64> {
self.0.to_i64()
}
fn to_u64(&self) -> Option<u64> {
self.0.to_u64()
}
fn to_isize(&self) -> Option<isize> {
self.0.to_isize()
}
fn to_i8(&self) -> Option<i8> {
self.0.to_i8()
}
fn to_i16(&self) -> Option<i16> {
self.0.to_i16()
}
fn to_i32(&self) -> Option<i32> {
self.0.to_i32()
}
fn to_i128(&self) -> Option<i128> {
self.0.to_i128()
}
fn to_usize(&self) -> Option<usize> {
self.0.to_usize()
}
fn to_u8(&self) -> Option<u8> {
self.0.to_u8()
}
fn to_u16(&self) -> Option<u16> {
self.0.to_u16()
}
fn to_u32(&self) -> Option<u32> {
self.0.to_u32()
}
fn to_u128(&self) -> Option<u128> {
self.0.to_u128()
}
fn to_f32(&self) -> Option<f32> {
self.0.to_f32()
}
fn to_f64(&self) -> Option<f64> {
self.0.to_f64()
}
}
impl num_traits::NumCast for F32 {
fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
num_traits::NumCast::from(n).map(Self)
}
}
impl num_traits::Num for F32 {
type FromStrRadixErr = <f32 as num_traits::Num>::FromStrRadixErr;
fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
f32::from_str_radix(str, radix).map(Self)
}
}
impl num_traits::Float for F32 {
fn nan() -> Self {
Self(f32::nan())
}
fn infinity() -> Self {
Self(f32::infinity())
}
fn neg_infinity() -> Self {
Self(f32::neg_infinity())
}
fn neg_zero() -> Self {
Self(f32::neg_zero())
}
fn min_value() -> Self {
Self(f32::min_value())
}
fn min_positive_value() -> Self {
Self(f32::min_positive_value())
}
fn max_value() -> Self {
Self(f32::max_value())
}
fn is_nan(self) -> bool {
self.0.is_nan()
}
fn is_infinite(self) -> bool {
self.0.is_infinite()
}
fn is_finite(self) -> bool {
self.0.is_finite()
}
fn is_normal(self) -> bool {
self.0.is_normal()
}
fn classify(self) -> std::num::FpCategory {
self.0.classify()
}
fn floor(self) -> Self {
Self(self.0.floor())
}
fn ceil(self) -> Self {
Self(self.0.ceil())
}
fn round(self) -> Self {
Self(self.0.round())
}
fn trunc(self) -> Self {
Self(self.0.trunc())
}
fn fract(self) -> Self {
Self(self.0.fract())
}
fn abs(self) -> Self {
Self(self.0.abs())
}
fn signum(self) -> Self {
Self(self.0.signum())
}
fn is_sign_positive(self) -> bool {
self.0.is_sign_positive()
}
fn is_sign_negative(self) -> bool {
self.0.is_sign_negative()
}
fn mul_add(self, a: Self, b: Self) -> Self {
Self(self.0.mul_add(a.0, b.0))
}
fn recip(self) -> Self {
Self(self.0.recip())
}
fn powi(self, n: i32) -> Self {
Self(self.0.powi(n))
}
fn powf(self, n: Self) -> Self {
Self(self.0.powf(n.0))
}
fn sqrt(self) -> Self {
Self(self.0.sqrt())
}
fn exp(self) -> Self {
Self(self.0.exp())
}
fn exp2(self) -> Self {
Self(self.0.exp2())
}
fn ln(self) -> Self {
Self(self.0.ln())
}
fn log(self, base: Self) -> Self {
Self(self.0.log(base.0))
}
fn log2(self) -> Self {
Self(self.0.log2())
}
fn log10(self) -> Self {
Self(self.0.log10())
}
fn max(self, other: Self) -> Self {
Self(self.0.max(other.0))
}
fn min(self, other: Self) -> Self {
Self(self.0.min(other.0))
}
fn abs_sub(self, _: Self) -> Self {
unimplemented!()
}
fn cbrt(self) -> Self {
Self(self.0.cbrt())
}
fn hypot(self, other: Self) -> Self {
Self(self.0.hypot(other.0))
}
fn sin(self) -> Self {
Self(self.0.sin())
}
fn cos(self) -> Self {
Self(self.0.cos())
}
fn tan(self) -> Self {
Self(self.0.tan())
}
fn asin(self) -> Self {
Self(self.0.asin())
}
fn acos(self) -> Self {
Self(self.0.acos())
}
fn atan(self) -> Self {
Self(self.0.atan())
}
fn atan2(self, other: Self) -> Self {
Self(self.0.atan2(other.0))
}
fn sin_cos(self) -> (Self, Self) {
let (_x, _y) = self.0.sin_cos();
(Self(_x), Self(_y))
}
fn exp_m1(self) -> Self {
Self(self.0.exp_m1())
}
fn ln_1p(self) -> Self {
Self(self.0.ln_1p())
}
fn sinh(self) -> Self {
Self(self.0.sinh())
}
fn cosh(self) -> Self {
Self(self.0.cosh())
}
fn tanh(self) -> Self {
Self(self.0.tanh())
}
fn asinh(self) -> Self {
Self(self.0.asinh())
}
fn acosh(self) -> Self {
Self(self.0.acosh())
}
fn atanh(self) -> Self {
Self(self.0.atanh())
}
fn integer_decode(self) -> (u64, i16, i8) {
self.0.integer_decode()
}
fn epsilon() -> Self {
Self(f32::EPSILON)
}
fn is_subnormal(self) -> bool {
self.0.classify() == std::num::FpCategory::Subnormal
}
fn to_degrees(self) -> Self {
Self(self.0.to_degrees())
}
fn to_radians(self) -> Self {
Self(self.0.to_radians())
}
fn copysign(self, sign: Self) -> Self {
Self(self.0.copysign(sign.0))
}
}
impl Add<F32> for F32 {
type Output = F32;
#[inline(always)]
fn add(self, rhs: F32) -> F32 {
unsafe { std::intrinsics::fadd_fast(self.0, rhs.0).into() }
}
}
impl AddAssign<F32> for F32 {
#[inline(always)]
fn add_assign(&mut self, rhs: F32) {
unsafe { self.0 = std::intrinsics::fadd_fast(self.0, rhs.0) }
}
}
impl Sub<F32> for F32 {
type Output = F32;
#[inline(always)]
fn sub(self, rhs: F32) -> F32 {
unsafe { std::intrinsics::fsub_fast(self.0, rhs.0).into() }
}
}
impl SubAssign<F32> for F32 {
#[inline(always)]
fn sub_assign(&mut self, rhs: F32) {
unsafe { self.0 = std::intrinsics::fsub_fast(self.0, rhs.0) }
}
}
impl Mul<F32> for F32 {
type Output = F32;
#[inline(always)]
fn mul(self, rhs: F32) -> F32 {
unsafe { std::intrinsics::fmul_fast(self.0, rhs.0).into() }
}
}
impl MulAssign<F32> for F32 {
#[inline(always)]
fn mul_assign(&mut self, rhs: F32) {
unsafe { self.0 = std::intrinsics::fmul_fast(self.0, rhs.0) }
}
}
impl Div<F32> for F32 {
type Output = F32;
#[inline(always)]
fn div(self, rhs: F32) -> F32 {
unsafe { std::intrinsics::fdiv_fast(self.0, rhs.0).into() }
}
}
impl DivAssign<F32> for F32 {
#[inline(always)]
fn div_assign(&mut self, rhs: F32) {
unsafe { self.0 = std::intrinsics::fdiv_fast(self.0, rhs.0) }
}
}
impl Rem<F32> for F32 {
type Output = F32;
#[inline(always)]
fn rem(self, rhs: F32) -> F32 {
unsafe { std::intrinsics::frem_fast(self.0, rhs.0).into() }
}
}
impl RemAssign<F32> for F32 {
#[inline(always)]
fn rem_assign(&mut self, rhs: F32) {
unsafe { self.0 = std::intrinsics::frem_fast(self.0, rhs.0) }
}
}
impl Neg for F32 {
type Output = Self;
fn neg(self) -> Self::Output {
Self(self.0.neg())
}
}
impl FromStr for F32 {
type Err = ParseFloatError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
f32::from_str(s).map(|x| x.into())
}
}
impl FloatCast for F32 {
fn from_f32(x: f32) -> Self {
Self(x)
}
fn to_f32(self) -> f32 {
self.0
}
}
impl From<f32> for F32 {
fn from(value: f32) -> Self {
Self(value)
}
}
impl From<F32> for f32 {
fn from(F32(float): F32) -> Self {
float
}
}
impl Add<f32> for F32 {
type Output = F32;
#[inline(always)]
fn add(self, rhs: f32) -> F32 {
unsafe { std::intrinsics::fadd_fast(self.0, rhs).into() }
}
}
impl AddAssign<f32> for F32 {
fn add_assign(&mut self, rhs: f32) {
unsafe { self.0 = std::intrinsics::fadd_fast(self.0, rhs) }
}
}
impl Sub<f32> for F32 {
type Output = F32;
#[inline(always)]
fn sub(self, rhs: f32) -> F32 {
unsafe { std::intrinsics::fsub_fast(self.0, rhs).into() }
}
}
impl SubAssign<f32> for F32 {
#[inline(always)]
fn sub_assign(&mut self, rhs: f32) {
unsafe { self.0 = std::intrinsics::fsub_fast(self.0, rhs) }
}
}
impl Mul<f32> for F32 {
type Output = F32;
#[inline(always)]
fn mul(self, rhs: f32) -> F32 {
unsafe { std::intrinsics::fmul_fast(self.0, rhs).into() }
}
}
impl MulAssign<f32> for F32 {
#[inline(always)]
fn mul_assign(&mut self, rhs: f32) {
unsafe { self.0 = std::intrinsics::fmul_fast(self.0, rhs) }
}
}
impl Div<f32> for F32 {
type Output = F32;
#[inline(always)]
fn div(self, rhs: f32) -> F32 {
unsafe { std::intrinsics::fdiv_fast(self.0, rhs).into() }
}
}
impl DivAssign<f32> for F32 {
#[inline(always)]
fn div_assign(&mut self, rhs: f32) {
unsafe { self.0 = std::intrinsics::fdiv_fast(self.0, rhs) }
}
}
impl Rem<f32> for F32 {
type Output = F32;
#[inline(always)]
fn rem(self, rhs: f32) -> F32 {
unsafe { std::intrinsics::frem_fast(self.0, rhs).into() }
}
}
impl RemAssign<f32> for F32 {
#[inline(always)]
fn rem_assign(&mut self, rhs: f32) {
unsafe { self.0 = std::intrinsics::frem_fast(self.0, rhs) }
}
}

View File

@@ -0,0 +1,5 @@
mod f16;
mod f32;
pub use f16::F16;
pub use f32::F32;

View File

@@ -3,15 +3,10 @@ use std::{fmt::Display, num::ParseIntError, str::FromStr};
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Id {
newtype: u32,
pub newtype: u32,
}
impl Id {
pub fn from_sys(sys: pgrx::pg_sys::Oid) -> Self {
Self {
newtype: sys.as_u32(),
}
}
pub fn as_u32(self) -> u32 {
self.newtype
}
@@ -35,26 +30,10 @@ impl FromStr for Id {
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct Pointer {
newtype: u64,
pub newtype: u64,
}
impl Pointer {
pub fn from_sys(sys: pgrx::pg_sys::ItemPointerData) -> Self {
let mut newtype = 0;
newtype |= (sys.ip_blkid.bi_hi as u64) << 32;
newtype |= (sys.ip_blkid.bi_lo as u64) << 16;
newtype |= sys.ip_posid as u64;
Self { newtype }
}
pub fn into_sys(self) -> pgrx::pg_sys::ItemPointerData {
pgrx::pg_sys::ItemPointerData {
ip_blkid: pgrx::pg_sys::BlockIdData {
bi_hi: ((self.newtype >> 32) & 0xffff) as u16,
bi_lo: ((self.newtype >> 16) & 0xffff) as u16,
},
ip_posid: (self.newtype & 0xffff) as u16,
}
}
pub fn from_u48(value: u64) -> Self {
assert!(value < (1u64 << 48));
Self { newtype: value }

View File

@@ -0,0 +1,26 @@
use std::cell::UnsafeCell;
#[repr(transparent)]
pub struct SyncUnsafeCell<T: ?Sized> {
value: UnsafeCell<T>,
}
unsafe impl<T: ?Sized + Sync> Sync for SyncUnsafeCell<T> {}
impl<T> SyncUnsafeCell<T> {
pub const fn new(value: T) -> Self {
Self {
value: UnsafeCell::new(value),
}
}
}
impl<T: ?Sized> SyncUnsafeCell<T> {
pub fn get(&self) -> *mut T {
self.value.get()
}
pub fn get_mut(&mut self) -> &mut T {
self.value.get_mut()
}
}

View File

@@ -0,0 +1 @@
pub mod x86_64;

View File

@@ -0,0 +1,85 @@
#![cfg(target_arch = "x86_64")]
use std::sync::atomic::{AtomicBool, Ordering};
static ATOMIC_AVX512FP16: AtomicBool = AtomicBool::new(false);
pub fn test_avx512fp16() -> bool {
std_detect::is_x86_feature_detected!("avx512fp16") && test_v4()
}
#[ctor::ctor]
fn ctor_avx512fp16() {
ATOMIC_AVX512FP16.store(test_avx512fp16(), Ordering::Relaxed);
}
pub fn detect_avx512fp16() -> bool {
ATOMIC_AVX512FP16.load(Ordering::Relaxed)
}
static ATOMIC_V4: AtomicBool = AtomicBool::new(false);
pub fn test_v4() -> bool {
std_detect::is_x86_feature_detected!("avx512bw")
&& std_detect::is_x86_feature_detected!("avx512cd")
&& std_detect::is_x86_feature_detected!("avx512dq")
&& std_detect::is_x86_feature_detected!("avx512f")
&& std_detect::is_x86_feature_detected!("avx512vl")
&& test_v3()
}
#[ctor::ctor]
fn ctor_v4() {
ATOMIC_V4.store(test_v4(), Ordering::Relaxed);
}
pub fn _detect_v4() -> bool {
ATOMIC_V4.load(Ordering::Relaxed)
}
static ATOMIC_V3: AtomicBool = AtomicBool::new(false);
pub fn test_v3() -> bool {
std_detect::is_x86_feature_detected!("avx")
&& std_detect::is_x86_feature_detected!("avx2")
&& std_detect::is_x86_feature_detected!("bmi1")
&& std_detect::is_x86_feature_detected!("bmi2")
&& std_detect::is_x86_feature_detected!("f16c")
&& std_detect::is_x86_feature_detected!("fma")
&& std_detect::is_x86_feature_detected!("lzcnt")
&& std_detect::is_x86_feature_detected!("movbe")
&& std_detect::is_x86_feature_detected!("xsave")
&& test_v2()
}
#[ctor::ctor]
fn ctor_v3() {
ATOMIC_V3.store(test_v3(), Ordering::Relaxed);
}
pub fn detect_v3() -> bool {
ATOMIC_V3.load(Ordering::Relaxed)
}
static ATOMIC_V2: AtomicBool = AtomicBool::new(false);
pub fn test_v2() -> bool {
std_detect::is_x86_feature_detected!("cmpxchg16b")
&& std_detect::is_x86_feature_detected!("fxsr")
&& std_detect::is_x86_feature_detected!("popcnt")
&& std_detect::is_x86_feature_detected!("sse")
&& std_detect::is_x86_feature_detected!("sse2")
&& std_detect::is_x86_feature_detected!("sse3")
&& std_detect::is_x86_feature_detected!("sse4.1")
&& std_detect::is_x86_feature_detected!("sse4.2")
&& std_detect::is_x86_feature_detected!("ssse3")
}
#[ctor::ctor]
fn ctor_v2() {
ATOMIC_V2.store(test_v2(), Ordering::Relaxed);
}
pub fn _detect_v2() -> bool {
ATOMIC_V2.load(Ordering::Relaxed)
}

View File

@@ -111,9 +111,11 @@ fn read_information(mut file: &File) -> Information {
unsafe fn read_mmap(file: &File, len: usize) -> memmap2::Mmap {
let len = len.next_multiple_of(4096);
memmap2::MmapOptions::new()
.populate()
.len(len)
.map(file)
.unwrap()
unsafe {
memmap2::MmapOptions::new()
.populate()
.len(len)
.map(file)
.unwrap()
}
}

View File

@@ -0,0 +1,8 @@
pub mod cells;
pub mod clean;
pub mod detect;
pub mod dir_ops;
pub mod file_atomic;
pub mod file_wal;
pub mod mmap_array;
pub mod vec2;

View File

@@ -2,16 +2,16 @@ use crate::prelude::*;
use std::ops::{Deref, DerefMut, Index, IndexMut};
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Vec2 {
pub struct Vec2<S: G> {
dims: u16,
v: Box<[Scalar]>,
v: Vec<S::Scalar>,
}
impl Vec2 {
impl<S: G> Vec2<S> {
pub fn new(dims: u16, n: usize) -> Self {
Self {
dims,
v: bytemuck::zeroed_slice_box(dims as usize * n),
v: bytemuck::zeroed_vec(dims as usize * n),
}
}
pub fn dims(&self) -> u16 {
@@ -32,29 +32,29 @@ impl Vec2 {
}
}
impl Index<usize> for Vec2 {
type Output = [Scalar];
impl<S: G> Index<usize> for Vec2<S> {
type Output = [S::Scalar];
fn index(&self, index: usize) -> &Self::Output {
&self.v[self.dims as usize * index..][..self.dims as usize]
}
}
impl IndexMut<usize> for Vec2 {
impl<S: G> IndexMut<usize> for Vec2<S> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
&mut self.v[self.dims as usize * index..][..self.dims as usize]
}
}
impl Deref for Vec2 {
type Target = [Scalar];
impl<S: G> Deref for Vec2<S> {
type Target = [S::Scalar];
fn deref(&self) -> &Self::Target {
self.v.deref()
}
}
impl DerefMut for Vec2 {
impl<S: G> DerefMut for Vec2<S> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.v.deref_mut()
}

View File

@@ -0,0 +1,248 @@
use crate::index::Index;
use crate::index::IndexOptions;
use crate::index::IndexStat;
use crate::index::IndexView;
use crate::index::OutdatedError;
use crate::prelude::*;
use std::path::PathBuf;
use std::sync::Arc;
#[derive(Clone)]
pub enum Instance {
F32Cos(Arc<Index<F32Cos>>),
F32Dot(Arc<Index<F32Dot>>),
F32L2(Arc<Index<F32L2>>),
F16Cos(Arc<Index<F16Cos>>),
F16Dot(Arc<Index<F16Dot>>),
F16L2(Arc<Index<F16L2>>),
}
impl Instance {
pub fn create(path: PathBuf, options: IndexOptions) -> Self {
match (options.vector.d, options.vector.k) {
(Distance::Cos, Kind::F32) => Self::F32Cos(Index::create(path, options)),
(Distance::Dot, Kind::F32) => Self::F32Dot(Index::create(path, options)),
(Distance::L2, Kind::F32) => Self::F32L2(Index::create(path, options)),
(Distance::Cos, Kind::F16) => Self::F16Cos(Index::create(path, options)),
(Distance::Dot, Kind::F16) => Self::F16Dot(Index::create(path, options)),
(Distance::L2, Kind::F16) => Self::F16L2(Index::create(path, options)),
}
}
pub fn open(path: PathBuf, options: IndexOptions) -> Self {
match (options.vector.d, options.vector.k) {
(Distance::Cos, Kind::F32) => Self::F32Cos(Index::open(path, options)),
(Distance::Dot, Kind::F32) => Self::F32Dot(Index::open(path, options)),
(Distance::L2, Kind::F32) => Self::F32L2(Index::open(path, options)),
(Distance::Cos, Kind::F16) => Self::F16Cos(Index::open(path, options)),
(Distance::Dot, Kind::F16) => Self::F16Dot(Index::open(path, options)),
(Distance::L2, Kind::F16) => Self::F16L2(Index::open(path, options)),
}
}
pub fn options(&self) -> &IndexOptions {
match self {
Instance::F32Cos(x) => x.options(),
Instance::F32Dot(x) => x.options(),
Instance::F32L2(x) => x.options(),
Instance::F16Cos(x) => x.options(),
Instance::F16Dot(x) => x.options(),
Instance::F16L2(x) => x.options(),
}
}
pub fn refresh(&self) {
match self {
Instance::F32Cos(x) => x.refresh(),
Instance::F32Dot(x) => x.refresh(),
Instance::F32L2(x) => x.refresh(),
Instance::F16Cos(x) => x.refresh(),
Instance::F16Dot(x) => x.refresh(),
Instance::F16L2(x) => x.refresh(),
}
}
pub fn view(&self) -> InstanceView {
match self {
Instance::F32Cos(x) => InstanceView::F32Cos(x.view()),
Instance::F32Dot(x) => InstanceView::F32Dot(x.view()),
Instance::F32L2(x) => InstanceView::F32L2(x.view()),
Instance::F16Cos(x) => InstanceView::F16Cos(x.view()),
Instance::F16Dot(x) => InstanceView::F16Dot(x.view()),
Instance::F16L2(x) => InstanceView::F16L2(x.view()),
}
}
pub fn stat(&self) -> IndexStat {
match self {
Instance::F32Cos(x) => x.stat(),
Instance::F32Dot(x) => x.stat(),
Instance::F32L2(x) => x.stat(),
Instance::F16Cos(x) => x.stat(),
Instance::F16Dot(x) => x.stat(),
Instance::F16L2(x) => x.stat(),
}
}
}
pub enum InstanceView {
F32Cos(Arc<IndexView<F32Cos>>),
F32Dot(Arc<IndexView<F32Dot>>),
F32L2(Arc<IndexView<F32L2>>),
F16Cos(Arc<IndexView<F16Cos>>),
F16Dot(Arc<IndexView<F16Dot>>),
F16L2(Arc<IndexView<F16L2>>),
}
impl InstanceView {
pub fn search<F: FnMut(Pointer) -> bool>(
&self,
k: usize,
vector: DynamicVector,
filter: F,
) -> Result<Vec<Pointer>, FriendlyError> {
match (self, vector) {
(InstanceView::F32Cos(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.search(k, &vector, filter))
}
(InstanceView::F32Dot(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.search(k, &vector, filter))
}
(InstanceView::F32L2(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.search(k, &vector, filter))
}
(InstanceView::F16Cos(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.search(k, &vector, filter))
}
(InstanceView::F16Dot(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.search(k, &vector, filter))
}
(InstanceView::F16L2(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.search(k, &vector, filter))
}
_ => Err(FriendlyError::Unmatched2),
}
}
pub fn vbase(
&self,
vector: DynamicVector,
) -> Result<impl Iterator<Item = Pointer> + '_, FriendlyError> {
match (self, vector) {
(InstanceView::F32Cos(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(Box::new(x.vbase(&vector)) as Box<dyn Iterator<Item = Pointer>>)
}
(InstanceView::F32Dot(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(Box::new(x.vbase(&vector)))
}
(InstanceView::F32L2(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(Box::new(x.vbase(&vector)))
}
(InstanceView::F16Cos(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(Box::new(x.vbase(&vector)))
}
(InstanceView::F16Dot(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(Box::new(x.vbase(&vector)))
}
(InstanceView::F16L2(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(Box::new(x.vbase(&vector)))
}
_ => Err(FriendlyError::Unmatched2),
}
}
pub fn insert(
&self,
vector: DynamicVector,
pointer: Pointer,
) -> Result<Result<(), OutdatedError>, FriendlyError> {
match (self, vector) {
(InstanceView::F32Cos(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.insert(vector, pointer))
}
(InstanceView::F32Dot(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.insert(vector, pointer))
}
(InstanceView::F32L2(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.insert(vector, pointer))
}
(InstanceView::F16Cos(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.insert(vector, pointer))
}
(InstanceView::F16Dot(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.insert(vector, pointer))
}
(InstanceView::F16L2(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.insert(vector, pointer))
}
_ => Err(FriendlyError::Unmatched2),
}
}
pub fn delete<F: FnMut(Pointer) -> bool>(&self, f: F) {
match self {
InstanceView::F32Cos(x) => x.delete(f),
InstanceView::F32Dot(x) => x.delete(f),
InstanceView::F32L2(x) => x.delete(f),
InstanceView::F16Cos(x) => x.delete(f),
InstanceView::F16Dot(x) => x.delete(f),
InstanceView::F16L2(x) => x.delete(f),
}
}
pub fn flush(&self) {
match self {
InstanceView::F32Cos(x) => x.flush(),
InstanceView::F32Dot(x) => x.flush(),
InstanceView::F32L2(x) => x.flush(),
InstanceView::F16Cos(x) => x.flush(),
InstanceView::F16Dot(x) => x.flush(),
InstanceView::F16L2(x) => x.flush(),
}
}
}

View File

@@ -1,7 +1,9 @@
use crate::index::Index;
use crate::index::IndexInsertError;
pub mod instance;
use self::instance::Instance;
use crate::index::IndexOptions;
use crate::index::IndexSearchError;
use crate::index::IndexStat;
use crate::index::OutdatedError;
use crate::prelude::*;
use crate::utils::clean::clean;
use crate::utils::dir_ops::sync_dir;
@@ -57,7 +59,7 @@ impl Worker {
let mut indexes = HashMap::new();
for (&id, options) in startup.get().indexes.iter() {
let path = path.join("indexes").join(id.to_string());
let index = Index::open(path, options.clone());
let index = Instance::open(path, options.clone());
indexes.insert(id, index);
}
let view = Arc::new(WorkerView {
@@ -72,7 +74,7 @@ impl Worker {
}
pub fn call_create(&self, id: Id, options: IndexOptions) {
let mut protect = self.protect.lock();
let index = Index::create(self.path.join("indexes").join(id.to_string()), options);
let index = Instance::create(self.path.join("indexes").join(id.to_string()), options);
if protect.indexes.insert(id, index).is_some() {
panic!("index {} already exists", id)
}
@@ -81,44 +83,29 @@ impl Worker {
pub fn call_search<F>(
&self,
id: Id,
search: (Vec<Scalar>, usize),
search: (DynamicVector, usize),
filter: F,
) -> Result<Vec<Pointer>, FriendlyError>
where
F: FnMut(Pointer) -> bool,
{
let view = self.view.load_full();
let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?;
let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?;
let view = index.view();
match view.search(search.1, &search.0, filter) {
Ok(x) => Ok(x),
Err(IndexSearchError::InvalidVector(x)) => Err(FriendlyError::BadVector(x)),
}
view.search(search.1, search.0, filter)
}
pub fn call_search_vbase<F>(
pub fn call_insert(
&self,
id: Id,
search: (Vec<Scalar>, usize),
next: F,
) -> Result<(), FriendlyError>
where
F: FnMut(Pointer) -> bool,
{
insert: (DynamicVector, Pointer),
) -> Result<(), FriendlyError> {
let view = self.view.load_full();
let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?;
let view = index.view();
view.search_vbase(search.1, &search.0, next)
.map_err(|IndexSearchError::InvalidVector(x)| FriendlyError::BadVector(x))
}
pub fn call_insert(&self, id: Id, insert: (Vec<Scalar>, Pointer)) -> Result<(), FriendlyError> {
let view = self.view.load_full();
let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?;
let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?;
loop {
let view = index.view();
match view.insert(insert.0.clone(), insert.1) {
match view.insert(insert.0.clone(), insert.1)? {
Ok(()) => break Ok(()),
Err(IndexInsertError::InvalidVector(x)) => break Err(FriendlyError::BadVector(x)),
Err(IndexInsertError::OutdatedView(_)) => index.refresh(),
Err(OutdatedError(_)) => index.refresh(),
}
}
}
@@ -127,16 +114,16 @@ impl Worker {
F: FnMut(Pointer) -> bool,
{
let view = self.view.load_full();
let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?;
let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?;
let view = index.view();
view.delete(f);
Ok(())
}
pub fn call_flush(&self, id: Id) -> Result<(), FriendlyError> {
let view = self.view.load_full();
let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?;
let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?;
let view = index.view();
view.flush().unwrap();
view.flush();
Ok(())
}
pub fn call_destory(&self, ids: Vec<Id>) {
@@ -149,44 +136,25 @@ impl Worker {
protect.maintain(&self.view);
}
}
pub fn call_stat(&self, id: Id) -> Result<VectorIndexInfo, FriendlyError> {
pub fn call_stat(&self, id: Id) -> Result<IndexStat, FriendlyError> {
let view = self.view.load_full();
let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?;
let view = index.view();
let idx_sealed_len = view.sealed_len();
let idx_growing_len = view.growing_len();
let idx_write = view.write_len();
let res = VectorIndexInfo {
indexing: index.indexing(),
idx_tuples: (idx_write + idx_sealed_len + idx_growing_len)
.try_into()
.unwrap(),
idx_sealed_len: idx_sealed_len.try_into().unwrap(),
idx_growing_len: idx_growing_len.try_into().unwrap(),
idx_write: idx_write.try_into().unwrap(),
idx_sealed: view
.sealed_len_vec()
.into_iter()
.map(|x| x.try_into().unwrap())
.collect(),
idx_growing: view
.growing_len_vec()
.into_iter()
.map(|x| x.try_into().unwrap())
.collect(),
idx_config: serde_json::to_string(index.options()).unwrap(),
};
Ok(res)
let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?;
Ok(index.stat())
}
pub fn get_instance(&self, id: Id) -> Result<Instance, FriendlyError> {
let view = self.view.load_full();
let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?;
Ok(index.clone())
}
}
struct WorkerView {
indexes: HashMap<Id, Arc<Index>>,
indexes: HashMap<Id, Instance>,
}
struct WorkerProtect {
startup: FileAtomic<WorkerStartup>,
indexes: HashMap<Id, Arc<Index>>,
indexes: HashMap<Id, Instance>,
}
impl WorkerProtect {

View File

@@ -11,7 +11,7 @@ Why not just use Postgres to do the vector similarity search? This is the reason
UPDATE documents SET embedding = ai_embedding_vector(content) WHERE length(embedding) = 0;
-- Create an index on the embedding column
CREATE INDEX ON documents USING vectors (embedding l2_ops);
CREATE INDEX ON documents USING vectors (embedding vector_l2_ops);
-- Query the similar embeddings
SELECT * FROM documents ORDER BY embedding <-> ai_embedding_vector('hello world') LIMIT 5;

View File

@@ -46,9 +46,9 @@ We support three operators to calculate the distance between two vectors.
-- squared Euclidean distance
SELECT '[1, 2, 3]'::vector <-> '[3, 2, 1]'::vector;
-- negative dot product
SELECT '[1, 2, 3]' <#> '[3, 2, 1]';
SELECT '[1, 2, 3]'::vector <#> '[3, 2, 1]'::vector;
-- negative cosine similarity
SELECT '[1, 2, 3]' <=> '[3, 2, 1]';
SELECT '[1, 2, 3]'::vector <=> '[3, 2, 1]'::vector;
```
You can search for a vector simply like this.
@@ -58,6 +58,10 @@ You can search for a vector simply like this.
SELECT * FROM items ORDER BY embedding <-> '[3,2,1]' LIMIT 5;
```
## Half-precision floating-point
`vecf16` type is the same with `vector` in anything but the scalar type. It stores 16-bit floating point numbers. If you want to reduce the memory usage to get better performace, you can try to replace `vector` type with `vecf16` type.
## Things You Need to Know
`vector(n)` is a valid data type only if $1 \leq n \leq 65535$. Due to limits of PostgreSQL, it's possible to create a value of type `vector(3)` of $5$ dimensions and `vector` is also a valid data. However, you cannot still put $0$ scalar or more than $65535$ scalars to a vector. If you use `vector` for a column or there is some values mismatched with dimension denoted by the column, you won't able to create an index on it.

View File

@@ -5,11 +5,19 @@ Indexing is the core ability of pgvecto.rs.
Assuming there is a table `items` and there is a column named `embedding` of type `vector(n)`, you can create a vector index for squared Euclidean distance with the following SQL.
```sql
CREATE INDEX ON items USING vectors (embedding l2_ops);
CREATE INDEX ON items USING vectors (embedding vector_l2_ops);
```
For negative dot product, replace `l2_ops` with `dot_ops`.
For negative cosine similarity, replace `l2_ops` with `cosine_ops`.
There is a table for you to choose a proper operator class for creating indexes.
| Vector type | Distance type | Operator class |
| ----------- | -------------------------- | -------------- |
| vector | squared Euclidean distance | vector_l2_ops |
| vector | negative dot product | vector_dot_ops |
| vector | negative cosine similarity | vector_cos_ops |
| vecf16 | squared Euclidean distance | vecf16_l2_ops |
| vecf16 | negative dot product | vecf16_dot_ops |
| vecf16 | negative cosine similarity | vecf16_cos_ops |
Now you can perform a KNN search with the following SQL again, but this time the vector index is used for searching.
@@ -36,14 +44,15 @@ Options for table `segment`.
| Key | Type | Description |
| ------------------------ | ------- | ------------------------------------------------------------------- |
| max_growing_segment_size | integer | Maximum size of unindexed vectors. Default value is `20_000`. |
| min_sealed_segment_size | integer | Minimum size of vectors for indexing. Default value is `1_000`. |
| max_sealed_segment_size | integer | Maximum size of vectors for indexing. Default value is `1_000_000`. |
Options for table `optimizing`.
| Key | Type | Description |
| ------------------ | ------- | --------------------------------------------------------------------------- |
| optimizing_threads | integer | Maximum threads for indexing. Default value is the sqrt of number of cores. |
| Key | Type | Description |
| ------------------ | ------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
| optimizing_threads | integer | Maximum threads for indexing. Default value is the sqrt of number of cores. |
| sealing_secs | integer | If a writing segment larger than `sealing_size` do not accept new data for `sealing_secs` seconds, the writing segment will be turned to a sealed segment. |
| sealing_size | integer | See above. |
Options for table `indexing`.
@@ -99,23 +108,19 @@ Options for table `product`.
## Progress View
We also provide a view `pg_vector_index_info` to monitor the progress of indexing.
Note that whether idx_sealed_len is equal to idx_tuples doesn't relate to the completion of indexing.
It may do further optimization after indexing. It may also stop indexing because there are too few tuples left.
| Column | Type | Description |
| --------------- | ------ | --------------------------------------------- |
| tablerelid | oid | The oid of the table. |
| indexrelid | oid | The oid of the index. |
| tablename | name | The name of the table. |
| indexname | name | The name of the index. |
| indexing | bool | Whether the background thread is indexing. |
| idx_tuples | int4 | The number of tuples. |
| idx_sealed_len | int4 | The number of tuples in sealed segments. |
| idx_growing_len | int4 | The number of tuples in growing segments. |
| idx_write | int4 | The number of tuples in write buffer. |
| idx_sealed | int4[] | The number of tuples in each sealed segment. |
| idx_growing | int4[] | The number of tuples in each growing segment. |
| idx_config | text | The configuration of the index. |
| Column | Type | Description |
| ------------ | ------ | --------------------------------------------- |
| tablerelid | oid | The oid of the table. |
| indexrelid | oid | The oid of the index. |
| tablename | name | The name of the table. |
| indexname | name | The name of the index. |
| idx_indexing | bool | Whether the background thread is indexing. |
| idx_tuples | int8 | The number of tuples. |
| idx_sealed | int8[] | The number of tuples in each sealed segment. |
| idx_growing | int8[] | The number of tuples in each growing segment. |
| idx_write | int8 | The number of tuples in write buffer. |
| idx_config | text | The configuration of the index. |
## Examples
@@ -124,11 +129,11 @@ There are some examples.
```sql
-- HNSW algorithm, default settings.
CREATE INDEX ON items USING vectors (embedding l2_ops);
CREATE INDEX ON items USING vectors (embedding vector_l2_ops);
--- Or using bruteforce with PQ.
CREATE INDEX ON items USING vectors (embedding l2_ops)
CREATE INDEX ON items USING vectors (embedding vector_l2_ops)
WITH (options = $$
[indexing.flat]
quantization.product.ratio = "x16"
@@ -136,7 +141,7 @@ $$);
--- Or using IVFPQ algorithm.
CREATE INDEX ON items USING vectors (embedding l2_ops)
CREATE INDEX ON items USING vectors (embedding vector_l2_ops)
WITH (options = $$
[indexing.ivf]
quantization.product.ratio = "x16"
@@ -144,14 +149,14 @@ $$);
-- Use more threads for background building the index.
CREATE INDEX ON items USING vectors (embedding l2_ops)
CREATE INDEX ON items USING vectors (embedding vector_l2_ops)
WITH (options = $$
optimizing.optimizing_threads = 16
$$);
-- Prefer smaller HNSW graph.
CREATE INDEX ON items USING vectors (embedding l2_ops)
CREATE INDEX ON items USING vectors (embedding vector_l2_ops)
WITH (options = $$
segment.max_growing_segment_size = 200000
$$);

View File

@@ -19,24 +19,49 @@ To acheive full performance, please mount the volume to pg data directory by add
You can configure PostgreSQL by the reference of the parent image in https://hub.docker.com/_/postgres/.
## Build from source
## Install from source
Install Rust and base dependency.
```sh
sudo apt install -y build-essential libpq-dev libssl-dev pkg-config gcc libreadline-dev flex bison libxml2-dev libxslt-dev libxml2-utils xsltproc zlib1g-dev ccache clang git
sudo apt install -y \
build-essential \
libpq-dev \
libssl-dev \
pkg-config \
gcc \
libreadline-dev \
flex \
bison \
libxml2-dev \
libxslt-dev \
libxml2-utils \
xsltproc \
zlib1g-dev \
ccache \
clang \
git
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
```
Install PostgreSQL.
```sh
sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list'
sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" >> /etc/apt/sources.list.d/pgdg.list'
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
sudo apt-get update
sudo apt-get -y install libpq-dev postgresql-15 postgresql-server-dev-15
```
Install clang-16.
```sh
sudo sh -c 'echo "deb http://apt.llvm.org/$(lsb_release -cs)/ llvm-toolchain-$(lsb_release -cs)-16 main" >> /etc/apt/sources.list'
wget --quiet -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
sudo apt-get update
sudo apt-get -y install clang-16
```
Clone the Repository.
```sh
@@ -54,7 +79,7 @@ cargo pgrx init --pg15=/usr/lib/postgresql/15/bin/pg_config
Install pgvecto.rs.
```sh
cargo pgrx install --release
cargo pgrx install --sudo --release
```
Configure your PostgreSQL by modifying the `shared_preload_libraries` to include `vectors.so`.

View File

@@ -15,11 +15,12 @@ If `vectors.k` is set to `64`, but your SQL returned less than `64` rows, for ex
* The vector index returned `64` rows, but `32` of which are invisble to the transaction so PostgreSQL decided to hide these rows for you.
* The vector index returned `64` rows, but `32` of which are satifying the condition `id % 2 = 0` in `WHERE` clause.
There are three ways to solve the problem:
There are four ways to solve the problem:
* Set `vectors.k` larger. If you estimate that 20% of rows will satisfy the condition in `WHERE`, just set `vectors.k` to be 5 times than before.
* Set `vectors.enable_vector_index` to `off`. If you estimate that 0.0001% of rows will satisfy the condition in `WHERE`, just do not use vector index. No alogrithms will be faster than brute force by PostgreSQL.
* Set `vectors.enable_prefilter` to `on`. If you cannot estimate how many rows will satisfy the condition in `WHERE`, leave the job for the index. The index will check if the returned row can be accepted by PostgreSQL. However, it will make queries slower so the default value for this option is `off`.
* Set `vectors.enable_vbase` to `on`. It will use vbase optimization, so that the index will pull rows as many as you need. It only works for HNSW algorithm.
## Options
@@ -30,6 +31,6 @@ Search options are specified by PostgreSQL GUC. You can use `SET` command to app
| vectors.k | integer | Expected number of candidates returned by index. The parameter will influence the recall if you use HNSW or quantization for indexing. Default value is `64`. |
| vectors.enable_prefilter | boolean | Enable prefiltering or not. Default value is `off`. |
| vectors.enable_vector_index | boolean | Enable vector indexes or not. This option is for debugging. Default value is `on`. |
| vectors.vbase_range | int4 | The range size when using vbase optimization. When it is set to `0`, vbase optimization will be disabled. A recommended value is `86`. Default value is `0`. |
| vectors.enable_vbase | boolean | Enable vbase optimization. Default value is `off`. |
Note: When `vectors.vbase_range` is enabled, it will ignore `vectors.enable_prefilter`.
Note: When `vectors.enable_vbase` is enabled, prefilter does not work.

View File

@@ -6,10 +6,14 @@ if [ "$OS" == "ubuntu-latest" ]; then
sudo pg_dropcluster 14 main
fi
sudo apt-get remove -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*'
sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list'
sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" >> /etc/apt/sources.list.d/pgdg.list'
sudo sh -c 'echo "deb http://apt.llvm.org/$(lsb_release -cs)/ llvm-toolchain-$(lsb_release -cs)-16 main" >> /etc/apt/sources.list'
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
wget --quiet -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
sudo apt-get update
sudo apt-get -y install build-essential libpq-dev postgresql-$VERSION postgresql-server-dev-$VERSION
sudo apt-get -y install clang-16
sudo apt-get -y install crossbuild-essential-arm64
echo "local all all trust" | sudo tee /etc/postgresql/$VERSION/main/pg_hba.conf
echo "host all all 127.0.0.1/32 trust" | sudo tee -a /etc/postgresql/$VERSION/main/pg_hba.conf
echo "host all all ::1/128 trust" | sudo tee -a /etc/postgresql/$VERSION/main/pg_hba.conf

View File

@@ -1 +0,0 @@
pub mod vamana;

View File

@@ -1,11 +1,26 @@
pub mod worker;
use self::worker::Worker;
use crate::ipc::server::RpcHandler;
use crate::ipc::IpcError;
use service::worker::Worker;
use std::path::{Path, PathBuf};
use std::sync::Arc;
pub unsafe fn init() {
use pgrx::bgworkers::BackgroundWorkerBuilder;
use pgrx::bgworkers::BgWorkerStartTime;
BackgroundWorkerBuilder::new("vectors")
.set_function("vectors_main")
.set_library("vectors")
.set_argument(None)
.enable_shmem_access(None)
.set_start_time(BgWorkerStartTime::PostmasterStart)
.load();
}
#[no_mangle]
extern "C" fn vectors_main(_arg: pgrx::pg_sys::Datum) {
let _ = std::panic::catch_unwind(crate::bgworker::main);
}
pub fn main() {
{
let mut builder = env_logger::builder();
@@ -109,10 +124,6 @@ fn session(worker: Arc<Worker>, mut handler: RpcHandler) -> Result<(), IpcError>
handler = x.leave(res)?;
}
}
RpcHandle::SearchVbase { id, search, mut x } => {
let res = worker.call_search_vbase(id, search, |p| x.next(p).unwrap());
handler = x.leave(res)?;
}
RpcHandle::Flush { id, x } => {
let result = worker.call_flush(id);
handler = x.leave(result)?;
@@ -125,11 +136,36 @@ fn session(worker: Arc<Worker>, mut handler: RpcHandler) -> Result<(), IpcError>
let result = worker.call_stat(id);
handler = x.leave(result)?;
}
RpcHandle::Leave {} => {
log::debug!("Handle leave rpc.");
break;
RpcHandle::Vbase { id, vector, x } => {
use crate::ipc::server::VbaseHandle::*;
let instance = match worker.get_instance(id) {
Ok(x) => x,
Err(e) => {
x.error(Err(e))?;
break Ok(());
}
};
let view = instance.view();
let mut it = match view.vbase(vector) {
Ok(x) => x,
Err(e) => {
x.error(Err(e))?;
break Ok(());
}
};
let mut x = x.error(Ok(()))?;
loop {
match x.handle()? {
Next { x: y } => {
x = y.leave(it.next())?;
}
Leave { x } => {
handler = x;
break;
}
}
}
}
}
}
Ok(())
}

26
src/datatype/casts_f32.rs Normal file
View File

@@ -0,0 +1,26 @@
use crate::datatype::typmod::Typmod;
use crate::datatype::vecf32::{Vecf32, Vecf32Input, Vecf32Output};
use service::prelude::*;
#[pgrx::pg_extern(immutable, parallel_safe, strict)]
fn vecf32_cast_array_to_vector(
array: pgrx::Array<f32>,
typmod: i32,
_explicit: bool,
) -> Vecf32Output {
assert!(!array.is_empty());
assert!(array.len() <= 65535);
assert!(!array.contains_nulls());
let typmod = Typmod::parse_from_i32(typmod).unwrap();
let len = typmod.dims().unwrap_or(array.len() as u16);
let mut data = vec![F32::zero(); len as usize];
for (i, x) in array.iter().enumerate() {
data[i] = F32(x.unwrap_or(f32::NAN));
}
Vecf32::new_in_postgres(&data)
}
#[pgrx::pg_extern(immutable, parallel_safe, strict)]
fn vecf32_cast_vector_to_array(vector: Vecf32Input<'_>, _typmod: i32, _explicit: bool) -> Vec<f32> {
vector.data().iter().map(|x| x.to_f32()).collect()
}

6
src/datatype/mod.rs Normal file
View File

@@ -0,0 +1,6 @@
pub mod casts_f32;
pub mod operators_f16;
pub mod operators_f32;
pub mod typmod;
pub mod vecf16;
pub mod vecf32;

View File

@@ -1,53 +1,53 @@
use crate::postgres::datatype::{Vector, VectorInput, VectorOutput};
use crate::prelude::*;
use crate::datatype::vecf16::{Vecf16, Vecf16Input, Vecf16Output};
use service::prelude::*;
use std::ops::Deref;
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])]
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(+)]
#[pgrx::commutator(+)]
fn operator_add(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput {
fn vecf16_operator_add(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> Vecf16Output {
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
let n = lhs.len();
let mut v = Vector::new_zeroed(n);
let mut v = vec![F16::zero(); n];
for i in 0..n {
v[i] = lhs[i] + rhs[i];
}
v.copy_into_postgres()
Vecf16::new_in_postgres(&v)
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])]
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(-)]
fn operator_minus(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput {
fn vecf16_operator_minus(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> Vecf16Output {
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
let n = lhs.len();
let mut v = Vector::new_zeroed(n);
let mut v = vec![F16::zero(); n];
for i in 0..n {
v[i] = lhs[i] - rhs[i];
}
v.copy_into_postgres()
Vecf16::new_in_postgres(&v)
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])]
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(<)]
#[pgrx::negator(>=)]
#[pgrx::commutator(>)]
#[pgrx::restrict(scalarltsel)]
#[pgrx::join(scalarltjoinsel)]
fn operator_lt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
fn vecf16_operator_lt(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
@@ -56,15 +56,15 @@ fn operator_lt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
lhs.deref() < rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])]
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(<=)]
#[pgrx::negator(>)]
#[pgrx::commutator(>=)]
#[pgrx::restrict(scalarltsel)]
#[pgrx::join(scalarltjoinsel)]
fn operator_lte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
fn vecf16_operator_lte(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
@@ -73,15 +73,15 @@ fn operator_lte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
lhs.deref() <= rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])]
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(>)]
#[pgrx::negator(<=)]
#[pgrx::commutator(<)]
#[pgrx::restrict(scalargtsel)]
#[pgrx::join(scalargtjoinsel)]
fn operator_gt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
fn vecf16_operator_gt(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
@@ -90,15 +90,15 @@ fn operator_gt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
lhs.deref() > rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])]
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(>=)]
#[pgrx::negator(<)]
#[pgrx::commutator(<=)]
#[pgrx::restrict(scalargtsel)]
#[pgrx::join(scalargtjoinsel)]
fn operator_gte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
fn vecf16_operator_gte(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
@@ -107,15 +107,15 @@ fn operator_gte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
lhs.deref() >= rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])]
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(=)]
#[pgrx::negator(<>)]
#[pgrx::commutator(=)]
#[pgrx::restrict(eqsel)]
#[pgrx::join(eqjoinsel)]
fn operator_eq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
fn vecf16_operator_eq(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
@@ -124,15 +124,15 @@ fn operator_eq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
lhs.deref() == rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])]
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(<>)]
#[pgrx::negator(=)]
#[pgrx::commutator(<>)]
#[pgrx::restrict(eqsel)]
#[pgrx::join(eqjoinsel)]
fn operator_neq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
fn vecf16_operator_neq(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
@@ -141,44 +141,44 @@ fn operator_neq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
lhs.deref() != rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])]
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(<=>)]
#[pgrx::commutator(<=>)]
fn operator_cosine(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar {
fn vecf16_operator_cosine(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 {
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
Distance::Cosine.distance(&lhs, &rhs)
F16Cos::distance(&lhs, &rhs).to_f32()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])]
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(<#>)]
#[pgrx::commutator(<#>)]
fn operator_dot(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar {
fn vecf16_operator_dot(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 {
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
Distance::Dot.distance(&lhs, &rhs)
F16Dot::distance(&lhs, &rhs).to_f32()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])]
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(<->)]
#[pgrx::commutator(<->)]
fn operator_l2(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar {
fn vecf16_operator_l2(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 {
if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
Distance::L2.distance(&lhs, &rhs)
F16L2::distance(&lhs, &rhs).to_f32()
}

View File

@@ -0,0 +1,184 @@
use crate::datatype::vecf32::{Vecf32, Vecf32Input, Vecf32Output};
use service::prelude::*;
use std::ops::Deref;
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(+)]
#[pgrx::commutator(+)]
fn vecf32_operator_add(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> Vecf32Output {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
let n = lhs.len();
let mut v = vec![F32::zero(); n];
for i in 0..n {
v[i] = lhs[i] + rhs[i];
}
Vecf32::new_in_postgres(&v)
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(-)]
fn vecf32_operator_minus(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> Vecf32Output {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
let n = lhs.len();
let mut v = vec![F32::zero(); n];
for i in 0..n {
v[i] = lhs[i] - rhs[i];
}
Vecf32::new_in_postgres(&v)
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(<)]
#[pgrx::negator(>=)]
#[pgrx::commutator(>)]
#[pgrx::restrict(scalarltsel)]
#[pgrx::join(scalarltjoinsel)]
fn vecf32_operator_lt(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() < rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(<=)]
#[pgrx::negator(>)]
#[pgrx::commutator(>=)]
#[pgrx::restrict(scalarltsel)]
#[pgrx::join(scalarltjoinsel)]
fn vecf32_operator_lte(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() <= rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(>)]
#[pgrx::negator(<=)]
#[pgrx::commutator(<)]
#[pgrx::restrict(scalargtsel)]
#[pgrx::join(scalargtjoinsel)]
fn vecf32_operator_gt(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() > rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(>=)]
#[pgrx::negator(<)]
#[pgrx::commutator(<=)]
#[pgrx::restrict(scalargtsel)]
#[pgrx::join(scalargtjoinsel)]
fn vecf32_operator_gte(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() >= rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(=)]
#[pgrx::negator(<>)]
#[pgrx::commutator(=)]
#[pgrx::restrict(eqsel)]
#[pgrx::join(eqjoinsel)]
fn vecf32_operator_eq(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() == rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(<>)]
#[pgrx::negator(=)]
#[pgrx::commutator(<>)]
#[pgrx::restrict(eqsel)]
#[pgrx::join(eqjoinsel)]
fn vecf32_operator_neq(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() != rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(<=>)]
#[pgrx::commutator(<=>)]
fn vecf32_operator_cosine(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
F32Cos::distance(&lhs, &rhs).to_f32()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(<#>)]
#[pgrx::commutator(<#>)]
fn vecf32_operator_dot(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
F32Dot::distance(&lhs, &rhs).to_f32()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(<->)]
#[pgrx::commutator(<->)]
fn vecf32_operator_l2(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
F32L2::distance(&lhs, &rhs).to_f32()
}

77
src/datatype/typmod.rs Normal file
View File

@@ -0,0 +1,77 @@
use pgrx::Array;
use serde::{Deserialize, Serialize};
use service::prelude::*;
use std::ffi::{CStr, CString};
use std::num::NonZeroU16;
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum Typmod {
Any,
Dims(NonZeroU16),
}
impl Typmod {
pub fn parse_from_str(s: &str) -> Option<Self> {
use Typmod::*;
if let Ok(x) = s.parse::<NonZeroU16>() {
Some(Dims(x))
} else {
None
}
}
pub fn parse_from_i32(x: i32) -> Option<Self> {
use Typmod::*;
if x == -1 {
Some(Any)
} else if 1 <= x && x <= u16::MAX as i32 {
Some(Dims(NonZeroU16::new(x as u16).unwrap()))
} else {
None
}
}
pub fn into_option_string(self) -> Option<String> {
use Typmod::*;
match self {
Any => None,
Dims(x) => Some(i32::from(x.get()).to_string()),
}
}
pub fn into_i32(self) -> i32 {
use Typmod::*;
match self {
Any => -1,
Dims(x) => i32::from(x.get()),
}
}
pub fn dims(self) -> Option<u16> {
use Typmod::*;
match self {
Any => None,
Dims(dims) => Some(dims.get()),
}
}
}
#[pgrx::pg_extern(immutable, parallel_safe, strict)]
fn typmod_in(list: Array<&CStr>) -> i32 {
if list.is_empty() {
-1
} else if list.len() == 1 {
let s = list.get(0).unwrap().unwrap().to_str().unwrap();
let typmod = Typmod::parse_from_str(s)
.ok_or(FriendlyError::BadTypeDimensions)
.friendly();
typmod.into_i32()
} else {
FriendlyError::BadTypeDimensions.friendly();
}
}
#[pgrx::pg_extern(immutable, parallel_safe, strict)]
fn typmod_out(typmod: i32) -> CString {
let typmod = Typmod::parse_from_i32(typmod).unwrap();
match typmod.into_option_string() {
Some(s) => CString::new(format!("({})", s)).unwrap(),
None => CString::new("()").unwrap(),
}
}

343
src/datatype/vecf16.rs Normal file
View File

@@ -0,0 +1,343 @@
use crate::datatype::typmod::Typmod;
use pgrx::pg_sys::Datum;
use pgrx::pg_sys::Oid;
use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError;
use pgrx::pgrx_sql_entity_graph::metadata::Returns;
use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError;
use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
use pgrx::FromDatum;
use pgrx::IntoDatum;
use service::prelude::*;
use std::alloc::Layout;
use std::cmp::Ordering;
use std::ffi::CStr;
use std::ffi::CString;
use std::ops::Deref;
use std::ops::DerefMut;
use std::ops::Index;
use std::ops::IndexMut;
use std::ptr::NonNull;
pgrx::extension_sql!(
r#"
CREATE TYPE vecf16 (
INPUT = vecf16_in,
OUTPUT = vecf16_out,
TYPMOD_IN = typmod_in,
TYPMOD_OUT = typmod_out,
STORAGE = EXTENDED,
INTERNALLENGTH = VARIABLE,
ALIGNMENT = double
);
"#,
name = "vecf16",
creates = [Type(Vecf16)],
requires = [vecf16_in, vecf16_out, typmod_in, typmod_out],
);
#[repr(C, align(8))]
pub struct Vecf16 {
varlena: u32,
len: u16,
kind: u8,
reserved: u8,
phantom: [F16; 0],
}
impl Vecf16 {
fn varlena(size: usize) -> u32 {
(size << 2) as u32
}
fn layout(len: usize) -> Layout {
u16::try_from(len).expect("Vector is too large.");
let layout_alpha = Layout::new::<Vecf16>();
let layout_beta = Layout::array::<F16>(len).unwrap();
let layout = layout_alpha.extend(layout_beta).unwrap().0;
layout.pad_to_align()
}
pub fn new_in_postgres(slice: &[F16]) -> Vecf16Output {
unsafe {
assert!(u16::try_from(slice.len()).is_ok());
let layout = Vecf16::layout(slice.len());
let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf16;
ptr.cast::<u8>().add(layout.size() - 8).write_bytes(0, 8);
std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size()));
std::ptr::addr_of_mut!((*ptr).kind).write(1);
std::ptr::addr_of_mut!((*ptr).reserved).write(0);
std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16);
std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len());
Vecf16Output(NonNull::new(ptr).unwrap())
}
}
pub fn len(&self) -> usize {
self.len as usize
}
pub fn data(&self) -> &[F16] {
debug_assert_eq!(self.varlena & 3, 0);
debug_assert_eq!(self.kind, 1);
unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.len as usize) }
}
pub fn data_mut(&mut self) -> &mut [F16] {
debug_assert_eq!(self.varlena & 3, 0);
debug_assert_eq!(self.kind, 1);
unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.len as usize) }
}
}
impl Deref for Vecf16 {
type Target = [F16];
fn deref(&self) -> &Self::Target {
self.data()
}
}
impl DerefMut for Vecf16 {
fn deref_mut(&mut self) -> &mut Self::Target {
self.data_mut()
}
}
impl AsRef<[F16]> for Vecf16 {
fn as_ref(&self) -> &[F16] {
self.data()
}
}
impl AsMut<[F16]> for Vecf16 {
fn as_mut(&mut self) -> &mut [F16] {
self.data_mut()
}
}
impl Index<usize> for Vecf16 {
type Output = F16;
fn index(&self, index: usize) -> &Self::Output {
self.data().index(index)
}
}
impl IndexMut<usize> for Vecf16 {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
self.data_mut().index_mut(index)
}
}
impl PartialEq for Vecf16 {
fn eq(&self, other: &Self) -> bool {
if self.len() != other.len() {
return false;
}
let n = self.len();
for i in 0..n {
if self[i] != other[i] {
return false;
}
}
true
}
}
impl Eq for Vecf16 {}
impl PartialOrd for Vecf16 {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Vecf16 {
fn cmp(&self, other: &Self) -> Ordering {
use Ordering::*;
if let x @ Less | x @ Greater = self.len().cmp(&other.len()) {
return x;
}
let n = self.len();
for i in 0..n {
if let x @ Less | x @ Greater = self[i].cmp(&other[i]) {
return x;
}
}
Equal
}
}
pub enum Vecf16Input<'a> {
Owned(Vecf16Output),
Borrowed(&'a Vecf16),
}
impl<'a> Vecf16Input<'a> {
pub unsafe fn new(p: NonNull<Vecf16>) -> Self {
let q = unsafe {
NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap()
};
if p != q {
Vecf16Input::Owned(Vecf16Output(q))
} else {
unsafe { Vecf16Input::Borrowed(p.as_ref()) }
}
}
}
impl Deref for Vecf16Input<'_> {
type Target = Vecf16;
fn deref(&self) -> &Self::Target {
match self {
Vecf16Input::Owned(x) => x,
Vecf16Input::Borrowed(x) => x,
}
}
}
pub struct Vecf16Output(NonNull<Vecf16>);
impl Vecf16Output {
pub fn into_raw(self) -> *mut Vecf16 {
let result = self.0.as_ptr();
std::mem::forget(self);
result
}
}
impl Deref for Vecf16Output {
type Target = Vecf16;
fn deref(&self) -> &Self::Target {
unsafe { self.0.as_ref() }
}
}
impl DerefMut for Vecf16Output {
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { self.0.as_mut() }
}
}
impl Drop for Vecf16Output {
fn drop(&mut self) {
unsafe {
pgrx::pg_sys::pfree(self.0.as_ptr() as _);
}
}
}
impl<'a> FromDatum for Vecf16Input<'a> {
unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option<Self> {
if is_null {
None
} else {
let ptr = NonNull::new(datum.cast_mut_ptr::<Vecf16>()).unwrap();
unsafe { Some(Vecf16Input::new(ptr)) }
}
}
}
impl IntoDatum for Vecf16Output {
fn into_datum(self) -> Option<Datum> {
Some(Datum::from(self.into_raw() as *mut ()))
}
fn type_oid() -> Oid {
pgrx::wrappers::regtypein("vecf16")
}
}
unsafe impl SqlTranslatable for Vecf16Input<'_> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("vecf16")))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::As(String::from("vecf16"))))
}
}
unsafe impl SqlTranslatable for Vecf16Output {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("vecf16")))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::As(String::from("vecf16"))))
}
}
#[pgrx::pg_extern(immutable, parallel_safe, strict)]
fn vecf16_in(input: &CStr, _oid: Oid, typmod: i32) -> Vecf16Output {
fn solve<T>(option: Option<T>, hint: &str) -> T {
if let Some(x) = option {
x
} else {
FriendlyError::BadLiteral {
hint: hint.to_string(),
}
.friendly()
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
MatchingLeft,
Reading,
MatchedRight,
}
use State::*;
let input = input.to_bytes();
let typmod = Typmod::parse_from_i32(typmod).unwrap();
let mut vector = Vec::<F16>::with_capacity(typmod.dims().unwrap_or(0) as usize);
let mut state = MatchingLeft;
let mut token: Option<String> = None;
for &c in input {
match (state, c) {
(MatchingLeft, b'[') => {
state = Reading;
}
(Reading, b'0'..=b'9' | b'.' | b'e' | b'+' | b'-') => {
let token = token.get_or_insert(String::new());
token.push(char::from_u32(c as u32).unwrap());
}
(Reading, b',') => {
let token = solve(token.take(), "Expect a number.");
vector.push(solve(token.parse().ok(), "Bad number."));
}
(Reading, b']') => {
if let Some(token) = token.take() {
vector.push(solve(token.parse().ok(), "Bad number."));
}
state = MatchedRight;
}
(_, b' ') => {}
_ => {
FriendlyError::BadLiteral {
hint: format!("Bad charactor with ascii {:#x}.", c),
}
.friendly();
}
}
}
if state != MatchedRight {
FriendlyError::BadLiteral {
hint: "Bad sequence.".to_string(),
}
.friendly();
}
if vector.is_empty() || vector.len() > 65535 {
FriendlyError::BadValueDimensions.friendly();
}
Vecf16::new_in_postgres(&vector)
}
#[pgrx::pg_extern(immutable, parallel_safe, strict)]
fn vecf16_out(vector: Vecf16Input<'_>) -> CString {
let mut buffer = String::new();
buffer.push('[');
if let Some(&x) = vector.data().first() {
buffer.push_str(format!("{}", x).as_str());
}
for &x in vector.data().iter().skip(1) {
buffer.push_str(format!(", {}", x).as_str());
}
buffer.push(']');
CString::new(buffer).unwrap()
}

343
src/datatype/vecf32.rs Normal file
View File

@@ -0,0 +1,343 @@
use crate::datatype::typmod::Typmod;
use pgrx::pg_sys::Datum;
use pgrx::pg_sys::Oid;
use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError;
use pgrx::pgrx_sql_entity_graph::metadata::Returns;
use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError;
use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
use pgrx::FromDatum;
use pgrx::IntoDatum;
use service::prelude::*;
use std::alloc::Layout;
use std::cmp::Ordering;
use std::ffi::CStr;
use std::ffi::CString;
use std::ops::Deref;
use std::ops::DerefMut;
use std::ops::Index;
use std::ops::IndexMut;
use std::ptr::NonNull;
pgrx::extension_sql!(
r#"
CREATE TYPE vector (
INPUT = vecf32_in,
OUTPUT = vecf32_out,
TYPMOD_IN = typmod_in,
TYPMOD_OUT = typmod_out,
STORAGE = EXTENDED,
INTERNALLENGTH = VARIABLE,
ALIGNMENT = double
);
"#,
name = "vecf32",
creates = [Type(Vecf32)],
requires = [vecf32_in, vecf32_out, typmod_in, typmod_out],
);
#[repr(C, align(8))]
pub struct Vecf32 {
varlena: u32,
len: u16,
kind: u8,
reserved: u8,
phantom: [F32; 0],
}
impl Vecf32 {
fn varlena(size: usize) -> u32 {
(size << 2) as u32
}
fn layout(len: usize) -> Layout {
u16::try_from(len).expect("Vector is too large.");
let layout_alpha = Layout::new::<Vecf32>();
let layout_beta = Layout::array::<F32>(len).unwrap();
let layout = layout_alpha.extend(layout_beta).unwrap().0;
layout.pad_to_align()
}
pub fn new_in_postgres(slice: &[F32]) -> Vecf32Output {
unsafe {
assert!(u16::try_from(slice.len()).is_ok());
let layout = Vecf32::layout(slice.len());
let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf32;
ptr.cast::<u8>().add(layout.size() - 8).write_bytes(0, 8);
std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size()));
std::ptr::addr_of_mut!((*ptr).kind).write(0);
std::ptr::addr_of_mut!((*ptr).reserved).write(0);
std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16);
std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len());
Vecf32Output(NonNull::new(ptr).unwrap())
}
}
pub fn len(&self) -> usize {
self.len as usize
}
pub fn data(&self) -> &[F32] {
debug_assert_eq!(self.varlena & 3, 0);
debug_assert_eq!(self.kind, 0);
unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.len as usize) }
}
pub fn data_mut(&mut self) -> &mut [F32] {
debug_assert_eq!(self.varlena & 3, 0);
debug_assert_eq!(self.kind, 0);
unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.len as usize) }
}
}
impl Deref for Vecf32 {
type Target = [F32];
fn deref(&self) -> &Self::Target {
self.data()
}
}
impl DerefMut for Vecf32 {
fn deref_mut(&mut self) -> &mut Self::Target {
self.data_mut()
}
}
impl AsRef<[F32]> for Vecf32 {
fn as_ref(&self) -> &[F32] {
self.data()
}
}
impl AsMut<[F32]> for Vecf32 {
fn as_mut(&mut self) -> &mut [F32] {
self.data_mut()
}
}
impl Index<usize> for Vecf32 {
type Output = F32;
fn index(&self, index: usize) -> &Self::Output {
self.data().index(index)
}
}
impl IndexMut<usize> for Vecf32 {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
self.data_mut().index_mut(index)
}
}
impl PartialEq for Vecf32 {
fn eq(&self, other: &Self) -> bool {
if self.len() != other.len() {
return false;
}
let n = self.len();
for i in 0..n {
if self[i] != other[i] {
return false;
}
}
true
}
}
impl Eq for Vecf32 {}
impl PartialOrd for Vecf32 {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Vecf32 {
fn cmp(&self, other: &Self) -> Ordering {
use Ordering::*;
if let x @ Less | x @ Greater = self.len().cmp(&other.len()) {
return x;
}
let n = self.len();
for i in 0..n {
if let x @ Less | x @ Greater = self[i].cmp(&other[i]) {
return x;
}
}
Equal
}
}
pub enum Vecf32Input<'a> {
Owned(Vecf32Output),
Borrowed(&'a Vecf32),
}
impl<'a> Vecf32Input<'a> {
pub unsafe fn new(p: NonNull<Vecf32>) -> Self {
let q = unsafe {
NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap()
};
if p != q {
Vecf32Input::Owned(Vecf32Output(q))
} else {
unsafe { Vecf32Input::Borrowed(p.as_ref()) }
}
}
}
impl Deref for Vecf32Input<'_> {
type Target = Vecf32;
fn deref(&self) -> &Self::Target {
match self {
Vecf32Input::Owned(x) => x,
Vecf32Input::Borrowed(x) => x,
}
}
}
pub struct Vecf32Output(NonNull<Vecf32>);
impl Vecf32Output {
pub fn into_raw(self) -> *mut Vecf32 {
let result = self.0.as_ptr();
std::mem::forget(self);
result
}
}
impl Deref for Vecf32Output {
type Target = Vecf32;
fn deref(&self) -> &Self::Target {
unsafe { self.0.as_ref() }
}
}
impl DerefMut for Vecf32Output {
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { self.0.as_mut() }
}
}
impl Drop for Vecf32Output {
fn drop(&mut self) {
unsafe {
pgrx::pg_sys::pfree(self.0.as_ptr() as _);
}
}
}
impl<'a> FromDatum for Vecf32Input<'a> {
unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option<Self> {
if is_null {
None
} else {
let ptr = NonNull::new(datum.cast_mut_ptr::<Vecf32>()).unwrap();
unsafe { Some(Vecf32Input::new(ptr)) }
}
}
}
impl IntoDatum for Vecf32Output {
fn into_datum(self) -> Option<Datum> {
Some(Datum::from(self.into_raw() as *mut ()))
}
fn type_oid() -> Oid {
pgrx::wrappers::regtypein("vector")
}
}
unsafe impl SqlTranslatable for Vecf32Input<'_> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("vector")))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::As(String::from("vector"))))
}
}
unsafe impl SqlTranslatable for Vecf32Output {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("vector")))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::As(String::from("vector"))))
}
}
#[pgrx::pg_extern(immutable, parallel_safe, strict)]
fn vecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> Vecf32Output {
fn solve<T>(option: Option<T>, hint: &str) -> T {
if let Some(x) = option {
x
} else {
FriendlyError::BadLiteral {
hint: hint.to_string(),
}
.friendly()
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
MatchingLeft,
Reading,
MatchedRight,
}
use State::*;
let input = input.to_bytes();
let typmod = Typmod::parse_from_i32(typmod).unwrap();
let mut vector = Vec::<F32>::with_capacity(typmod.dims().unwrap_or(0) as usize);
let mut state = MatchingLeft;
let mut token: Option<String> = None;
for &c in input {
match (state, c) {
(MatchingLeft, b'[') => {
state = Reading;
}
(Reading, b'0'..=b'9' | b'.' | b'e' | b'+' | b'-') => {
let token = token.get_or_insert(String::new());
token.push(char::from_u32(c as u32).unwrap());
}
(Reading, b',') => {
let token = solve(token.take(), "Expect a number.");
vector.push(solve(token.parse().ok(), "Bad number."));
}
(Reading, b']') => {
if let Some(token) = token.take() {
vector.push(solve(token.parse().ok(), "Bad number."));
}
state = MatchedRight;
}
(_, b' ') => {}
_ => {
FriendlyError::BadLiteral {
hint: format!("Bad charactor with ascii {:#x}.", c),
}
.friendly();
}
}
}
if state != MatchedRight {
FriendlyError::BadLiteral {
hint: "Bad sequence.".to_string(),
}
.friendly();
}
if vector.is_empty() || vector.len() > 65535 {
FriendlyError::BadValueDimensions.friendly();
}
Vecf32::new_in_postgres(&vector)
}
#[pgrx::pg_extern(immutable, parallel_safe, strict)]
fn vecf32_out(vector: Vecf32Input<'_>) -> CString {
let mut buffer = String::new();
buffer.push('[');
if let Some(&x) = vector.data().first() {
buffer.push_str(format!("{}", x).as_str());
}
for &x in vector.data().iter().skip(1) {
buffer.push_str(format!(", {}", x).as_str());
}
buffer.push(']');
CString::new(buffer).unwrap()
}

View File

@@ -1,14 +1,12 @@
use super::openai::{EmbeddingCreator, OpenAIEmbedding};
use super::Embedding;
use crate::postgres::datatype::Vector;
use crate::postgres::datatype::VectorOutput;
use crate::postgres::gucs::OPENAI_API_KEY_GUC;
use crate::prelude::Float;
use crate::prelude::Scalar;
use crate::datatype::vecf32::{Vecf32, Vecf32Output};
use crate::gucs::OPENAI_API_KEY_GUC;
use pgrx::prelude::*;
use service::prelude::F32;
#[pg_extern]
fn ai_embedding_vector(input: String) -> VectorOutput {
fn ai_embedding_vector(input: String) -> Vecf32Output {
let api_key = match OPENAI_API_KEY_GUC.get() {
Some(key) => key
.to_str()
@@ -26,9 +24,9 @@ fn ai_embedding_vector(input: String) -> VectorOutput {
Ok(embedding) => {
let embedding = embedding
.into_iter()
.map(|x| Scalar(x as Float))
.map(|x| F32(x as f32))
.collect::<Vec<_>>();
Vector::new_in_postgres(&embedding)
Vecf32::new_in_postgres(&embedding)
}
Err(e) => {
error!("{}", e)

View File

@@ -23,7 +23,7 @@ pub static ENABLE_VECTOR_INDEX: GucSetting<bool> = GucSetting::<bool>::new(true)
pub static ENABLE_PREFILTER: GucSetting<bool> = GucSetting::<bool>::new(false);
pub static VBASE_RANGE: GucSetting<i32> = GucSetting::<i32>::new(0);
pub static ENABLE_VBASE: GucSetting<bool> = GucSetting::<bool>::new(false);
pub static TRANSPORT: GucSetting<Transport> = GucSetting::<Transport>::new(Transport::default());
@@ -62,13 +62,11 @@ pub unsafe fn init() {
GucContext::Userset,
GucFlags::default(),
);
GucRegistry::define_int_guc(
"vectors.vbase_range",
"The range of vbase.",
"The range of vbase.",
&VBASE_RANGE,
0,
u16::MAX as _,
GucRegistry::define_bool_guc(
"vectors.enable_vbase",
"Whether to enable vbase.",
"When enabled, it will use vbase for filtering.",
&ENABLE_VBASE,
GucContext::Userset,
GucFlags::default(),
);

View File

@@ -1,11 +1,12 @@
use super::index_build;
use super::index_scan;
use super::index_setup;
use super::index_update;
use crate::postgres::datatype::VectorInput;
use crate::postgres::gucs::ENABLE_VECTOR_INDEX;
use super::am_build;
use super::am_scan;
use super::am_setup;
use super::am_update;
use crate::gucs::ENABLE_VECTOR_INDEX;
use crate::index::utils::from_datum;
use crate::prelude::*;
use crate::utils::cells::PgCell;
use service::prelude::*;
static RELOPT_KIND: PgCell<pgrx::pg_sys::relopt_kind> = unsafe { PgCell::new(0) };
@@ -28,9 +29,7 @@ pub unsafe fn init() {
#[pgrx::pg_extern(sql = "
CREATE OR REPLACE FUNCTION vectors_amhandler(internal) RETURNS index_am_handler
PARALLEL SAFE IMMUTABLE STRICT LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';
CREATE ACCESS METHOD vectors TYPE INDEX HANDLER vectors_amhandler;
COMMENT ON ACCESS METHOD vectors IS 'pgvecto.rs index access method';
", requires = ["vector"])]
", requires = ["vecf32"])]
fn vectors_amhandler(
_fcinfo: pgrx::pg_sys::FunctionCallInfo,
) -> pgrx::PgBox<pgrx::pg_sys::IndexAmRoutine> {
@@ -85,7 +84,7 @@ const AM_HANDLER: pgrx::pg_sys::IndexAmRoutine = {
#[pgrx::pg_guard]
pub unsafe extern "C" fn amvalidate(opclass_oid: pgrx::pg_sys::Oid) -> bool {
index_setup::convert_opclass_to_distance(opclass_oid);
am_setup::convert_opclass_to_distance(opclass_oid);
true
}
@@ -99,7 +98,7 @@ pub unsafe extern "C" fn amoptions(
let tab: &[pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt {
optname: "options".as_pg_cstr(),
opttype: pgrx::pg_sys::relopt_type_RELOPT_TYPE_STRING,
offset: index_setup::helper_offset() as i32,
offset: am_setup::helper_offset() as i32,
}];
let mut noptions = 0;
let options =
@@ -111,10 +110,10 @@ pub unsafe extern "C" fn amoptions(
relopt.gen.as_mut().unwrap().lockmode =
pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE;
}
let rdopts = pgrx::pg_sys::allocateReloptStruct(index_setup::helper_size(), options, noptions);
let rdopts = pgrx::pg_sys::allocateReloptStruct(am_setup::helper_size(), options, noptions);
pgrx::pg_sys::fillRelOptions(
rdopts,
index_setup::helper_size(),
am_setup::helper_size(),
options,
noptions,
validate,
@@ -136,13 +135,13 @@ pub unsafe extern "C" fn amoptions(
let tab: &[pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt {
optname: "options".as_pg_cstr(),
opttype: pgrx::pg_sys::relopt_type_RELOPT_TYPE_STRING,
offset: index_setup::helper_offset() as i32,
offset: am_setup::helper_offset() as i32,
}];
let rdopts = pgrx::pg_sys::build_reloptions(
reloptions,
validate,
RELOPT_KIND.get(),
index_setup::helper_size(),
am_setup::helper_size(),
tab.as_ptr(),
tab.len() as _,
);
@@ -182,7 +181,7 @@ pub unsafe extern "C" fn ambuild(
index_info: *mut pgrx::pg_sys::IndexInfo,
) -> *mut pgrx::pg_sys::IndexBuildResult {
let result = pgrx::PgBox::<pgrx::pg_sys::IndexBuildResult>::alloc0();
index_build::build(
am_build::build(
index_relation,
Some((heap_relation, index_info, result.as_ptr())),
);
@@ -191,7 +190,7 @@ pub unsafe extern "C" fn ambuild(
#[pgrx::pg_guard]
pub unsafe extern "C" fn ambuildempty(index_relation: pgrx::pg_sys::Relation) {
index_build::build(index_relation, None);
am_build::build(index_relation, None);
}
#[cfg(any(feature = "pg12", feature = "pg13"))]
@@ -199,18 +198,16 @@ pub unsafe extern "C" fn ambuildempty(index_relation: pgrx::pg_sys::Relation) {
pub unsafe extern "C" fn aminsert(
index_relation: pgrx::pg_sys::Relation,
values: *mut pgrx::pg_sys::Datum,
is_null: *mut bool,
_is_null: *mut bool,
heap_tid: pgrx::pg_sys::ItemPointer,
_heap_relation: pgrx::pg_sys::Relation,
_check_unique: pgrx::pg_sys::IndexUniqueCheck,
_index_info: *mut pgrx::pg_sys::IndexInfo,
) -> bool {
use pgrx::FromDatum;
let oid = (*index_relation).rd_node.relNode;
let id = Id::from_sys(oid);
let vector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap();
let vector = vector.data().to_vec();
index_update::update_insert(id, vector, *heap_tid);
let vector = from_datum(*values.add(0));
am_update::update_insert(id, vector, *heap_tid);
true
}
@@ -219,22 +216,20 @@ pub unsafe extern "C" fn aminsert(
pub unsafe extern "C" fn aminsert(
index_relation: pgrx::pg_sys::Relation,
values: *mut pgrx::pg_sys::Datum,
is_null: *mut bool,
_is_null: *mut bool,
heap_tid: pgrx::pg_sys::ItemPointer,
_heap_relation: pgrx::pg_sys::Relation,
_check_unique: pgrx::pg_sys::IndexUniqueCheck,
_index_unchanged: bool,
_index_info: *mut pgrx::pg_sys::IndexInfo,
) -> bool {
use pgrx::FromDatum;
#[cfg(any(feature = "pg14", feature = "pg15"))]
let oid = (*index_relation).rd_node.relNode;
#[cfg(feature = "pg16")]
let oid = (*index_relation).rd_locator.relNumber;
let id = Id::from_sys(oid);
let vector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap();
let vector = vector.data().to_vec();
index_update::update_insert(id, vector, *heap_tid);
let vector = from_datum(*values.add(0));
am_update::update_insert(id, vector, *heap_tid);
true
}
@@ -242,20 +237,26 @@ pub unsafe extern "C" fn aminsert(
pub unsafe extern "C" fn ambeginscan(
index_relation: pgrx::pg_sys::Relation,
n_keys: std::os::raw::c_int,
n_order_bys: std::os::raw::c_int,
n_orderbys: std::os::raw::c_int,
) -> pgrx::pg_sys::IndexScanDesc {
index_scan::make_scan(index_relation, n_keys, n_order_bys)
assert!(n_keys == 0);
assert!(n_orderbys == 1);
am_scan::make_scan(index_relation)
}
#[pgrx::pg_guard]
pub unsafe extern "C" fn amrescan(
scan: pgrx::pg_sys::IndexScanDesc,
keys: pgrx::pg_sys::ScanKey,
_keys: pgrx::pg_sys::ScanKey,
n_keys: std::os::raw::c_int,
orderbys: pgrx::pg_sys::ScanKey,
n_orderbys: std::os::raw::c_int,
) {
index_scan::start_scan(scan, keys, n_keys, orderbys, n_orderbys);
assert!((*scan).numberOfKeys == n_keys);
assert!((*scan).numberOfOrderBys == n_orderbys);
assert!(n_keys == 0);
assert!(n_orderbys == 1);
am_scan::start_scan(scan, orderbys);
}
#[pgrx::pg_guard]
@@ -264,12 +265,12 @@ pub unsafe extern "C" fn amgettuple(
direction: pgrx::pg_sys::ScanDirection,
) -> bool {
assert!(direction == pgrx::pg_sys::ScanDirection_ForwardScanDirection);
index_scan::next_scan(scan)
am_scan::next_scan(scan)
}
#[pgrx::pg_guard]
pub unsafe extern "C" fn amendscan(scan: pgrx::pg_sys::IndexScanDesc) {
index_scan::end_scan(scan);
am_scan::end_scan(scan);
}
#[pgrx::pg_guard]
@@ -285,7 +286,7 @@ pub unsafe extern "C" fn ambulkdelete(
let oid = (*(*info).index).rd_locator.relNumber;
let id = Id::from_sys(oid);
if let Some(callback) = callback {
index_update::update_delete(id, |pointer| {
am_update::update_delete(id, |pointer| {
callback(
&mut pointer.into_sys() as *mut pgrx::pg_sys::ItemPointerData,
callback_state,

View File

@@ -1,11 +1,13 @@
use super::hook_transaction::{client, flush_if_commit};
use crate::ipc::client::Rpc;
use crate::postgres::index_setup::options;
use super::hook_transaction::flush_if_commit;
use crate::index::utils::from_datum;
use crate::ipc::client::ClientGuard;
use crate::prelude::*;
use crate::{index::am_setup::options, ipc::client::Rpc};
use pgrx::pg_sys::{IndexBuildResult, IndexInfo, RelationData};
use service::prelude::*;
pub struct Builder {
pub rpc: Rpc,
pub rpc: ClientGuard<Rpc>,
pub heap_relation: *mut RelationData,
pub index_info: *mut IndexInfo,
pub result: *mut IndexBuildResult,
@@ -22,27 +24,22 @@ pub unsafe fn build(
let id = Id::from_sys(oid);
flush_if_commit(id);
let options = options(index);
client(|mut rpc| {
rpc.create(id, options).friendly();
rpc
});
let mut rpc = crate::ipc::client::borrow_mut();
rpc.create(id, options);
if let Some((heap_relation, index_info, result)) = data {
client(|rpc| {
let mut builder = Builder {
rpc,
heap_relation,
index_info,
result,
};
pgrx::pg_sys::IndexBuildHeapScan(
heap_relation,
index,
index_info,
Some(callback),
&mut builder,
);
builder.rpc
});
let mut builder = Builder {
rpc,
heap_relation,
index_info,
result,
};
pgrx::pg_sys::IndexBuildHeapScan(
heap_relation,
index,
index_info,
Some(callback),
&mut builder,
);
}
}
@@ -52,21 +49,17 @@ unsafe extern "C" fn callback(
index_relation: pgrx::pg_sys::Relation,
htup: pgrx::pg_sys::HeapTuple,
values: *mut pgrx::pg_sys::Datum,
is_null: *mut bool,
_is_null: *mut bool,
_tuple_is_alive: bool,
state: *mut std::os::raw::c_void,
) {
use super::datatype::VectorInput;
use pgrx::FromDatum;
let ctid = &(*htup).t_self;
let oid = (*index_relation).rd_node.relNode;
let id = Id::from_sys(oid);
let state = &mut *(state as *mut Builder);
let pgvector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap();
let data = (pgvector.to_vec(), Pointer::from_sys(*ctid));
state.rpc.insert(id, data).friendly().friendly();
let vector = from_datum(*values.add(0));
let data = (vector, Pointer::from_sys(*ctid));
state.rpc.insert(id, data);
(*state.result).heap_tuples += 1.0;
(*state.result).index_tuples += 1.0;
}
@@ -77,22 +70,19 @@ unsafe extern "C" fn callback(
index_relation: pgrx::pg_sys::Relation,
ctid: pgrx::pg_sys::ItemPointer,
values: *mut pgrx::pg_sys::Datum,
is_null: *mut bool,
_is_null: *mut bool,
_tuple_is_alive: bool,
state: *mut std::os::raw::c_void,
) {
use super::datatype::VectorInput;
use pgrx::FromDatum;
#[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15"))]
let oid = (*index_relation).rd_node.relNode;
#[cfg(feature = "pg16")]
let oid = (*index_relation).rd_locator.relNumber;
let id = Id::from_sys(oid);
let state = &mut *(state as *mut Builder);
let pgvector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap();
let data = (pgvector.to_vec(), Pointer::from_sys(*ctid));
state.rpc.insert(id, data).friendly().friendly();
let vector = from_datum(*values.add(0));
let data = (vector, Pointer::from_sys(*ctid));
state.rpc.insert(id, data);
(*state.result).heap_tuples += 1.0;
(*state.result).index_tuples += 1.0;
}

234
src/index/am_scan.rs Normal file
View File

@@ -0,0 +1,234 @@
use crate::gucs::ENABLE_PREFILTER;
use crate::gucs::ENABLE_VBASE;
use crate::gucs::K;
use crate::index::utils::from_datum;
use crate::ipc::client::ClientGuard;
use crate::ipc::client::Vbase;
use crate::prelude::*;
use pgrx::FromDatum;
use service::prelude::*;
pub enum Scanner {
Initial {
node: Option<*mut pgrx::pg_sys::IndexScanState>,
vector: Option<DynamicVector>,
},
Search {
node: *mut pgrx::pg_sys::IndexScanState,
data: Vec<Pointer>,
},
Vbase {
node: *mut pgrx::pg_sys::IndexScanState,
vbase: ClientGuard<Vbase>,
},
}
impl Scanner {
fn node(&self) -> Option<*mut pgrx::pg_sys::IndexScanState> {
match self {
Scanner::Initial { node, .. } => *node,
Scanner::Search { node, .. } => Some(*node),
Scanner::Vbase { node, .. } => Some(*node),
}
}
}
pub unsafe fn make_scan(index_relation: pgrx::pg_sys::Relation) -> pgrx::pg_sys::IndexScanDesc {
use pgrx::PgMemoryContexts;
let scan = pgrx::pg_sys::RelationGetIndexScan(index_relation, 0, 1);
(*scan).xs_recheck = false;
(*scan).xs_recheckorderby = false;
(*scan).opaque =
PgMemoryContexts::CurrentMemoryContext.leak_and_drop_on_delete(Scanner::Initial {
vector: None,
node: None,
}) as _;
(*scan).xs_orderbyvals = pgrx::pg_sys::palloc0(std::mem::size_of::<pgrx::pg_sys::Datum>()) as _;
(*scan).xs_orderbynulls = {
let data = pgrx::pg_sys::palloc(std::mem::size_of::<bool>()) as *mut bool;
data.write_bytes(1, 1);
data
};
scan
}
pub unsafe fn start_scan(scan: pgrx::pg_sys::IndexScanDesc, orderbys: pgrx::pg_sys::ScanKey) {
std::ptr::copy(orderbys, (*scan).orderByData, 1);
let vector = from_datum((*orderbys.add(0)).sk_argument);
let scanner = &mut *((*scan).opaque as *mut Scanner);
let scanner = std::mem::replace(
scanner,
Scanner::Initial {
node: scanner.node(),
vector: Some(vector),
},
);
match scanner {
Scanner::Initial { .. } => {}
Scanner::Search { .. } => {}
Scanner::Vbase { vbase, .. } => {
vbase.leave();
}
}
}
pub unsafe fn next_scan(scan: pgrx::pg_sys::IndexScanDesc) -> bool {
let scanner = &mut *((*scan).opaque as *mut Scanner);
if let Scanner::Initial { node, vector } = scanner {
let node = node.expect("Hook failed.");
let vector = vector.as_ref().expect("Scan failed.");
#[cfg(any(feature = "pg12", feature = "pg13", feature = "pg14", feature = "pg15"))]
let oid = (*(*scan).indexRelation).rd_node.relNode;
#[cfg(feature = "pg16")]
let oid = (*(*scan).indexRelation).rd_locator.relNumber;
let id = Id::from_sys(oid);
let mut rpc = crate::ipc::client::borrow_mut();
if ENABLE_VBASE.get() {
let vbase = rpc.vbase(id, vector.clone());
*scanner = Scanner::Vbase { node, vbase };
} else {
let k = K.get() as _;
struct Search {
node: *mut pgrx::pg_sys::IndexScanState,
}
impl crate::ipc::client::Search for Search {
fn check(&mut self, p: Pointer) -> bool {
unsafe { check(self.node, p) }
}
}
let search = Search { node };
let mut data = rpc.search(id, (vector.clone(), k), ENABLE_PREFILTER.get(), search);
data.reverse();
*scanner = Scanner::Search { node, data };
}
}
match scanner {
Scanner::Initial { .. } => unreachable!(),
Scanner::Search { data, .. } => {
if let Some(p) = data.pop() {
(*scan).xs_heaptid = p.into_sys();
true
} else {
false
}
}
Scanner::Vbase { vbase, .. } => {
if let Some(p) = vbase.next() {
(*scan).xs_heaptid = p.into_sys();
true
} else {
false
}
}
}
}
pub unsafe fn end_scan(scan: pgrx::pg_sys::IndexScanDesc) {
let scanner = &mut *((*scan).opaque as *mut Scanner);
let scanner = std::mem::replace(
scanner,
Scanner::Initial {
node: scanner.node(),
vector: None,
},
);
match scanner {
Scanner::Initial { .. } => {}
Scanner::Search { .. } => {}
Scanner::Vbase { vbase, .. } => {
vbase.leave();
}
}
}
unsafe fn execute_boolean_qual(
state: *mut pgrx::pg_sys::ExprState,
econtext: *mut pgrx::pg_sys::ExprContext,
) -> bool {
use pgrx::PgMemoryContexts;
if state.is_null() {
return true;
}
assert!((*state).flags & pgrx::pg_sys::EEO_FLAG_IS_QUAL as u8 != 0);
let mut is_null = true;
pgrx::pg_sys::MemoryContextReset((*econtext).ecxt_per_tuple_memory);
let ret = PgMemoryContexts::For((*econtext).ecxt_per_tuple_memory)
.switch_to(|_| (*state).evalfunc.unwrap()(state, econtext, &mut is_null));
assert!(!is_null);
bool::from_datum(ret, is_null).unwrap()
}
unsafe fn check_quals(node: *mut pgrx::pg_sys::IndexScanState) -> bool {
let slot = (*node).ss.ss_ScanTupleSlot;
let econtext = (*node).ss.ps.ps_ExprContext;
(*econtext).ecxt_scantuple = slot;
if (*node).ss.ps.qual.is_null() {
return true;
}
let state = (*node).ss.ps.qual;
let econtext = (*node).ss.ps.ps_ExprContext;
execute_boolean_qual(state, econtext)
}
unsafe fn check_mvcc(node: *mut pgrx::pg_sys::IndexScanState, p: Pointer) -> bool {
let scan_desc = (*node).iss_ScanDesc;
let heap_fetch = (*scan_desc).xs_heapfetch;
let index_relation = (*heap_fetch).rel;
let rd_tableam = (*index_relation).rd_tableam;
let snapshot = (*scan_desc).xs_snapshot;
let index_fetch_tuple = (*rd_tableam).index_fetch_tuple.unwrap();
let mut all_dead = false;
let slot = (*node).ss.ss_ScanTupleSlot;
let mut heap_continue = false;
let found = index_fetch_tuple(
heap_fetch,
&mut p.into_sys(),
snapshot,
slot,
&mut heap_continue,
&mut all_dead,
);
if found {
return true;
}
while heap_continue {
let found = index_fetch_tuple(
heap_fetch,
&mut p.into_sys(),
snapshot,
slot,
&mut heap_continue,
&mut all_dead,
);
if found {
return true;
}
}
false
}
unsafe fn check(node: *mut pgrx::pg_sys::IndexScanState, p: Pointer) -> bool {
if !check_mvcc(node, p) {
return false;
}
if !check_quals(node) {
return false;
}
true
}

View File

@@ -1,22 +1,22 @@
use crate::index::indexing::IndexingOptions;
use crate::index::optimizing::OptimizingOptions;
use crate::index::segments::SegmentsOptions;
use crate::index::{IndexOptions, VectorOptions};
use crate::postgres::datatype::VectorTypmod;
use crate::prelude::*;
use crate::datatype::typmod::Typmod;
use serde::{Deserialize, Serialize};
use service::index::indexing::IndexingOptions;
use service::index::optimizing::OptimizingOptions;
use service::index::segments::SegmentsOptions;
use service::index::{IndexOptions, VectorOptions};
use service::prelude::*;
use std::ffi::CStr;
use validator::Validate;
pub fn helper_offset() -> usize {
memoffset::offset_of!(Helper, offset)
std::mem::offset_of!(Helper, offset)
}
pub fn helper_size() -> usize {
std::mem::size_of::<Helper>()
}
pub unsafe fn convert_opclass_to_distance(opclass: pgrx::pg_sys::Oid) -> Distance {
pub unsafe fn convert_opclass_to_distance(opclass: pgrx::pg_sys::Oid) -> (Distance, Kind) {
let opclass_cache_id = pgrx::pg_sys::SysCacheIdentifier_CLAOID as _;
let tuple = pgrx::pg_sys::SearchSysCache1(opclass_cache_id, opclass.into());
assert!(
@@ -25,12 +25,12 @@ pub unsafe fn convert_opclass_to_distance(opclass: pgrx::pg_sys::Oid) -> Distanc
);
let classform = pgrx::pg_sys::GETSTRUCT(tuple).cast::<pgrx::pg_sys::FormData_pg_opclass>();
let opfamily = (*classform).opcfamily;
let distance = convert_opfamily_to_distance(opfamily);
let result = convert_opfamily_to_distance(opfamily);
pgrx::pg_sys::ReleaseSysCache(tuple);
distance
result
}
pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> Distance {
pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> (Distance, Kind) {
let opfamily_cache_id = pgrx::pg_sys::SysCacheIdentifier_OPFAMILYOID as _;
let opstrategy_cache_id = pgrx::pg_sys::SysCacheIdentifier_AMOPSTRATEGY as _;
let tuple = pgrx::pg_sys::SearchSysCache1(opfamily_cache_id, opfamily.into());
@@ -52,19 +52,25 @@ pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> Dista
assert!((*amop).amopstrategy == 1);
assert!((*amop).amoppurpose == pgrx::pg_sys::AMOP_ORDER as libc::c_char);
let operator = (*amop).amopopr;
let distance;
let result;
if operator == regoperatorin("<->(vector,vector)") {
distance = Distance::L2;
result = (Distance::L2, Kind::F32);
} else if operator == regoperatorin("<#>(vector,vector)") {
distance = Distance::Dot;
result = (Distance::Dot, Kind::F32);
} else if operator == regoperatorin("<=>(vector,vector)") {
distance = Distance::Cosine;
result = (Distance::Cos, Kind::F32);
} else if operator == regoperatorin("<->(vecf16,vecf16)") {
result = (Distance::L2, Kind::F16);
} else if operator == regoperatorin("<#>(vecf16,vecf16)") {
result = (Distance::Dot, Kind::F16);
} else if operator == regoperatorin("<=>(vecf16,vecf16)") {
result = (Distance::Cos, Kind::F16);
} else {
FriendlyError::UnsupportedOperator.friendly();
FriendlyError::BadOptions3.friendly();
};
pgrx::pg_sys::ReleaseCatCacheList(list);
pgrx::pg_sys::ReleaseSysCache(tuple);
distance
result
}
pub unsafe fn options(index_relation: pgrx::pg_sys::Relation) -> IndexOptions {
@@ -72,22 +78,25 @@ pub unsafe fn options(index_relation: pgrx::pg_sys::Relation) -> IndexOptions {
assert!(nkeysatts == 1, "Can not be built on multicolumns.");
// get distance
let opfamily = (*index_relation).rd_opfamily.read();
let d = convert_opfamily_to_distance(opfamily);
let (d, k) = convert_opfamily_to_distance(opfamily);
// get dims
let attrs = (*(*index_relation).rd_att).attrs.as_slice(1);
let attr = &attrs[0];
let typmod = VectorTypmod::parse_from_i32(attr.type_mod()).unwrap();
let dims = typmod.dims().ok_or(FriendlyError::DimsIsNeeded).friendly();
let typmod = Typmod::parse_from_i32(attr.type_mod()).unwrap();
let dims = typmod.dims().ok_or(FriendlyError::BadOption2).friendly();
// get other options
let parsed = get_parsed_from_varlena((*index_relation).rd_options);
let options = IndexOptions {
vector: VectorOptions { dims, d },
vector: VectorOptions { dims, d, k },
segment: parsed.segment,
optimizing: parsed.optimizing,
indexing: parsed.indexing,
};
if let Err(errors) = options.validate() {
FriendlyError::BadOption(errors.to_string()).friendly();
FriendlyError::BadOption {
validation: errors.to_string(),
}
.friendly();
}
options
}

31
src/index/am_update.rs Normal file
View File

@@ -0,0 +1,31 @@
use crate::index::hook_transaction::flush_if_commit;
use crate::prelude::*;
use service::prelude::*;
pub fn update_insert(id: Id, vector: DynamicVector, tid: pgrx::pg_sys::ItemPointerData) {
flush_if_commit(id);
let p = Pointer::from_sys(tid);
let mut rpc = crate::ipc::client::borrow_mut();
rpc.insert(id, (vector, p));
}
pub fn update_delete(id: Id, hook: impl Fn(Pointer) -> bool) {
struct Delete<H> {
hook: H,
}
impl<H> crate::ipc::client::Delete for Delete<H>
where
H: Fn(Pointer) -> bool,
{
fn test(&mut self, p: Pointer) -> bool {
(self.hook)(p)
}
}
let client_delete = Delete { hook };
flush_if_commit(id);
let mut rpc = crate::ipc::client::borrow_mut();
rpc.delete(id, client_delete);
}

View File

@@ -1,5 +1,4 @@
use crate::postgres::index_scan::Scanner;
use crate::postgres::index_scan::ScannerState;
use crate::index::am_scan::Scanner;
use std::ptr::null_mut;
pub unsafe fn post_executor_start(query_desc: *mut pgrx::pg_sys::QueryDesc) {
@@ -21,7 +20,7 @@ unsafe extern "C" fn rewrite_plan_state(
if index_relation
.as_ref()
.and_then(|p| p.rd_indam.as_ref())
.map(|p| p.amvalidate == Some(super::index::amvalidate))
.map(|p| p.amvalidate == Some(super::am::amvalidate))
.unwrap_or(false)
{
// The logic is copied from Postgres source code.
@@ -33,6 +32,13 @@ unsafe extern "C" fn rewrite_plan_state(
(*node).iss_NumScanKeys,
(*node).iss_NumOrderByKeys,
);
let scanner = &mut *((*(*node).iss_ScanDesc).opaque as *mut Scanner);
*scanner = Scanner::Initial {
node: Some(node),
vector: None,
};
if (*node).iss_NumRuntimeKeys == 0 || (*node).iss_RuntimeKeysReady {
pgrx::pg_sys::index_rescan(
(*node).iss_ScanDesc,
@@ -42,10 +48,6 @@ unsafe extern "C" fn rewrite_plan_state(
(*node).iss_NumOrderByKeys,
);
}
// inject
let scanner = &mut *((*(*node).iss_ScanDesc).opaque as *mut Scanner);
scanner.index_scan_state = node;
assert!(matches!(scanner.state, ScannerState::Initial { .. }));
}
}
}

View File

@@ -0,0 +1,26 @@
use crate::utils::cells::PgRefCell;
use service::prelude::*;
use std::collections::BTreeSet;
static FLUSH_IF_COMMIT: PgRefCell<BTreeSet<Id>> = unsafe { PgRefCell::new(BTreeSet::new()) };
pub fn aborting() {
*FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new();
}
pub fn committing() {
{
let flush_if_commit = FLUSH_IF_COMMIT.borrow();
if flush_if_commit.len() != 0 {
let mut rpc = crate::ipc::client::borrow_mut();
for id in flush_if_commit.iter().copied() {
rpc.flush(id);
}
}
}
*FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new();
}
pub fn flush_if_commit(id: Id) {
FLUSH_IF_COMMIT.borrow_mut().insert(id);
}

View File

@@ -1,5 +1,5 @@
use crate::postgres::hook_transaction::client;
use crate::prelude::*;
use service::prelude::*;
static mut PREV_EXECUTOR_START: pgrx::pg_sys::ExecutorStart_hook_type = None;
@@ -46,10 +46,8 @@ unsafe fn xact_delete() {
.iter()
.map(|node| Id::from_sys(node.relNode))
.collect::<Vec<_>>();
client(|mut rpc| {
rpc.destory(ids).friendly();
rpc
});
let mut rpc = crate::ipc::client::borrow_mut();
rpc.destory(ids);
}
}
@@ -63,9 +61,7 @@ unsafe fn xact_delete() {
.iter()
.map(|node| Id::from_sys(node.relNumber))
.collect::<Vec<_>>();
client(|mut rpc| {
rpc.destory(ids).friendly();
rpc
});
let mut rpc = crate::ipc::client::borrow_mut();
rpc.destory(ids);
}
}

View File

@@ -1,552 +1,17 @@
pub mod delete;
pub mod indexing;
pub mod optimizing;
pub mod segments;
#![allow(unsafe_op_in_unsafe_fn)]
use self::delete::Delete;
use self::indexing::IndexingOptions;
use self::optimizing::OptimizingOptions;
use self::segments::growing::GrowingSegment;
use self::segments::growing::GrowingSegmentInsertError;
use self::segments::sealed::SealedSegment;
use self::segments::SegmentsOptions;
use crate::index::indexing::DynamicIndexIter;
use crate::prelude::*;
use crate::utils::clean::clean;
use crate::utils::dir_ops::sync_dir;
use crate::utils::file_atomic::FileAtomic;
use arc_swap::ArcSwap;
use crossbeam::sync::Parker;
use crossbeam::sync::Unparker;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::{Arc, Weak};
use thiserror::Error;
use uuid::Uuid;
use validator::Validate;
mod am;
mod am_build;
mod am_scan;
mod am_setup;
mod am_update;
mod hook_executor;
mod hook_transaction;
mod hooks;
mod utils;
mod views;
#[derive(Debug, Error)]
pub enum IndexInsertError {
#[error("The vector is invalid.")]
InvalidVector(Vec<Scalar>),
#[error("The index view is outdated.")]
OutdatedView(#[from] Option<GrowingSegmentInsertError>),
}
#[derive(Debug, Error)]
pub enum IndexSearchError {
#[error("The vector is invalid.")]
InvalidVector(Vec<Scalar>),
}
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct VectorOptions {
#[validate(range(min = 1, max = 65535))]
#[serde(rename = "dimensions")]
pub dims: u16,
#[serde(rename = "distance")]
pub d: Distance,
}
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct IndexOptions {
#[validate]
pub vector: VectorOptions,
#[validate]
pub segment: SegmentsOptions,
#[validate]
pub optimizing: OptimizingOptions,
#[validate]
pub indexing: IndexingOptions,
}
pub struct Index {
path: PathBuf,
options: IndexOptions,
delete: Arc<Delete>,
protect: Mutex<IndexProtect>,
view: ArcSwap<IndexView>,
optimize_unparker: Unparker,
indexing: Mutex<bool>,
_tracker: Arc<IndexTracker>,
}
impl Index {
pub fn create(path: PathBuf, options: IndexOptions) -> Arc<Self> {
assert!(options.validate().is_ok());
std::fs::create_dir(&path).unwrap();
std::fs::create_dir(path.join("segments")).unwrap();
let startup = FileAtomic::create(
path.join("startup"),
IndexStartup {
sealeds: HashSet::new(),
growings: HashSet::new(),
},
);
let delete = Delete::create(path.join("delete"));
sync_dir(&path);
let parker = Parker::new();
let index = Arc::new(Index {
path: path.clone(),
options: options.clone(),
delete: delete.clone(),
protect: Mutex::new(IndexProtect {
startup,
sealed: HashMap::new(),
growing: HashMap::new(),
write: None,
}),
view: ArcSwap::new(Arc::new(IndexView {
options: options.clone(),
sealed: HashMap::new(),
growing: HashMap::new(),
delete: delete.clone(),
write: None,
})),
optimize_unparker: parker.unparker().clone(),
indexing: Mutex::new(true),
_tracker: Arc::new(IndexTracker { path }),
});
IndexBackground {
index: Arc::downgrade(&index),
parker,
}
.spawn();
index
}
pub fn open(path: PathBuf, options: IndexOptions) -> Arc<Self> {
let tracker = Arc::new(IndexTracker { path: path.clone() });
let startup = FileAtomic::<IndexStartup>::open(path.join("startup"));
clean(
path.join("segments"),
startup
.get()
.sealeds
.iter()
.map(|s| s.to_string())
.chain(startup.get().growings.iter().map(|s| s.to_string())),
);
let sealed = startup
.get()
.sealeds
.iter()
.map(|&uuid| {
(
uuid,
SealedSegment::open(
tracker.clone(),
path.join("segments").join(uuid.to_string()),
uuid,
options.clone(),
),
)
})
.collect::<HashMap<_, _>>();
let growing = startup
.get()
.growings
.iter()
.map(|&uuid| {
(
uuid,
GrowingSegment::open(
tracker.clone(),
path.join("segments").join(uuid.to_string()),
uuid,
options.clone(),
),
)
})
.collect::<HashMap<_, _>>();
let delete = Delete::open(path.join("delete"));
let parker = Parker::new();
let index = Arc::new(Index {
path: path.clone(),
options: options.clone(),
delete: delete.clone(),
protect: Mutex::new(IndexProtect {
startup,
sealed: sealed.clone(),
growing: growing.clone(),
write: None,
}),
view: ArcSwap::new(Arc::new(IndexView {
options: options.clone(),
delete: delete.clone(),
sealed,
growing,
write: None,
})),
optimize_unparker: parker.unparker().clone(),
indexing: Mutex::new(true),
_tracker: tracker,
});
IndexBackground {
index: Arc::downgrade(&index),
parker,
}
.spawn();
index
}
pub fn options(&self) -> &IndexOptions {
&self.options
}
pub fn view(&self) -> Arc<IndexView> {
self.view.load_full()
}
pub fn refresh(&self) {
let mut protect = self.protect.lock();
if let Some((uuid, write)) = protect.write.clone() {
write.seal();
protect.growing.insert(uuid, write);
}
let write_segment_uuid = Uuid::new_v4();
let write_segment = GrowingSegment::create(
self._tracker.clone(),
self.path
.join("segments")
.join(write_segment_uuid.to_string()),
write_segment_uuid,
self.options.clone(),
);
protect.write = Some((write_segment_uuid, write_segment));
protect.maintain(self.options.clone(), self.delete.clone(), &self.view);
self.optimize_unparker.unpark();
}
pub fn indexing(&self) -> bool {
*self.indexing.lock()
}
}
impl Drop for Index {
fn drop(&mut self) {
self.optimize_unparker.unpark();
}
}
#[derive(Debug, Clone)]
pub struct IndexTracker {
path: PathBuf,
}
impl Drop for IndexTracker {
fn drop(&mut self) {
std::fs::remove_dir_all(&self.path).unwrap();
}
}
pub struct IndexView {
options: IndexOptions,
delete: Arc<Delete>,
sealed: HashMap<Uuid, Arc<SealedSegment>>,
growing: HashMap<Uuid, Arc<GrowingSegment>>,
write: Option<(Uuid, Arc<GrowingSegment>)>,
}
impl IndexView {
pub fn sealed_len(&self) -> u32 {
self.sealed.values().map(|x| x.len()).sum::<u32>()
}
pub fn growing_len(&self) -> u32 {
self.growing.values().map(|x| x.len()).sum::<u32>()
}
pub fn write_len(&self) -> u32 {
self.write.as_ref().map(|x| x.1.len()).unwrap_or(0)
}
pub fn sealed_len_vec(&self) -> Vec<u32> {
self.sealed.values().map(|x| x.len()).collect()
}
pub fn growing_len_vec(&self) -> Vec<u32> {
self.growing.values().map(|x| x.len()).collect()
}
pub fn search<F: FnMut(Pointer) -> bool>(
&self,
k: usize,
vector: &[Scalar],
mut filter: F,
) -> Result<Vec<Pointer>, IndexSearchError> {
if self.options.vector.dims as usize != vector.len() {
return Err(IndexSearchError::InvalidVector(vector.to_vec()));
}
struct Comparer(BinaryHeap<Reverse<HeapElement>>);
impl PartialEq for Comparer {
fn eq(&self, other: &Self) -> bool {
self.cmp(other).is_eq()
}
}
impl Eq for Comparer {}
impl PartialOrd for Comparer {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Comparer {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0.peek().cmp(&other.0.peek()).reverse()
}
}
let mut filter = |payload| {
if let Some(p) = self.delete.check(payload) {
filter(p)
} else {
false
}
};
let n = self.sealed.len() + self.growing.len() + 1;
let mut result = Heap::new(k);
let mut heaps = BinaryHeap::with_capacity(1 + n);
for (_, sealed) in self.sealed.iter() {
let p = sealed.search(k, vector, &mut filter).into_reversed_heap();
heaps.push(Comparer(p));
}
for (_, growing) in self.growing.iter() {
let p = growing.search(k, vector, &mut filter).into_reversed_heap();
heaps.push(Comparer(p));
}
if let Some((_, write)) = &self.write {
let p = write.search(k, vector, &mut filter).into_reversed_heap();
heaps.push(Comparer(p));
}
while let Some(Comparer(mut heap)) = heaps.pop() {
if let Some(Reverse(x)) = heap.pop() {
result.push(x);
heaps.push(Comparer(heap));
}
}
Ok(result
.into_sorted_vec()
.iter()
.map(|x| Pointer::from_u48(x.payload >> 16))
.collect())
}
pub fn search_vbase<F>(
&self,
range: usize,
vector: &[Scalar],
mut next: F,
) -> Result<(), IndexSearchError>
where
F: FnMut(Pointer) -> bool,
{
if self.options.vector.dims as usize != vector.len() {
return Err(IndexSearchError::InvalidVector(vector.to_vec()));
}
struct Comparer<'index, 'vector> {
iter: ComparerIter<'index, 'vector>,
item: Option<HeapElement>,
}
enum ComparerIter<'index, 'vector> {
Sealed(DynamicIndexIter<'index, 'vector>),
Growing(std::vec::IntoIter<HeapElement>),
}
impl PartialEq for Comparer<'_, '_> {
fn eq(&self, other: &Self) -> bool {
self.cmp(other).is_eq()
}
}
impl Eq for Comparer<'_, '_> {}
impl PartialOrd for Comparer<'_, '_> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Comparer<'_, '_> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.item.cmp(&other.item).reverse()
}
}
impl Iterator for ComparerIter<'_, '_> {
type Item = HeapElement;
fn next(&mut self) -> Option<Self::Item> {
match self {
Self::Sealed(iter) => iter.next(),
Self::Growing(iter) => iter.next(),
}
}
}
impl Iterator for Comparer<'_, '_> {
type Item = HeapElement;
fn next(&mut self) -> Option<Self::Item> {
let item = self.item.take();
self.item = self.iter.next();
item
}
}
fn from_iter<'index, 'vector>(
mut iter: ComparerIter<'index, 'vector>,
) -> Comparer<'index, 'vector> {
let item = iter.next();
Comparer { iter, item }
}
use ComparerIter::*;
let filter = |payload| self.delete.check(payload).is_some();
let n = self.sealed.len() + self.growing.len() + 1;
let mut heaps: BinaryHeap<Comparer> = BinaryHeap::with_capacity(1 + n);
for (_, sealed) in self.sealed.iter() {
let res = sealed.search_vbase(range, vector);
heaps.push(from_iter(Sealed(res)));
}
for (_, growing) in self.growing.iter() {
let mut res = growing.search_all(vector);
res.sort_unstable();
heaps.push(from_iter(Growing(res.into_iter())));
}
if let Some((_, write)) = &self.write {
let mut res = write.search_all(vector);
res.sort_unstable();
heaps.push(from_iter(Growing(res.into_iter())));
}
while let Some(mut iter) = heaps.pop() {
if let Some(x) = iter.next() {
if !filter(x.payload) {
continue;
}
let stop = next(Pointer::from_u48(x.payload >> 16));
if stop {
break;
}
heaps.push(iter);
}
}
Ok(())
}
pub fn insert(&self, vector: Vec<Scalar>, pointer: Pointer) -> Result<(), IndexInsertError> {
if self.options.vector.dims as usize != vector.len() {
return Err(IndexInsertError::InvalidVector(vector));
}
let payload = (pointer.as_u48() << 16) | self.delete.version(pointer) as Payload;
if let Some((_, growing)) = self.write.as_ref() {
Ok(growing.insert(vector, payload)?)
} else {
Err(IndexInsertError::OutdatedView(None))
}
}
pub fn delete<F: FnMut(Pointer) -> bool>(&self, mut f: F) {
for (_, sealed) in self.sealed.iter() {
let n = sealed.len();
for i in 0..n {
if let Some(p) = self.delete.check(sealed.payload(i)) {
if f(p) {
self.delete.delete(p);
}
}
}
}
for (_, growing) in self.growing.iter() {
let n = growing.len();
for i in 0..n {
if let Some(p) = self.delete.check(growing.payload(i)) {
if f(p) {
self.delete.delete(p);
}
}
}
}
if let Some((_, write)) = &self.write {
let n = write.len();
for i in 0..n {
if let Some(p) = self.delete.check(write.payload(i)) {
if f(p) {
self.delete.delete(p);
}
}
}
}
}
pub fn flush(&self) -> Result<(), IndexInsertError> {
self.delete.flush();
if let Some((_, write)) = &self.write {
write.flush();
}
Ok(())
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct IndexStartup {
sealeds: HashSet<Uuid>,
growings: HashSet<Uuid>,
}
struct IndexProtect {
startup: FileAtomic<IndexStartup>,
sealed: HashMap<Uuid, Arc<SealedSegment>>,
growing: HashMap<Uuid, Arc<GrowingSegment>>,
write: Option<(Uuid, Arc<GrowingSegment>)>,
}
impl IndexProtect {
fn maintain(&mut self, options: IndexOptions, delete: Arc<Delete>, swap: &ArcSwap<IndexView>) {
let view: Arc<IndexView> = Arc::new(IndexView {
options,
delete,
sealed: self.sealed.clone(),
growing: self.growing.clone(),
write: self.write.clone(),
});
let startup_write = self.write.as_ref().map(|(uuid, _)| *uuid);
let startup_sealeds = self.sealed.keys().copied().collect();
let startup_growings = self.growing.keys().copied().chain(startup_write).collect();
self.startup.set(IndexStartup {
sealeds: startup_sealeds,
growings: startup_growings,
});
swap.swap(view);
}
}
pub struct IndexBackground {
index: Weak<Index>,
parker: Parker,
}
impl IndexBackground {
pub fn main(self) {
let pool;
if let Some(index) = self.index.upgrade() {
pool = rayon::ThreadPoolBuilder::new()
.num_threads(index.options.optimizing.optimizing_threads)
.build()
.unwrap();
} else {
return;
}
while let Some(index) = self.index.upgrade() {
let done = pool.install(|| optimizing::indexing::optimizing_indexing(index.clone()));
if done {
*index.indexing.lock() = false;
drop(index);
self.parker.park();
if let Some(index) = self.index.upgrade() {
*index.indexing.lock() = true;
}
}
}
}
pub fn spawn(self) {
std::thread::spawn(move || {
self.main();
});
}
pub unsafe fn init() {
self::hooks::init();
self::am::init();
}

Some files were not shown because too many files have changed in this diff Show More