mirror of https://github.com/tensorchord/pgvecto.rs.git synced 2025-08-08 14:22:07 +03:00

feat: fp16 vector (#178)

* feat: fp16 vector

Signed-off-by: usamoi <usamoi@outlook.com>

* feat: detect avx512fp16

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: install clang-16 for ci

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: clippy

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: rename c to pgvectorsc

Signed-off-by: usamoi <usamoi@outlook.com>

* feat: hand-writing avx512fp16

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: index on fp16

Signed-off-by: usamoi <usamoi@outlook.com>

* feat: hand-writing avx2

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: clippy

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: add rerun in build script

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: cross compilation

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: do not leave uninitialized bytes in datatype input function

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: compiler built-in function calling convention workaround

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: cross compile on aarch64

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: fix detect avx512fp16

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: avx512 codegen by multiversion

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: enable more target features for c

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: use tensorchord/stdarch

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: ci

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: remove no-run cross test

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: vbase

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: error and document

Signed-off-by: usamoi <usamoi@outlook.com>

* [skip ci]

Signed-off-by: usamoi <usamoi@outlook.com>

---------

Signed-off-by: usamoi <usamoi@outlook.com>
Author: Usamoi
Date: 2023-12-14 17:50:52 +08:00
Committed by: GitHub
parent 2ab76118fc
commit 5c0450274d
146 changed files with 7436 additions and 4108 deletions
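For context, the new data type is IEEE half precision (fp16); the scalar fallback paths in this commit widen each lane to f32 before accumulating. A minimal, hedged sketch of such a kernel using the `half` crate (an illustration only, not the extension's actual code; it assumes `half` as a dependency, which the diff below adds):

```rust
use half::f16;

/// Squared Euclidean distance between two fp16 vectors, accumulating in f32
/// the way the scalar (non-SIMD) kernels do.
fn squared_l2(a: &[f16], b: &[f16]) -> f32 {
    assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let d = x.to_f32() - y.to_f32();
            d * d
        })
        .sum()
}

fn main() {
    let a: Vec<f16> = [1.0f32, 2.0, 3.0].iter().map(|&x| f16::from_f32(x)).collect();
    let b: Vec<f16> = [1.5f32, 2.0, 2.5].iter().map(|&x| f16::from_f32(x)).collect();
    println!("{}", squared_l2(&a, &b)); // 0.5
}
```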


@@ -4,3 +4,13 @@ rustdocflags = ["--document-private-items"]
[target.'cfg(target_os="macos")'] [target.'cfg(target_os="macos")']
# Postgres symbols won't be available until runtime # Postgres symbols won't be available until runtime
rustflags = ["-Clink-arg=-Wl,-undefined,dynamic_lookup"] rustflags = ["-Clink-arg=-Wl,-undefined,dynamic_lookup"]
[target.x86_64-unknown-linux-gnu]
linker = "x86_64-linux-gnu-gcc"
[target.aarch64-unknown-linux-gnu]
linker = "aarch64-linux-gnu-gcc"
[env]
BINDGEN_EXTRA_CLANG_ARGS_x86_64_unknown_linux_gnu = "-isystem /usr/x86_64-linux-gnu/include/ -ccc-gcc-name x86_64-linux-gnu-gcc"
BINDGEN_EXTRA_CLANG_ARGS_aarch64_unknown_linux_gnu = "-isystem /usr/aarch64-linux-gnu/include/ -ccc-gcc-name aarch64-linux-gnu-gcc"


@@ -6,6 +6,7 @@ on:
paths: paths:
- ".cargo/**" - ".cargo/**"
- ".github/**" - ".github/**"
- "crates/**"
- "scripts/**" - "scripts/**"
- "src/**" - "src/**"
- "tests/**" - "tests/**"
@@ -18,6 +19,7 @@ on:
paths: paths:
- ".cargo/**" - ".cargo/**"
- ".github/**" - ".github/**"
- "crates/**"
- "scripts/**" - "scripts/**"
- "src/**" - "src/**"
- "tests/**" - "tests/**"
@@ -90,11 +92,16 @@ jobs:
- name: Format check - name: Format check
run: cargo fmt --check run: cargo fmt --check
- name: Semantic check - name: Semantic check
run: cargo clippy --no-default-features --features "pg${{ matrix.version }} pg_test" run: |
cargo clippy --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu
cargo clippy --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu
- name: Debug build - name: Debug build
run: cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" run: |
cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu
cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu
- name: Test - name: Test
run: cargo test --all --no-default-features --features "pg${{ matrix.version }} pg_test" -- --nocapture run: |
cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu
- name: Install release - name: Install release
run: ./scripts/ci_install.sh run: ./scripts/ci_install.sh
- name: Sqllogictest - name: Sqllogictest


@@ -112,15 +112,17 @@ jobs:
- uses: mozilla-actions/sccache-action@v0.0.3 - uses: mozilla-actions/sccache-action@v0.0.3
- name: Prepare - name: Prepare
run: | run: |
sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" >> /etc/apt/sources.list.d/pgdg.list'
sudo sh -c 'echo "deb http://apt.llvm.org/$(lsb_release -cs)/ llvm-toolchain-$(lsb_release -cs)-16 main" >> /etc/apt/sources.list'
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
wget --quiet -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
sudo apt-get update sudo apt-get update
sudo apt-get -y install libpq-dev postgresql-${{ matrix.version }} postgresql-server-dev-${{ matrix.version }} sudo apt-get -y install libpq-dev postgresql-${{ matrix.version }} postgresql-server-dev-${{ matrix.version }}
sudo apt-get -y install clang-16
cargo install cargo-pgrx --git https://github.com/tensorchord/pgrx.git --rev $(cat Cargo.toml | grep "pgrx =" | awk -F'rev = "' '{print $2}' | cut -d'"' -f1) cargo install cargo-pgrx --git https://github.com/tensorchord/pgrx.git --rev $(cat Cargo.toml | grep "pgrx =" | awk -F'rev = "' '{print $2}' | cut -d'"' -f1)
cargo pgrx init --pg${{ matrix.version }}=/usr/lib/postgresql/${{ matrix.version }}/bin/pg_config cargo pgrx init --pg${{ matrix.version }}=/usr/lib/postgresql/${{ matrix.version }}/bin/pg_config
if [[ "${{ matrix.arch }}" == "arm64" ]]; then if [[ "${{ matrix.arch }}" == "arm64" ]]; then
sudo apt-get -y install crossbuild-essential-arm64 sudo apt-get -y install crossbuild-essential-arm64
rustup target add aarch64-unknown-linux-gnu
fi fi
- name: Build Release - name: Build Release
id: build_release id: build_release
@@ -130,8 +132,6 @@ jobs:
mkdir ./artifacts mkdir ./artifacts
cargo pgrx package cargo pgrx package
if [[ "${{ matrix.arch }}" == "arm64" ]]; then if [[ "${{ matrix.arch }}" == "arm64" ]]; then
export CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER=aarch64-linux-gnu-gcc
export BINDGEN_EXTRA_CLANG_ARGS_aarch64_unknown_linux_gnu="-target aarch64-unknown-linux-gnu -isystem /usr/aarch64-linux-gnu/include/ -ccc-gcc-name aarch64-linux-gnu-gcc"
cargo build --target aarch64-unknown-linux-gnu --release --features "pg${{ matrix.version }}" --no-default-features cargo build --target aarch64-unknown-linux-gnu --release --features "pg${{ matrix.version }}" --no-default-features
mv ./target/aarch64-unknown-linux-gnu/release/libvectors.so ./target/release/vectors-pg${{ matrix.version }}/usr/lib/postgresql/${{ matrix.version }}/lib/vectors.so mv ./target/aarch64-unknown-linux-gnu/release/libvectors.so ./target/release/vectors-pg${{ matrix.version }}/usr/lib/postgresql/${{ matrix.version }}/lib/vectors.so
fi fi

.gitignore vendored

@@ -6,4 +6,5 @@
.vscode .vscode
.ignore .ignore
__pycache__ __pycache__
.pytest_cache .pytest_cache
rustc-ice-*.txt

Cargo.lock generated

File diff suppressed because it is too large.

@@ -1,7 +1,7 @@
[package] [package]
name = "vectors" name = "vectors"
version = "0.1.1" version.workspace = true
edition = "2021" edition.workspace = true
[lib] [lib]
crate-type = ["cdylib"] crate-type = ["cdylib"]
@@ -16,45 +16,60 @@ pg16 = ["pgrx/pg16", "pgrx-tests/pg16"]
pg_test = [] pg_test = []
[dependencies] [dependencies]
libc.workspace = true
log.workspace = true
serde.workspace = true
serde_json.workspace = true
validator.workspace = true
rustix.workspace = true
thiserror.workspace = true
byteorder.workspace = true
bincode.workspace = true
half.workspace = true
num-traits.workspace = true
service = { path = "crates/service" }
pgrx = { git = "https://github.com/tensorchord/pgrx.git", rev = "7c30e2023876c1efce613756f5ec81f3ab05696b", default-features = false, features = [ pgrx = { git = "https://github.com/tensorchord/pgrx.git", rev = "7c30e2023876c1efce613756f5ec81f3ab05696b", default-features = false, features = [
] } ] }
openai_api_rust = { git = "https://github.com/tensorchord/openai-api.git", rev = "228d54b6002e98257b3c81501a054942342f585f" } openai_api_rust = { git = "https://github.com/tensorchord/openai-api.git", rev = "228d54b6002e98257b3c81501a054942342f585f" }
static_assertions = "1.1.0"
libc = "~0.2"
serde = "1.0.163"
bincode = "1.3.3"
rand = "0.8.5"
byteorder = "1.4.3"
crc32fast = "1.3.2"
log = "0.4.18"
env_logger = "0.10.0" env_logger = "0.10.0"
crossbeam = "0.8.2"
dashmap = "5.4.0"
parking_lot = "0.12.1"
memoffset = "0.9.0"
serde_json = "1"
thiserror = "1.0.40"
tempfile = "3.6.0"
cstr = "0.2.11"
arrayvec = { version = "0.7.3", features = ["serde"] }
memmap2 = "0.9.0"
validator = { version = "0.16.1", features = ["derive"] }
toml = "0.8.8" toml = "0.8.8"
rayon = "1.6.1"
uuid = { version = "1.4.1", features = ["serde"] }
rustix = { version = "0.38.20", features = ["net", "mm"] }
arc-swap = "1.6.0"
bytemuck = { version = "1.14.0", features = ["extern_crate_alloc"] }
serde_with = "3.4.0"
multiversion = "0.7.3"
[dev-dependencies] [dev-dependencies]
pgrx-tests = { git = "https://github.com/tensorchord/pgrx.git", rev = "7c30e2023876c1efce613756f5ec81f3ab05696b" } pgrx-tests = { git = "https://github.com/tensorchord/pgrx.git", rev = "7c30e2023876c1efce613756f5ec81f3ab05696b" }
httpmock = "0.6" httpmock = "0.6"
mockall = "0.11.4" mockall = "0.11.4"
[target.'cfg(target_os = "macos")'.dependencies] [lints]
ulock-sys = "0.1.0" clippy.too_many_arguments = "allow"
clippy.unnecessary_literal_unwrap = "allow"
clippy.unnecessary_unwrap = "allow"
rust.unsafe_op_in_unsafe_fn = "warn"
[workspace]
resolver = "2"
members = ["crates/*"]
[workspace.package]
version = "0.0.0"
edition = "2021"
[workspace.dependencies]
libc = "~0.2"
log = "~0.4"
serde = "~1.0"
serde_json = "1"
thiserror = "~1.0"
bincode = "~1.3"
byteorder = "~1.4"
half = { version = "~2.3", features = [
"bytemuck",
"num-traits",
"serde",
"use-intrinsics",
] }
num-traits = "~0.2"
validator = { version = "~0.16", features = ["derive"] }
rustix = { version = "~0.38", features = ["net", "mm"] }
[profile.dev] [profile.dev]
panic = "unwind" panic = "unwind"
@@ -65,10 +80,3 @@ opt-level = 3
lto = "fat" lto = "fat"
codegen-units = 1 codegen-units = 1
debug = true debug = true
[lints.clippy]
needless_range_loop = "allow"
derivable_impls = "allow"
unnecessary_literal_unwrap = "allow"
too_many_arguments = "allow"
unnecessary_unwrap = "allow"


@@ -21,13 +21,13 @@ pgvecto.rs is a Postgres extension that provides vector similarity search functi
## Comparison with pgvector ## Comparison with pgvector
| | pgvecto.rs | pgvector | | | pgvecto.rs | pgvector |
| ------------------------------------------- | ------------------------------------------------------ | ------------------------ | | ------------------------------------------- | ------------------------------------------------------ | ----------------------- |
| Transaction support | ✅ | ⚠️ | | Transaction support | ✅ | ⚠️ |
| Sufficient Result with Delete/Update/Filter | ✅ | ⚠️ | | Sufficient Result with Delete/Update/Filter | ✅ | ⚠️ |
| Vector Dimension Limit | 65535 | 2000 | | Vector Dimension Limit | 65535 | 2000 |
| Prefilter on HNSW | ✅ | ❌ | | Prefilter on HNSW | ✅ | ❌ |
| Parallel HNSW Index build | ⚡️ Linearly faster with more cores | 🐌 Only single core used | | Parallel HNSW Index build | ⚡️ Linearly faster with more cores | 🐌 Only single core used |
| Async Index build | Ready for queries anytime and do not block insertions. | ❌ | | Async Index build | Ready for queries anytime and do not block insertions. | ❌ |
| Quantization | Scalar/Product Quantization | ❌ | | Quantization | Scalar/Product Quantization | ❌ |
@@ -45,7 +45,11 @@ More details at [./docs/comparison-pgvector.md](./docs/comparison-pgvector.md)
For users, we recommend you to try pgvecto.rs using our pre-built docker image, by running For users, we recommend you to try pgvecto.rs using our pre-built docker image, by running
```sh ```sh
docker run --name pgvecto-rs-demo -e POSTGRES_PASSWORD=mysecretpassword -p 5432:5432 -d tensorchord/pgvecto-rs:pg16-latest docker run \
--name pgvecto-rs-demo \
-e POSTGRES_PASSWORD=mysecretpassword \
-p 5432:5432 \
-d tensorchord/pgvecto-rs:pg16-latest
``` ```
## Development with envd ## Development with envd


@@ -19,14 +19,12 @@ URL = f"postgresql://{USER}:{PASS}@{HOST}:{PORT}/{DB_NAME}"
TOML_SETTINGS = { TOML_SETTINGS = {
"flat": toml.dumps( "flat": toml.dumps(
{ {
"capacity": 2097152, "indexing": {"flat": {}},
"algorithm": {"flat": {}},
}, },
), ),
"hnsw": toml.dumps( "hnsw": toml.dumps(
{ {
"capacity": 2097152, "indexing": {"hnsw": {}},
"algorithm": {"hnsw": {}},
}, },
), ),
} }


@@ -38,7 +38,7 @@ def conn():
@pytest.mark.parametrize(("index_name", "index_setting"), TOML_SETTINGS.items()) @pytest.mark.parametrize(("index_name", "index_setting"), TOML_SETTINGS.items())
def test_create_index(conn: Connection, index_name: str, index_setting: str): def test_create_index(conn: Connection, index_name: str, index_setting: str):
stat = sql.SQL( stat = sql.SQL(
"CREATE INDEX {} ON tb_test_item USING vectors (embedding l2_ops) WITH (options={});", "CREATE INDEX {} ON tb_test_item USING vectors (embedding vector_l2_ops) WITH (options={});",
).format(sql.Identifier(index_name), index_setting) ).format(sql.Identifier(index_name), index_setting)
conn.execute(stat) conn.execute(stat)


@@ -68,7 +68,7 @@ def test_create_index(session: Session, index_name: str, index_setting: str):
Document.embedding, Document.embedding,
postgresql_using="vectors", postgresql_using="vectors",
postgresql_with={"options": f"$${index_setting}$$"}, postgresql_with={"options": f"$${index_setting}$$"},
postgresql_ops={"embedding": "l2_ops"}, postgresql_ops={"embedding": "vector_l2_ops"},
) )
index.create(session.bind) index.create(session.bind)
session.commit() session.commit()

crates/c/.gitignore vendored Normal file

@@ -0,0 +1,3 @@
*.s
*.o
*.out

crates/c/Cargo.toml Normal file

@@ -0,0 +1,10 @@
[package]
name = "c"
version.workspace = true
edition.workspace = true
[dependencies]
half = { version = "~2.3", features = ["use-intrinsics"] }
[build-dependencies]
cc = "1.0"

crates/c/build.rs Normal file

@@ -0,0 +1,10 @@
fn main() {
println!("cargo:rerun-if-changed=src/c.h");
println!("cargo:rerun-if-changed=src/c.c");
cc::Build::new()
.compiler("/usr/bin/clang-16")
.file("./src/c.c")
.opt_level(3)
.debug(true)
.compile("pgvectorsc");
}

crates/c/src/c.c Normal file

@@ -0,0 +1,118 @@
#include "c.h"
#include <math.h>
#if defined(__x86_64__)
#include <immintrin.h>
#endif
#if defined(__x86_64__)
__attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float
v_f16_cosine_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
__m512h xy = _mm512_set1_ph(0);
__m512h xx = _mm512_set1_ph(0);
__m512h yy = _mm512_set1_ph(0);
while (n >= 32) {
__m512h x = _mm512_loadu_ph(a);
__m512h y = _mm512_loadu_ph(b);
a += 32, b += 32, n -= 32;
xy = _mm512_fmadd_ph(x, y, xy);
xx = _mm512_fmadd_ph(x, x, xx);
yy = _mm512_fmadd_ph(y, y, yy);
}
if (n > 0) {
__mmask32 mask = _bzhi_u32(0xFFFFFFFF, n);
__m512h x = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a));
__m512h y = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b));
xy = _mm512_fmadd_ph(x, y, xy);
xx = _mm512_fmadd_ph(x, x, xx);
yy = _mm512_fmadd_ph(y, y, yy);
}
return (float)(_mm512_reduce_add_ph(xy) /
sqrt(_mm512_reduce_add_ph(xx) * _mm512_reduce_add_ph(yy)));
}
__attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float
v_f16_dot_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
__m512h xy = _mm512_set1_ph(0);
while (n >= 32) {
__m512h x = _mm512_loadu_ph(a);
__m512h y = _mm512_loadu_ph(b);
a += 32, b += 32, n -= 32;
xy = _mm512_fmadd_ph(x, y, xy);
}
if (n > 0) {
__mmask32 mask = _bzhi_u32(0xFFFFFFFF, n);
__m512h x = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a));
__m512h y = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b));
xy = _mm512_fmadd_ph(x, y, xy);
}
return (float)_mm512_reduce_add_ph(xy);
}
__attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float
v_f16_sl2_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
__m512h dd = _mm512_set1_ph(0);
while (n >= 32) {
__m512h x = _mm512_loadu_ph(a);
__m512h y = _mm512_loadu_ph(b);
a += 32, b += 32, n -= 32;
__m512h d = _mm512_sub_ph(x, y);
dd = _mm512_fmadd_ph(d, d, dd);
}
if (n > 0) {
__mmask32 mask = _bzhi_u32(0xFFFFFFFF, n);
__m512h x = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a));
__m512h y = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b));
__m512h d = _mm512_sub_ph(x, y);
dd = _mm512_fmadd_ph(d, d, dd);
}
return (float)_mm512_reduce_add_ph(dd);
}
__attribute__((target("arch=x86-64-v3"))) extern float
v_f16_cosine_v3(_Float16 *a, _Float16 *b, size_t n) {
float xy = 0;
float xx = 0;
float yy = 0;
#pragma clang loop vectorize_width(8)
for (size_t i = 0; i < n; i++) {
float x = a[i];
float y = b[i];
xy += x * y;
xx += x * x;
yy += y * y;
}
return xy / sqrt(xx * yy);
}
__attribute__((target("arch=x86-64-v3"))) extern float
v_f16_dot_v3(_Float16 *a, _Float16 *b, size_t n) {
float xy = 0;
#pragma clang loop vectorize_width(8)
for (size_t i = 0; i < n; i++) {
float x = a[i];
float y = b[i];
xy += x * y;
}
return xy;
}
__attribute__((target("arch=x86-64-v3"))) extern float
v_f16_sl2_v3(_Float16 *a, _Float16 *b, size_t n) {
float dd = 0;
#pragma clang loop vectorize_width(8)
for (size_t i = 0; i < n; i++) {
float x = a[i];
float y = b[i];
float d = x - y;
dd += d * d;
}
return dd;
}
#endif
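The AVX-512 FP16 kernels above process 32 half-precision lanes per iteration with fused multiply-add, then cover the remaining `n % 32` lanes with a masked load (`_bzhi_u32` produces a mask whose low `n` bits are set). The same main-loop/tail structure in portable Rust, over f32 for brevity, purely as an illustration of the shape of the computation (not part of the commit):

```rust
fn dot(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    let split = a.len() - a.len() % 32;
    // `acc` plays the role of one 512-bit register holding 32 partial sums.
    let mut acc = [0.0f32; 32];
    for (ca, cb) in a[..split].chunks_exact(32).zip(b[..split].chunks_exact(32)) {
        for i in 0..32 {
            acc[i] += ca[i] * cb[i]; // the _mm512_fmadd_ph step
        }
    }
    // Horizontal reduction, analogous to _mm512_reduce_add_ph.
    let mut sum: f32 = acc.iter().sum();
    // Tail of n % 32 elements, which the kernel handles via the masked load.
    for (x, y) in a[split..].iter().zip(&b[split..]) {
        sum += x * y;
    }
    sum
}
```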

crates/c/src/c.h Normal file

@@ -0,0 +1,13 @@
#include <stddef.h>
#include <stdint.h>
#if defined(__x86_64__)
extern float v_f16_cosine_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_dot_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_sl2_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_cosine_v3(_Float16 *, _Float16 *, size_t n);
extern float v_f16_dot_v3(_Float16 *, _Float16 *, size_t n);
extern float v_f16_sl2_v3(_Float16 *, _Float16 *, size_t n);
#endif

crates/c/src/c.rs Normal file

@@ -0,0 +1,24 @@
#[cfg(target_arch = "x86_64")]
#[link(name = "pgvectorsc", kind = "static")]
extern "C" {
pub fn v_f16_cosine_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_dot_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_sl2_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_cosine_v3(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_dot_v3(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_sl2_v3(a: *const u16, b: *const u16, n: usize) -> f32;
}
// `compiler_builtins` defines `__extendhfsf2` with an integer calling convention.
// However, C compilers link `__extendhfsf2` with a floating-point calling convention.
// This code should be removed once Rust officially supports `f16`.
#[cfg(target_arch = "x86_64")]
#[no_mangle]
#[linkage = "external"]
extern "C" fn __extendhfsf2(f: f64) -> f32 {
unsafe {
let f: half::f16 = std::mem::transmute_copy(&f);
f.to_f32()
}
}
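Each operation comes in an AVX-512 FP16 variant and an x86-64-v3 fallback, and the caller picks one at runtime after CPU feature detection (the service crate pulls in a pinned `std_detect` fork for the avx512fp16 check). A hedged sketch of what such a dispatcher could look like, reusing the extern declarations above; `cpu_supports_avx512fp16()` is a hypothetical stand-in for whatever detection the crate actually performs:

```rust
/// Hypothetical feature probe; in the real crate this is answered by the
/// std_detect fork at runtime, not hard-coded.
fn cpu_supports_avx512fp16() -> bool {
    false
}

/// Wrapper choosing between the hand-written kernels. The slices hold fp16
/// values reinterpreted as u16 bit patterns, and the v3 path assumes the CPU
/// is at least x86-64-v3 (which the C code targets).
#[cfg(target_arch = "x86_64")]
fn f16_dot(a: &[u16], b: &[u16]) -> f32 {
    assert_eq!(a.len(), b.len());
    unsafe {
        if cpu_supports_avx512fp16() {
            v_f16_dot_avx512fp16(a.as_ptr(), b.as_ptr(), a.len())
        } else {
            v_f16_dot_v3(a.as_ptr(), b.as_ptr(), a.len())
        }
    }
}
```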

crates/c/src/lib.rs Normal file

@@ -0,0 +1,6 @@
#![feature(linkage)]
mod c;
#[allow(unused_imports)]
pub use self::c::*;

crates/service/Cargo.toml Normal file

@@ -0,0 +1,45 @@
[package]
name = "service"
version.workspace = true
edition.workspace = true
[dependencies]
libc.workspace = true
log.workspace = true
serde.workspace = true
serde_json.workspace = true
validator.workspace = true
rustix.workspace = true
thiserror.workspace = true
byteorder.workspace = true
bincode.workspace = true
half.workspace = true
num-traits.workspace = true
c = { path = "../c" }
std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "avx512fp16" }
rand = "0.8.5"
crc32fast = "1.3.2"
crossbeam = "0.8.2"
dashmap = "5.4.0"
parking_lot = "0.12.1"
memoffset = "0.9.0"
tempfile = "3.6.0"
arrayvec = { version = "0.7.3", features = ["serde"] }
memmap2 = "0.9.0"
rayon = "1.6.1"
uuid = { version = "1.6.1", features = ["serde"] }
arc-swap = "1.6.0"
bytemuck = { version = "1.14.0", features = ["extern_crate_alloc"] }
serde_with = "3.4.0"
multiversion = "0.7.3"
ctor = "0.2.6"
[target.'cfg(target_os = "macos")'.dependencies]
ulock-sys = "0.1.0"
[lints]
clippy.derivable_impls = "allow"
clippy.len_without_is_empty = "allow"
clippy.needless_range_loop = "allow"
clippy.too_many_arguments = "allow"
rust.unsafe_op_in_unsafe_fn = "warn"


@@ -4,38 +4,37 @@ use rand::rngs::StdRng;
use rand::{Rng, SeedableRng}; use rand::{Rng, SeedableRng};
use std::ops::{Index, IndexMut}; use std::ops::{Index, IndexMut};
pub struct ElkanKMeans { pub struct ElkanKMeans<S: G> {
dims: u16, dims: u16,
c: usize, c: usize,
pub centroids: Vec2, pub centroids: Vec2<S>,
lowerbound: Square, lowerbound: Square,
upperbound: Vec<Scalar>, upperbound: Vec<F32>,
assign: Vec<usize>, assign: Vec<usize>,
rand: StdRng, rand: StdRng,
samples: Vec2, samples: Vec2<S>,
d: Distance,
} }
const DELTA: f32 = 1.0 / 1024.0; const DELTA: f32 = 1.0 / 1024.0;
impl ElkanKMeans { impl<S: G> ElkanKMeans<S> {
pub fn new(c: usize, samples: Vec2, d: Distance) -> Self { pub fn new(c: usize, samples: Vec2<S>) -> Self {
let n = samples.len(); let n = samples.len();
let dims = samples.dims(); let dims = samples.dims();
let mut rand = StdRng::from_entropy(); let mut rand = StdRng::from_entropy();
let mut centroids = Vec2::new(dims, c); let mut centroids = Vec2::new(dims, c);
let mut lowerbound = Square::new(n, c); let mut lowerbound = Square::new(n, c);
let mut upperbound = vec![Scalar::Z; n]; let mut upperbound = vec![F32::zero(); n];
let mut assign = vec![0usize; n]; let mut assign = vec![0usize; n];
centroids[0].copy_from_slice(&samples[rand.gen_range(0..n)]); centroids[0].copy_from_slice(&samples[rand.gen_range(0..n)]);
let mut weight = vec![Scalar::INFINITY; n]; let mut weight = vec![F32::infinity(); n];
for i in 0..c { for i in 0..c {
let mut sum = Scalar::Z; let mut sum = F32::zero();
for j in 0..n { for j in 0..n {
let dis = d.elkan_k_means_distance(&samples[j], &centroids[i]); let dis = S::elkan_k_means_distance(&samples[j], &centroids[i]);
lowerbound[(j, i)] = dis; lowerbound[(j, i)] = dis;
if dis * dis < weight[j] { if dis * dis < weight[j] {
weight[j] = dis * dis; weight[j] = dis * dis;
@@ -49,7 +48,7 @@ impl ElkanKMeans {
let mut choice = sum * rand.gen_range(0.0..1.0); let mut choice = sum * rand.gen_range(0.0..1.0);
for j in 0..(n - 1) { for j in 0..(n - 1) {
choice -= weight[j]; choice -= weight[j];
if choice <= Scalar::Z { if choice <= F32::zero() {
break 'a j; break 'a j;
} }
} }
@@ -59,7 +58,7 @@ impl ElkanKMeans {
} }
for i in 0..n { for i in 0..n {
let mut minimal = Scalar::INFINITY; let mut minimal = F32::infinity();
let mut target = 0; let mut target = 0;
for j in 0..c { for j in 0..c {
let dis = lowerbound[(i, j)]; let dis = lowerbound[(i, j)];
@@ -81,13 +80,11 @@ impl ElkanKMeans {
assign, assign,
rand, rand,
samples, samples,
d,
} }
} }
pub fn iterate(&mut self) -> bool { pub fn iterate(&mut self) -> bool {
let c = self.c; let c = self.c;
let f = |lhs: &[Scalar], rhs: &[Scalar]| self.d.elkan_k_means_distance(lhs, rhs);
let dims = self.dims; let dims = self.dims;
let samples = &self.samples; let samples = &self.samples;
let rand = &mut self.rand; let rand = &mut self.rand;
@@ -100,16 +97,16 @@ impl ElkanKMeans {
// Step 1 // Step 1
let mut dist0 = Square::new(c, c); let mut dist0 = Square::new(c, c);
let mut sp = vec![Scalar::Z; c]; let mut sp = vec![F32::zero(); c];
for i in 0..c { for i in 0..c {
for j in i + 1..c { for j in i + 1..c {
let dis = f(&centroids[i], &centroids[j]) * 0.5; let dis = S::elkan_k_means_distance(&centroids[i], &centroids[j]) * 0.5;
dist0[(i, j)] = dis; dist0[(i, j)] = dis;
dist0[(j, i)] = dis; dist0[(j, i)] = dis;
} }
} }
for i in 0..c { for i in 0..c {
let mut minimal = Scalar::INFINITY; let mut minimal = F32::infinity();
for j in 0..c { for j in 0..c {
if i == j { if i == j {
continue; continue;
@@ -127,7 +124,7 @@ impl ElkanKMeans {
if upperbound[i] <= sp[assign[i]] { if upperbound[i] <= sp[assign[i]] {
continue; continue;
} }
let mut minimal = f(&samples[i], &centroids[assign[i]]); let mut minimal = S::elkan_k_means_distance(&samples[i], &centroids[assign[i]]);
lowerbound[(i, assign[i])] = minimal; lowerbound[(i, assign[i])] = minimal;
upperbound[i] = minimal; upperbound[i] = minimal;
// Step 3 // Step 3
@@ -142,7 +139,7 @@ impl ElkanKMeans {
continue; continue;
} }
if minimal > lowerbound[(i, j)] || minimal > dist0[(assign[i], j)] { if minimal > lowerbound[(i, j)] || minimal > dist0[(assign[i], j)] {
let dis = f(&samples[i], &centroids[j]); let dis = S::elkan_k_means_distance(&samples[i], &centroids[j]);
lowerbound[(i, j)] = dis; lowerbound[(i, j)] = dis;
if dis < minimal { if dis < minimal {
minimal = dis; minimal = dis;
@@ -156,8 +153,8 @@ impl ElkanKMeans {
// Step 4, 7 // Step 4, 7
let old = std::mem::replace(centroids, Vec2::new(dims, c)); let old = std::mem::replace(centroids, Vec2::new(dims, c));
let mut count = vec![Scalar::Z; c]; let mut count = vec![F32::zero(); c];
centroids.fill(Scalar::Z); centroids.fill(S::Scalar::zero());
for i in 0..n { for i in 0..n {
for j in 0..dims as usize { for j in 0..dims as usize {
centroids[assign[i]][j] += samples[i][j]; centroids[assign[i]][j] += samples[i][j];
@@ -165,21 +162,21 @@ impl ElkanKMeans {
count[assign[i]] += 1.0; count[assign[i]] += 1.0;
} }
for i in 0..c { for i in 0..c {
if count[i] == Scalar::Z { if count[i] == F32::zero() {
continue; continue;
} }
for dim in 0..dims as usize { for dim in 0..dims as usize {
centroids[i][dim] /= count[i]; centroids[i][dim] /= S::Scalar::from_f32(count[i].into());
} }
} }
for i in 0..c { for i in 0..c {
if count[i] != Scalar::Z { if count[i] != F32::zero() {
continue; continue;
} }
let mut o = 0; let mut o = 0;
loop { loop {
let alpha = Scalar(rand.gen_range(0.0..1.0)); let alpha = F32::from_f32(rand.gen_range(0.0..1.0f32));
let beta = (count[o] - 1.0) / (n - c) as Float; let beta = (count[o] - 1.0) / (n - c) as f32;
if alpha < beta { if alpha < beta {
break; break;
} }
@@ -188,28 +185,28 @@ impl ElkanKMeans {
centroids.copy_within(o, i); centroids.copy_within(o, i);
for dim in 0..dims as usize { for dim in 0..dims as usize {
if dim % 2 == 0 { if dim % 2 == 0 {
centroids[i][dim] *= 1.0 + DELTA; centroids[i][dim] *= S::Scalar::from_f32(1.0 + DELTA);
centroids[o][dim] *= 1.0 - DELTA; centroids[o][dim] *= S::Scalar::from_f32(1.0 - DELTA);
} else { } else {
centroids[i][dim] *= 1.0 - DELTA; centroids[i][dim] *= S::Scalar::from_f32(1.0 - DELTA);
centroids[o][dim] *= 1.0 + DELTA; centroids[o][dim] *= S::Scalar::from_f32(1.0 + DELTA);
} }
} }
count[i] = count[o] / 2.0; count[i] = count[o] / 2.0;
count[o] = count[o] - count[i]; count[o] = count[o] - count[i];
} }
for i in 0..c { for i in 0..c {
self.d.elkan_k_means_normalize(&mut centroids[i]); S::elkan_k_means_normalize(&mut centroids[i]);
} }
// Step 5, 6 // Step 5, 6
let mut dist1 = vec![Scalar::Z; c]; let mut dist1 = vec![F32::zero(); c];
for i in 0..c { for i in 0..c {
dist1[i] = f(&old[i], &centroids[i]); dist1[i] = S::elkan_k_means_distance(&old[i], &centroids[i]);
} }
for i in 0..n { for i in 0..n {
for j in 0..c { for j in 0..c {
lowerbound[(i, j)] = (lowerbound[(i, j)] - dist1[j]).max(Scalar::Z); lowerbound[(i, j)] = std::cmp::max(lowerbound[(i, j)] - dist1[j], F32::zero());
} }
} }
for i in 0..n { for i in 0..n {
@@ -219,7 +216,7 @@ impl ElkanKMeans {
change == 0 change == 0
} }
pub fn finish(self) -> Vec2 { pub fn finish(self) -> Vec2<S> {
self.centroids self.centroids
} }
} }
@@ -227,7 +224,7 @@ impl ElkanKMeans {
pub struct Square { pub struct Square {
x: usize, x: usize,
y: usize, y: usize,
v: Box<[Scalar]>, v: Vec<F32>,
} }
impl Square { impl Square {
@@ -235,13 +232,13 @@ impl Square {
Self { Self {
x, x,
y, y,
v: bytemuck::zeroed_slice_box(x * y), v: bytemuck::zeroed_vec(x * y),
} }
} }
} }
impl Index<(usize, usize)> for Square { impl Index<(usize, usize)> for Square {
type Output = Scalar; type Output = F32;
fn index(&self, (x, y): (usize, usize)) -> &Self::Output { fn index(&self, (x, y): (usize, usize)) -> &Self::Output {
debug_assert!(x < self.x); debug_assert!(x < self.x);
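The refactor above threads a generic parameter `S: G` through `ElkanKMeans`, so the distance and normalization routines move from a runtime `Distance` value onto associated functions resolved per scalar kind (f32 or the new fp16). A rough sketch of the shape of that abstraction, inferred from the call sites in this diff rather than copied from the crate:

```rust
/// Ordered f32 newtype used for distances in this diff (stand-in definition).
#[derive(Clone, Copy, Debug, Default, PartialEq, PartialOrd)]
pub struct F32(pub f32);

/// Rough shape of the `G` trait; the real trait in crates/service has more
/// members (vector kinds, distance selection, quantization hooks, ...).
pub trait G: Copy + 'static {
    /// Element type of stored vectors: f32 for single precision,
    /// half::f16 for the new fp16 vectors.
    type Scalar: Copy + Send + Sync;

    /// Normalize a vector in place before k-means (used by cosine-style metrics).
    fn elkan_k_means_normalize(vector: &mut [Self::Scalar]);

    /// Distance used inside Elkan k-means, previously a method on a runtime
    /// `Distance` value, now resolved statically per scalar kind.
    fn elkan_k_means_distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32;
}
```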


@@ -9,16 +9,16 @@ use std::fs::create_dir;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
pub struct Flat { pub struct Flat<S: G> {
mmap: FlatMmap, mmap: FlatMmap<S>,
} }
impl Flat { impl<S: G> Flat<S> {
pub fn create( pub fn create(
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self { ) -> Self {
create_dir(&path).unwrap(); create_dir(&path).unwrap();
let ram = make(path.clone(), sealed, growing, options.clone()); let ram = make(path.clone(), sealed, growing, options.clone());
@@ -35,7 +35,7 @@ impl Flat {
self.mmap.raw.len() self.mmap.raw.len()
} }
pub fn vector(&self, i: u32) -> &[Scalar] { pub fn vector(&self, i: u32) -> &[S::Scalar] {
self.mmap.raw.vector(i) self.mmap.raw.vector(i)
} }
@@ -43,35 +43,33 @@ impl Flat {
self.mmap.raw.payload(i) self.mmap.raw.payload(i)
} }
pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
search(&self.mmap, k, vector, filter) search(&self.mmap, k, vector, filter)
} }
} }
unsafe impl Send for Flat {} unsafe impl<S: G> Send for Flat<S> {}
unsafe impl Sync for Flat {} unsafe impl<S: G> Sync for Flat<S> {}
pub struct FlatRam { pub struct FlatRam<S: G> {
raw: Arc<Raw>, raw: Arc<Raw<S>>,
quantization: Quantization, quantization: Quantization<S>,
d: Distance,
} }
pub struct FlatMmap { pub struct FlatMmap<S: G> {
raw: Arc<Raw>, raw: Arc<Raw<S>>,
quantization: Quantization, quantization: Quantization<S>,
d: Distance,
} }
unsafe impl Send for FlatMmap {} unsafe impl<S: G> Send for FlatMmap<S> {}
unsafe impl Sync for FlatMmap {} unsafe impl<S: G> Sync for FlatMmap<S> {}
pub fn make( pub fn make<S: G>(
path: PathBuf, path: PathBuf,
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
options: IndexOptions, options: IndexOptions,
) -> FlatRam { ) -> FlatRam<S> {
let idx_opts = options.indexing.clone().unwrap_flat(); let idx_opts = options.indexing.clone().unwrap_flat();
let raw = Arc::new(Raw::create( let raw = Arc::new(Raw::create(
path.join("raw"), path.join("raw"),
@@ -85,22 +83,17 @@ pub fn make(
idx_opts.quantization, idx_opts.quantization,
&raw, &raw,
); );
FlatRam { FlatRam { raw, quantization }
raw,
quantization,
d: options.vector.d,
}
} }
pub fn save(ram: FlatRam, _: PathBuf) -> FlatMmap { pub fn save<S: G>(ram: FlatRam<S>, _: PathBuf) -> FlatMmap<S> {
FlatMmap { FlatMmap {
raw: ram.raw, raw: ram.raw,
quantization: ram.quantization, quantization: ram.quantization,
d: ram.d,
} }
} }
pub fn load(path: PathBuf, options: IndexOptions) -> FlatMmap { pub fn load<S: G>(path: PathBuf, options: IndexOptions) -> FlatMmap<S> {
let idx_opts = options.indexing.clone().unwrap_flat(); let idx_opts = options.indexing.clone().unwrap_flat();
let raw = Arc::new(Raw::open(path.join("raw"), options.clone())); let raw = Arc::new(Raw::open(path.join("raw"), options.clone()));
let quantization = Quantization::open( let quantization = Quantization::open(
@@ -109,17 +102,18 @@ pub fn load(path: PathBuf, options: IndexOptions) -> FlatMmap {
idx_opts.quantization, idx_opts.quantization,
&raw, &raw,
); );
FlatMmap { FlatMmap { raw, quantization }
raw,
quantization,
d: options.vector.d,
}
} }
pub fn search(mmap: &FlatMmap, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { pub fn search<S: G>(
mmap: &FlatMmap<S>,
k: usize,
vector: &[S::Scalar],
filter: &mut impl Filter,
) -> Heap {
let mut result = Heap::new(k); let mut result = Heap::new(k);
for i in 0..mmap.raw.len() { for i in 0..mmap.raw.len() {
let distance = mmap.quantization.distance(mmap.d, vector, i); let distance = mmap.quantization.distance(vector, i);
let payload = mmap.raw.payload(i); let payload = mmap.raw.payload(i);
if filter.check(payload) { if filter.check(payload) {
result.push(HeapElement { distance, payload }); result.push(HeapElement { distance, payload });


@@ -3,7 +3,7 @@ use super::raw::Raw;
use crate::index::indexing::hnsw::HnswIndexingOptions; use crate::index::indexing::hnsw::HnswIndexingOptions;
use crate::index::segments::growing::GrowingSegment; use crate::index::segments::growing::GrowingSegment;
use crate::index::segments::sealed::SealedSegment; use crate::index::segments::sealed::SealedSegment;
use crate::index::{IndexOptions, VectorOptions}; use crate::index::IndexOptions;
use crate::prelude::*; use crate::prelude::*;
use crate::utils::dir_ops::sync_dir; use crate::utils::dir_ops::sync_dir;
use crate::utils::mmap_array::MmapArray; use crate::utils::mmap_array::MmapArray;
@@ -17,16 +17,16 @@ use std::ops::RangeInclusive;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
pub struct Hnsw { pub struct Hnsw<S: G> {
mmap: HnswMmap, mmap: HnswMmap<S>,
} }
impl Hnsw { impl<S: G> Hnsw<S> {
pub fn create( pub fn create(
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self { ) -> Self {
create_dir(&path).unwrap(); create_dir(&path).unwrap();
let ram = make(path.clone(), sealed, growing, options.clone()); let ram = make(path.clone(), sealed, growing, options.clone());
@@ -43,7 +43,7 @@ impl Hnsw {
self.mmap.raw.len() self.mmap.raw.len()
} }
pub fn vector(&self, i: u32) -> &[Scalar] { pub fn vector(&self, i: u32) -> &[S::Scalar] {
self.mmap.raw.vector(i) self.mmap.raw.vector(i)
} }
@@ -51,27 +51,21 @@ impl Hnsw {
self.mmap.raw.payload(i) self.mmap.raw.payload(i)
} }
pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
search(&self.mmap, k, vector, filter) search(&self.mmap, k, vector, filter)
} }
pub fn search_vbase<'index, 'vector>( pub fn search_vbase(&self, range: usize, vector: &[S::Scalar]) -> HnswIndexIter<'_, S> {
&'index self,
range: usize,
vector: &'vector [Scalar],
) -> HnswIndexIter<'index, 'vector> {
search_vbase(&self.mmap, range, vector) search_vbase(&self.mmap, range, vector)
} }
} }
unsafe impl Send for Hnsw {} unsafe impl<S: G> Send for Hnsw<S> {}
unsafe impl Sync for Hnsw {} unsafe impl<S: G> Sync for Hnsw<S> {}
pub struct HnswRam { pub struct HnswRam<S: G> {
raw: Arc<Raw>, raw: Arc<Raw<S>>,
quantization: Quantization, quantization: Quantization<S>,
// ----------------------
d: Distance,
// ---------------------- // ----------------------
m: u32, m: u32,
// ---------------------- // ----------------------
@@ -95,14 +89,12 @@ impl HnswRamVertex {
} }
struct HnswRamLayer { struct HnswRamLayer {
edges: Vec<(Scalar, u32)>, edges: Vec<(F32, u32)>,
} }
pub struct HnswMmap { pub struct HnswMmap<S: G> {
raw: Arc<Raw>, raw: Arc<Raw<S>>,
quantization: Quantization, quantization: Quantization<S>,
// ----------------------
d: Distance,
// ---------------------- // ----------------------
m: u32, m: u32,
// ---------------------- // ----------------------
@@ -114,20 +106,19 @@ pub struct HnswMmap {
} }
#[derive(Debug, Clone, Copy, Default)] #[derive(Debug, Clone, Copy, Default)]
struct HnswMmapEdge(Scalar, u32); struct HnswMmapEdge(F32, u32);
unsafe impl Send for HnswMmap {} unsafe impl<S: G> Send for HnswMmap<S> {}
unsafe impl Sync for HnswMmap {} unsafe impl<S: G> Sync for HnswMmap<S> {}
unsafe impl Pod for HnswMmapEdge {} unsafe impl Pod for HnswMmapEdge {}
unsafe impl Zeroable for HnswMmapEdge {} unsafe impl Zeroable for HnswMmapEdge {}
pub fn make( pub fn make<S: G>(
path: PathBuf, path: PathBuf,
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
options: IndexOptions, options: IndexOptions,
) -> HnswRam { ) -> HnswRam<S> {
let VectorOptions { d, .. } = options.vector;
let HnswIndexingOptions { let HnswIndexingOptions {
m, m,
ef_construction, ef_construction,
@@ -159,23 +150,22 @@ pub fn make(
let entry = RwLock::<Option<u32>>::new(None); let entry = RwLock::<Option<u32>>::new(None);
let visited = VisitedPool::new(raw.len()); let visited = VisitedPool::new(raw.len());
(0..n).into_par_iter().for_each(|i| { (0..n).into_par_iter().for_each(|i| {
fn fast_search( fn fast_search<S: G>(
quantization: &Quantization, quantization: &Quantization<S>,
graph: &HnswRamGraph, graph: &HnswRamGraph,
d: Distance,
levels: RangeInclusive<u8>, levels: RangeInclusive<u8>,
u: u32, u: u32,
target: &[Scalar], target: &[S::Scalar],
) -> u32 { ) -> u32 {
let mut u = u; let mut u = u;
let mut u_dis = quantization.distance(d, target, u); let mut u_dis = quantization.distance(target, u);
for i in levels.rev() { for i in levels.rev() {
let mut changed = true; let mut changed = true;
while changed { while changed {
changed = false; changed = false;
let guard = graph.vertexs[u as usize].layers[i as usize].read(); let guard = graph.vertexs[u as usize].layers[i as usize].read();
for &(_, v) in guard.edges.iter() { for &(_, v) in guard.edges.iter() {
let v_dis = quantization.distance(d, target, v); let v_dis = quantization.distance(target, v);
if v_dis < u_dis { if v_dis < u_dis {
u = v; u = v;
u_dis = v_dis; u_dis = v_dis;
@@ -186,21 +176,20 @@ pub fn make(
} }
u u
} }
fn local_search( fn local_search<S: G>(
quantization: &Quantization, quantization: &Quantization<S>,
graph: &HnswRamGraph, graph: &HnswRamGraph,
d: Distance,
visited: &mut VisitedGuard, visited: &mut VisitedGuard,
vector: &[Scalar], vector: &[S::Scalar],
s: u32, s: u32,
k: usize, k: usize,
i: u8, i: u8,
) -> Vec<(Scalar, u32)> { ) -> Vec<(F32, u32)> {
assert!(k > 0); assert!(k > 0);
let mut visited = visited.fetch(); let mut visited = visited.fetch();
let mut candidates = BinaryHeap::<Reverse<(Scalar, u32)>>::new(); let mut candidates = BinaryHeap::<Reverse<(F32, u32)>>::new();
let mut results = BinaryHeap::new(); let mut results = BinaryHeap::new();
let s_dis = quantization.distance(d, vector, s); let s_dis = quantization.distance(vector, s);
visited.mark(s); visited.mark(s);
candidates.push(Reverse((s_dis, s))); candidates.push(Reverse((s_dis, s)));
results.push((s_dis, s)); results.push((s_dis, s));
@@ -217,7 +206,7 @@ pub fn make(
continue; continue;
} }
visited.mark(v); visited.mark(v);
let v_dis = quantization.distance(d, vector, v); let v_dis = quantization.distance(vector, v);
if results.len() < k || v_dis < results.peek().unwrap().0 { if results.len() < k || v_dis < results.peek().unwrap().0 {
candidates.push(Reverse((v_dis, v))); candidates.push(Reverse((v_dis, v)));
results.push((v_dis, v)); results.push((v_dis, v));
@@ -229,12 +218,7 @@ pub fn make(
} }
results.into_sorted_vec() results.into_sorted_vec()
} }
fn select( fn select<S: G>(quantization: &Quantization<S>, input: &mut Vec<(F32, u32)>, size: u32) {
quantization: &Quantization,
d: Distance,
input: &mut Vec<(Scalar, u32)>,
size: u32,
) {
if input.len() <= size as usize { if input.len() <= size as usize {
return; return;
} }
@@ -245,7 +229,7 @@ pub fn make(
} }
let check = res let check = res
.iter() .iter()
.map(|&(_, v)| quantization.distance2(d, u, v)) .map(|&(_, v)| quantization.distance2(u, v))
.all(|dist| dist > u_dis); .all(|dist| dist > u_dis);
if check { if check {
res.push((u_dis, u)); res.push((u_dis, u));
@@ -290,14 +274,13 @@ pub fn make(
}; };
let top = graph.vertexs[u as usize].levels(); let top = graph.vertexs[u as usize].levels();
if top > levels { if top > levels {
u = fast_search(&quantization, &graph, d, levels + 1..=top, u, target); u = fast_search(&quantization, &graph, levels + 1..=top, u, target);
} }
let mut result = Vec::with_capacity(1 + std::cmp::min(levels, top) as usize); let mut result = Vec::with_capacity(1 + std::cmp::min(levels, top) as usize);
for j in (0..=std::cmp::min(levels, top)).rev() { for j in (0..=std::cmp::min(levels, top)).rev() {
let mut edges = local_search( let mut edges = local_search(
&quantization, &quantization,
&graph, &graph,
d,
&mut visited, &mut visited,
target, target,
u, u,
@@ -305,12 +288,7 @@ pub fn make(
j, j,
); );
edges.sort(); edges.sort();
select( select(&quantization, &mut edges, count_max_edges_of_a_layer(m, j));
&quantization,
d,
&mut edges,
count_max_edges_of_a_layer(m, j),
);
u = edges.first().unwrap().1; u = edges.first().unwrap().1;
result.push(edges); result.push(edges);
} }
@@ -325,7 +303,6 @@ pub fn make(
write.edges.insert(index, element); write.edges.insert(index, element);
select( select(
&quantization, &quantization,
d,
&mut write.edges, &mut write.edges,
count_max_edges_of_a_layer(m, j), count_max_edges_of_a_layer(m, j),
); );
@@ -338,14 +315,13 @@ pub fn make(
HnswRam { HnswRam {
raw, raw,
quantization, quantization,
d,
m, m,
graph, graph,
visited, visited,
} }
} }
pub fn save(mut ram: HnswRam, path: PathBuf) -> HnswMmap { pub fn save<S: G>(mut ram: HnswRam<S>, path: PathBuf) -> HnswMmap<S> {
let edges = MmapArray::create( let edges = MmapArray::create(
path.join("edges"), path.join("edges"),
ram.graph ram.graph
@@ -369,7 +345,6 @@ pub fn save(mut ram: HnswRam, path: PathBuf) -> HnswMmap {
HnswMmap { HnswMmap {
raw: ram.raw, raw: ram.raw,
quantization: ram.quantization, quantization: ram.quantization,
d: ram.d,
m: ram.m, m: ram.m,
edges, edges,
by_layer_id, by_layer_id,
@@ -378,7 +353,7 @@ pub fn save(mut ram: HnswRam, path: PathBuf) -> HnswMmap {
} }
} }
pub fn load(path: PathBuf, options: IndexOptions) -> HnswMmap { pub fn load<S: G>(path: PathBuf, options: IndexOptions) -> HnswMmap<S> {
let idx_opts = options.indexing.clone().unwrap_hnsw(); let idx_opts = options.indexing.clone().unwrap_hnsw();
let raw = Arc::new(Raw::open(path.join("raw"), options.clone())); let raw = Arc::new(Raw::open(path.join("raw"), options.clone()));
let quantization = Quantization::open( let quantization = Quantization::open(
@@ -395,7 +370,6 @@ pub fn load(path: PathBuf, options: IndexOptions) -> HnswMmap {
HnswMmap { HnswMmap {
raw, raw,
quantization, quantization,
d: options.vector.d,
m: idx_opts.m, m: idx_opts.m,
edges, edges,
by_layer_id, by_layer_id,
@@ -404,7 +378,12 @@ pub fn load(path: PathBuf, options: IndexOptions) -> HnswMmap {
} }
} }
pub fn search(mmap: &HnswMmap, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { pub fn search<S: G>(
mmap: &HnswMmap<S>,
k: usize,
vector: &[S::Scalar],
filter: &mut impl Filter,
) -> Heap {
let Some(s) = entry(mmap, filter) else { let Some(s) = entry(mmap, filter) else {
return Heap::new(k); return Heap::new(k);
}; };
@@ -413,11 +392,11 @@ pub fn search(mmap: &HnswMmap, k: usize, vector: &[Scalar], filter: &mut impl Fi
local_search(mmap, k, u, vector, filter) local_search(mmap, k, u, vector, filter)
} }
pub fn search_vbase<'index, 'vector>( pub fn search_vbase<'a, S: G>(
mmap: &'index HnswMmap, mmap: &'a HnswMmap<S>,
range: usize, range: usize,
vector: &'vector [Scalar], vector: &[S::Scalar],
) -> HnswIndexIter<'index, 'vector> { ) -> HnswIndexIter<'a, S> {
let filter_fn = &mut |_| true; let filter_fn = &mut |_| true;
let Some(s) = entry(mmap, filter_fn) else { let Some(s) = entry(mmap, filter_fn) else {
return HnswIndexIter(None); return HnswIndexIter(None);
@@ -427,7 +406,7 @@ pub fn search_vbase<'index, 'vector>(
local_search_vbase(mmap, range, u, vector) local_search_vbase(mmap, range, u, vector)
} }
pub fn entry(mmap: &HnswMmap, filter: &mut impl Filter) -> Option<u32> { pub fn entry<S: G>(mmap: &HnswMmap<S>, filter: &mut impl Filter) -> Option<u32> {
let m = mmap.m; let m = mmap.m;
let n = mmap.raw.len(); let n = mmap.raw.len();
let mut shift = 1u64; let mut shift = 1u64;
@@ -455,15 +434,15 @@ pub fn entry(mmap: &HnswMmap, filter: &mut impl Filter) -> Option<u32> {
None None
} }
pub fn fast_search( pub fn fast_search<S: G>(
mmap: &HnswMmap, mmap: &HnswMmap<S>,
levels: RangeInclusive<u8>, levels: RangeInclusive<u8>,
u: u32, u: u32,
vector: &[Scalar], vector: &[S::Scalar],
filter: &mut impl Filter, filter: &mut impl Filter,
) -> u32 { ) -> u32 {
let mut u = u; let mut u = u;
let mut u_dis = mmap.quantization.distance(mmap.d, vector, u); let mut u_dis = mmap.quantization.distance(vector, u);
for i in levels.rev() { for i in levels.rev() {
let mut changed = true; let mut changed = true;
while changed { while changed {
@@ -473,7 +452,7 @@ pub fn fast_search(
if !filter.check(mmap.raw.payload(v)) { if !filter.check(mmap.raw.payload(v)) {
continue; continue;
} }
let v_dis = mmap.quantization.distance(mmap.d, vector, v); let v_dis = mmap.quantization.distance(vector, v);
if v_dis < u_dis { if v_dis < u_dis {
u = v; u = v;
u_dis = v_dis; u_dis = v_dis;
@@ -485,20 +464,20 @@ pub fn fast_search(
u u
} }
pub fn local_search( pub fn local_search<S: G>(
mmap: &HnswMmap, mmap: &HnswMmap<S>,
k: usize, k: usize,
s: u32, s: u32,
vector: &[Scalar], vector: &[S::Scalar],
filter: &mut impl Filter, filter: &mut impl Filter,
) -> Heap { ) -> Heap {
assert!(k > 0); assert!(k > 0);
let mut visited = mmap.visited.fetch(); let mut visited = mmap.visited.fetch();
let mut visited = visited.fetch(); let mut visited = visited.fetch();
let mut candidates = BinaryHeap::<Reverse<(Scalar, u32)>>::new(); let mut candidates = BinaryHeap::<Reverse<(F32, u32)>>::new();
let mut results = Heap::new(k); let mut results = Heap::new(k);
visited.mark(s); visited.mark(s);
let s_dis = mmap.quantization.distance(mmap.d, vector, s); let s_dis = mmap.quantization.distance(vector, s);
candidates.push(Reverse((s_dis, s))); candidates.push(Reverse((s_dis, s)));
results.push(HeapElement { results.push(HeapElement {
distance: s_dis, distance: s_dis,
@@ -517,7 +496,7 @@ pub fn local_search(
if !filter.check(mmap.raw.payload(v)) { if !filter.check(mmap.raw.payload(v)) {
continue; continue;
} }
let v_dis = mmap.quantization.distance(mmap.d, vector, v); let v_dis = mmap.quantization.distance(vector, v);
if !results.check(v_dis) { if !results.check(v_dis) {
continue; continue;
} }
@@ -531,20 +510,20 @@ pub fn local_search(
results results
} }
fn local_search_vbase<'mmap, 'vector>( fn local_search_vbase<'a, S: G>(
mmap: &'mmap HnswMmap, mmap: &'a HnswMmap<S>,
range: usize, range: usize,
s: u32, s: u32,
vector: &'vector [Scalar], vector: &[S::Scalar],
) -> HnswIndexIter<'mmap, 'vector> { ) -> HnswIndexIter<'a, S> {
assert!(range > 0); assert!(range > 0);
let mut visited_guard = mmap.visited.fetch(); let mut visited_guard = mmap.visited.fetch();
let mut visited = visited_guard.fetch(); let mut visited = visited_guard.fetch();
let mut candidates = BinaryHeap::<Reverse<(Scalar, u32)>>::new(); let mut candidates = BinaryHeap::<Reverse<(F32, u32)>>::new();
let mut results = Heap::new(range); let mut results = Heap::new(range);
let mut lost = Vec::<Reverse<HeapElement>>::new(); let mut lost = Vec::<Reverse<HeapElement>>::new();
visited.mark(s); visited.mark(s);
let s_dis = mmap.quantization.distance(mmap.d, vector, s); let s_dis = mmap.quantization.distance(vector, s);
candidates.push(Reverse((s_dis, s))); candidates.push(Reverse((s_dis, s)));
results.push(HeapElement { results.push(HeapElement {
distance: s_dis, distance: s_dis,
@@ -561,7 +540,7 @@ fn local_search_vbase<'mmap, 'vector>(
continue; continue;
} }
visited.mark(v); visited.mark(v);
let v_dis = mmap.quantization.distance(mmap.d, vector, v); let v_dis = mmap.quantization.distance(vector, v);
if !results.check(v_dis) { if !results.check(v_dis) {
continue; continue;
} }
@@ -582,7 +561,7 @@ fn local_search_vbase<'mmap, 'vector>(
results: results.into_reversed_heap(), results: results.into_reversed_heap(),
lost, lost,
visited: visited_guard, visited: visited_guard,
vector, vector: vector.to_vec(),
})) }))
} }
@@ -614,7 +593,7 @@ fn caluate_offsets(iter: impl Iterator<Item = usize>) -> impl Iterator<Item = us
}) })
} }
fn find_edges(mmap: &HnswMmap, u: u32, level: u8) -> &[HnswMmapEdge] { fn find_edges<S: G>(mmap: &HnswMmap<S>, u: u32, level: u8) -> &[HnswMmapEdge] {
let offset = u as usize; let offset = u as usize;
let index = mmap.by_vertex_id[offset]..mmap.by_vertex_id[offset + 1]; let index = mmap.by_vertex_id[offset]..mmap.by_vertex_id[offset + 1];
let offset = index.start + level as usize; let offset = index.start + level as usize;
@@ -670,7 +649,7 @@ impl<'a> Drop for VisitedGuard<'a> {
fn drop(&mut self) { fn drop(&mut self) {
let src = VisitedBuffer { let src = VisitedBuffer {
version: 0, version: 0,
data: Box::new([]), data: Vec::new(),
}; };
let buffer = std::mem::replace(&mut self.buffer, src); let buffer = std::mem::replace(&mut self.buffer, src);
self.pool.locked_buffers.lock().push(buffer); self.pool.locked_buffers.lock().push(buffer);
@@ -692,39 +671,39 @@ impl<'a> VisitedChecker<'a> {
struct VisitedBuffer { struct VisitedBuffer {
version: usize, version: usize,
data: Box<[usize]>, data: Vec<usize>,
} }
impl VisitedBuffer { impl VisitedBuffer {
fn new(capacity: usize) -> Self { fn new(capacity: usize) -> Self {
Self { Self {
version: 0, version: 0,
data: bytemuck::zeroed_slice_box(capacity), data: bytemuck::zeroed_vec(capacity),
} }
} }
} }
pub struct HnswIndexIter<'mmap, 'vector>(Option<HnswIndexIterInner<'mmap, 'vector>>); pub struct HnswIndexIter<'mmap, S: G>(Option<HnswIndexIterInner<'mmap, S>>);
pub struct HnswIndexIterInner<'mmap, 'vector> { pub struct HnswIndexIterInner<'mmap, S: G> {
mmap: &'mmap HnswMmap, mmap: &'mmap HnswMmap<S>,
range: usize, range: usize,
candidates: BinaryHeap<Reverse<(Scalar, u32)>>, candidates: BinaryHeap<Reverse<(F32, u32)>>,
results: BinaryHeap<Reverse<HeapElement>>, results: BinaryHeap<Reverse<HeapElement>>,
// The points lost in the first stage, we should keep it to the second stage. // The points lost in the first stage, we should keep it to the second stage.
lost: Vec<Reverse<HeapElement>>, lost: Vec<Reverse<HeapElement>>,
visited: VisitedGuard<'mmap>, visited: VisitedGuard<'mmap>,
vector: &'vector [Scalar], vector: Vec<S::Scalar>,
} }
impl Iterator for HnswIndexIter<'_, '_> { impl<S: G> Iterator for HnswIndexIter<'_, S> {
type Item = HeapElement; type Item = HeapElement;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
self.0.as_mut()?.next() self.0.as_mut()?.next()
} }
} }
impl Iterator for HnswIndexIterInner<'_, '_> { impl<S: G> Iterator for HnswIndexIterInner<'_, S> {
type Item = HeapElement; type Item = HeapElement;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
if self.results.len() > self.range { if self.results.len() > self.range {
@@ -739,7 +718,7 @@ impl Iterator for HnswIndexIterInner<'_, '_> {
continue; continue;
} }
visited.mark(v); visited.mark(v);
let v_dis = self.mmap.quantization.distance(self.mmap.d, self.vector, v); let v_dis = self.mmap.quantization.distance(&self.vector, v);
self.candidates.push(Reverse((v_dis, v))); self.candidates.push(Reverse((v_dis, v)));
self.results.push(Reverse(HeapElement { self.results.push(Reverse(HeapElement {
distance: v_dis, distance: v_dis,
@@ -755,7 +734,7 @@ impl Iterator for HnswIndexIterInner<'_, '_> {
} }
} }
impl HnswIndexIterInner<'_, '_> { impl<S: G> HnswIndexIterInner<'_, S> {
fn pop(&mut self) -> Option<HeapElement> { fn pop(&mut self) -> Option<HeapElement> {
if self.results.peek() > self.lost.last() { if self.results.peek() > self.lost.last() {
self.results.pop().map(|x| x.0) self.results.pop().map(|x| x.0)


@@ -20,16 +20,16 @@ use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::{Acquire, Relaxed, Release}; use std::sync::atomic::Ordering::{Acquire, Relaxed, Release};
use std::sync::Arc; use std::sync::Arc;
pub struct IvfNaive { pub struct IvfNaive<S: G> {
mmap: IvfMmap, mmap: IvfMmap<S>,
} }
impl IvfNaive { impl<S: G> IvfNaive<S> {
pub fn create( pub fn create(
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self { ) -> Self {
create_dir(&path).unwrap(); create_dir(&path).unwrap();
let ram = make(path.clone(), sealed, growing, options); let ram = make(path.clone(), sealed, growing, options);
@@ -47,7 +47,7 @@ impl IvfNaive {
self.mmap.raw.len() self.mmap.raw.len()
} }
pub fn vector(&self, i: u32) -> &[Scalar] { pub fn vector(&self, i: u32) -> &[S::Scalar] {
self.mmap.raw.vector(i) self.mmap.raw.vector(i)
} }
@@ -55,65 +55,63 @@ impl IvfNaive {
self.mmap.raw.payload(i) self.mmap.raw.payload(i)
} }
pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
search(&self.mmap, k, vector, filter) search(&self.mmap, k, vector, filter)
} }
} }
unsafe impl Send for IvfNaive {} unsafe impl<S: G> Send for IvfNaive<S> {}
unsafe impl Sync for IvfNaive {} unsafe impl<S: G> Sync for IvfNaive<S> {}
pub struct IvfRam { pub struct IvfRam<S: G> {
raw: Arc<Raw>, raw: Arc<Raw<S>>,
quantization: Quantization, quantization: Quantization<S>,
// ---------------------- // ----------------------
dims: u16, dims: u16,
d: Distance,
// ---------------------- // ----------------------
nlist: u32, nlist: u32,
nprobe: u32, nprobe: u32,
// ---------------------- // ----------------------
centroids: Vec2, centroids: Vec2<S>,
heads: Vec<AtomicU32>, heads: Vec<AtomicU32>,
nexts: Vec<SyncUnsafeCell<u32>>, nexts: Vec<SyncUnsafeCell<u32>>,
} }
unsafe impl Send for IvfRam {} unsafe impl<S: G> Send for IvfRam<S> {}
unsafe impl Sync for IvfRam {} unsafe impl<S: G> Sync for IvfRam<S> {}
pub struct IvfMmap { pub struct IvfMmap<S: G> {
raw: Arc<Raw>, raw: Arc<Raw<S>>,
quantization: Quantization, quantization: Quantization<S>,
// ---------------------- // ----------------------
dims: u16, dims: u16,
d: Distance,
// ---------------------- // ----------------------
nlist: u32, nlist: u32,
nprobe: u32, nprobe: u32,
// ---------------------- // ----------------------
centroids: MmapArray<Scalar>, centroids: MmapArray<S::Scalar>,
heads: MmapArray<u32>, heads: MmapArray<u32>,
nexts: MmapArray<u32>, nexts: MmapArray<u32>,
} }
unsafe impl Send for IvfMmap {} unsafe impl<S: G> Send for IvfMmap<S> {}
unsafe impl Sync for IvfMmap {} unsafe impl<S: G> Sync for IvfMmap<S> {}
impl IvfMmap { impl<S: G> IvfMmap<S> {
fn centroids(&self, i: u32) -> &[Scalar] { fn centroids(&self, i: u32) -> &[S::Scalar] {
let s = i as usize * self.dims as usize; let s = i as usize * self.dims as usize;
let e = (i + 1) as usize * self.dims as usize; let e = (i + 1) as usize * self.dims as usize;
&self.centroids[s..e] &self.centroids[s..e]
} }
} }
pub fn make( pub fn make<S: G>(
path: PathBuf, path: PathBuf,
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
options: IndexOptions, options: IndexOptions,
) -> IvfRam { ) -> IvfRam<S> {
let VectorOptions { dims, d } = options.vector; let VectorOptions { dims, .. } = options.vector;
let IvfIndexingOptions { let IvfIndexingOptions {
least_iterations, least_iterations,
iterations, iterations,
@@ -140,9 +138,9 @@ pub fn make(
let mut samples = Vec2::new(dims, m as usize); let mut samples = Vec2::new(dims, m as usize);
for i in 0..m { for i in 0..m {
samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32)); samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32));
d.elkan_k_means_normalize(&mut samples[i as usize]); S::elkan_k_means_normalize(&mut samples[i as usize]);
} }
let mut k_means = ElkanKMeans::new(nlist as usize, samples, d); let mut k_means = ElkanKMeans::new(nlist as usize, samples);
for _ in 0..least_iterations { for _ in 0..least_iterations {
k_means.iterate(); k_means.iterate();
} }
@@ -164,10 +162,10 @@ pub fn make(
}; };
(0..n).into_par_iter().for_each(|i| { (0..n).into_par_iter().for_each(|i| {
let mut vector = raw.vector(i).to_vec(); let mut vector = raw.vector(i).to_vec();
d.elkan_k_means_normalize(&mut vector); S::elkan_k_means_normalize(&mut vector);
let mut result = (Scalar::INFINITY, 0); let mut result = (F32::infinity(), 0);
for i in 0..nlist { for i in 0..nlist {
let dis = d.elkan_k_means_distance(&vector, &centroids[i as usize]); let dis = S::elkan_k_means_distance(&vector, &centroids[i as usize]);
result = std::cmp::min(result, (dis, i)); result = std::cmp::min(result, (dis, i));
} }
let centroid_id = result.1; let centroid_id = result.1;
@@ -191,11 +189,10 @@ pub fn make(
nprobe, nprobe,
nlist, nlist,
dims, dims,
d,
} }
} }
pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { pub fn save<S: G>(mut ram: IvfRam<S>, path: PathBuf) -> IvfMmap<S> {
let centroids = MmapArray::create( let centroids = MmapArray::create(
path.join("centroids"), path.join("centroids"),
(0..ram.nlist) (0..ram.nlist)
@@ -214,7 +211,6 @@ pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap {
raw: ram.raw, raw: ram.raw,
quantization: ram.quantization, quantization: ram.quantization,
dims: ram.dims, dims: ram.dims,
d: ram.d,
nlist: ram.nlist, nlist: ram.nlist,
nprobe: ram.nprobe, nprobe: ram.nprobe,
centroids, centroids,
@@ -223,7 +219,7 @@ pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap {
} }
} }
pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { pub fn load<S: G>(path: PathBuf, options: IndexOptions) -> IvfMmap<S> {
let raw = Arc::new(Raw::open(path.join("raw"), options.clone())); let raw = Arc::new(Raw::open(path.join("raw"), options.clone()));
let quantization = Quantization::open( let quantization = Quantization::open(
path.join("quantization"), path.join("quantization"),
@@ -239,7 +235,6 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap {
raw, raw,
quantization, quantization,
dims: options.vector.dims, dims: options.vector.dims,
d: options.vector.d,
nlist, nlist,
nprobe, nprobe,
centroids, centroids,
@@ -248,13 +243,18 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap {
} }
} }
pub fn search(mmap: &IvfMmap, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { pub fn search<S: G>(
mmap: &IvfMmap<S>,
k: usize,
vector: &[S::Scalar],
filter: &mut impl Filter,
) -> Heap {
let mut target = vector.to_vec(); let mut target = vector.to_vec();
mmap.d.elkan_k_means_normalize(&mut target); S::elkan_k_means_normalize(&mut target);
let mut lists = Heap::new(mmap.nprobe as usize); let mut lists = Heap::new(mmap.nprobe as usize);
for i in 0..mmap.nlist { for i in 0..mmap.nlist {
let centroid = mmap.centroids(i); let centroid = mmap.centroids(i);
let distance = mmap.d.elkan_k_means_distance(&target, centroid); let distance = S::elkan_k_means_distance(&target, centroid);
if lists.check(distance) { if lists.check(distance) {
lists.push(HeapElement { lists.push(HeapElement {
distance, distance,
@@ -267,7 +267,7 @@ pub fn search(mmap: &IvfMmap, k: usize, vector: &[Scalar], filter: &mut impl Fil
for i in lists.iter().map(|e| e.payload as usize) { for i in lists.iter().map(|e| e.payload as usize) {
let mut j = mmap.heads[i]; let mut j = mmap.heads[i];
while u32::MAX != j { while u32::MAX != j {
let distance = mmap.quantization.distance(mmap.d, vector, j); let distance = mmap.quantization.distance(vector, j);
let payload = mmap.raw.payload(j); let payload = mmap.raw.payload(j);
if result.check(distance) && filter.check(payload) { if result.check(distance) && filter.check(payload) {
result.push(HeapElement { distance, payload }); result.push(HeapElement { distance, payload });
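The search routine above is a two-stage probe: the query is normalized, the nprobe lists with the nearest centroids are selected, and only those lists' chains are scanned against the quantized vectors. A minimal stand-alone sketch of the same two-stage idea, using plain Vecs and exact L2 in place of the crate's Heap, quantization, and linked heads/nexts layout (all names below are illustrative, not the crate's API):

/// Squared L2 distance between two equally sized slices.
fn l2(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

/// Two-stage IVF probe, mirroring `search` above with plain Vecs
/// (illustrative sketch only, not the crate's implementation).
fn ivf_search(
    query: &[f32],
    centroids: &[Vec<f32>],         // one centroid per list
    lists: &[Vec<(u32, Vec<f32>)>], // (id, vector) entries per list
    nprobe: usize,
    k: usize,
) -> Vec<u32> {
    // Stage 1: pick the nprobe lists whose centroids are closest to the query.
    let mut order: Vec<usize> = (0..centroids.len()).collect();
    order.sort_by(|&a, &b| l2(query, &centroids[a]).total_cmp(&l2(query, &centroids[b])));
    order.truncate(nprobe);
    // Stage 2: scan only those lists and keep the k nearest entries.
    let mut candidates: Vec<(f32, u32)> = Vec::new();
    for &list in &order {
        for (id, vector) in &lists[list] {
            candidates.push((l2(query, vector), *id));
        }
    }
    candidates.sort_by(|a, b| a.0.total_cmp(&b.0));
    candidates.into_iter().take(k).map(|(_, id)| id).collect()
}

fn main() {
    let centroids = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
    let lists = vec![
        vec![(1, vec![0.1, 0.0]), (2, vec![0.0, 0.9])],
        vec![(3, vec![10.0, 10.1])],
    ];
    // Only list 0 is probed; its nearest entry is id 1.
    assert_eq!(ivf_search(&[0.0, 0.0], &centroids, &lists, 1, 1), vec![1]);
}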

View File

@@ -20,16 +20,16 @@ use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::{Acquire, Relaxed, Release}; use std::sync::atomic::Ordering::{Acquire, Relaxed, Release};
use std::sync::Arc; use std::sync::Arc;
pub struct IvfPq { pub struct IvfPq<S: G> {
mmap: IvfMmap, mmap: IvfMmap<S>,
} }
impl IvfPq { impl<S: G> IvfPq<S> {
pub fn create( pub fn create(
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self { ) -> Self {
create_dir(&path).unwrap(); create_dir(&path).unwrap();
let ram = make(path.clone(), sealed, growing, options); let ram = make(path.clone(), sealed, growing, options);
@@ -47,7 +47,7 @@ impl IvfPq {
self.mmap.raw.len() self.mmap.raw.len()
} }
pub fn vector(&self, i: u32) -> &[Scalar] { pub fn vector(&self, i: u32) -> &[S::Scalar] {
self.mmap.raw.vector(i) self.mmap.raw.vector(i)
} }
@@ -55,65 +55,63 @@ impl IvfPq {
self.mmap.raw.payload(i) self.mmap.raw.payload(i)
} }
pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
search(&self.mmap, k, vector, filter) search(&self.mmap, k, vector, filter)
} }
} }
unsafe impl Send for IvfPq {} unsafe impl<S: G> Send for IvfPq<S> {}
unsafe impl Sync for IvfPq {} unsafe impl<S: G> Sync for IvfPq<S> {}
pub struct IvfRam { pub struct IvfRam<S: G> {
raw: Arc<Raw>, raw: Arc<Raw<S>>,
quantization: ProductQuantization, quantization: ProductQuantization<S>,
// ---------------------- // ----------------------
dims: u16, dims: u16,
d: Distance,
// ---------------------- // ----------------------
nlist: u32, nlist: u32,
nprobe: u32, nprobe: u32,
// ---------------------- // ----------------------
centroids: Vec2, centroids: Vec2<S>,
heads: Vec<AtomicU32>, heads: Vec<AtomicU32>,
nexts: Vec<SyncUnsafeCell<u32>>, nexts: Vec<SyncUnsafeCell<u32>>,
} }
unsafe impl Send for IvfRam {} unsafe impl<S: G> Send for IvfRam<S> {}
unsafe impl Sync for IvfRam {} unsafe impl<S: G> Sync for IvfRam<S> {}
pub struct IvfMmap { pub struct IvfMmap<S: G> {
raw: Arc<Raw>, raw: Arc<Raw<S>>,
quantization: ProductQuantization, quantization: ProductQuantization<S>,
// ---------------------- // ----------------------
dims: u16, dims: u16,
d: Distance,
// ---------------------- // ----------------------
nlist: u32, nlist: u32,
nprobe: u32, nprobe: u32,
// ---------------------- // ----------------------
centroids: MmapArray<Scalar>, centroids: MmapArray<S::Scalar>,
heads: MmapArray<u32>, heads: MmapArray<u32>,
nexts: MmapArray<u32>, nexts: MmapArray<u32>,
} }
unsafe impl Send for IvfMmap {} unsafe impl<S: G> Send for IvfMmap<S> {}
unsafe impl Sync for IvfMmap {} unsafe impl<S: G> Sync for IvfMmap<S> {}
impl IvfMmap { impl<S: G> IvfMmap<S> {
fn centroids(&self, i: u32) -> &[Scalar] { fn centroids(&self, i: u32) -> &[S::Scalar] {
let s = i as usize * self.dims as usize; let s = i as usize * self.dims as usize;
let e = (i + 1) as usize * self.dims as usize; let e = (i + 1) as usize * self.dims as usize;
&self.centroids[s..e] &self.centroids[s..e]
} }
} }
pub fn make( pub fn make<S: G>(
path: PathBuf, path: PathBuf,
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
options: IndexOptions, options: IndexOptions,
) -> IvfRam { ) -> IvfRam<S> {
let VectorOptions { dims, d } = options.vector; let VectorOptions { dims, .. } = options.vector;
let IvfIndexingOptions { let IvfIndexingOptions {
least_iterations, least_iterations,
iterations, iterations,
@@ -134,9 +132,9 @@ pub fn make(
let mut samples = Vec2::new(dims, m as usize); let mut samples = Vec2::new(dims, m as usize);
for i in 0..m { for i in 0..m {
samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32)); samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32));
d.elkan_k_means_normalize(&mut samples[i as usize]); S::elkan_k_means_normalize(&mut samples[i as usize]);
} }
let mut k_means = ElkanKMeans::new(nlist as usize, samples, d); let mut k_means = ElkanKMeans::new(nlist as usize, samples);
for _ in 0..least_iterations { for _ in 0..least_iterations {
k_means.iterate(); k_means.iterate();
} }
@@ -163,10 +161,10 @@ pub fn make(
&raw, &raw,
|i, target| { |i, target| {
let mut vector = target.to_vec(); let mut vector = target.to_vec();
d.elkan_k_means_normalize(&mut vector); S::elkan_k_means_normalize(&mut vector);
let mut result = (Scalar::INFINITY, 0); let mut result = (F32::infinity(), 0);
for i in 0..nlist { for i in 0..nlist {
let dis = d.elkan_k_means_distance(&vector, &centroids[i as usize]); let dis = S::elkan_k_means_distance(&vector, &centroids[i as usize]);
result = std::cmp::min(result, (dis, i)); result = std::cmp::min(result, (dis, i));
} }
let centroid_id = result.1; let centroid_id = result.1;
@@ -194,11 +192,10 @@ pub fn make(
nprobe, nprobe,
nlist, nlist,
dims, dims,
d,
} }
} }
pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { pub fn save<S: G>(mut ram: IvfRam<S>, path: PathBuf) -> IvfMmap<S> {
let centroids = MmapArray::create( let centroids = MmapArray::create(
path.join("centroids"), path.join("centroids"),
(0..ram.nlist) (0..ram.nlist)
@@ -217,7 +214,6 @@ pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap {
raw: ram.raw, raw: ram.raw,
quantization: ram.quantization, quantization: ram.quantization,
dims: ram.dims, dims: ram.dims,
d: ram.d,
nlist: ram.nlist, nlist: ram.nlist,
nprobe: ram.nprobe, nprobe: ram.nprobe,
centroids, centroids,
@@ -226,7 +222,7 @@ pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap {
} }
} }
pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { pub fn load<S: G>(path: PathBuf, options: IndexOptions) -> IvfMmap<S> {
let raw = Arc::new(Raw::open(path.join("raw"), options.clone())); let raw = Arc::new(Raw::open(path.join("raw"), options.clone()));
let quantization = ProductQuantization::open( let quantization = ProductQuantization::open(
path.join("quantization"), path.join("quantization"),
@@ -242,7 +238,6 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap {
raw, raw,
quantization, quantization,
dims: options.vector.dims, dims: options.vector.dims,
d: options.vector.d,
nlist, nlist,
nprobe, nprobe,
centroids, centroids,
@@ -251,13 +246,18 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap {
} }
} }
pub fn search(mmap: &IvfMmap, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { pub fn search<S: G>(
mmap: &IvfMmap<S>,
k: usize,
vector: &[S::Scalar],
filter: &mut impl Filter,
) -> Heap {
let mut target = vector.to_vec(); let mut target = vector.to_vec();
mmap.d.elkan_k_means_normalize(&mut target); S::elkan_k_means_normalize(&mut target);
let mut lists = Heap::new(mmap.nprobe as usize); let mut lists = Heap::new(mmap.nprobe as usize);
for i in 0..mmap.nlist { for i in 0..mmap.nlist {
let centroid = mmap.centroids(i); let centroid = mmap.centroids(i);
let distance = mmap.d.elkan_k_means_distance(&target, centroid); let distance = S::elkan_k_means_distance(&target, centroid);
if lists.check(distance) { if lists.check(distance) {
lists.push(HeapElement { lists.push(HeapElement {
distance, distance,
@@ -270,9 +270,9 @@ pub fn search(mmap: &IvfMmap, k: usize, vector: &[Scalar], filter: &mut impl Fil
for i in lists.iter().map(|e| e.payload as u32) { for i in lists.iter().map(|e| e.payload as u32) {
let mut j = mmap.heads[i as usize]; let mut j = mmap.heads[i as usize];
while u32::MAX != j { while u32::MAX != j {
let distance = let distance = mmap
mmap.quantization .quantization
.distance_with_delta(mmap.d, vector, j, mmap.centroids(i)); .distance_with_delta(vector, j, mmap.centroids(i));
let payload = mmap.raw.payload(j); let payload = mmap.raw.payload(j);
if result.check(distance) && filter.check(payload) { if result.check(distance) && filter.check(payload) {
result.push(HeapElement { distance, payload }); result.push(HeapElement { distance, payload });

View File

@@ -10,17 +10,17 @@ use crate::prelude::*;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
pub enum Ivf { pub enum Ivf<S: G> {
Naive(IvfNaive), Naive(IvfNaive<S>),
Pq(IvfPq), Pq(IvfPq<S>),
} }
impl Ivf { impl<S: G> Ivf<S> {
pub fn create( pub fn create(
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self { ) -> Self {
if options if options
.indexing .indexing
@@ -56,7 +56,7 @@ impl Ivf {
} }
} }
pub fn vector(&self, i: u32) -> &[Scalar] { pub fn vector(&self, i: u32) -> &[S::Scalar] {
match self { match self {
Ivf::Naive(x) => x.vector(i), Ivf::Naive(x) => x.vector(i),
Ivf::Pq(x) => x.vector(i), Ivf::Pq(x) => x.vector(i),
@@ -70,7 +70,7 @@ impl Ivf {
} }
} }
pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
match self { match self {
Ivf::Naive(x) => x.search(k, vector, filter), Ivf::Naive(x) => x.search(k, vector, filter),
Ivf::Pq(x) => x.search(k, vector, filter), Ivf::Pq(x) => x.search(k, vector, filter),

View File

@@ -1,5 +1,4 @@
pub mod clustering; pub mod clustering;
pub mod diskann;
pub mod flat; pub mod flat;
pub mod hnsw; pub mod hnsw;
pub mod ivf; pub mod ivf;

View File

@@ -56,35 +56,35 @@ impl QuantizationOptions {
} }
} }
pub trait Quan { pub trait Quan<S: G> {
fn create( fn create(
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
quantization_options: QuantizationOptions, quantization_options: QuantizationOptions,
raw: &Arc<Raw>, raw: &Arc<Raw<S>>,
) -> Self; ) -> Self;
fn open( fn open(
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
quantization_options: QuantizationOptions, quantization_options: QuantizationOptions,
raw: &Arc<Raw>, raw: &Arc<Raw<S>>,
) -> Self; ) -> Self;
fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar; fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32;
fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar; fn distance2(&self, lhs: u32, rhs: u32) -> F32;
} }
pub enum Quantization { pub enum Quantization<S: G> {
Trivial(TrivialQuantization), Trivial(TrivialQuantization<S>),
Scalar(ScalarQuantization), Scalar(ScalarQuantization<S>),
Product(ProductQuantization), Product(ProductQuantization<S>),
} }
impl Quantization { impl<S: G> Quantization<S> {
pub fn create( pub fn create(
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
quantization_options: QuantizationOptions, quantization_options: QuantizationOptions,
raw: &Arc<Raw>, raw: &Arc<Raw<S>>,
) -> Self { ) -> Self {
match quantization_options { match quantization_options {
QuantizationOptions::Trivial(_) => Self::Trivial(TrivialQuantization::create( QuantizationOptions::Trivial(_) => Self::Trivial(TrivialQuantization::create(
@@ -112,7 +112,7 @@ impl Quantization {
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
quantization_options: QuantizationOptions, quantization_options: QuantizationOptions,
raw: &Arc<Raw>, raw: &Arc<Raw<S>>,
) -> Self { ) -> Self {
match quantization_options { match quantization_options {
QuantizationOptions::Trivial(_) => Self::Trivial(TrivialQuantization::open( QuantizationOptions::Trivial(_) => Self::Trivial(TrivialQuantization::open(
@@ -136,21 +136,21 @@ impl Quantization {
} }
} }
pub fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar { pub fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 {
use Quantization::*; use Quantization::*;
match self { match self {
Trivial(x) => x.distance(d, lhs, rhs), Trivial(x) => x.distance(lhs, rhs),
Scalar(x) => x.distance(d, lhs, rhs), Scalar(x) => x.distance(lhs, rhs),
Product(x) => x.distance(d, lhs, rhs), Product(x) => x.distance(lhs, rhs),
} }
} }
pub fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar { pub fn distance2(&self, lhs: u32, rhs: u32) -> F32 {
use Quantization::*; use Quantization::*;
match self { match self {
Trivial(x) => x.distance2(d, lhs, rhs), Trivial(x) => x.distance2(lhs, rhs),
Scalar(x) => x.distance2(d, lhs, rhs), Scalar(x) => x.distance2(lhs, rhs),
Product(x) => x.distance2(d, lhs, rhs), Product(x) => x.distance2(lhs, rhs),
} }
} }
} }
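In the trait above, the distance kind travels in the generic parameter S: G together with an associated scalar type, so distance/distance2 take no extra runtime argument. A minimal, self-contained sketch of that static-dispatch shape, with toy G, Quan, and F32 stand-ins that only hint at the real definitions in the crate:

// Simplified stand-ins for illustration only; the crate's G, Quan and F32
// carry many more methods and live in other modules.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
struct F32(f32);

trait G {
    type Scalar: Copy;
    fn distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32;
}

struct L2F32; // one concrete "spec": f32 vectors under squared L2

impl G for L2F32 {
    type Scalar = f32;
    fn distance(lhs: &[f32], rhs: &[f32]) -> F32 {
        F32(lhs.iter().zip(rhs).map(|(a, b)| (a - b) * (a - b)).sum())
    }
}

// Quantizers are generic over the spec, so `distance` needs no runtime tag.
trait Quan<S: G> {
    fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32;
}

struct Trivial<S: G> {
    vectors: Vec<Vec<S::Scalar>>,
}

impl<S: G> Quan<S> for Trivial<S> {
    fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 {
        S::distance(lhs, &self.vectors[rhs as usize])
    }
}

fn main() {
    let q = Trivial::<L2F32> { vectors: vec![vec![0.0, 3.0], vec![1.0, 1.0]] };
    println!("{:?}", q.distance(&[0.0, 0.0], 0)); // F32(9.0)
}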

View File

@@ -55,17 +55,17 @@ impl Default for ProductQuantizationOptionsRatio {
} }
} }
pub struct ProductQuantization { pub struct ProductQuantization<S: G> {
dims: u16, dims: u16,
ratio: u16, ratio: u16,
centroids: Vec<Scalar>, centroids: Vec<S::Scalar>,
codes: MmapArray<u8>, codes: MmapArray<u8>,
} }
unsafe impl Send for ProductQuantization {} unsafe impl<S: G> Send for ProductQuantization<S> {}
unsafe impl Sync for ProductQuantization {} unsafe impl<S: G> Sync for ProductQuantization<S> {}
impl ProductQuantization { impl<S: G> ProductQuantization<S> {
fn codes(&self, i: u32) -> &[u8] { fn codes(&self, i: u32) -> &[u8] {
let width = self.dims.div_ceil(self.ratio); let width = self.dims.div_ceil(self.ratio);
let s = i as usize * width as usize; let s = i as usize * width as usize;
@@ -74,12 +74,12 @@ impl ProductQuantization {
} }
} }
impl Quan for ProductQuantization { impl<S: G> Quan<S> for ProductQuantization<S> {
fn create( fn create(
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
quantization_options: QuantizationOptions, quantization_options: QuantizationOptions,
raw: &Arc<Raw>, raw: &Arc<Raw<S>>,
) -> Self { ) -> Self {
Self::with_normalizer(path, options, quantization_options, raw, |_, _| ()) Self::with_normalizer(path, options, quantization_options, raw, |_, _| ())
} }
@@ -88,7 +88,7 @@ impl Quan for ProductQuantization {
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
quantization_options: QuantizationOptions, quantization_options: QuantizationOptions,
_: &Arc<Raw>, _: &Arc<Raw<S>>,
) -> Self { ) -> Self {
let centroids = let centroids =
serde_json::from_slice(&std::fs::read(path.join("centroids")).unwrap()).unwrap(); serde_json::from_slice(&std::fs::read(path.join("centroids")).unwrap()).unwrap();
@@ -101,32 +101,32 @@ impl Quan for ProductQuantization {
} }
} }
fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar { fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 {
let dims = self.dims; let dims = self.dims;
let ratio = self.ratio; let ratio = self.ratio;
let rhs = self.codes(rhs); let rhs = self.codes(rhs);
d.product_quantization_distance(dims, ratio, &self.centroids, lhs, rhs) S::product_quantization_distance(dims, ratio, &self.centroids, lhs, rhs)
} }
fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar { fn distance2(&self, lhs: u32, rhs: u32) -> F32 {
let dims = self.dims; let dims = self.dims;
let ratio = self.ratio; let ratio = self.ratio;
let lhs = self.codes(lhs); let lhs = self.codes(lhs);
let rhs = self.codes(rhs); let rhs = self.codes(rhs);
d.product_quantization_distance2(dims, ratio, &self.centroids, lhs, rhs) S::product_quantization_distance2(dims, ratio, &self.centroids, lhs, rhs)
} }
} }
impl ProductQuantization { impl<S: G> ProductQuantization<S> {
pub fn with_normalizer<F>( pub fn with_normalizer<F>(
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
quantization_options: QuantizationOptions, quantization_options: QuantizationOptions,
raw: &Raw, raw: &Raw<S>,
normalizer: F, normalizer: F,
) -> Self ) -> Self
where where
F: Fn(u32, &mut [Scalar]), F: Fn(u32, &mut [S::Scalar]),
{ {
std::fs::create_dir(&path).unwrap(); std::fs::create_dir(&path).unwrap();
let quantization_options = quantization_options.unwrap_product_quantization(); let quantization_options = quantization_options.unwrap_product_quantization();
@@ -136,22 +136,22 @@ impl ProductQuantization {
let m = std::cmp::min(n, quantization_options.sample); let m = std::cmp::min(n, quantization_options.sample);
let samples = { let samples = {
let f = sample(&mut thread_rng(), n as usize, m as usize).into_vec(); let f = sample(&mut thread_rng(), n as usize, m as usize).into_vec();
let mut samples = Vec2::new(options.vector.dims, m as usize); let mut samples = Vec2::<S>::new(options.vector.dims, m as usize);
for i in 0..m { for i in 0..m {
samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32)); samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32));
} }
samples samples
}; };
let width = dims.div_ceil(ratio); let width = dims.div_ceil(ratio);
let mut centroids = vec![Scalar::Z; 256 * dims as usize]; let mut centroids = vec![S::Scalar::zero(); 256 * dims as usize];
for i in 0..width { for i in 0..width {
let subdims = std::cmp::min(ratio, dims - ratio * i); let subdims = std::cmp::min(ratio, dims - ratio * i);
let mut subsamples = Vec2::new(subdims, m as usize); let mut subsamples = Vec2::<S::L2>::new(subdims, m as usize);
for j in 0..m { for j in 0..m {
let src = &samples[j as usize][(i * ratio) as usize..][..subdims as usize]; let src = &samples[j as usize][(i * ratio) as usize..][..subdims as usize];
subsamples[j as usize].copy_from_slice(src); subsamples[j as usize].copy_from_slice(src);
} }
let mut k_means = ElkanKMeans::new(256, subsamples, Distance::L2); let mut k_means = ElkanKMeans::<S::L2>::new(256, subsamples);
for _ in 0..25 { for _ in 0..25 {
if k_means.iterate() { if k_means.iterate() {
break; break;
@@ -170,13 +170,13 @@ impl ProductQuantization {
let mut result = Vec::with_capacity(width as usize); let mut result = Vec::with_capacity(width as usize);
for i in 0..width { for i in 0..width {
let subdims = std::cmp::min(ratio, dims - ratio * i); let subdims = std::cmp::min(ratio, dims - ratio * i);
let mut minimal = Scalar::INFINITY; let mut minimal = F32::infinity();
let mut target = 0u8; let mut target = 0u8;
let left = &vector[(i * ratio) as usize..][..subdims as usize]; let left = &vector[(i * ratio) as usize..][..subdims as usize];
for j in 0u8..=255 { for j in 0u8..=255 {
let right = &centroids[j as usize * dims as usize..][(i * ratio) as usize..] let right = &centroids[j as usize * dims as usize..][(i * ratio) as usize..]
[..subdims as usize]; [..subdims as usize];
let dis = Distance::L2.distance(left, right); let dis = S::L2::distance(left, right);
if dis < minimal { if dis < minimal {
minimal = dis; minimal = dis;
target = j; target = j;
@@ -201,16 +201,10 @@ impl ProductQuantization {
} }
} }
pub fn distance_with_delta( pub fn distance_with_delta(&self, lhs: &[S::Scalar], rhs: u32, delta: &[S::Scalar]) -> F32 {
&self,
d: Distance,
lhs: &[Scalar],
rhs: u32,
delta: &[Scalar],
) -> Scalar {
let dims = self.dims; let dims = self.dims;
let ratio = self.ratio; let ratio = self.ratio;
let rhs = self.codes(rhs); let rhs = self.codes(rhs);
d.product_quantization_distance_with_delta(dims, ratio, &self.centroids, lhs, rhs, delta) S::product_quantization_distance_with_delta(dims, ratio, &self.centroids, lhs, rhs, delta)
} }
} }
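with_normalizer above trains 256 sub-centroids per subspace (width = ceil(dims / ratio) subspaces of at most ratio dimensions each) and encodes every vector as one byte per subspace: the index of its nearest sub-centroid under L2. A small sketch of just that encoding step, with a simplified centroids[code] layout standing in for the crate's flat centroid array:

/// Squared L2 distance between two equally sized slices.
fn l2(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

/// Encode a vector as one byte per subspace: the index of the closest
/// sub-centroid. `centroids[code]` holds the full-width centroid for `code`,
/// laid out subspace by subspace (a simplified stand-in for the flat array
/// used above; illustrative only).
fn pq_encode(vector: &[f32], centroids: &[Vec<f32>], ratio: usize) -> Vec<u8> {
    let dims = vector.len();
    let width = dims.div_ceil(ratio);
    let mut codes = Vec::with_capacity(width);
    for i in 0..width {
        let lo = i * ratio;
        let hi = usize::min(lo + ratio, dims);
        let left = &vector[lo..hi];
        let mut best = (f32::INFINITY, 0u8);
        for (code, centroid) in centroids.iter().enumerate() {
            let dis = l2(left, &centroid[lo..hi]);
            if dis < best.0 {
                best = (dis, code as u8);
            }
        }
        codes.push(best.1);
    }
    codes
}

fn main() {
    // 4 dims, ratio 2 -> 2 subspaces; two candidate centroids (codes 0 and 1).
    let centroids = vec![vec![0.0, 0.0, 0.0, 0.0], vec![1.0, 1.0, 1.0, 1.0]];
    assert_eq!(pq_encode(&[0.1, 0.0, 0.9, 1.0], &centroids, 2), vec![0, 1]);
}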

View File

@@ -19,17 +19,17 @@ impl Default for ScalarQuantizationOptions {
} }
} }
pub struct ScalarQuantization { pub struct ScalarQuantization<S: G> {
dims: u16, dims: u16,
max: Vec<Scalar>, max: Vec<S::Scalar>,
min: Vec<Scalar>, min: Vec<S::Scalar>,
codes: MmapArray<u8>, codes: MmapArray<u8>,
} }
unsafe impl Send for ScalarQuantization {} unsafe impl<S: G> Send for ScalarQuantization<S> {}
unsafe impl Sync for ScalarQuantization {} unsafe impl<S: G> Sync for ScalarQuantization<S> {}
impl ScalarQuantization { impl<S: G> ScalarQuantization<S> {
fn codes(&self, i: u32) -> &[u8] { fn codes(&self, i: u32) -> &[u8] {
let s = i as usize * self.dims as usize; let s = i as usize * self.dims as usize;
let e = (i + 1) as usize * self.dims as usize; let e = (i + 1) as usize * self.dims as usize;
@@ -37,17 +37,17 @@ impl ScalarQuantization {
} }
} }
impl Quan for ScalarQuantization { impl<S: G> Quan<S> for ScalarQuantization<S> {
fn create( fn create(
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
_: QuantizationOptions, _: QuantizationOptions,
raw: &Arc<Raw>, raw: &Arc<Raw<S>>,
) -> Self { ) -> Self {
std::fs::create_dir(&path).unwrap(); std::fs::create_dir(&path).unwrap();
let dims = options.vector.dims; let dims = options.vector.dims;
let mut max = vec![Scalar::NEG_INFINITY; dims as usize]; let mut max = vec![S::Scalar::neg_infinity(); dims as usize];
let mut min = vec![Scalar::INFINITY; dims as usize]; let mut min = vec![S::Scalar::infinity(); dims as usize];
let n = raw.len(); let n = raw.len();
for i in 0..n { for i in 0..n {
let vector = raw.vector(i); let vector = raw.vector(i);
@@ -62,7 +62,7 @@ impl Quan for ScalarQuantization {
let vector = raw.vector(i); let vector = raw.vector(i);
let mut result = vec![0u8; dims as usize]; let mut result = vec![0u8; dims as usize];
for i in 0..dims as usize { for i in 0..dims as usize {
let w = ((vector[i] - min[i]) / (max[i] - min[i]) * 256.0).0 as u32; let w = (((vector[i] - min[i]) / (max[i] - min[i])).to_f32() * 256.0) as u32;
result[i] = w.clamp(0, 255) as u8; result[i] = w.clamp(0, 255) as u8;
} }
result.into_iter() result.into_iter()
@@ -77,7 +77,7 @@ impl Quan for ScalarQuantization {
} }
} }
fn open(path: PathBuf, options: IndexOptions, _: QuantizationOptions, _: &Arc<Raw>) -> Self { fn open(path: PathBuf, options: IndexOptions, _: QuantizationOptions, _: &Arc<Raw<S>>) -> Self {
let dims = options.vector.dims; let dims = options.vector.dims;
let max = serde_json::from_slice(&std::fs::read("max").unwrap()).unwrap(); let max = serde_json::from_slice(&std::fs::read("max").unwrap()).unwrap();
let min = serde_json::from_slice(&std::fs::read("min").unwrap()).unwrap(); let min = serde_json::from_slice(&std::fs::read("min").unwrap()).unwrap();
@@ -90,16 +90,16 @@ impl Quan for ScalarQuantization {
} }
} }
fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar { fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 {
let dims = self.dims; let dims = self.dims;
let rhs = self.codes(rhs); let rhs = self.codes(rhs);
d.scalar_quantization_distance(dims, &self.max, &self.min, lhs, rhs) S::scalar_quantization_distance(dims, &self.max, &self.min, lhs, rhs)
} }
fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar { fn distance2(&self, lhs: u32, rhs: u32) -> F32 {
let dims = self.dims; let dims = self.dims;
let lhs = self.codes(lhs); let lhs = self.codes(lhs);
let rhs = self.codes(rhs); let rhs = self.codes(rhs);
d.scalar_quantization_distance2(dims, &self.max, &self.min, lhs, rhs) S::scalar_quantization_distance2(dims, &self.max, &self.min, lhs, rhs)
} }
} }
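The encoding loop above maps each component to a byte with (v - min) / (max - min) * 256, clamped to 0..=255, using per-dimension min/max collected from the raw vectors. A self-contained sketch of that per-dimension mapping on plain f32 slices (the helper name is made up for illustration):

/// Encode one f32 vector into bytes using per-dimension min/max,
/// mirroring the quantization loop shown above (illustrative only).
fn sq_encode(vector: &[f32], min: &[f32], max: &[f32]) -> Vec<u8> {
    vector
        .iter()
        .zip(min.iter().zip(max.iter()))
        .map(|(&v, (&lo, &hi))| {
            let w = ((v - lo) / (hi - lo) * 256.0) as u32;
            w.clamp(0, 255) as u8
        })
        .collect()
}

fn main() {
    let min = [0.0, -1.0];
    let max = [2.0, 1.0];
    // 1.0 maps to 128 in [0, 2]; 0.0 maps to 128 in [-1, 1].
    assert_eq!(sq_encode(&[1.0, 0.0], &min, &max), vec![128, 128]);
    // Values at the top of the range clamp to 255.
    assert_eq!(sq_encode(&[2.0, 1.0], &min, &max), vec![255, 255]);
}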

View File

@@ -17,24 +17,24 @@ impl Default for TrivialQuantizationOptions {
} }
} }
pub struct TrivialQuantization { pub struct TrivialQuantization<S: G> {
raw: Arc<Raw>, raw: Arc<Raw<S>>,
} }
impl Quan for TrivialQuantization { impl<S: G> Quan<S> for TrivialQuantization<S> {
fn create(_: PathBuf, _: IndexOptions, _: QuantizationOptions, raw: &Arc<Raw>) -> Self { fn create(_: PathBuf, _: IndexOptions, _: QuantizationOptions, raw: &Arc<Raw<S>>) -> Self {
Self { raw: raw.clone() } Self { raw: raw.clone() }
} }
fn open(_: PathBuf, _: IndexOptions, _: QuantizationOptions, raw: &Arc<Raw>) -> Self { fn open(_: PathBuf, _: IndexOptions, _: QuantizationOptions, raw: &Arc<Raw<S>>) -> Self {
Self { raw: raw.clone() } Self { raw: raw.clone() }
} }
fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar { fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 {
d.distance(lhs, self.raw.vector(rhs)) S::distance(lhs, self.raw.vector(rhs))
} }
fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar { fn distance2(&self, lhs: u32, rhs: u32) -> F32 {
d.distance(self.raw.vector(lhs), self.raw.vector(rhs)) S::distance(self.raw.vector(lhs), self.raw.vector(rhs))
} }
} }

View File

@@ -6,16 +6,16 @@ use crate::utils::mmap_array::MmapArray;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
pub struct Raw { pub struct Raw<S: G> {
mmap: RawMmap, mmap: RawMmap<S>,
} }
impl Raw { impl<S: G> Raw<S> {
pub fn create( pub fn create(
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self { ) -> Self {
std::fs::create_dir(&path).unwrap(); std::fs::create_dir(&path).unwrap();
let ram = make(sealed, growing, options); let ram = make(sealed, growing, options);
@@ -33,7 +33,7 @@ impl Raw {
self.mmap.len() self.mmap.len()
} }
pub fn vector(&self, i: u32) -> &[Scalar] { pub fn vector(&self, i: u32) -> &[S::Scalar] {
self.mmap.vector(i) self.mmap.vector(i)
} }
@@ -42,21 +42,21 @@ impl Raw {
} }
} }
unsafe impl Send for Raw {} unsafe impl<S: G> Send for Raw<S> {}
unsafe impl Sync for Raw {} unsafe impl<S: G> Sync for Raw<S> {}
struct RawRam { struct RawRam<S: G> {
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
dims: u16, dims: u16,
} }
impl RawRam { impl<S: G> RawRam<S> {
fn len(&self) -> u32 { fn len(&self) -> u32 {
self.sealed.iter().map(|x| x.len()).sum::<u32>() self.sealed.iter().map(|x| x.len()).sum::<u32>()
+ self.growing.iter().map(|x| x.len()).sum::<u32>() + self.growing.iter().map(|x| x.len()).sum::<u32>()
} }
fn vector(&self, mut index: u32) -> &[Scalar] { fn vector(&self, mut index: u32) -> &[S::Scalar] {
for x in self.sealed.iter() { for x in self.sealed.iter() {
if index < x.len() { if index < x.len() {
return x.vector(index); return x.vector(index);
@@ -88,18 +88,18 @@ impl RawRam {
} }
} }
struct RawMmap { struct RawMmap<S: G> {
vectors: MmapArray<Scalar>, vectors: MmapArray<S::Scalar>,
payload: MmapArray<Payload>, payload: MmapArray<Payload>,
dims: u16, dims: u16,
} }
impl RawMmap { impl<S: G> RawMmap<S> {
fn len(&self) -> u32 { fn len(&self) -> u32 {
self.payload.len() as u32 self.payload.len() as u32
} }
fn vector(&self, i: u32) -> &[Scalar] { fn vector(&self, i: u32) -> &[S::Scalar] {
let s = i as usize * self.dims as usize; let s = i as usize * self.dims as usize;
let e = (i + 1) as usize * self.dims as usize; let e = (i + 1) as usize * self.dims as usize;
&self.vectors[s..e] &self.vectors[s..e]
@@ -110,14 +110,14 @@ impl RawMmap {
} }
} }
unsafe impl Send for RawMmap {} unsafe impl<S: G> Send for RawMmap<S> {}
unsafe impl Sync for RawMmap {} unsafe impl<S: G> Sync for RawMmap<S> {}
fn make( fn make<S: G>(
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
options: IndexOptions, options: IndexOptions,
) -> RawRam { ) -> RawRam<S> {
RawRam { RawRam {
sealed, sealed,
growing, growing,
@@ -125,7 +125,7 @@ fn make(
} }
} }
fn save(ram: RawRam, path: PathBuf) -> RawMmap { fn save<S: G>(ram: RawRam<S>, path: PathBuf) -> RawMmap<S> {
let n = ram.len(); let n = ram.len();
let vectors_iter = (0..n).flat_map(|i| ram.vector(i)).copied(); let vectors_iter = (0..n).flat_map(|i| ram.vector(i)).copied();
let payload_iter = (0..n).map(|i| ram.payload(i)); let payload_iter = (0..n).map(|i| ram.payload(i));
@@ -138,8 +138,8 @@ fn save(ram: RawRam, path: PathBuf) -> RawMmap {
} }
} }
fn load(path: PathBuf, options: IndexOptions) -> RawMmap { fn load<S: G>(path: PathBuf, options: IndexOptions) -> RawMmap<S> {
let vectors: MmapArray<Scalar> = MmapArray::open(path.join("vectors")); let vectors = MmapArray::open(path.join("vectors"));
let payload = MmapArray::open(path.join("payload")); let payload = MmapArray::open(path.join("payload"));
RawMmap { RawMmap {
vectors, vectors,

View File

@@ -24,16 +24,16 @@ impl Default for FlatIndexingOptions {
} }
} }
pub struct FlatIndexing { pub struct FlatIndexing<S: G> {
raw: crate::algorithms::flat::Flat, raw: crate::algorithms::flat::Flat<S>,
} }
impl AbstractIndexing for FlatIndexing { impl<S: G> AbstractIndexing<S> for FlatIndexing<S> {
fn create( fn create(
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self { ) -> Self {
let raw = Flat::create(path, options, sealed, growing); let raw = Flat::create(path, options, sealed, growing);
Self { raw } Self { raw }
@@ -48,7 +48,7 @@ impl AbstractIndexing for FlatIndexing {
self.raw.len() self.raw.len()
} }
fn vector(&self, i: u32) -> &[Scalar] { fn vector(&self, i: u32) -> &[S::Scalar] {
self.raw.vector(i) self.raw.vector(i)
} }
@@ -56,7 +56,7 @@ impl AbstractIndexing for FlatIndexing {
self.raw.payload(i) self.raw.payload(i)
} }
fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
self.raw.search(k, vector, filter) self.raw.search(k, vector, filter)
} }
} }

View File

@@ -41,16 +41,16 @@ impl Default for HnswIndexingOptions {
} }
} }
pub struct HnswIndexing { pub struct HnswIndexing<S: G> {
raw: Hnsw, raw: Hnsw<S>,
} }
impl AbstractIndexing for HnswIndexing { impl<S: G> AbstractIndexing<S> for HnswIndexing<S> {
fn create( fn create(
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self { ) -> Self {
let raw = Hnsw::create(path, options, sealed, growing); let raw = Hnsw::create(path, options, sealed, growing);
Self { raw } Self { raw }
@@ -65,7 +65,7 @@ impl AbstractIndexing for HnswIndexing {
self.raw.len() self.raw.len()
} }
fn vector(&self, i: u32) -> &[Scalar] { fn vector(&self, i: u32) -> &[S::Scalar] {
self.raw.vector(i) self.raw.vector(i)
} }
@@ -73,17 +73,13 @@ impl AbstractIndexing for HnswIndexing {
self.raw.payload(i) self.raw.payload(i)
} }
fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
self.raw.search(k, vector, filter) self.raw.search(k, vector, filter)
} }
} }
impl HnswIndexing { impl<S: G> HnswIndexing<S> {
pub fn search_vbase<'index, 'vector>( pub fn search_vbase(&self, range: usize, vector: &[S::Scalar]) -> HnswIndexIter<'_, S> {
&'index self,
range: usize,
vector: &'vector [Scalar],
) -> HnswIndexIter<'index, 'vector> {
self.raw.search_vbase(range, vector) self.raw.search_vbase(range, vector)
} }
} }

View File

@@ -4,7 +4,6 @@ use crate::algorithms::quantization::QuantizationOptions;
use crate::index::segments::growing::GrowingSegment; use crate::index::segments::growing::GrowingSegment;
use crate::index::segments::sealed::SealedSegment; use crate::index::segments::sealed::SealedSegment;
use crate::index::IndexOptions; use crate::index::IndexOptions;
use crate::prelude::Scalar;
use crate::prelude::*; use crate::prelude::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::path::PathBuf; use std::path::PathBuf;
@@ -64,16 +63,16 @@ impl Default for IvfIndexingOptions {
} }
} }
pub struct IvfIndexing { pub struct IvfIndexing<S: G> {
raw: Ivf, raw: Ivf<S>,
} }
impl AbstractIndexing for IvfIndexing { impl<S: G> AbstractIndexing<S> for IvfIndexing<S> {
fn create( fn create(
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self { ) -> Self {
let raw = Ivf::create(path, options, sealed, growing); let raw = Ivf::create(path, options, sealed, growing);
Self { raw } Self { raw }
@@ -88,7 +87,7 @@ impl AbstractIndexing for IvfIndexing {
self.raw.len() self.raw.len()
} }
fn vector(&self, i: u32) -> &[Scalar] { fn vector(&self, i: u32) -> &[S::Scalar] {
self.raw.vector(i) self.raw.vector(i)
} }
@@ -96,7 +95,7 @@ impl AbstractIndexing for IvfIndexing {
self.raw.payload(i) self.raw.payload(i)
} }
fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
self.raw.search(k, vector, filter) self.raw.search(k, vector, filter)
} }
} }

View File

@@ -60,36 +60,36 @@ impl Validate for IndexingOptions {
} }
} }
pub trait AbstractIndexing: Sized { pub trait AbstractIndexing<S: G>: Sized {
fn create( fn create(
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self; ) -> Self;
fn open(path: PathBuf, options: IndexOptions) -> Self; fn open(path: PathBuf, options: IndexOptions) -> Self;
fn len(&self) -> u32; fn len(&self) -> u32;
fn vector(&self, i: u32) -> &[Scalar]; fn vector(&self, i: u32) -> &[S::Scalar];
fn payload(&self, i: u32) -> Payload; fn payload(&self, i: u32) -> Payload;
fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap; fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap;
} }
pub enum DynamicIndexing { pub enum DynamicIndexing<S: G> {
Flat(FlatIndexing), Flat(FlatIndexing<S>),
Ivf(IvfIndexing), Ivf(IvfIndexing<S>),
Hnsw(HnswIndexing), Hnsw(HnswIndexing<S>),
} }
pub enum DynamicIndexIter<'index, 'vector> { pub enum DynamicIndexIter<'a, S: G> {
Hnsw(HnswIndexIter<'index, 'vector>), Hnsw(HnswIndexIter<'a, S>),
} }
impl DynamicIndexing { impl<S: G> DynamicIndexing<S> {
pub fn create( pub fn create(
path: PathBuf, path: PathBuf,
options: IndexOptions, options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
) -> Self { ) -> Self {
match options.indexing { match options.indexing {
IndexingOptions::Flat(_) => { IndexingOptions::Flat(_) => {
@@ -120,7 +120,7 @@ impl DynamicIndexing {
} }
} }
pub fn vector(&self, i: u32) -> &[Scalar] { pub fn vector(&self, i: u32) -> &[S::Scalar] {
match self { match self {
DynamicIndexing::Flat(x) => x.vector(i), DynamicIndexing::Flat(x) => x.vector(i),
DynamicIndexing::Ivf(x) => x.vector(i), DynamicIndexing::Ivf(x) => x.vector(i),
@@ -136,7 +136,7 @@ impl DynamicIndexing {
} }
} }
pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
match self { match self {
DynamicIndexing::Flat(x) => x.search(k, vector, filter), DynamicIndexing::Flat(x) => x.search(k, vector, filter),
DynamicIndexing::Ivf(x) => x.search(k, vector, filter), DynamicIndexing::Ivf(x) => x.search(k, vector, filter),
@@ -144,11 +144,7 @@ impl DynamicIndexing {
} }
} }
pub fn search_vbase<'index, 'vector>( pub fn vbase(&self, range: usize, vector: &[S::Scalar]) -> DynamicIndexIter<'_, S> {
&'index self,
range: usize,
vector: &'vector [Scalar],
) -> DynamicIndexIter<'index, 'vector> {
use DynamicIndexIter::*; use DynamicIndexIter::*;
match self { match self {
DynamicIndexing::Hnsw(x) => Hnsw(x.search_vbase(range, vector)), DynamicIndexing::Hnsw(x) => Hnsw(x.search_vbase(range, vector)),
@@ -157,7 +153,7 @@ impl DynamicIndexing {
} }
} }
impl Iterator for DynamicIndexIter<'_, '_> { impl<S: G> Iterator for DynamicIndexIter<'_, S> {
type Item = HeapElement; type Item = HeapElement;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
use DynamicIndexIter::*; use DynamicIndexIter::*;

View File

@@ -0,0 +1,506 @@
pub mod delete;
pub mod indexing;
pub mod optimizing;
pub mod segments;
use self::delete::Delete;
use self::indexing::IndexingOptions;
use self::optimizing::OptimizingOptions;
use self::segments::growing::GrowingSegment;
use self::segments::growing::GrowingSegmentInsertError;
use self::segments::sealed::SealedSegment;
use self::segments::SegmentsOptions;
use crate::index::indexing::DynamicIndexIter;
use crate::index::optimizing::indexing::OptimizerIndexing;
use crate::index::optimizing::sealing::OptimizerSealing;
use crate::prelude::*;
use crate::utils::clean::clean;
use crate::utils::dir_ops::sync_dir;
use crate::utils::file_atomic::FileAtomic;
use arc_swap::ArcSwap;
use crossbeam::atomic::AtomicCell;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use thiserror::Error;
use uuid::Uuid;
use validator::Validate;
#[derive(Debug, Error)]
#[error("The index view is outdated.")]
pub struct OutdatedError(#[from] pub Option<GrowingSegmentInsertError>);
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct VectorOptions {
#[validate(range(min = 1, max = 65535))]
#[serde(rename = "dimensions")]
pub dims: u16,
#[serde(rename = "distance")]
pub d: Distance,
#[serde(rename = "kind")]
pub k: Kind,
}
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct IndexOptions {
#[validate]
pub vector: VectorOptions,
#[validate]
pub segment: SegmentsOptions,
#[validate]
pub optimizing: OptimizingOptions,
#[validate]
pub indexing: IndexingOptions,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct IndexStat {
pub indexing: bool,
pub sealed: Vec<u32>,
pub growing: Vec<u32>,
pub write: u32,
pub options: IndexOptions,
}
pub struct Index<S: G> {
path: PathBuf,
options: IndexOptions,
delete: Arc<Delete>,
protect: Mutex<IndexProtect<S>>,
view: ArcSwap<IndexView<S>>,
instant_index: AtomicCell<Instant>,
instant_write: AtomicCell<Instant>,
_tracker: Arc<IndexTracker>,
}
impl<S: G> Index<S> {
pub fn create(path: PathBuf, options: IndexOptions) -> Arc<Self> {
assert!(options.validate().is_ok());
std::fs::create_dir(&path).unwrap();
std::fs::create_dir(path.join("segments")).unwrap();
let startup = FileAtomic::create(
path.join("startup"),
IndexStartup {
sealeds: HashSet::new(),
growings: HashSet::new(),
},
);
let delete = Delete::create(path.join("delete"));
sync_dir(&path);
let index = Arc::new(Index {
path: path.clone(),
options: options.clone(),
delete: delete.clone(),
protect: Mutex::new(IndexProtect {
startup,
sealed: HashMap::new(),
growing: HashMap::new(),
write: None,
}),
view: ArcSwap::new(Arc::new(IndexView {
options: options.clone(),
sealed: HashMap::new(),
growing: HashMap::new(),
delete: delete.clone(),
write: None,
})),
instant_index: AtomicCell::new(Instant::now()),
instant_write: AtomicCell::new(Instant::now()),
_tracker: Arc::new(IndexTracker { path }),
});
OptimizerIndexing::new(index.clone()).spawn();
OptimizerSealing::new(index.clone()).spawn();
index
}
pub fn open(path: PathBuf, options: IndexOptions) -> Arc<Self> {
let tracker = Arc::new(IndexTracker { path: path.clone() });
let startup = FileAtomic::<IndexStartup>::open(path.join("startup"));
clean(
path.join("segments"),
startup
.get()
.sealeds
.iter()
.map(|s| s.to_string())
.chain(startup.get().growings.iter().map(|s| s.to_string())),
);
let sealed = startup
.get()
.sealeds
.iter()
.map(|&uuid| {
(
uuid,
SealedSegment::open(
tracker.clone(),
path.join("segments").join(uuid.to_string()),
uuid,
options.clone(),
),
)
})
.collect::<HashMap<_, _>>();
let growing = startup
.get()
.growings
.iter()
.map(|&uuid| {
(
uuid,
GrowingSegment::open(
tracker.clone(),
path.join("segments").join(uuid.to_string()),
uuid,
),
)
})
.collect::<HashMap<_, _>>();
let delete = Delete::open(path.join("delete"));
let index = Arc::new(Index {
path: path.clone(),
options: options.clone(),
delete: delete.clone(),
protect: Mutex::new(IndexProtect {
startup,
sealed: sealed.clone(),
growing: growing.clone(),
write: None,
}),
view: ArcSwap::new(Arc::new(IndexView {
options: options.clone(),
delete: delete.clone(),
sealed,
growing,
write: None,
})),
instant_index: AtomicCell::new(Instant::now()),
instant_write: AtomicCell::new(Instant::now()),
_tracker: tracker,
});
OptimizerIndexing::new(index.clone()).spawn();
OptimizerSealing::new(index.clone()).spawn();
index
}
pub fn options(&self) -> &IndexOptions {
&self.options
}
pub fn view(&self) -> Arc<IndexView<S>> {
self.view.load_full()
}
pub fn refresh(&self) {
let mut protect = self.protect.lock();
if let Some((uuid, write)) = protect.write.clone() {
if !write.is_full() {
return;
}
write.seal();
protect.growing.insert(uuid, write);
}
let write_segment_uuid = Uuid::new_v4();
let write_segment = GrowingSegment::create(
self._tracker.clone(),
self.path
.join("segments")
.join(write_segment_uuid.to_string()),
write_segment_uuid,
self.options.clone(),
);
protect.write = Some((write_segment_uuid, write_segment));
protect.maintain(self.options.clone(), self.delete.clone(), &self.view);
self.instant_write.store(Instant::now());
}
pub fn seal(&self, check: Uuid) {
let mut protect = self.protect.lock();
if let Some((uuid, write)) = protect.write.clone() {
if check != uuid {
return;
}
write.seal();
protect.growing.insert(uuid, write);
}
protect.write = None;
protect.maintain(self.options.clone(), self.delete.clone(), &self.view);
self.instant_write.store(Instant::now());
}
pub fn stat(&self) -> IndexStat {
let view = self.view();
IndexStat {
indexing: self.instant_index.load() < self.instant_write.load(),
sealed: view.sealed.values().map(|x| x.len()).collect(),
growing: view.growing.values().map(|x| x.len()).collect(),
write: view.write.as_ref().map(|(_, x)| x.len()).unwrap_or(0),
options: self.options().clone(),
}
}
}
impl<S: G> Drop for Index<S> {
fn drop(&mut self) {}
}
#[derive(Debug, Clone)]
pub struct IndexTracker {
path: PathBuf,
}
impl Drop for IndexTracker {
fn drop(&mut self) {
std::fs::remove_dir_all(&self.path).unwrap();
}
}
pub struct IndexView<S: G> {
pub options: IndexOptions,
pub delete: Arc<Delete>,
pub sealed: HashMap<Uuid, Arc<SealedSegment<S>>>,
pub growing: HashMap<Uuid, Arc<GrowingSegment<S>>>,
pub write: Option<(Uuid, Arc<GrowingSegment<S>>)>,
}
impl<S: G> IndexView<S> {
pub fn search<F: FnMut(Pointer) -> bool>(
&self,
k: usize,
vector: &[S::Scalar],
mut filter: F,
) -> Vec<Pointer> {
assert_eq!(self.options.vector.dims as usize, vector.len());
struct Comparer(BinaryHeap<Reverse<HeapElement>>);
impl PartialEq for Comparer {
fn eq(&self, other: &Self) -> bool {
self.cmp(other).is_eq()
}
}
impl Eq for Comparer {}
impl PartialOrd for Comparer {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Comparer {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0.peek().cmp(&other.0.peek()).reverse()
}
}
let mut filter = |payload| {
if let Some(p) = self.delete.check(payload) {
filter(p)
} else {
false
}
};
let n = self.sealed.len() + self.growing.len() + 1;
let mut result = Heap::new(k);
let mut heaps = BinaryHeap::with_capacity(1 + n);
for (_, sealed) in self.sealed.iter() {
let p = sealed.search(k, vector, &mut filter).into_reversed_heap();
heaps.push(Comparer(p));
}
for (_, growing) in self.growing.iter() {
let p = growing.search(k, vector, &mut filter).into_reversed_heap();
heaps.push(Comparer(p));
}
if let Some((_, write)) = &self.write {
let p = write.search(k, vector, &mut filter).into_reversed_heap();
heaps.push(Comparer(p));
}
while let Some(Comparer(mut heap)) = heaps.pop() {
if let Some(Reverse(x)) = heap.pop() {
result.push(x);
heaps.push(Comparer(heap));
}
}
result
.into_sorted_vec()
.iter()
.map(|x| Pointer::from_u48(x.payload >> 16))
.collect()
}
pub fn vbase(&self, vector: &[S::Scalar]) -> impl Iterator<Item = Pointer> + '_ {
assert_eq!(self.options.vector.dims as usize, vector.len());
let range = 86;
struct Comparer<'a, S: G> {
iter: ComparerIter<'a, S>,
item: Option<HeapElement>,
}
enum ComparerIter<'a, S: G> {
Sealed(DynamicIndexIter<'a, S>),
Growing(std::vec::IntoIter<HeapElement>),
}
impl<S: G> PartialEq for Comparer<'_, S> {
fn eq(&self, other: &Self) -> bool {
self.cmp(other).is_eq()
}
}
impl<S: G> Eq for Comparer<'_, S> {}
impl<S: G> PartialOrd for Comparer<'_, S> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl<S: G> Ord for Comparer<'_, S> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.item.cmp(&other.item).reverse()
}
}
impl<S: G> Iterator for ComparerIter<'_, S> {
type Item = HeapElement;
fn next(&mut self) -> Option<Self::Item> {
match self {
Self::Sealed(iter) => iter.next(),
Self::Growing(iter) => iter.next(),
}
}
}
impl<S: G> Iterator for Comparer<'_, S> {
type Item = HeapElement;
fn next(&mut self) -> Option<Self::Item> {
let item = self.item.take();
self.item = self.iter.next();
item
}
}
fn from_iter<S: G>(mut iter: ComparerIter<'_, S>) -> Comparer<'_, S> {
let item = iter.next();
Comparer { iter, item }
}
use ComparerIter::*;
let filter = |payload| self.delete.check(payload).is_some();
let n = self.sealed.len() + self.growing.len() + 1;
let mut heaps: BinaryHeap<Comparer<S>> = BinaryHeap::with_capacity(1 + n);
for (_, sealed) in self.sealed.iter() {
let res = sealed.vbase(range, vector);
heaps.push(from_iter(Sealed(res)));
}
for (_, growing) in self.growing.iter() {
let mut res = growing.vbase(vector);
res.sort_unstable();
heaps.push(from_iter(Growing(res.into_iter())));
}
if let Some((_, write)) = &self.write {
let mut res = write.vbase(vector);
res.sort_unstable();
heaps.push(from_iter(Growing(res.into_iter())));
}
std::iter::from_fn(move || {
while let Some(mut iter) = heaps.pop() {
if let Some(x) = iter.next() {
if !filter(x.payload) {
continue;
}
heaps.push(iter);
return Some(Pointer::from_u48(x.payload >> 16));
}
}
None
})
}
pub fn insert(&self, vector: Vec<S::Scalar>, pointer: Pointer) -> Result<(), OutdatedError> {
assert_eq!(self.options.vector.dims as usize, vector.len());
let payload = (pointer.as_u48() << 16) | self.delete.version(pointer) as Payload;
if let Some((_, growing)) = self.write.as_ref() {
Ok(growing.insert(vector, payload)?)
} else {
Err(OutdatedError(None))
}
}
pub fn delete<F: FnMut(Pointer) -> bool>(&self, mut f: F) {
for (_, sealed) in self.sealed.iter() {
let n = sealed.len();
for i in 0..n {
if let Some(p) = self.delete.check(sealed.payload(i)) {
if f(p) {
self.delete.delete(p);
}
}
}
}
for (_, growing) in self.growing.iter() {
let n = growing.len();
for i in 0..n {
if let Some(p) = self.delete.check(growing.payload(i)) {
if f(p) {
self.delete.delete(p);
}
}
}
}
if let Some((_, write)) = &self.write {
let n = write.len();
for i in 0..n {
if let Some(p) = self.delete.check(write.payload(i)) {
if f(p) {
self.delete.delete(p);
}
}
}
}
}
pub fn flush(&self) {
self.delete.flush();
if let Some((_, write)) = &self.write {
write.flush();
}
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct IndexStartup {
sealeds: HashSet<Uuid>,
growings: HashSet<Uuid>,
}
struct IndexProtect<S: G> {
startup: FileAtomic<IndexStartup>,
sealed: HashMap<Uuid, Arc<SealedSegment<S>>>,
growing: HashMap<Uuid, Arc<GrowingSegment<S>>>,
write: Option<(Uuid, Arc<GrowingSegment<S>>)>,
}
impl<S: G> IndexProtect<S> {
fn maintain(
&mut self,
options: IndexOptions,
delete: Arc<Delete>,
swap: &ArcSwap<IndexView<S>>,
) {
let view = Arc::new(IndexView {
options,
delete,
sealed: self.sealed.clone(),
growing: self.growing.clone(),
write: self.write.clone(),
});
let startup_write = self.write.as_ref().map(|(uuid, _)| *uuid);
let startup_sealeds = self.sealed.keys().copied().collect();
let startup_growings = self.growing.keys().copied().chain(startup_write).collect();
self.startup.set(IndexStartup {
sealeds: startup_sealeds,
growings: startup_growings,
});
swap.swap(view);
}
}
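vbase above merges one sorted candidate stream per segment by wrapping each stream in a Comparer that exposes its current head and keeping the comparers in a BinaryHeap, i.e. a k-way merge. A stripped-down sketch of that merge over plain sorted iterators of u32 keys (the real code merges HeapElements and filters deleted payloads):

use std::cmp::Reverse;
use std::collections::BinaryHeap;

/// Minimal k-way merge over already-sorted streams, the same idea as the
/// `Comparer` wrapper in `vbase` above: keep each stream's current head in a
/// heap and always pop the globally smallest one (illustrative only).
fn kway_merge<I>(mut streams: Vec<I>) -> impl Iterator<Item = u32>
where
    I: Iterator<Item = u32>,
{
    let mut heap: BinaryHeap<Reverse<(u32, usize)>> = BinaryHeap::new();
    for (idx, s) in streams.iter_mut().enumerate() {
        if let Some(head) = s.next() {
            heap.push(Reverse((head, idx)));
        }
    }
    std::iter::from_fn(move || {
        let Reverse((head, idx)) = heap.pop()?;
        // Refill the heap from the stream we just consumed.
        if let Some(next) = streams[idx].next() {
            heap.push(Reverse((next, idx)));
        }
        Some(head)
    })
}

fn main() {
    let merged: Vec<u32> = kway_merge(vec![
        vec![1, 4, 7].into_iter(),
        vec![2, 3, 9].into_iter(),
        vec![5, 6].into_iter(),
    ])
    .collect();
    assert_eq!(merged, vec![1, 2, 3, 4, 5, 6, 7, 9]);
}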

View File

@@ -0,0 +1,148 @@
use crate::index::GrowingSegment;
use crate::index::Index;
use crate::index::SealedSegment;
use crate::prelude::*;
use std::cmp::Reverse;
use std::sync::Arc;
use std::time::Instant;
use uuid::Uuid;
pub struct OptimizerIndexing<S: G> {
index: Arc<Index<S>>,
}
impl<S: G> OptimizerIndexing<S> {
pub fn new(index: Arc<Index<S>>) -> Self {
Self { index }
}
pub fn spawn(self) {
std::thread::spawn(move || {
self.main();
});
}
pub fn main(self) {
let index = self.index;
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(index.options.optimizing.optimizing_threads)
.build()
.unwrap();
let weak_index = Arc::downgrade(&index);
std::mem::drop(index);
loop {
{
let Some(index) = weak_index.upgrade() else {
return;
};
if let Ok(()) = pool.install(|| optimizing_indexing(index.clone())) {
continue;
}
}
std::thread::sleep(std::time::Duration::from_secs(60));
}
}
}
enum Seg<S: G> {
Sealed(Arc<SealedSegment<S>>),
Growing(Arc<GrowingSegment<S>>),
}
impl<S: G> Seg<S> {
fn uuid(&self) -> Uuid {
use Seg::*;
match self {
Sealed(x) => x.uuid(),
Growing(x) => x.uuid(),
}
}
fn len(&self) -> u32 {
use Seg::*;
match self {
Sealed(x) => x.len(),
Growing(x) => x.len(),
}
}
fn get_sealed(&self) -> Option<Arc<SealedSegment<S>>> {
match self {
Seg::Sealed(x) => Some(x.clone()),
_ => None,
}
}
fn get_growing(&self) -> Option<Arc<GrowingSegment<S>>> {
match self {
Seg::Growing(x) => Some(x.clone()),
_ => None,
}
}
}
#[derive(Debug, thiserror::Error)]
#[error("Interrupted, retry again.")]
pub struct RetryError;
pub fn optimizing_indexing<S: G>(index: Arc<Index<S>>) -> Result<(), RetryError> {
use Seg::*;
let segs = {
let protect = index.protect.lock();
let mut segs_0 = Vec::new();
segs_0.extend(protect.growing.values().map(|x| Growing(x.clone())));
segs_0.extend(protect.sealed.values().map(|x| Sealed(x.clone())));
segs_0.sort_by_key(|case| Reverse(case.len()));
let mut segs_1 = Vec::new();
let mut total = 0u64;
let mut count = 0;
while let Some(seg) = segs_0.pop() {
if total + seg.len() as u64 <= index.options.segment.max_sealed_segment_size as u64 {
total += seg.len() as u64;
if let Growing(_) = seg {
count += 1;
}
segs_1.push(seg);
} else {
break;
}
}
if segs_1.is_empty() || (segs_1.len() == 1 && count == 0) {
index.instant_index.store(Instant::now());
return Err(RetryError);
}
segs_1
};
let sealed_segment = merge(&index, &segs);
{
let mut protect = index.protect.lock();
for seg in segs.iter() {
if protect.sealed.contains_key(&seg.uuid()) {
continue;
}
if protect.growing.contains_key(&seg.uuid()) {
continue;
}
return Ok(());
}
for seg in segs.iter() {
protect.sealed.remove(&seg.uuid());
protect.growing.remove(&seg.uuid());
}
protect.sealed.insert(sealed_segment.uuid(), sealed_segment);
protect.maintain(index.options.clone(), index.delete.clone(), &index.view);
}
Ok(())
}
fn merge<S: G>(index: &Arc<Index<S>>, segs: &[Seg<S>]) -> Arc<SealedSegment<S>> {
let sealed = segs.iter().filter_map(|x| x.get_sealed()).collect();
let growing = segs.iter().filter_map(|x| x.get_growing()).collect();
let sealed_segment_uuid = Uuid::new_v4();
SealedSegment::create(
index._tracker.clone(),
index
.path
.join("segments")
.join(sealed_segment_uuid.to_string()),
sealed_segment_uuid,
index.options.clone(),
sealed,
growing,
)
}
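The optimizer loop above only upgrades a Weak handle at the top of each iteration, so the background thread ends once the last strong Arc to the index is gone. A stand-alone sketch of that worker pattern with a toy shared state (names are illustrative, not the crate's types):

use std::sync::{Arc, Weak};
use std::time::Duration;

/// Background-worker shape used by `OptimizerIndexing::main` above: keep only
/// a `Weak` handle between iterations so the worker stops once every strong
/// `Arc` to the shared state is dropped (simplified stand-alone sketch).
struct Shared {
    name: &'static str,
}

fn spawn_worker(shared: &Arc<Shared>) -> std::thread::JoinHandle<()> {
    let weak: Weak<Shared> = Arc::downgrade(shared);
    std::thread::spawn(move || loop {
        {
            let Some(shared) = weak.upgrade() else {
                return; // owner dropped, stop the worker
            };
            println!("optimizing {}", shared.name);
            // the strong reference is released here, before sleeping
        }
        std::thread::sleep(Duration::from_millis(10));
    })
}

fn main() {
    let shared = Arc::new(Shared { name: "index" });
    let handle = spawn_worker(&shared);
    std::thread::sleep(Duration::from_millis(35));
    drop(shared); // the worker sees the upgrade fail and exits
    handle.join().unwrap();
}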

View File

@@ -1,4 +1,5 @@
pub mod indexing; pub mod indexing;
pub mod sealing;
pub mod vacuum; pub mod vacuum;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -6,9 +7,12 @@ use validator::Validate;
#[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct OptimizingOptions { pub struct OptimizingOptions {
#[serde(default = "OptimizingOptions::default_waiting_secs", skip)] #[serde(default = "OptimizingOptions::default_sealing_secs")]
#[validate(range(min = 0, max = 600))] #[validate(range(min = 0, max = 60))]
pub waiting_secs: u64, pub sealing_secs: u64,
#[serde(default = "OptimizingOptions::default_sealing_size")]
#[validate(range(min = 1, max = 4_000_000_000))]
pub sealing_size: u32,
#[serde(default = "OptimizingOptions::default_deleted_threshold", skip)] #[serde(default = "OptimizingOptions::default_deleted_threshold", skip)]
#[validate(range(min = 0.01, max = 1.00))] #[validate(range(min = 0.01, max = 1.00))]
pub deleted_threshold: f64, pub deleted_threshold: f64,
@@ -18,9 +22,12 @@ pub struct OptimizingOptions {
} }
impl OptimizingOptions { impl OptimizingOptions {
fn default_waiting_secs() -> u64 { fn default_sealing_secs() -> u64 {
60 60
} }
fn default_sealing_size() -> u32 {
1
}
fn default_deleted_threshold() -> f64 { fn default_deleted_threshold() -> f64 {
0.2 0.2
} }
@@ -35,7 +42,8 @@ impl OptimizingOptions {
impl Default for OptimizingOptions { impl Default for OptimizingOptions {
fn default() -> Self { fn default() -> Self {
Self { Self {
waiting_secs: Self::default_waiting_secs(), sealing_secs: Self::default_sealing_secs(),
sealing_size: Self::default_sealing_size(),
deleted_threshold: Self::default_deleted_threshold(), deleted_threshold: Self::default_deleted_threshold(),
optimizing_threads: Self::default_optimizing_threads(), optimizing_threads: Self::default_optimizing_threads(),
} }

View File

@@ -0,0 +1,49 @@
use crate::index::Index;
use crate::prelude::*;
use std::sync::Arc;
use std::time::Duration;
pub struct OptimizerSealing<S: G> {
index: Arc<Index<S>>,
}
impl<S: G> OptimizerSealing<S> {
pub fn new(index: Arc<Index<S>>) -> Self {
Self { index }
}
pub fn spawn(self) {
std::thread::spawn(move || {
self.main();
});
}
pub fn main(self) {
let index = self.index;
let dur = Duration::from_secs(index.options.optimizing.sealing_secs);
let least = index.options.optimizing.sealing_size;
let weak_index = Arc::downgrade(&index);
std::mem::drop(index);
let mut check = None;
loop {
{
let Some(index) = weak_index.upgrade() else {
return;
};
let view = index.view();
let stamp = view
.write
.as_ref()
.map(|(uuid, segment)| (*uuid, segment.len()));
if stamp == check {
if let Some((uuid, len)) = stamp {
if len >= least {
index.seal(uuid);
}
}
} else {
check = stamp;
}
}
std::thread::sleep(dur);
}
}
}
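The sealing loop above uses a pattern worth noting: between iterations the background thread keeps only a weak reference to the index, so once the last Arc<Index<S>> is dropped the next upgrade fails and the thread exits on its own. A minimal, self-contained sketch of that pattern, assuming nothing from the crate (Worker and spawn_maintenance below are illustrative names, not the crate's API):

use std::sync::{Arc, Weak};
use std::time::Duration;

struct Worker {
    data: std::sync::Mutex<u64>,
}

fn spawn_maintenance(worker: &Arc<Worker>) -> std::thread::JoinHandle<()> {
    // Keep only a weak reference so the worker can be freed while we sleep.
    let weak: Weak<Worker> = Arc::downgrade(worker);
    std::thread::spawn(move || loop {
        {
            let Some(worker) = weak.upgrade() else {
                // Every strong reference is gone: stop maintaining.
                return;
            };
            *worker.data.lock().unwrap() += 1;
        } // The temporary strong reference is dropped before sleeping.
        std::thread::sleep(Duration::from_millis(10));
    })
}

fn main() {
    let worker = Arc::new(Worker { data: std::sync::Mutex::new(0) });
    let handle = spawn_maintenance(&worker);
    std::thread::sleep(Duration::from_millis(50));
    drop(worker); // The maintenance thread notices on its next iteration...
    handle.join().unwrap(); // ...and this join returns.
}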

View File

@@ -1,7 +1,8 @@
#![allow(clippy::all)] // Clippy bug.
use super::SegmentTracker; use super::SegmentTracker;
use crate::index::IndexOptions; use crate::index::IndexOptions;
use crate::index::IndexTracker; use crate::index::IndexTracker;
use crate::index::VectorOptions;
use crate::prelude::*; use crate::prelude::*;
use crate::utils::dir_ops::sync_dir; use crate::utils::dir_ops::sync_dir;
use crate::utils::file_wal::FileWal; use crate::utils::file_wal::FileWal;
@@ -19,17 +20,16 @@ use uuid::Uuid;
#[error("`GrowingSegment` stopped growing.")] #[error("`GrowingSegment` stopped growing.")]
pub struct GrowingSegmentInsertError; pub struct GrowingSegmentInsertError;
pub struct GrowingSegment { pub struct GrowingSegment<S: G> {
uuid: Uuid, uuid: Uuid,
options: VectorOptions, vec: Vec<UnsafeCell<MaybeUninit<Log<S>>>>,
vec: Vec<UnsafeCell<MaybeUninit<Log>>>,
wal: Mutex<FileWal>, wal: Mutex<FileWal>,
len: AtomicUsize, len: AtomicUsize,
pro: Mutex<Protect>, pro: Mutex<Protect>,
_tracker: Arc<SegmentTracker>, _tracker: Arc<SegmentTracker>,
} }
impl GrowingSegment { impl<S: G> GrowingSegment<S> {
pub fn create( pub fn create(
_tracker: Arc<IndexTracker>, _tracker: Arc<IndexTracker>,
path: PathBuf, path: PathBuf,
@@ -42,7 +42,6 @@ impl GrowingSegment {
sync_dir(&path); sync_dir(&path);
Arc::new(Self { Arc::new(Self {
uuid, uuid,
options: options.vector,
vec: unsafe { vec: unsafe {
let mut vec = Vec::with_capacity(capacity as usize); let mut vec = Vec::with_capacity(capacity as usize);
vec.set_len(capacity as usize); vec.set_len(capacity as usize);
@@ -57,23 +56,17 @@ impl GrowingSegment {
_tracker: Arc::new(SegmentTracker { path, _tracker }), _tracker: Arc::new(SegmentTracker { path, _tracker }),
}) })
} }
pub fn open( pub fn open(_tracker: Arc<IndexTracker>, path: PathBuf, uuid: Uuid) -> Arc<Self> {
_tracker: Arc<IndexTracker>,
path: PathBuf,
uuid: Uuid,
options: IndexOptions,
) -> Arc<Self> {
let mut wal = FileWal::open(path.join("wal")); let mut wal = FileWal::open(path.join("wal"));
let mut vec = Vec::new(); let mut vec = Vec::new();
while let Some(log) = wal.read() { while let Some(log) = wal.read() {
let log = bincode::deserialize::<Log>(&log).unwrap(); let log = bincode::deserialize::<Log<S>>(&log).unwrap();
vec.push(UnsafeCell::new(MaybeUninit::new(log))); vec.push(UnsafeCell::new(MaybeUninit::new(log)));
} }
wal.truncate(); wal.truncate();
let n = vec.len(); let n = vec.len();
Arc::new(Self { Arc::new(Self {
uuid, uuid,
options: options.vector,
vec, vec,
wal: { Mutex::new(wal) }, wal: { Mutex::new(wal) },
len: AtomicUsize::new(n), len: AtomicUsize::new(n),
@@ -87,6 +80,20 @@ impl GrowingSegment {
pub fn uuid(&self) -> Uuid { pub fn uuid(&self) -> Uuid {
self.uuid self.uuid
} }
pub fn is_full(&self) -> bool {
let n;
{
let pro = self.pro.lock();
if pro.inflight < pro.capacity {
return false;
}
n = pro.inflight;
}
while self.len.load(Ordering::Acquire) != n {
std::hint::spin_loop();
}
true
}
pub fn seal(&self) { pub fn seal(&self) {
let n; let n;
{ {
@@ -104,7 +111,7 @@ impl GrowingSegment {
} }
pub fn insert( pub fn insert(
&self, &self,
vector: Vec<Scalar>, vector: Vec<S::Scalar>,
payload: Payload, payload: Payload,
) -> Result<(), GrowingSegmentInsertError> { ) -> Result<(), GrowingSegmentInsertError> {
let log = Log { vector, payload }; let log = Log { vector, payload };
@@ -126,13 +133,13 @@ impl GrowingSegment {
self.len.store(1 + i, Ordering::Release); self.len.store(1 + i, Ordering::Release);
self.wal self.wal
.lock() .lock()
.write(&bincode::serialize::<Log>(&log).unwrap()); .write(&bincode::serialize::<Log<S>>(&log).unwrap());
Ok(()) Ok(())
} }
pub fn len(&self) -> u32 { pub fn len(&self) -> u32 {
self.len.load(Ordering::Acquire) as u32 self.len.load(Ordering::Acquire) as u32
} }
pub fn vector(&self, i: u32) -> &[Scalar] { pub fn vector(&self, i: u32) -> &[S::Scalar] {
let i = i as usize; let i = i as usize;
if i >= self.len.load(Ordering::Acquire) { if i >= self.len.load(Ordering::Acquire) {
panic!("Out of bound."); panic!("Out of bound.");
@@ -148,12 +155,12 @@ impl GrowingSegment {
let log = unsafe { (*self.vec[i].get()).assume_init_ref() }; let log = unsafe { (*self.vec[i].get()).assume_init_ref() };
log.payload log.payload
} }
pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
let n = self.len.load(Ordering::Acquire); let n = self.len.load(Ordering::Acquire);
let mut heap = Heap::new(k); let mut heap = Heap::new(k);
for i in 0..n { for i in 0..n {
let log = unsafe { (*self.vec[i].get()).assume_init_ref() }; let log = unsafe { (*self.vec[i].get()).assume_init_ref() };
let distance = self.options.d.distance(vector, &log.vector); let distance = S::distance(vector, &log.vector);
if heap.check(distance) && filter.check(log.payload) { if heap.check(distance) && filter.check(log.payload) {
heap.push(HeapElement { heap.push(HeapElement {
distance, distance,
@@ -163,12 +170,12 @@ impl GrowingSegment {
} }
heap heap
} }
pub fn search_all(&self, vector: &[Scalar]) -> Vec<HeapElement> { pub fn vbase(&self, vector: &[S::Scalar]) -> Vec<HeapElement> {
let n = self.len.load(Ordering::Acquire); let n = self.len.load(Ordering::Acquire);
let mut result = Vec::new(); let mut result = Vec::new();
for i in 0..n { for i in 0..n {
let log = unsafe { (*self.vec[i].get()).assume_init_ref() }; let log = unsafe { (*self.vec[i].get()).assume_init_ref() };
let distance = self.options.d.distance(vector, &log.vector); let distance = S::distance(vector, &log.vector);
result.push(HeapElement { result.push(HeapElement {
distance, distance,
payload: log.payload, payload: log.payload,
@@ -178,10 +185,10 @@ impl GrowingSegment {
} }
} }
unsafe impl Send for GrowingSegment {} unsafe impl<S: G> Send for GrowingSegment<S> {}
unsafe impl Sync for GrowingSegment {} unsafe impl<S: G> Sync for GrowingSegment<S> {}
impl Drop for GrowingSegment { impl<S: G> Drop for GrowingSegment<S> {
fn drop(&mut self) { fn drop(&mut self) {
let n = *self.len.get_mut(); let n = *self.len.get_mut();
for i in 0..n { for i in 0..n {
@@ -193,8 +200,8 @@ impl Drop for GrowingSegment {
} }
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct Log { struct Log<S: G> {
vector: Vec<Scalar>, vector: Vec<S::Scalar>,
payload: Payload, payload: Payload,
} }
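The diff above mostly swaps the concrete Scalar for the generic S::Scalar; the concurrency scheme of GrowingSegment itself is unchanged: a slot index is reserved under a short-lived mutex, the log is written into a pre-allocated UnsafeCell<MaybeUninit<_>>, and readers only look at indices below an atomically published length. A self-contained sketch of that append-only idea, assuming nothing from the crate (AppendOnly below is illustrative, not the real type):

use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;

struct AppendOnly<T> {
    slots: Vec<UnsafeCell<MaybeUninit<T>>>,
    len: AtomicUsize,
    next: Mutex<usize>,
}

unsafe impl<T: Send> Send for AppendOnly<T> {}
unsafe impl<T: Send + Sync> Sync for AppendOnly<T> {}

impl<T> AppendOnly<T> {
    fn new(capacity: usize) -> Self {
        let mut slots = Vec::new();
        slots.resize_with(capacity, || UnsafeCell::new(MaybeUninit::uninit()));
        AppendOnly { slots, len: AtomicUsize::new(0), next: Mutex::new(0) }
    }
    fn push(&self, value: T) -> Result<(), T> {
        // Reserve a slot index under a short-lived lock.
        let i = {
            let mut next = self.next.lock().unwrap();
            if *next == self.slots.len() {
                return Err(value); // full
            }
            let i = *next;
            *next += 1;
            i
        };
        // Write outside the lock; the slot is exclusively ours.
        unsafe { (*self.slots[i].get()).write(value) };
        // Publish in order: wait until every earlier slot is visible,
        // then bump the length with Release so readers see the write.
        while self.len.load(Ordering::Acquire) != i {
            std::hint::spin_loop();
        }
        self.len.store(i + 1, Ordering::Release);
        Ok(())
    }
    fn get(&self, i: usize) -> Option<&T> {
        if i < self.len.load(Ordering::Acquire) {
            Some(unsafe { (*self.slots[i].get()).assume_init_ref() })
        } else {
            None
        }
    }
}

impl<T> Drop for AppendOnly<T> {
    fn drop(&mut self) {
        // Only the published prefix is initialized.
        let n = *self.len.get_mut();
        for slot in &mut self.slots[..n] {
            unsafe { slot.get_mut().assume_init_drop() };
        }
    }
}

fn main() {
    let seg = AppendOnly::new(4);
    seg.push(String::from("hello")).unwrap();
    assert_eq!(seg.get(0).map(String::as_str), Some("hello"));
    assert!(seg.get(1).is_none());
}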

View File

@@ -10,14 +10,10 @@ use validator::ValidationError;
#[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[validate(schema(function = "Self::validate_0"))] #[validate(schema(function = "Self::validate_0"))]
#[validate(schema(function = "Self::validate_1"))]
pub struct SegmentsOptions { pub struct SegmentsOptions {
#[serde(default = "SegmentsOptions::default_max_growing_segment_size")] #[serde(default = "SegmentsOptions::default_max_growing_segment_size")]
#[validate(range(min = 1, max = 4_000_000_000))] #[validate(range(min = 1, max = 4_000_000_000))]
pub max_growing_segment_size: u32, pub max_growing_segment_size: u32,
#[serde(default = "SegmentsOptions::default_min_sealed_segment_size")]
#[validate(range(min = 1, max = 4_000_000_000))]
pub min_sealed_segment_size: u32,
#[serde(default = "SegmentsOptions::default_max_sealed_segment_size")] #[serde(default = "SegmentsOptions::default_max_sealed_segment_size")]
#[validate(range(min = 1, max = 4_000_000_000))] #[validate(range(min = 1, max = 4_000_000_000))]
pub max_sealed_segment_size: u32, pub max_sealed_segment_size: u32,
@@ -27,22 +23,11 @@ impl SegmentsOptions {
fn default_max_growing_segment_size() -> u32 { fn default_max_growing_segment_size() -> u32 {
20_000 20_000
} }
fn default_min_sealed_segment_size() -> u32 {
1_000
}
fn default_max_sealed_segment_size() -> u32 { fn default_max_sealed_segment_size() -> u32 {
1_000_000 1_000_000
} }
// min_sealed_segment_size <= max_growing_segment_size <= max_sealed_segment_size // max_growing_segment_size <= max_sealed_segment_size
fn validate_0(&self) -> Result<(), ValidationError> { fn validate_0(&self) -> Result<(), ValidationError> {
if self.min_sealed_segment_size > self.max_growing_segment_size {
return Err(ValidationError::new(
"`min_sealed_segment_size` must be less than or equal to `max_growing_segment_size`",
));
}
Ok(())
}
fn validate_1(&self) -> Result<(), ValidationError> {
if self.max_growing_segment_size > self.max_sealed_segment_size { if self.max_growing_segment_size > self.max_sealed_segment_size {
return Err(ValidationError::new( return Err(ValidationError::new(
"`max_growing_segment_size` must be less than or equal to `max_sealed_segment_size`", "`max_growing_segment_size` must be less than or equal to `max_sealed_segment_size`",
@@ -56,7 +41,6 @@ impl Default for SegmentsOptions {
fn default() -> Self { fn default() -> Self {
Self { Self {
max_growing_segment_size: Self::default_max_growing_segment_size(), max_growing_segment_size: Self::default_max_growing_segment_size(),
min_sealed_segment_size: Self::default_min_sealed_segment_size(),
max_sealed_segment_size: Self::default_max_sealed_segment_size(), max_sealed_segment_size: Self::default_max_sealed_segment_size(),
} }
} }

View File

@@ -8,20 +8,20 @@ use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use uuid::Uuid; use uuid::Uuid;
pub struct SealedSegment { pub struct SealedSegment<S: G> {
uuid: Uuid, uuid: Uuid,
indexing: DynamicIndexing, indexing: DynamicIndexing<S>,
_tracker: Arc<SegmentTracker>, _tracker: Arc<SegmentTracker>,
} }
impl SealedSegment { impl<S: G> SealedSegment<S> {
pub fn create( pub fn create(
_tracker: Arc<IndexTracker>, _tracker: Arc<IndexTracker>,
path: PathBuf, path: PathBuf,
uuid: Uuid, uuid: Uuid,
options: IndexOptions, options: IndexOptions,
sealed: Vec<Arc<SealedSegment>>, sealed: Vec<Arc<SealedSegment<S>>>,
growing: Vec<Arc<GrowingSegment>>, growing: Vec<Arc<GrowingSegment<S>>>,
) -> Arc<Self> { ) -> Arc<Self> {
std::fs::create_dir(&path).unwrap(); std::fs::create_dir(&path).unwrap();
let indexing = DynamicIndexing::create(path.join("indexing"), options, sealed, growing); let indexing = DynamicIndexing::create(path.join("indexing"), options, sealed, growing);
@@ -51,20 +51,16 @@ impl SealedSegment {
pub fn len(&self) -> u32 { pub fn len(&self) -> u32 {
self.indexing.len() self.indexing.len()
} }
pub fn vector(&self, i: u32) -> &[Scalar] { pub fn vector(&self, i: u32) -> &[S::Scalar] {
self.indexing.vector(i) self.indexing.vector(i)
} }
pub fn payload(&self, i: u32) -> Payload { pub fn payload(&self, i: u32) -> Payload {
self.indexing.payload(i) self.indexing.payload(i)
} }
pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap {
self.indexing.search(k, vector, filter) self.indexing.search(k, vector, filter)
} }
pub fn search_vbase<'index, 'vector>( pub fn vbase(&self, range: usize, vector: &[S::Scalar]) -> DynamicIndexIter<'_, S> {
&'index self, self.indexing.vbase(range, vector)
range: usize,
vector: &'vector [Scalar],
) -> DynamicIndexIter<'index, 'vector> {
self.indexing.search_vbase(range, vector)
} }
} }

View File

@@ -0,0 +1,9 @@
#![feature(core_intrinsics)]
#![feature(avx512_target_feature)]
pub mod algorithms;
pub mod index;
pub mod prelude;
pub mod worker;
mod utils;

View File

@@ -1,5 +1,3 @@
use crate::ipc::IpcError;
use crate::prelude::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use thiserror::Error; use thiserror::Error;
@@ -15,100 +13,77 @@ or simply run the command `psql -U postgres -c 'ALTER SYSTEM SET shared_preload_
")] ")]
BadInit, BadInit,
#[error("\ #[error("\
The given index option is invalid. Bad literal.
INFORMATION: reason = {0:?}\ INFORMATION: hint = {hint}\
")] ")]
BadOption(String), BadLiteral {
#[error("\ hint: String,
The given vector is invalid for input. },
INFORMATION: vector = {0:?}
ADVICE: Check if dimensions of the vector is matched with the index.\
")]
BadVector(Vec<Scalar>),
#[error("\ #[error("\
Modifier of the type is invalid. Modifier of the type is invalid.
ADVICE: Check if modifier of the type is an integer between 1 and 65535.\ ADVICE: Check if modifier of the type is an integer between 1 and 65535.\
")] ")]
BadTypmod, BadTypeDimensions,
#[error("\ #[error("\
Dimensions of the vector are invalid. Dimensions of the vector are invalid.
ADVICE: Check if dimensions of the vector are between 1 and 65535.\ ADVICE: Check if dimensions of the vector are between 1 and 65535.\
")] ")]
BadVecForDims, BadValueDimensions,
#[error("\ #[error("\
Dimensions of the vector is unmatched with the type modifier. The given index option is invalid.
INFORMATION: type_dimensions = {type_dimensions}, value_dimensions = {value_dimensions}\ INFORMATION: reason = {validation:?}\
")] ")]
BadVecForUnmatchedDims { BadOption { validation: String },
value_dimensions: u16,
type_dimensions: u16,
},
#[error("\ #[error("\
Operands of the operator differs in dimensions. Dimensions type modifier of a vector column is needed for building the index.\
INFORMATION: left_dimensions = {left_dimensions}, right_dimensions = {right_dimensions}\
")] ")]
DifferentVectorDims { BadOption2,
left_dimensions: u16,
right_dimensions: u16,
},
#[error("\ #[error("\
Indexes can only be built on built-in distance functions. Indexes can only be built on built-in distance functions.
ADVICE: If you want pgvecto.rs to support more distance functions, \ ADVICE: If you want pgvecto.rs to support more distance functions, \
visit `https://github.com/tensorchord/pgvecto.rs/issues` and contribute your ideas.\ visit `https://github.com/tensorchord/pgvecto.rs/issues` and contribute your ideas.\
")] ")]
UnsupportedOperator, BadOptions3,
#[error("\ #[error("\
The index does not exist in the background worker. The index does not exist in the background worker.
ADVICE: Drop or rebuild the index.\ ADVICE: Drop or rebuild the index.\
")] ")]
Index404, UnknownIndex,
#[error("\ #[error("\
Dimensions type modifier of a vector column is needed for building the index.\ Operands of the operator differ in dimensions or scalar type.
INFORMATION: left_dimensions = {left_dimensions}, right_dimensions = {right_dimensions}\
")] ")]
DimsIsNeeded, Unmatched {
#[error("\ left_dimensions: u16,
Bad vector string. right_dimensions: u16,
INFORMATION: hint = {hint}\
")]
BadVectorString {
hint: String,
}, },
#[error("\ #[error("\
`mmap` transport is not supported by MacOS.\ The given vector is invalid for input.
ADVICE: Check if dimensions and scalar type of the vector is matched with the index.\
")] ")]
MmapTransportNotSupported, Unmatched2,
} }
impl FriendlyError { pub trait FriendlyErrorLike {
pub fn friendly(self) -> ! { fn friendly(self) -> !;
}
impl FriendlyErrorLike for FriendlyError {
fn friendly(self) -> ! {
panic!("pgvecto.rs: {}", self); panic!("pgvecto.rs: {}", self);
} }
} }
impl IpcError { pub trait FriendlyResult {
pub fn friendly(self) -> ! {
panic!("pgvecto.rs: {}", self);
}
}
pub trait Friendly {
type Output; type Output;
fn friendly(self) -> Self::Output; fn friendly(self) -> Self::Output;
} }
impl<T> Friendly for Result<T, FriendlyError> { impl<T, E> FriendlyResult for Result<T, E>
type Output = T; where
E: FriendlyErrorLike,
fn friendly(self) -> T { {
match self {
Ok(x) => x,
Err(e) => e.friendly(),
}
}
}
impl<T> Friendly for Result<T, IpcError> {
type Output = T; type Output = T;
fn friendly(self) -> T { fn friendly(self) -> T {

View File

@@ -0,0 +1,114 @@
use crate::prelude::*;
pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 {
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..n {
xy += lhs[i].to_f() * rhs[i].to_f();
x2 += lhs[i].to_f() * lhs[i].to_f();
y2 += rhs[i].to_f() * rhs[i].to_f();
}
xy / (x2 * y2).sqrt()
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_avx512fp16() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_cosine_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_v3() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_cosine_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
cosine(lhs, rhs)
}
pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 {
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
for i in 0..n {
xy += lhs[i].to_f() * rhs[i].to_f();
}
xy
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_avx512fp16() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_dot_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_v3() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_dot_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
dot(lhs, rhs)
}
pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 {
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut d2 = F32::zero();
for i in 0..n {
let d = lhs[i].to_f() - rhs[i].to_f();
d2 += d * d;
}
d2
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_avx512fp16() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_sl2_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_v3() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_sl2_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
sl2(lhs, rhs)
}
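Each kernel in this file follows the same shape: a portable loop, compiled for several instruction sets via multiversion, serves as the fallback, and hand-written kernels (the c::v_f16_* functions) are chosen at runtime when AVX-512 FP16 or x86-64-v3 is detected. A minimal, self-contained sketch of that dispatch shape, written for f32 so it needs nothing from the crate (the kernel call is a placeholder, not a real symbol):

// Dot product with a runtime-dispatched fast path and a portable fallback.
pub fn dot(lhs: &[f32], rhs: &[f32]) -> f32 {
    // Portable fallback; the crate additionally compiles its fallbacks for
    // x86-64-v2/v3/v4 and aarch64+neon through `multiversion`.
    fn fallback(lhs: &[f32], rhs: &[f32]) -> f32 {
        assert_eq!(lhs.len(), rhs.len());
        lhs.iter().zip(rhs).map(|(x, y)| x * y).sum()
    }
    #[cfg(target_arch = "x86_64")]
    if std::is_x86_feature_detected!("avx2") {
        // A real implementation would call a hand-written SIMD kernel here
        // (in the crate, functions such as c::v_f16_dot_v3); this sketch
        // just reuses the portable loop.
        return fallback(lhs, rhs);
    }
    fallback(lhs, rhs)
}

fn main() {
    assert_eq!(dot(&[1.0, 2.0], &[3.0, 4.0]), 11.0);
}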

View File

@@ -0,0 +1,244 @@
use super::G;
use crate::prelude::scalar::F32;
use crate::prelude::*;
#[derive(Debug, Clone, Copy)]
pub enum F16Cos {}
impl G for F16Cos {
type Scalar = F16;
const DISTANCE: Distance = Distance::Cos;
type L2 = F16L2;
fn distance(lhs: &[F16], rhs: &[F16]) -> F32 {
super::f16::cosine(lhs, rhs) * (-1.0)
}
fn elkan_k_means_normalize(vector: &mut [F16]) {
l2_normalize(vector)
}
fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 {
super::f16::dot(lhs, rhs).acos()
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance(
dims: u16,
max: &[F16],
min: &[F16],
lhs: &[F16],
rhs: &[u8],
) -> F32 {
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..dims as usize {
let _x = lhs[i].to_f();
let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
xy += _x * _y;
x2 += _x * _x;
y2 += _y * _y;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance2(
dims: u16,
max: &[F16],
min: &[F16],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..dims as usize {
let _x = F32(lhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
xy += _x * _y;
x2 += _x * _x;
y2 += _y * _y;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[F16],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs);
xy += _xy;
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance2(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhsp = lhs[i as usize] as usize * dims as usize;
let lhs = &centroids[lhsp..][(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs);
xy += _xy;
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance_with_delta(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[F16],
rhs: &[u8],
delta: &[F16],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let del = &delta[(i * ratio) as usize..][..k as usize];
let (_xy, _x2, _y2) = xy_x2_y2_delta(lhs, rhs, del);
xy += _xy;
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn length(vector: &[F16]) -> F16 {
let n = vector.len();
let mut dot = F16::zero();
for i in 0..n {
dot += vector[i] * vector[i];
}
dot.sqrt()
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn l2_normalize(vector: &mut [F16]) {
let n = vector.len();
let l = length(vector);
for i in 0..n {
vector[i] /= l;
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn xy_x2_y2(lhs: &[F16], rhs: &[F16]) -> (F32, F32, F32) {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..n {
xy += lhs[i].to_f() * rhs[i].to_f();
x2 += lhs[i].to_f() * lhs[i].to_f();
y2 += rhs[i].to_f() * rhs[i].to_f();
}
(xy, x2, y2)
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn xy_x2_y2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> (F32, F32, F32) {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..n {
xy += lhs[i].to_f() * (rhs[i].to_f() + del[i].to_f());
x2 += lhs[i].to_f() * lhs[i].to_f();
y2 += (rhs[i].to_f() + del[i].to_f()) * (rhs[i].to_f() + del[i].to_f());
}
(xy, x2, y2)
}

View File

@@ -0,0 +1,199 @@
use super::G;
use crate::prelude::scalar::F32;
use crate::prelude::*;
#[derive(Debug, Clone, Copy)]
pub enum F16Dot {}
impl G for F16Dot {
type Scalar = F16;
const DISTANCE: Distance = Distance::Dot;
type L2 = F16L2;
fn distance(lhs: &[F16], rhs: &[F16]) -> F32 {
super::f16::dot(lhs, rhs) * (-1.0)
}
fn elkan_k_means_normalize(vector: &mut [F16]) {
l2_normalize(vector)
}
fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 {
super::f16::dot(lhs, rhs).acos()
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance(
dims: u16,
max: &[F16],
min: &[F16],
lhs: &[F16],
rhs: &[u8],
) -> F32 {
let mut xy = F32::zero();
for i in 0..dims as usize {
let _x = lhs[i].to_f();
let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
xy += _x * _y;
}
xy * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance2(
dims: u16,
max: &[F16],
min: &[F16],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let mut xy = F32::zero();
for i in 0..dims as usize {
let _x = F32(lhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
xy += _x * _y;
}
xy * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[F16],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let _xy = super::f16::dot(lhs, rhs);
xy += _xy;
}
xy * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance2(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhsp = lhs[i as usize] as usize * dims as usize;
let lhs = &centroids[lhsp..][(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let _xy = super::f16::dot(lhs, rhs);
xy += _xy;
}
xy * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance_with_delta(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[F16],
rhs: &[u8],
delta: &[F16],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let del = &delta[(i * ratio) as usize..][..k as usize];
let _xy = dot_delta(lhs, rhs, del);
xy += _xy;
}
xy * (-1.0)
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn length(vector: &[F16]) -> F16 {
let n = vector.len();
let mut dot = F16::zero();
for i in 0..n {
dot += vector[i] * vector[i];
}
dot.sqrt()
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn l2_normalize(vector: &mut [F16]) {
let n = vector.len();
let l = length(vector);
for i in 0..n {
vector[i] /= l;
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn dot_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 {
assert!(lhs.len() == rhs.len());
let n: usize = lhs.len();
let mut xy = F32::zero();
for i in 0..n {
xy += lhs[i].to_f() * (rhs[i].to_f() + del[i].to_f());
}
xy
}

View File

@@ -0,0 +1,165 @@
use super::G;
use crate::prelude::scalar::F16;
use crate::prelude::scalar::F32;
use crate::prelude::*;
#[derive(Debug, Clone, Copy)]
pub enum F16L2 {}
impl G for F16L2 {
type Scalar = F16;
const DISTANCE: Distance = Distance::L2;
type L2 = F16L2;
fn distance(lhs: &[F16], rhs: &[F16]) -> F32 {
super::f16::sl2(lhs, rhs)
}
fn elkan_k_means_normalize(_: &mut [F16]) {}
fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 {
super::f16::sl2(lhs, rhs).sqrt()
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance(
dims: u16,
max: &[F16],
min: &[F16],
lhs: &[F16],
rhs: &[u8],
) -> F32 {
let mut result = F32::zero();
for i in 0..dims as usize {
let _x = lhs[i].to_f();
let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
result += (_x - _y) * (_x - _y);
}
result
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance2(
dims: u16,
max: &[F16],
min: &[F16],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let mut result = F32::zero();
for i in 0..dims as usize {
let _x = F32(lhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
result += (_x - _y) * (_x - _y);
}
result
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[F16],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut result = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
result += super::f16::sl2(lhs, rhs);
}
result
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance2(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut result = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhsp = lhs[i as usize] as usize * dims as usize;
let lhs = &centroids[lhsp..][(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
result += super::f16::sl2(lhs, rhs);
}
result
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance_with_delta(
dims: u16,
ratio: u16,
centroids: &[F16],
lhs: &[F16],
rhs: &[u8],
delta: &[F16],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut result = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let del = &delta[(i * ratio) as usize..][..k as usize];
result += distance_squared_l2_delta(lhs, rhs, del);
}
result
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn distance_squared_l2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut d2 = F32::zero();
for i in 0..n {
let d = lhs[i].to_f() - (rhs[i].to_f() + del[i].to_f());
d2 += d * d;
}
d2
}

View File

@@ -0,0 +1,265 @@
use super::G;
use crate::prelude::scalar::F32;
use crate::prelude::*;
#[derive(Debug, Clone, Copy)]
pub enum F32Cos {}
impl G for F32Cos {
type Scalar = F32;
const DISTANCE: Distance = Distance::Cos;
type L2 = F32L2;
fn distance(lhs: &[F32], rhs: &[F32]) -> F32 {
cosine(lhs, rhs) * (-1.0)
}
fn elkan_k_means_normalize(vector: &mut [F32]) {
l2_normalize(vector)
}
fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 {
super::f32_dot::dot(lhs, rhs).acos()
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance(
dims: u16,
max: &[F32],
min: &[F32],
lhs: &[F32],
rhs: &[u8],
) -> F32 {
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..dims as usize {
let _x = lhs[i];
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
xy += _x * _y;
x2 += _x * _x;
y2 += _y * _y;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance2(
dims: u16,
max: &[F32],
min: &[F32],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..dims as usize {
let _x = F32(lhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
xy += _x * _y;
x2 += _x * _x;
y2 += _y * _y;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[F32],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs);
xy += _xy;
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance2(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhsp = lhs[i as usize] as usize * dims as usize;
let lhs = &centroids[lhsp..][(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs);
xy += _xy;
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance_with_delta(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[F32],
rhs: &[u8],
delta: &[F32],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let del = &delta[(i * ratio) as usize..][..k as usize];
let (_xy, _x2, _y2) = xy_x2_y2_delta(lhs, rhs, del);
xy += _xy;
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn length(vector: &[F32]) -> F32 {
let n = vector.len();
let mut dot = F32::zero();
for i in 0..n {
dot += vector[i] * vector[i];
}
dot.sqrt()
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn l2_normalize(vector: &mut [F32]) {
let n = vector.len();
let l = length(vector);
for i in 0..n {
vector[i] /= l;
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..n {
xy += lhs[i] * rhs[i];
x2 += lhs[i] * lhs[i];
y2 += rhs[i] * rhs[i];
}
xy / (x2 * y2).sqrt()
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn xy_x2_y2(lhs: &[F32], rhs: &[F32]) -> (F32, F32, F32) {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..n {
xy += lhs[i] * rhs[i];
x2 += lhs[i] * lhs[i];
y2 += rhs[i] * rhs[i];
}
(xy, x2, y2)
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn xy_x2_y2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> (F32, F32, F32) {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..n {
xy += lhs[i] * (rhs[i] + del[i]);
x2 += lhs[i] * lhs[i];
y2 += (rhs[i] + del[i]) * (rhs[i] + del[i]);
}
(xy, x2, y2)
}

View File

@@ -0,0 +1,237 @@
use super::G;
use crate::prelude::scalar::F32;
use crate::prelude::*;
#[derive(Debug, Clone, Copy)]
pub enum F32Dot {}
impl G for F32Dot {
type Scalar = F32;
const DISTANCE: Distance = Distance::Dot;
type L2 = F32L2;
fn distance(lhs: &[F32], rhs: &[F32]) -> F32 {
dot(lhs, rhs) * (-1.0)
}
fn elkan_k_means_normalize(vector: &mut [F32]) {
l2_normalize(vector)
}
fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 {
super::f32_dot::dot(lhs, rhs).acos()
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance(
dims: u16,
max: &[F32],
min: &[F32],
lhs: &[F32],
rhs: &[u8],
) -> F32 {
let mut xy = F32::zero();
for i in 0..dims as usize {
let _x = lhs[i];
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
xy += _x * _y;
}
xy * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance2(
dims: u16,
max: &[F32],
min: &[F32],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let mut xy = F32::zero();
for i in 0..dims as usize {
let _x = F32(lhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
xy += _x * _y;
}
xy * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[F32],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let _xy = dot(lhs, rhs);
xy += _xy;
}
xy * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance2(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhsp = lhs[i as usize] as usize * dims as usize;
let lhs = &centroids[lhsp..][(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let _xy = dot(lhs, rhs);
xy += _xy;
}
xy * (-1.0)
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance_with_delta(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[F32],
rhs: &[u8],
delta: &[F32],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut xy = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let del = &delta[(i * ratio) as usize..][..k as usize];
let _xy = dot_delta(lhs, rhs, del);
xy += _xy;
}
xy * (-1.0)
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn length(vector: &[F32]) -> F32 {
let n = vector.len();
let mut dot = F32::zero();
for i in 0..n {
dot += vector[i] * vector[i];
}
dot.sqrt()
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn l2_normalize(vector: &mut [F32]) {
let n = vector.len();
let l = length(vector);
for i in 0..n {
vector[i] /= l;
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
let mut x2 = F32::zero();
let mut y2 = F32::zero();
for i in 0..n {
xy += lhs[i] * rhs[i];
x2 += lhs[i] * lhs[i];
y2 += rhs[i] * rhs[i];
}
xy / (x2 * y2).sqrt()
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut xy = F32::zero();
for i in 0..n {
xy += lhs[i] * rhs[i];
}
xy
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn dot_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
let n: usize = lhs.len();
let mut xy = F32::zero();
for i in 0..n {
xy += lhs[i] * (rhs[i] + del[i]);
}
xy
}

View File

@@ -0,0 +1,182 @@
use super::G;
use crate::prelude::scalar::F32;
use crate::prelude::*;
#[derive(Debug, Clone, Copy)]
pub enum F32L2 {}
impl G for F32L2 {
type Scalar = F32;
const DISTANCE: Distance = Distance::L2;
type L2 = F32L2;
fn distance(lhs: &[F32], rhs: &[F32]) -> F32 {
distance_squared_l2(lhs, rhs)
}
fn elkan_k_means_normalize(_: &mut [F32]) {}
fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 {
distance_squared_l2(lhs, rhs).sqrt()
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance(
dims: u16,
max: &[F32],
min: &[F32],
lhs: &[F32],
rhs: &[u8],
) -> F32 {
let mut result = F32::zero();
for i in 0..dims as usize {
let _x = lhs[i];
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
result += (_x - _y) * (_x - _y);
}
result
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn scalar_quantization_distance2(
dims: u16,
max: &[F32],
min: &[F32],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let mut result = F32::zero();
for i in 0..dims as usize {
let _x = F32(lhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
result += (_x - _y) * (_x - _y);
}
result
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[F32],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut result = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
result += distance_squared_l2(lhs, rhs);
}
result
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance2(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[u8],
rhs: &[u8],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut result = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhsp = lhs[i as usize] as usize * dims as usize;
let lhs = &centroids[lhsp..][(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
result += distance_squared_l2(lhs, rhs);
}
result
}
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn product_quantization_distance_with_delta(
dims: u16,
ratio: u16,
centroids: &[F32],
lhs: &[F32],
rhs: &[u8],
delta: &[F32],
) -> F32 {
let width = dims.div_ceil(ratio);
let mut result = F32::zero();
for i in 0..width {
let k = std::cmp::min(ratio, dims - ratio * i);
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
let rhsp = rhs[i as usize] as usize * dims as usize;
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
let del = &delta[(i * ratio) as usize..][..k as usize];
result += distance_squared_l2_delta(lhs, rhs, del);
}
result
}
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
pub fn distance_squared_l2(lhs: &[F32], rhs: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut d2 = F32::zero();
for i in 0..n {
let d = lhs[i] - rhs[i];
d2 += d * d;
}
d2
}
#[inline(always)]
#[multiversion::multiversion(targets(
"x86_64/x86-64-v4",
"x86_64/x86-64-v3",
"x86_64/x86-64-v2",
"aarch64+neon"
))]
fn distance_squared_l2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
let mut d2 = F32::zero();
for i in 0..n {
let d = lhs[i] - (rhs[i] + del[i]);
d2 += d * d;
}
d2
}

View File

@@ -0,0 +1,121 @@
mod f16;
mod f16_cos;
mod f16_dot;
mod f16_l2;
mod f32_cos;
mod f32_dot;
mod f32_l2;
pub use f16_cos::F16Cos;
pub use f16_dot::F16Dot;
pub use f16_l2::F16L2;
pub use f32_cos::F32Cos;
pub use f32_dot::F32Dot;
pub use f32_l2::F32L2;
use crate::prelude::*;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
pub trait G: Copy + std::fmt::Debug + 'static {
type Scalar: Copy
+ Send
+ Sync
+ std::fmt::Debug
+ std::fmt::Display
+ serde::Serialize
+ for<'a> serde::Deserialize<'a>
+ Ord
+ bytemuck::Zeroable
+ bytemuck::Pod
+ num_traits::Float
+ num_traits::NumOps
+ num_traits::NumAssignOps
+ FloatCast;
const DISTANCE: Distance;
type L2: G<Scalar = Self::Scalar>;
fn distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32;
fn elkan_k_means_normalize(vector: &mut [Self::Scalar]);
fn elkan_k_means_distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32;
fn scalar_quantization_distance(
dims: u16,
max: &[Self::Scalar],
min: &[Self::Scalar],
lhs: &[Self::Scalar],
rhs: &[u8],
) -> F32;
fn scalar_quantization_distance2(
dims: u16,
max: &[Self::Scalar],
min: &[Self::Scalar],
lhs: &[u8],
rhs: &[u8],
) -> F32;
fn product_quantization_distance(
dims: u16,
ratio: u16,
centroids: &[Self::Scalar],
lhs: &[Self::Scalar],
rhs: &[u8],
) -> F32;
fn product_quantization_distance2(
dims: u16,
ratio: u16,
centroids: &[Self::Scalar],
lhs: &[u8],
rhs: &[u8],
) -> F32;
fn product_quantization_distance_with_delta(
dims: u16,
ratio: u16,
centroids: &[Self::Scalar],
lhs: &[Self::Scalar],
rhs: &[u8],
delta: &[Self::Scalar],
) -> F32;
}
pub trait FloatCast: Sized {
fn from_f32(x: f32) -> Self;
fn to_f32(self) -> f32;
fn from_f(x: F32) -> Self {
Self::from_f32(x.0)
}
fn to_f(self) -> F32 {
F32(Self::to_f32(self))
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum DynamicVector {
F32(Vec<F32>),
F16(Vec<F16>),
}
impl From<Vec<F32>> for DynamicVector {
fn from(value: Vec<F32>) -> Self {
Self::F32(value)
}
}
impl From<Vec<F16>> for DynamicVector {
fn from(value: Vec<F16>) -> Self {
Self::F16(value)
}
}
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Distance {
L2,
Cos,
Dot,
}
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Kind {
F32,
F16,
}
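The G trait is the pivot of this change: index code no longer hard-codes Scalar = F32 but takes a parameter S: G that bundles the scalar type with its distance kernels, so the same structures serve F32Cos/F32Dot/F32L2 and the new F16 back-ends. Note the sign convention visible in the implementations above: Cos and Dot return the negated similarity and L2 the squared distance, so a smaller value always means a closer vector. A self-contained sketch of the idea with an illustrative Global trait (every name below is made up for the example, none of it is the crate's API):

// An illustrative stand-in for the crate's `G` trait.
trait Global {
    type Scalar: Copy;
    fn distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> f32;
}

// Squared Euclidean distance over f32, mirroring the L2 back-ends.
enum L2f32 {}
impl Global for L2f32 {
    type Scalar = f32;
    fn distance(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter().zip(rhs).map(|(x, y)| (x - y) * (x - y)).sum()
    }
}

// Negated inner product, so a smaller value means a more similar vector,
// mirroring the Dot back-ends above.
enum Dotf32 {}
impl Global for Dotf32 {
    type Scalar = f32;
    fn distance(lhs: &[f32], rhs: &[f32]) -> f32 {
        -lhs.iter().zip(rhs).map(|(x, y)| x * y).sum::<f32>()
    }
}

// A brute-force scan written once, generically over the back-end.
fn nearest<S: Global>(query: &[S::Scalar], rows: &[Vec<S::Scalar>]) -> Option<usize> {
    let mut best: Option<(usize, f32)> = None;
    for (i, row) in rows.iter().enumerate() {
        let d = S::distance(query, row);
        if best.map_or(true, |(_, b)| d < b) {
            best = Some((i, d));
        }
    }
    best.map(|(i, _)| i)
}

fn main() {
    let rows = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
    assert_eq!(nearest::<L2f32>(&[0.9, 0.9], &rows), Some(1));
    assert_eq!(nearest::<Dotf32>(&[1.0, 0.0], &rows), Some(1));
}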

View File

@@ -1,9 +1,9 @@
use crate::prelude::{Payload, Scalar}; use crate::prelude::{Payload, F32};
use std::{cmp::Reverse, collections::BinaryHeap}; use std::{cmp::Reverse, collections::BinaryHeap};
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct HeapElement { pub struct HeapElement {
pub distance: Scalar, pub distance: F32,
pub payload: Payload, pub payload: Payload,
} }
@@ -20,7 +20,7 @@ impl Heap {
k, k,
} }
} }
pub fn check(&self, distance: Scalar) -> bool { pub fn check(&self, distance: F32) -> bool {
self.binary_heap.len() < self.k || distance < self.binary_heap.peek().unwrap().distance self.binary_heap.len() < self.k || distance < self.binary_heap.peek().unwrap().distance
} }
pub fn push(&mut self, element: HeapElement) -> Option<HeapElement> { pub fn push(&mut self, element: HeapElement) -> Option<HeapElement> {

View File

@@ -0,0 +1,16 @@
mod error;
mod filter;
mod global;
mod heap;
mod scalar;
mod sys;
pub use self::error::{FriendlyError, FriendlyErrorLike, FriendlyResult};
pub use self::global::*;
pub use self::scalar::{F16, F32};
pub use self::filter::{Filter, Payload};
pub use self::heap::{Heap, HeapElement};
pub use self::sys::{Id, Pointer};
pub use num_traits::{Float, Zero};

View File

@@ -0,0 +1,653 @@
use crate::prelude::global::FloatCast;
use half::f16;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::fmt::{Debug, Display};
use std::num::ParseFloatError;
use std::ops::*;
use std::str::FromStr;
#[derive(Clone, Copy, Default, Serialize, Deserialize)]
#[repr(transparent)]
#[serde(transparent)]
pub struct F16(pub f16);
impl Debug for F16 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.0, f)
}
}
impl Display for F16 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Display::fmt(&self.0, f)
}
}
impl PartialEq for F16 {
fn eq(&self, other: &Self) -> bool {
self.0.total_cmp(&other.0) == Ordering::Equal
}
}
impl Eq for F16 {}
impl PartialOrd for F16 {
#[inline(always)]
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(Ord::cmp(self, other))
}
}
impl Ord for F16 {
#[inline(always)]
fn cmp(&self, other: &Self) -> Ordering {
self.0.total_cmp(&other.0)
}
}
unsafe impl bytemuck::Zeroable for F16 {}
unsafe impl bytemuck::Pod for F16 {}
impl num_traits::Zero for F16 {
fn zero() -> Self {
Self(f16::zero())
}
fn is_zero(&self) -> bool {
self.0.is_zero()
}
}
impl num_traits::One for F16 {
fn one() -> Self {
Self(f16::one())
}
}
impl num_traits::FromPrimitive for F16 {
fn from_i64(n: i64) -> Option<Self> {
f16::from_i64(n).map(Self)
}
fn from_u64(n: u64) -> Option<Self> {
f16::from_u64(n).map(Self)
}
fn from_isize(n: isize) -> Option<Self> {
f16::from_isize(n).map(Self)
}
fn from_i8(n: i8) -> Option<Self> {
f16::from_i8(n).map(Self)
}
fn from_i16(n: i16) -> Option<Self> {
f16::from_i16(n).map(Self)
}
fn from_i32(n: i32) -> Option<Self> {
f16::from_i32(n).map(Self)
}
fn from_i128(n: i128) -> Option<Self> {
f16::from_i128(n).map(Self)
}
fn from_usize(n: usize) -> Option<Self> {
f16::from_usize(n).map(Self)
}
fn from_u8(n: u8) -> Option<Self> {
f16::from_u8(n).map(Self)
}
fn from_u16(n: u16) -> Option<Self> {
f16::from_u16(n).map(Self)
}
fn from_u32(n: u32) -> Option<Self> {
f16::from_u32(n).map(Self)
}
fn from_u128(n: u128) -> Option<Self> {
f16::from_u128(n).map(Self)
}
fn from_f32(n: f32) -> Option<Self> {
Some(Self(f16::from_f32(n)))
}
fn from_f64(n: f64) -> Option<Self> {
Some(Self(f16::from_f64(n)))
}
}
impl num_traits::ToPrimitive for F16 {
fn to_i64(&self) -> Option<i64> {
self.0.to_i64()
}
fn to_u64(&self) -> Option<u64> {
self.0.to_u64()
}
fn to_isize(&self) -> Option<isize> {
self.0.to_isize()
}
fn to_i8(&self) -> Option<i8> {
self.0.to_i8()
}
fn to_i16(&self) -> Option<i16> {
self.0.to_i16()
}
fn to_i32(&self) -> Option<i32> {
self.0.to_i32()
}
fn to_i128(&self) -> Option<i128> {
self.0.to_i128()
}
fn to_usize(&self) -> Option<usize> {
self.0.to_usize()
}
fn to_u8(&self) -> Option<u8> {
self.0.to_u8()
}
fn to_u16(&self) -> Option<u16> {
self.0.to_u16()
}
fn to_u32(&self) -> Option<u32> {
self.0.to_u32()
}
fn to_u128(&self) -> Option<u128> {
self.0.to_u128()
}
fn to_f32(&self) -> Option<f32> {
Some(self.0.to_f32())
}
fn to_f64(&self) -> Option<f64> {
Some(self.0.to_f64())
}
}
impl num_traits::NumCast for F16 {
fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
num_traits::NumCast::from(n).map(Self)
}
}
impl num_traits::Num for F16 {
type FromStrRadixErr = <f16 as num_traits::Num>::FromStrRadixErr;
fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
f16::from_str_radix(str, radix).map(Self)
}
}
impl num_traits::Float for F16 {
fn nan() -> Self {
Self(f16::nan())
}
fn infinity() -> Self {
Self(f16::infinity())
}
fn neg_infinity() -> Self {
Self(f16::neg_infinity())
}
fn neg_zero() -> Self {
Self(f16::neg_zero())
}
fn min_value() -> Self {
Self(f16::min_value())
}
fn min_positive_value() -> Self {
Self(f16::min_positive_value())
}
fn max_value() -> Self {
Self(f16::max_value())
}
fn is_nan(self) -> bool {
self.0.is_nan()
}
fn is_infinite(self) -> bool {
self.0.is_infinite()
}
fn is_finite(self) -> bool {
self.0.is_finite()
}
fn is_normal(self) -> bool {
self.0.is_normal()
}
fn classify(self) -> std::num::FpCategory {
self.0.classify()
}
fn floor(self) -> Self {
Self(self.0.floor())
}
fn ceil(self) -> Self {
Self(self.0.ceil())
}
fn round(self) -> Self {
Self(self.0.round())
}
fn trunc(self) -> Self {
Self(self.0.trunc())
}
fn fract(self) -> Self {
Self(self.0.fract())
}
fn abs(self) -> Self {
Self(self.0.abs())
}
fn signum(self) -> Self {
Self(self.0.signum())
}
fn is_sign_positive(self) -> bool {
self.0.is_sign_positive()
}
fn is_sign_negative(self) -> bool {
self.0.is_sign_negative()
}
fn mul_add(self, a: Self, b: Self) -> Self {
Self(self.0.mul_add(a.0, b.0))
}
fn recip(self) -> Self {
Self(self.0.recip())
}
fn powi(self, n: i32) -> Self {
Self(self.0.powi(n))
}
fn powf(self, n: Self) -> Self {
Self(self.0.powf(n.0))
}
fn sqrt(self) -> Self {
Self(self.0.sqrt())
}
fn exp(self) -> Self {
Self(self.0.exp())
}
fn exp2(self) -> Self {
Self(self.0.exp2())
}
fn ln(self) -> Self {
Self(self.0.ln())
}
fn log(self, base: Self) -> Self {
Self(self.0.log(base.0))
}
fn log2(self) -> Self {
Self(self.0.log2())
}
fn log10(self) -> Self {
Self(self.0.log10())
}
fn max(self, other: Self) -> Self {
Self(self.0.max(other.0))
}
fn min(self, other: Self) -> Self {
Self(self.0.min(other.0))
}
fn abs_sub(self, _: Self) -> Self {
unimplemented!()
}
fn cbrt(self) -> Self {
Self(self.0.cbrt())
}
fn hypot(self, other: Self) -> Self {
Self(self.0.hypot(other.0))
}
fn sin(self) -> Self {
Self(self.0.sin())
}
fn cos(self) -> Self {
Self(self.0.cos())
}
fn tan(self) -> Self {
Self(self.0.tan())
}
fn asin(self) -> Self {
Self(self.0.asin())
}
fn acos(self) -> Self {
Self(self.0.acos())
}
fn atan(self) -> Self {
Self(self.0.atan())
}
fn atan2(self, other: Self) -> Self {
Self(self.0.atan2(other.0))
}
fn sin_cos(self) -> (Self, Self) {
let (_x, _y) = self.0.sin_cos();
(Self(_x), Self(_y))
}
fn exp_m1(self) -> Self {
Self(self.0.exp_m1())
}
fn ln_1p(self) -> Self {
Self(self.0.ln_1p())
}
fn sinh(self) -> Self {
Self(self.0.sinh())
}
fn cosh(self) -> Self {
Self(self.0.cosh())
}
fn tanh(self) -> Self {
Self(self.0.tanh())
}
fn asinh(self) -> Self {
Self(self.0.asinh())
}
fn acosh(self) -> Self {
Self(self.0.acosh())
}
fn atanh(self) -> Self {
Self(self.0.atanh())
}
fn integer_decode(self) -> (u64, i16, i8) {
self.0.integer_decode()
}
fn epsilon() -> Self {
Self(f16::EPSILON)
}
fn is_subnormal(self) -> bool {
self.0.classify() == std::num::FpCategory::Subnormal
}
fn to_degrees(self) -> Self {
Self(self.0.to_degrees())
}
fn to_radians(self) -> Self {
Self(self.0.to_radians())
}
fn copysign(self, sign: Self) -> Self {
Self(self.0.copysign(sign.0))
}
}
impl Add<F16> for F16 {
type Output = F16;
#[inline(always)]
fn add(self, rhs: F16) -> F16 {
unsafe { self::intrinsics::fadd_fast(self.0, rhs.0).into() }
}
}
impl AddAssign<F16> for F16 {
#[inline(always)]
fn add_assign(&mut self, rhs: F16) {
unsafe { self.0 = self::intrinsics::fadd_fast(self.0, rhs.0) }
}
}
impl Sub<F16> for F16 {
type Output = F16;
#[inline(always)]
fn sub(self, rhs: F16) -> F16 {
unsafe { self::intrinsics::fsub_fast(self.0, rhs.0).into() }
}
}
impl SubAssign<F16> for F16 {
#[inline(always)]
fn sub_assign(&mut self, rhs: F16) {
unsafe { self.0 = self::intrinsics::fsub_fast(self.0, rhs.0) }
}
}
impl Mul<F16> for F16 {
type Output = F16;
#[inline(always)]
fn mul(self, rhs: F16) -> F16 {
unsafe { self::intrinsics::fmul_fast(self.0, rhs.0).into() }
}
}
impl MulAssign<F16> for F16 {
#[inline(always)]
fn mul_assign(&mut self, rhs: F16) {
unsafe { self.0 = self::intrinsics::fmul_fast(self.0, rhs.0) }
}
}
impl Div<F16> for F16 {
type Output = F16;
#[inline(always)]
fn div(self, rhs: F16) -> F16 {
unsafe { self::intrinsics::fdiv_fast(self.0, rhs.0).into() }
}
}
impl DivAssign<F16> for F16 {
#[inline(always)]
fn div_assign(&mut self, rhs: F16) {
unsafe { self.0 = self::intrinsics::fdiv_fast(self.0, rhs.0) }
}
}
impl Rem<F16> for F16 {
type Output = F16;
#[inline(always)]
fn rem(self, rhs: F16) -> F16 {
unsafe { self::intrinsics::frem_fast(self.0, rhs.0).into() }
}
}
impl RemAssign<F16> for F16 {
#[inline(always)]
fn rem_assign(&mut self, rhs: F16) {
unsafe { self.0 = self::intrinsics::frem_fast(self.0, rhs.0) }
}
}
impl Neg for F16 {
type Output = Self;
fn neg(self) -> Self::Output {
Self(self.0.neg())
}
}
impl FromStr for F16 {
type Err = ParseFloatError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
f16::from_str(s).map(|x| x.into())
}
}
impl FloatCast for F16 {
fn from_f32(x: f32) -> Self {
Self(f16::from_f32(x))
}
fn to_f32(self) -> f32 {
f16::to_f32(self.0)
}
}
impl From<f16> for F16 {
fn from(value: f16) -> Self {
Self(value)
}
}
impl From<F16> for f16 {
fn from(F16(float): F16) -> Self {
float
}
}
impl Add<f16> for F16 {
type Output = F16;
#[inline(always)]
fn add(self, rhs: f16) -> F16 {
unsafe { self::intrinsics::fadd_fast(self.0, rhs).into() }
}
}
impl AddAssign<f16> for F16 {
fn add_assign(&mut self, rhs: f16) {
unsafe { self.0 = self::intrinsics::fadd_fast(self.0, rhs) }
}
}
impl Sub<f16> for F16 {
type Output = F16;
#[inline(always)]
fn sub(self, rhs: f16) -> F16 {
unsafe { self::intrinsics::fsub_fast(self.0, rhs).into() }
}
}
impl SubAssign<f16> for F16 {
#[inline(always)]
fn sub_assign(&mut self, rhs: f16) {
unsafe { self.0 = self::intrinsics::fsub_fast(self.0, rhs) }
}
}
impl Mul<f16> for F16 {
type Output = F16;
#[inline(always)]
fn mul(self, rhs: f16) -> F16 {
unsafe { self::intrinsics::fmul_fast(self.0, rhs).into() }
}
}
impl MulAssign<f16> for F16 {
#[inline(always)]
fn mul_assign(&mut self, rhs: f16) {
unsafe { self.0 = self::intrinsics::fmul_fast(self.0, rhs) }
}
}
impl Div<f16> for F16 {
type Output = F16;
#[inline(always)]
fn div(self, rhs: f16) -> F16 {
unsafe { self::intrinsics::fdiv_fast(self.0, rhs).into() }
}
}
impl DivAssign<f16> for F16 {
#[inline(always)]
fn div_assign(&mut self, rhs: f16) {
unsafe { self.0 = self::intrinsics::fdiv_fast(self.0, rhs) }
}
}
impl Rem<f16> for F16 {
type Output = F16;
#[inline(always)]
fn rem(self, rhs: f16) -> F16 {
unsafe { self::intrinsics::frem_fast(self.0, rhs).into() }
}
}
impl RemAssign<f16> for F16 {
#[inline(always)]
fn rem_assign(&mut self, rhs: f16) {
unsafe { self.0 = self::intrinsics::frem_fast(self.0, rhs) }
}
}
mod intrinsics {
use half::f16;
pub unsafe fn fadd_fast(lhs: f16, rhs: f16) -> f16 {
lhs + rhs
}
pub unsafe fn fsub_fast(lhs: f16, rhs: f16) -> f16 {
lhs - rhs
}
pub unsafe fn fmul_fast(lhs: f16, rhs: f16) -> f16 {
lhs * rhs
}
pub unsafe fn fdiv_fast(lhs: f16, rhs: f16) -> f16 {
lhs / rhs
}
pub unsafe fn frem_fast(lhs: f16, rhs: f16) -> f16 {
lhs % rhs
}
}
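The `intrinsics` module above is a plain-Rust stand-in: the fast-math compiler intrinsics that the `F32` wrapper calls are not usable with the library-provided `half::f16` type, so these helpers mirror the same `unsafe fn` signatures while performing ordinary operations, and the arithmetic impls of both wrappers stay textually identical. As a hedged illustration of consuming such `f16` arithmetic, assuming the `half` crate (the crate's real SIMD kernels may accumulate differently):

```rust
use half::f16;

// Dot product over half-precision inputs, accumulating in f32 to limit rounding error.
fn dot_f16(a: &[f16], b: &[f16]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| x.to_f32() * y.to_f32())
        .sum()
}

fn main() {
    let a = [f16::from_f32(1.0), f16::from_f32(2.0)];
    let b = [f16::from_f32(3.0), f16::from_f32(4.0)];
    assert_eq!(dot_f16(&a, &b), 11.0);
}
```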

View File

@@ -0,0 +1,632 @@
use crate::prelude::global::FloatCast;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::fmt::{Debug, Display};
use std::num::ParseFloatError;
use std::ops::*;
use std::str::FromStr;
#[derive(Clone, Copy, Default, Serialize, Deserialize)]
#[repr(transparent)]
#[serde(transparent)]
pub struct F32(pub f32);
impl Debug for F32 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.0, f)
}
}
impl Display for F32 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Display::fmt(&self.0, f)
}
}
impl PartialEq for F32 {
fn eq(&self, other: &Self) -> bool {
self.0.total_cmp(&other.0) == Ordering::Equal
}
}
impl Eq for F32 {}
impl PartialOrd for F32 {
#[inline(always)]
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(Ord::cmp(self, other))
}
}
impl Ord for F32 {
#[inline(always)]
fn cmp(&self, other: &Self) -> Ordering {
self.0.total_cmp(&other.0)
}
}
unsafe impl bytemuck::Zeroable for F32 {}
unsafe impl bytemuck::Pod for F32 {}
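`PartialEq` and `Ord` delegate to `f32::total_cmp`, so `F32` has a total order (NaN included) and can be used directly as a sort or binary-heap key for distances. A standalone sketch of the same pattern:

```rust
use std::cmp::Ordering;

#[derive(Clone, Copy)]
struct OrderedF32(f32);

impl PartialEq for OrderedF32 {
    fn eq(&self, other: &Self) -> bool {
        self.0.total_cmp(&other.0) == Ordering::Equal
    }
}
impl Eq for OrderedF32 {}
impl PartialOrd for OrderedF32 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for OrderedF32 {
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.total_cmp(&other.0)
    }
}

fn main() {
    let mut xs = vec![OrderedF32(f32::NAN), OrderedF32(1.0), OrderedF32(-2.0)];
    xs.sort(); // total_cmp places (positive) NaN after every ordinary number
    assert_eq!(xs[0].0, -2.0);
    assert_eq!(xs[1].0, 1.0);
}
```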
impl num_traits::Zero for F32 {
fn zero() -> Self {
Self(f32::zero())
}
fn is_zero(&self) -> bool {
self.0.is_zero()
}
}
impl num_traits::One for F32 {
fn one() -> Self {
Self(f32::one())
}
}
impl num_traits::FromPrimitive for F32 {
fn from_i64(n: i64) -> Option<Self> {
f32::from_i64(n).map(Self)
}
fn from_u64(n: u64) -> Option<Self> {
f32::from_u64(n).map(Self)
}
fn from_isize(n: isize) -> Option<Self> {
f32::from_isize(n).map(Self)
}
fn from_i8(n: i8) -> Option<Self> {
f32::from_i8(n).map(Self)
}
fn from_i16(n: i16) -> Option<Self> {
f32::from_i16(n).map(Self)
}
fn from_i32(n: i32) -> Option<Self> {
f32::from_i32(n).map(Self)
}
fn from_i128(n: i128) -> Option<Self> {
f32::from_i128(n).map(Self)
}
fn from_usize(n: usize) -> Option<Self> {
f32::from_usize(n).map(Self)
}
fn from_u8(n: u8) -> Option<Self> {
f32::from_u8(n).map(Self)
}
fn from_u16(n: u16) -> Option<Self> {
f32::from_u16(n).map(Self)
}
fn from_u32(n: u32) -> Option<Self> {
f32::from_u32(n).map(Self)
}
fn from_u128(n: u128) -> Option<Self> {
f32::from_u128(n).map(Self)
}
fn from_f32(n: f32) -> Option<Self> {
f32::from_f32(n).map(Self)
}
fn from_f64(n: f64) -> Option<Self> {
f32::from_f64(n).map(Self)
}
}
impl num_traits::ToPrimitive for F32 {
fn to_i64(&self) -> Option<i64> {
self.0.to_i64()
}
fn to_u64(&self) -> Option<u64> {
self.0.to_u64()
}
fn to_isize(&self) -> Option<isize> {
self.0.to_isize()
}
fn to_i8(&self) -> Option<i8> {
self.0.to_i8()
}
fn to_i16(&self) -> Option<i16> {
self.0.to_i16()
}
fn to_i32(&self) -> Option<i32> {
self.0.to_i32()
}
fn to_i128(&self) -> Option<i128> {
self.0.to_i128()
}
fn to_usize(&self) -> Option<usize> {
self.0.to_usize()
}
fn to_u8(&self) -> Option<u8> {
self.0.to_u8()
}
fn to_u16(&self) -> Option<u16> {
self.0.to_u16()
}
fn to_u32(&self) -> Option<u32> {
self.0.to_u32()
}
fn to_u128(&self) -> Option<u128> {
self.0.to_u128()
}
fn to_f32(&self) -> Option<f32> {
self.0.to_f32()
}
fn to_f64(&self) -> Option<f64> {
self.0.to_f64()
}
}
impl num_traits::NumCast for F32 {
fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
num_traits::NumCast::from(n).map(Self)
}
}
impl num_traits::Num for F32 {
type FromStrRadixErr = <f32 as num_traits::Num>::FromStrRadixErr;
fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
f32::from_str_radix(str, radix).map(Self)
}
}
impl num_traits::Float for F32 {
fn nan() -> Self {
Self(f32::nan())
}
fn infinity() -> Self {
Self(f32::infinity())
}
fn neg_infinity() -> Self {
Self(f32::neg_infinity())
}
fn neg_zero() -> Self {
Self(f32::neg_zero())
}
fn min_value() -> Self {
Self(f32::min_value())
}
fn min_positive_value() -> Self {
Self(f32::min_positive_value())
}
fn max_value() -> Self {
Self(f32::max_value())
}
fn is_nan(self) -> bool {
self.0.is_nan()
}
fn is_infinite(self) -> bool {
self.0.is_infinite()
}
fn is_finite(self) -> bool {
self.0.is_finite()
}
fn is_normal(self) -> bool {
self.0.is_normal()
}
fn classify(self) -> std::num::FpCategory {
self.0.classify()
}
fn floor(self) -> Self {
Self(self.0.floor())
}
fn ceil(self) -> Self {
Self(self.0.ceil())
}
fn round(self) -> Self {
Self(self.0.round())
}
fn trunc(self) -> Self {
Self(self.0.trunc())
}
fn fract(self) -> Self {
Self(self.0.fract())
}
fn abs(self) -> Self {
Self(self.0.abs())
}
fn signum(self) -> Self {
Self(self.0.signum())
}
fn is_sign_positive(self) -> bool {
self.0.is_sign_positive()
}
fn is_sign_negative(self) -> bool {
self.0.is_sign_negative()
}
fn mul_add(self, a: Self, b: Self) -> Self {
Self(self.0.mul_add(a.0, b.0))
}
fn recip(self) -> Self {
Self(self.0.recip())
}
fn powi(self, n: i32) -> Self {
Self(self.0.powi(n))
}
fn powf(self, n: Self) -> Self {
Self(self.0.powf(n.0))
}
fn sqrt(self) -> Self {
Self(self.0.sqrt())
}
fn exp(self) -> Self {
Self(self.0.exp())
}
fn exp2(self) -> Self {
Self(self.0.exp2())
}
fn ln(self) -> Self {
Self(self.0.ln())
}
fn log(self, base: Self) -> Self {
Self(self.0.log(base.0))
}
fn log2(self) -> Self {
Self(self.0.log2())
}
fn log10(self) -> Self {
Self(self.0.log10())
}
fn max(self, other: Self) -> Self {
Self(self.0.max(other.0))
}
fn min(self, other: Self) -> Self {
Self(self.0.min(other.0))
}
fn abs_sub(self, _: Self) -> Self {
unimplemented!()
}
fn cbrt(self) -> Self {
Self(self.0.cbrt())
}
fn hypot(self, other: Self) -> Self {
Self(self.0.hypot(other.0))
}
fn sin(self) -> Self {
Self(self.0.sin())
}
fn cos(self) -> Self {
Self(self.0.cos())
}
fn tan(self) -> Self {
Self(self.0.tan())
}
fn asin(self) -> Self {
Self(self.0.asin())
}
fn acos(self) -> Self {
Self(self.0.acos())
}
fn atan(self) -> Self {
Self(self.0.atan())
}
fn atan2(self, other: Self) -> Self {
Self(self.0.atan2(other.0))
}
fn sin_cos(self) -> (Self, Self) {
let (sin, cos) = self.0.sin_cos();
(Self(sin), Self(cos))
}
fn exp_m1(self) -> Self {
Self(self.0.exp_m1())
}
fn ln_1p(self) -> Self {
Self(self.0.ln_1p())
}
fn sinh(self) -> Self {
Self(self.0.sinh())
}
fn cosh(self) -> Self {
Self(self.0.cosh())
}
fn tanh(self) -> Self {
Self(self.0.tanh())
}
fn asinh(self) -> Self {
Self(self.0.asinh())
}
fn acosh(self) -> Self {
Self(self.0.acosh())
}
fn atanh(self) -> Self {
Self(self.0.atanh())
}
fn integer_decode(self) -> (u64, i16, i8) {
self.0.integer_decode()
}
fn epsilon() -> Self {
Self(f32::EPSILON)
}
fn is_subnormal(self) -> bool {
self.0.classify() == std::num::FpCategory::Subnormal
}
fn to_degrees(self) -> Self {
Self(self.0.to_degrees())
}
fn to_radians(self) -> Self {
Self(self.0.to_radians())
}
fn copysign(self, sign: Self) -> Self {
Self(self.0.copysign(sign.0))
}
}
impl Add<F32> for F32 {
type Output = F32;
#[inline(always)]
fn add(self, rhs: F32) -> F32 {
unsafe { std::intrinsics::fadd_fast(self.0, rhs.0).into() }
}
}
impl AddAssign<F32> for F32 {
#[inline(always)]
fn add_assign(&mut self, rhs: F32) {
unsafe { self.0 = std::intrinsics::fadd_fast(self.0, rhs.0) }
}
}
impl Sub<F32> for F32 {
type Output = F32;
#[inline(always)]
fn sub(self, rhs: F32) -> F32 {
unsafe { std::intrinsics::fsub_fast(self.0, rhs.0).into() }
}
}
impl SubAssign<F32> for F32 {
#[inline(always)]
fn sub_assign(&mut self, rhs: F32) {
unsafe { self.0 = std::intrinsics::fsub_fast(self.0, rhs.0) }
}
}
impl Mul<F32> for F32 {
type Output = F32;
#[inline(always)]
fn mul(self, rhs: F32) -> F32 {
unsafe { std::intrinsics::fmul_fast(self.0, rhs.0).into() }
}
}
impl MulAssign<F32> for F32 {
#[inline(always)]
fn mul_assign(&mut self, rhs: F32) {
unsafe { self.0 = std::intrinsics::fmul_fast(self.0, rhs.0) }
}
}
impl Div<F32> for F32 {
type Output = F32;
#[inline(always)]
fn div(self, rhs: F32) -> F32 {
unsafe { std::intrinsics::fdiv_fast(self.0, rhs.0).into() }
}
}
impl DivAssign<F32> for F32 {
#[inline(always)]
fn div_assign(&mut self, rhs: F32) {
unsafe { self.0 = std::intrinsics::fdiv_fast(self.0, rhs.0) }
}
}
impl Rem<F32> for F32 {
type Output = F32;
#[inline(always)]
fn rem(self, rhs: F32) -> F32 {
unsafe { std::intrinsics::frem_fast(self.0, rhs.0).into() }
}
}
impl RemAssign<F32> for F32 {
#[inline(always)]
fn rem_assign(&mut self, rhs: F32) {
unsafe { self.0 = std::intrinsics::frem_fast(self.0, rhs.0) }
}
}
impl Neg for F32 {
type Output = Self;
fn neg(self) -> Self::Output {
Self(self.0.neg())
}
}
impl FromStr for F32 {
type Err = ParseFloatError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
f32::from_str(s).map(|x| x.into())
}
}
impl FloatCast for F32 {
fn from_f32(x: f32) -> Self {
Self(x)
}
fn to_f32(self) -> f32 {
self.0
}
}
impl From<f32> for F32 {
fn from(value: f32) -> Self {
Self(value)
}
}
impl From<F32> for f32 {
fn from(F32(float): F32) -> Self {
float
}
}
impl Add<f32> for F32 {
type Output = F32;
#[inline(always)]
fn add(self, rhs: f32) -> F32 {
unsafe { std::intrinsics::fadd_fast(self.0, rhs).into() }
}
}
impl AddAssign<f32> for F32 {
fn add_assign(&mut self, rhs: f32) {
unsafe { self.0 = std::intrinsics::fadd_fast(self.0, rhs) }
}
}
impl Sub<f32> for F32 {
type Output = F32;
#[inline(always)]
fn sub(self, rhs: f32) -> F32 {
unsafe { std::intrinsics::fsub_fast(self.0, rhs).into() }
}
}
impl SubAssign<f32> for F32 {
#[inline(always)]
fn sub_assign(&mut self, rhs: f32) {
unsafe { self.0 = std::intrinsics::fsub_fast(self.0, rhs) }
}
}
impl Mul<f32> for F32 {
type Output = F32;
#[inline(always)]
fn mul(self, rhs: f32) -> F32 {
unsafe { std::intrinsics::fmul_fast(self.0, rhs).into() }
}
}
impl MulAssign<f32> for F32 {
#[inline(always)]
fn mul_assign(&mut self, rhs: f32) {
unsafe { self.0 = std::intrinsics::fmul_fast(self.0, rhs) }
}
}
impl Div<f32> for F32 {
type Output = F32;
#[inline(always)]
fn div(self, rhs: f32) -> F32 {
unsafe { std::intrinsics::fdiv_fast(self.0, rhs).into() }
}
}
impl DivAssign<f32> for F32 {
#[inline(always)]
fn div_assign(&mut self, rhs: f32) {
unsafe { self.0 = std::intrinsics::fdiv_fast(self.0, rhs) }
}
}
impl Rem<f32> for F32 {
type Output = F32;
#[inline(always)]
fn rem(self, rhs: f32) -> F32 {
unsafe { std::intrinsics::frem_fast(self.0, rhs).into() }
}
}
impl RemAssign<f32> for F32 {
#[inline(always)]
fn rem_assign(&mut self, rhs: f32) {
unsafe { self.0 = std::intrinsics::frem_fast(self.0, rhs) }
}
}

View File

@@ -0,0 +1,5 @@
mod f16;
mod f32;
pub use f16::F16;
pub use f32::F32;

View File

@@ -3,15 +3,10 @@ use std::{fmt::Display, num::ParseIntError, str::FromStr};
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Id { pub struct Id {
newtype: u32, pub newtype: u32,
} }
impl Id { impl Id {
pub fn from_sys(sys: pgrx::pg_sys::Oid) -> Self {
Self {
newtype: sys.as_u32(),
}
}
pub fn as_u32(self) -> u32 { pub fn as_u32(self) -> u32 {
self.newtype self.newtype
} }
@@ -35,26 +30,10 @@ impl FromStr for Id {
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct Pointer { pub struct Pointer {
newtype: u64, pub newtype: u64,
} }
impl Pointer { impl Pointer {
pub fn from_sys(sys: pgrx::pg_sys::ItemPointerData) -> Self {
let mut newtype = 0;
newtype |= (sys.ip_blkid.bi_hi as u64) << 32;
newtype |= (sys.ip_blkid.bi_lo as u64) << 16;
newtype |= sys.ip_posid as u64;
Self { newtype }
}
pub fn into_sys(self) -> pgrx::pg_sys::ItemPointerData {
pgrx::pg_sys::ItemPointerData {
ip_blkid: pgrx::pg_sys::BlockIdData {
bi_hi: ((self.newtype >> 32) & 0xffff) as u16,
bi_lo: ((self.newtype >> 16) & 0xffff) as u16,
},
ip_posid: (self.newtype & 0xffff) as u16,
}
}
pub fn from_u48(value: u64) -> Self { pub fn from_u48(value: u64) -> Self {
assert!(value < (1u64 << 48)); assert!(value < (1u64 << 48));
Self { newtype: value } Self { newtype: value }
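The removed `from_sys`/`into_sys` shown on the left of this hunk pack PostgreSQL's `ItemPointerData` (the two 16-bit halves of the block number plus the 16-bit offset) into a single 48-bit integer, which is why `from_u48` asserts `value < (1 << 48)`. A standalone sketch of that packing, with plain integers standing in for the pgrx types:

```rust
// Plain integers stand in for pgrx::pg_sys::ItemPointerData fields.
fn pack(bi_hi: u16, bi_lo: u16, ip_posid: u16) -> u64 {
    ((bi_hi as u64) << 32) | ((bi_lo as u64) << 16) | ip_posid as u64
}

fn unpack(p: u64) -> (u16, u16, u16) {
    (
        ((p >> 32) & 0xffff) as u16,
        ((p >> 16) & 0xffff) as u16,
        (p & 0xffff) as u16,
    )
}

fn main() {
    let p = pack(0x0001, 0x00ab, 0x0007);
    assert!(p < (1u64 << 48)); // fits the 48-bit invariant that from_u48 asserts
    assert_eq!(unpack(p), (0x0001, 0x00ab, 0x0007));
}
```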

View File

@@ -0,0 +1,26 @@
use std::cell::UnsafeCell;
#[repr(transparent)]
pub struct SyncUnsafeCell<T: ?Sized> {
value: UnsafeCell<T>,
}
unsafe impl<T: ?Sized + Sync> Sync for SyncUnsafeCell<T> {}
impl<T> SyncUnsafeCell<T> {
pub const fn new(value: T) -> Self {
Self {
value: UnsafeCell::new(value),
}
}
}
impl<T: ?Sized> SyncUnsafeCell<T> {
pub fn get(&self) -> *mut T {
self.value.get()
}
pub fn get_mut(&mut self) -> &mut T {
self.value.get_mut()
}
}
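This is a vendored copy of the pattern behind the still-unstable `std::cell::SyncUnsafeCell`: the cell is `Sync` whenever `T` is, but every access goes through a raw pointer and the caller must rule out data races. A hedged usage sketch (the type is repeated so the snippet stands alone); the assumption is that the write happens before any concurrent reader exists:

```rust
use std::cell::UnsafeCell;

#[repr(transparent)]
pub struct SyncUnsafeCell<T: ?Sized> {
    value: UnsafeCell<T>,
}

// Sync only when T is Sync; upholding the aliasing rules is still on the caller.
unsafe impl<T: ?Sized + Sync> Sync for SyncUnsafeCell<T> {}

impl<T> SyncUnsafeCell<T> {
    pub const fn new(value: T) -> Self {
        Self {
            value: UnsafeCell::new(value),
        }
    }
    pub fn get(&self) -> *mut T {
        self.value.get()
    }
}

static CONFIG: SyncUnsafeCell<Option<u32>> = SyncUnsafeCell::new(None);

fn main() {
    // Written once during single-threaded startup...
    unsafe { *CONFIG.get() = Some(42) };
    // ...and only read afterwards, so no data race can occur.
    let v = unsafe { (*CONFIG.get()).unwrap() };
    assert_eq!(v, 42);
}
```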

View File

@@ -0,0 +1 @@
pub mod x86_64;

View File

@@ -0,0 +1,85 @@
#![cfg(target_arch = "x86_64")]
use std::sync::atomic::{AtomicBool, Ordering};
static ATOMIC_AVX512FP16: AtomicBool = AtomicBool::new(false);
pub fn test_avx512fp16() -> bool {
std_detect::is_x86_feature_detected!("avx512fp16") && test_v4()
}
#[ctor::ctor]
fn ctor_avx512fp16() {
ATOMIC_AVX512FP16.store(test_avx512fp16(), Ordering::Relaxed);
}
pub fn detect_avx512fp16() -> bool {
ATOMIC_AVX512FP16.load(Ordering::Relaxed)
}
static ATOMIC_V4: AtomicBool = AtomicBool::new(false);
pub fn test_v4() -> bool {
std_detect::is_x86_feature_detected!("avx512bw")
&& std_detect::is_x86_feature_detected!("avx512cd")
&& std_detect::is_x86_feature_detected!("avx512dq")
&& std_detect::is_x86_feature_detected!("avx512f")
&& std_detect::is_x86_feature_detected!("avx512vl")
&& test_v3()
}
#[ctor::ctor]
fn ctor_v4() {
ATOMIC_V4.store(test_v4(), Ordering::Relaxed);
}
pub fn _detect_v4() -> bool {
ATOMIC_V4.load(Ordering::Relaxed)
}
static ATOMIC_V3: AtomicBool = AtomicBool::new(false);
pub fn test_v3() -> bool {
std_detect::is_x86_feature_detected!("avx")
&& std_detect::is_x86_feature_detected!("avx2")
&& std_detect::is_x86_feature_detected!("bmi1")
&& std_detect::is_x86_feature_detected!("bmi2")
&& std_detect::is_x86_feature_detected!("f16c")
&& std_detect::is_x86_feature_detected!("fma")
&& std_detect::is_x86_feature_detected!("lzcnt")
&& std_detect::is_x86_feature_detected!("movbe")
&& std_detect::is_x86_feature_detected!("xsave")
&& test_v2()
}
#[ctor::ctor]
fn ctor_v3() {
ATOMIC_V3.store(test_v3(), Ordering::Relaxed);
}
pub fn detect_v3() -> bool {
ATOMIC_V3.load(Ordering::Relaxed)
}
static ATOMIC_V2: AtomicBool = AtomicBool::new(false);
pub fn test_v2() -> bool {
std_detect::is_x86_feature_detected!("cmpxchg16b")
&& std_detect::is_x86_feature_detected!("fxsr")
&& std_detect::is_x86_feature_detected!("popcnt")
&& std_detect::is_x86_feature_detected!("sse")
&& std_detect::is_x86_feature_detected!("sse2")
&& std_detect::is_x86_feature_detected!("sse3")
&& std_detect::is_x86_feature_detected!("sse4.1")
&& std_detect::is_x86_feature_detected!("sse4.2")
&& std_detect::is_x86_feature_detected!("ssse3")
}
#[ctor::ctor]
fn ctor_v2() {
ATOMIC_V2.store(test_v2(), Ordering::Relaxed);
}
pub fn _detect_v2() -> bool {
ATOMIC_V2.load(Ordering::Relaxed)
}
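Each `test_*` function probes a whole x86-64 microarchitecture level, a `#[ctor::ctor]` caches the result in an `AtomicBool` at process start, and the matching `detect_*` is then a cheap relaxed load on the hot path; the external `std_detect` crate is used so newer features such as `avx512fp16` can be probed. A minimal sketch of the same cache-then-dispatch pattern, assuming only the stable `is_x86_feature_detected!` macro and the widely available `avx2` feature:

```rust
#[cfg(target_arch = "x86_64")]
mod dispatch {
    use std::sync::atomic::{AtomicBool, Ordering};

    static HAS_AVX2: AtomicBool = AtomicBool::new(false);

    // Probe the CPU once (the crate does this in a #[ctor::ctor] at process start);
    // later calls only read the cached flag.
    pub fn init() {
        HAS_AVX2.store(is_x86_feature_detected!("avx2"), Ordering::Relaxed);
    }

    #[target_feature(enable = "avx2")]
    unsafe fn sum_avx2(x: &[f32]) -> f32 {
        // With avx2 enabled for this function, the compiler is free to vectorize the loop.
        x.iter().sum()
    }

    fn sum_scalar(x: &[f32]) -> f32 {
        x.iter().sum()
    }

    pub fn sum(x: &[f32]) -> f32 {
        if HAS_AVX2.load(Ordering::Relaxed) {
            // Sound because the runtime check confirmed avx2 is available.
            unsafe { sum_avx2(x) }
        } else {
            sum_scalar(x)
        }
    }
}

#[cfg(target_arch = "x86_64")]
fn main() {
    dispatch::init();
    println!("sum = {}", dispatch::sum(&[1.0, 2.0, 3.0]));
}

#[cfg(not(target_arch = "x86_64"))]
fn main() {}
```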

View File

@@ -111,9 +111,11 @@ fn read_information(mut file: &File) -> Information {
unsafe fn read_mmap(file: &File, len: usize) -> memmap2::Mmap { unsafe fn read_mmap(file: &File, len: usize) -> memmap2::Mmap {
let len = len.next_multiple_of(4096); let len = len.next_multiple_of(4096);
memmap2::MmapOptions::new() unsafe {
.populate() memmap2::MmapOptions::new()
.len(len) .populate()
.map(file) .len(len)
.unwrap() .map(file)
.unwrap()
}
} }

View File

@@ -0,0 +1,8 @@
pub mod cells;
pub mod clean;
pub mod detect;
pub mod dir_ops;
pub mod file_atomic;
pub mod file_wal;
pub mod mmap_array;
pub mod vec2;

View File

@@ -2,16 +2,16 @@ use crate::prelude::*;
use std::ops::{Deref, DerefMut, Index, IndexMut}; use std::ops::{Deref, DerefMut, Index, IndexMut};
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Vec2 { pub struct Vec2<S: G> {
dims: u16, dims: u16,
v: Box<[Scalar]>, v: Vec<S::Scalar>,
} }
impl Vec2 { impl<S: G> Vec2<S> {
pub fn new(dims: u16, n: usize) -> Self { pub fn new(dims: u16, n: usize) -> Self {
Self { Self {
dims, dims,
v: bytemuck::zeroed_slice_box(dims as usize * n), v: bytemuck::zeroed_vec(dims as usize * n),
} }
} }
pub fn dims(&self) -> u16 { pub fn dims(&self) -> u16 {
@@ -32,29 +32,29 @@ impl Vec2 {
} }
} }
impl Index<usize> for Vec2 { impl<S: G> Index<usize> for Vec2<S> {
type Output = [Scalar]; type Output = [S::Scalar];
fn index(&self, index: usize) -> &Self::Output { fn index(&self, index: usize) -> &Self::Output {
&self.v[self.dims as usize * index..][..self.dims as usize] &self.v[self.dims as usize * index..][..self.dims as usize]
} }
} }
impl IndexMut<usize> for Vec2 { impl<S: G> IndexMut<usize> for Vec2<S> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output { fn index_mut(&mut self, index: usize) -> &mut Self::Output {
&mut self.v[self.dims as usize * index..][..self.dims as usize] &mut self.v[self.dims as usize * index..][..self.dims as usize]
} }
} }
impl Deref for Vec2 { impl<S: G> Deref for Vec2<S> {
type Target = [Scalar]; type Target = [S::Scalar];
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
self.v.deref() self.v.deref()
} }
} }
impl DerefMut for Vec2 { impl<S: G> DerefMut for Vec2<S> {
fn deref_mut(&mut self) -> &mut Self::Target { fn deref_mut(&mut self) -> &mut Self::Target {
self.v.deref_mut() self.v.deref_mut()
} }
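`Vec2<S>` is a row-major matrix of scalars: row `i` is the slice beginning at `i * dims`. A standalone sketch of the same layout, with a plain `f32` scalar in place of the generic `S::Scalar`:

```rust
use std::ops::Index;

// Plain f32 stands in for the generic S::Scalar.
struct Vec2 {
    dims: u16,
    v: Vec<f32>, // n rows of `dims` scalars, stored contiguously
}

impl Vec2 {
    fn new(dims: u16, n: usize) -> Self {
        Self {
            dims,
            v: vec![0.0; dims as usize * n],
        }
    }
}

impl Index<usize> for Vec2 {
    type Output = [f32];
    fn index(&self, index: usize) -> &Self::Output {
        // Row `index` starts at index * dims and spans dims scalars.
        &self.v[self.dims as usize * index..][..self.dims as usize]
    }
}

fn main() {
    let m = Vec2::new(3, 2); // 2 rows of 3 scalars
    assert_eq!(&m[1], &[0.0, 0.0, 0.0][..]);
}
```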

View File

@@ -0,0 +1,248 @@
use crate::index::Index;
use crate::index::IndexOptions;
use crate::index::IndexStat;
use crate::index::IndexView;
use crate::index::OutdatedError;
use crate::prelude::*;
use std::path::PathBuf;
use std::sync::Arc;
#[derive(Clone)]
pub enum Instance {
F32Cos(Arc<Index<F32Cos>>),
F32Dot(Arc<Index<F32Dot>>),
F32L2(Arc<Index<F32L2>>),
F16Cos(Arc<Index<F16Cos>>),
F16Dot(Arc<Index<F16Dot>>),
F16L2(Arc<Index<F16L2>>),
}
impl Instance {
pub fn create(path: PathBuf, options: IndexOptions) -> Self {
match (options.vector.d, options.vector.k) {
(Distance::Cos, Kind::F32) => Self::F32Cos(Index::create(path, options)),
(Distance::Dot, Kind::F32) => Self::F32Dot(Index::create(path, options)),
(Distance::L2, Kind::F32) => Self::F32L2(Index::create(path, options)),
(Distance::Cos, Kind::F16) => Self::F16Cos(Index::create(path, options)),
(Distance::Dot, Kind::F16) => Self::F16Dot(Index::create(path, options)),
(Distance::L2, Kind::F16) => Self::F16L2(Index::create(path, options)),
}
}
pub fn open(path: PathBuf, options: IndexOptions) -> Self {
match (options.vector.d, options.vector.k) {
(Distance::Cos, Kind::F32) => Self::F32Cos(Index::open(path, options)),
(Distance::Dot, Kind::F32) => Self::F32Dot(Index::open(path, options)),
(Distance::L2, Kind::F32) => Self::F32L2(Index::open(path, options)),
(Distance::Cos, Kind::F16) => Self::F16Cos(Index::open(path, options)),
(Distance::Dot, Kind::F16) => Self::F16Dot(Index::open(path, options)),
(Distance::L2, Kind::F16) => Self::F16L2(Index::open(path, options)),
}
}
pub fn options(&self) -> &IndexOptions {
match self {
Instance::F32Cos(x) => x.options(),
Instance::F32Dot(x) => x.options(),
Instance::F32L2(x) => x.options(),
Instance::F16Cos(x) => x.options(),
Instance::F16Dot(x) => x.options(),
Instance::F16L2(x) => x.options(),
}
}
pub fn refresh(&self) {
match self {
Instance::F32Cos(x) => x.refresh(),
Instance::F32Dot(x) => x.refresh(),
Instance::F32L2(x) => x.refresh(),
Instance::F16Cos(x) => x.refresh(),
Instance::F16Dot(x) => x.refresh(),
Instance::F16L2(x) => x.refresh(),
}
}
pub fn view(&self) -> InstanceView {
match self {
Instance::F32Cos(x) => InstanceView::F32Cos(x.view()),
Instance::F32Dot(x) => InstanceView::F32Dot(x.view()),
Instance::F32L2(x) => InstanceView::F32L2(x.view()),
Instance::F16Cos(x) => InstanceView::F16Cos(x.view()),
Instance::F16Dot(x) => InstanceView::F16Dot(x.view()),
Instance::F16L2(x) => InstanceView::F16L2(x.view()),
}
}
pub fn stat(&self) -> IndexStat {
match self {
Instance::F32Cos(x) => x.stat(),
Instance::F32Dot(x) => x.stat(),
Instance::F32L2(x) => x.stat(),
Instance::F16Cos(x) => x.stat(),
Instance::F16Dot(x) => x.stat(),
Instance::F16L2(x) => x.stat(),
}
}
}
pub enum InstanceView {
F32Cos(Arc<IndexView<F32Cos>>),
F32Dot(Arc<IndexView<F32Dot>>),
F32L2(Arc<IndexView<F32L2>>),
F16Cos(Arc<IndexView<F16Cos>>),
F16Dot(Arc<IndexView<F16Dot>>),
F16L2(Arc<IndexView<F16L2>>),
}
impl InstanceView {
pub fn search<F: FnMut(Pointer) -> bool>(
&self,
k: usize,
vector: DynamicVector,
filter: F,
) -> Result<Vec<Pointer>, FriendlyError> {
match (self, vector) {
(InstanceView::F32Cos(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.search(k, &vector, filter))
}
(InstanceView::F32Dot(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.search(k, &vector, filter))
}
(InstanceView::F32L2(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.search(k, &vector, filter))
}
(InstanceView::F16Cos(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.search(k, &vector, filter))
}
(InstanceView::F16Dot(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.search(k, &vector, filter))
}
(InstanceView::F16L2(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.search(k, &vector, filter))
}
_ => Err(FriendlyError::Unmatched2),
}
}
pub fn vbase(
&self,
vector: DynamicVector,
) -> Result<impl Iterator<Item = Pointer> + '_, FriendlyError> {
match (self, vector) {
(InstanceView::F32Cos(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(Box::new(x.vbase(&vector)) as Box<dyn Iterator<Item = Pointer>>)
}
(InstanceView::F32Dot(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(Box::new(x.vbase(&vector)))
}
(InstanceView::F32L2(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(Box::new(x.vbase(&vector)))
}
(InstanceView::F16Cos(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(Box::new(x.vbase(&vector)))
}
(InstanceView::F16Dot(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(Box::new(x.vbase(&vector)))
}
(InstanceView::F16L2(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(Box::new(x.vbase(&vector)))
}
_ => Err(FriendlyError::Unmatched2),
}
}
pub fn insert(
&self,
vector: DynamicVector,
pointer: Pointer,
) -> Result<Result<(), OutdatedError>, FriendlyError> {
match (self, vector) {
(InstanceView::F32Cos(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.insert(vector, pointer))
}
(InstanceView::F32Dot(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.insert(vector, pointer))
}
(InstanceView::F32L2(x), DynamicVector::F32(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.insert(vector, pointer))
}
(InstanceView::F16Cos(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.insert(vector, pointer))
}
(InstanceView::F16Dot(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.insert(vector, pointer))
}
(InstanceView::F16L2(x), DynamicVector::F16(vector)) => {
if x.options.vector.dims as usize != vector.len() {
return Err(FriendlyError::Unmatched2);
}
Ok(x.insert(vector, pointer))
}
_ => Err(FriendlyError::Unmatched2),
}
}
pub fn delete<F: FnMut(Pointer) -> bool>(&self, f: F) {
match self {
InstanceView::F32Cos(x) => x.delete(f),
InstanceView::F32Dot(x) => x.delete(f),
InstanceView::F32L2(x) => x.delete(f),
InstanceView::F16Cos(x) => x.delete(f),
InstanceView::F16Dot(x) => x.delete(f),
InstanceView::F16L2(x) => x.delete(f),
}
}
pub fn flush(&self) {
match self {
InstanceView::F32Cos(x) => x.flush(),
InstanceView::F32Dot(x) => x.flush(),
InstanceView::F32L2(x) => x.flush(),
InstanceView::F16Cos(x) => x.flush(),
InstanceView::F16Dot(x) => x.flush(),
InstanceView::F16L2(x) => x.flush(),
}
}
}
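`Instance` and `InstanceView` carry one variant per (scalar type, distance) pair, so the inner `Index<S>` stays fully monomorphized while the outer API accepts a `DynamicVector` and checks dimensions before dispatching. A minimal sketch of that shape, with hypothetical stand-in types (`u16` stands in for `half::f16`):

```rust
// Hypothetical stand-ins; the real crate validates dims the same way before dispatch.
enum DynamicVector {
    F32(Vec<f32>),
    F16(Vec<u16>), // u16 stands in for half::f16 here
}

enum Instance {
    F32L2 { dims: usize },
    F16L2 { dims: usize },
}

impl Instance {
    fn search(&self, v: DynamicVector) -> Result<(), &'static str> {
        match (self, v) {
            (Instance::F32L2 { dims }, DynamicVector::F32(v)) if *dims == v.len() => Ok(()),
            (Instance::F16L2 { dims }, DynamicVector::F16(v)) if *dims == v.len() => Ok(()),
            _ => Err("unmatched vector type or dimensions"),
        }
    }
}

fn main() {
    let i = Instance::F32L2 { dims: 3 };
    assert!(i.search(DynamicVector::F32(vec![0.0; 3])).is_ok());
    assert!(i.search(DynamicVector::F16(vec![0; 3])).is_err());

    let j = Instance::F16L2 { dims: 3 };
    assert!(j.search(DynamicVector::F16(vec![0; 3])).is_ok());
}
```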

View File

@@ -1,7 +1,9 @@
use crate::index::Index; pub mod instance;
use crate::index::IndexInsertError;
use self::instance::Instance;
use crate::index::IndexOptions; use crate::index::IndexOptions;
use crate::index::IndexSearchError; use crate::index::IndexStat;
use crate::index::OutdatedError;
use crate::prelude::*; use crate::prelude::*;
use crate::utils::clean::clean; use crate::utils::clean::clean;
use crate::utils::dir_ops::sync_dir; use crate::utils::dir_ops::sync_dir;
@@ -57,7 +59,7 @@ impl Worker {
let mut indexes = HashMap::new(); let mut indexes = HashMap::new();
for (&id, options) in startup.get().indexes.iter() { for (&id, options) in startup.get().indexes.iter() {
let path = path.join("indexes").join(id.to_string()); let path = path.join("indexes").join(id.to_string());
let index = Index::open(path, options.clone()); let index = Instance::open(path, options.clone());
indexes.insert(id, index); indexes.insert(id, index);
} }
let view = Arc::new(WorkerView { let view = Arc::new(WorkerView {
@@ -72,7 +74,7 @@ impl Worker {
} }
pub fn call_create(&self, id: Id, options: IndexOptions) { pub fn call_create(&self, id: Id, options: IndexOptions) {
let mut protect = self.protect.lock(); let mut protect = self.protect.lock();
let index = Index::create(self.path.join("indexes").join(id.to_string()), options); let index = Instance::create(self.path.join("indexes").join(id.to_string()), options);
if protect.indexes.insert(id, index).is_some() { if protect.indexes.insert(id, index).is_some() {
panic!("index {} already exists", id) panic!("index {} already exists", id)
} }
@@ -81,44 +83,29 @@ impl Worker {
pub fn call_search<F>( pub fn call_search<F>(
&self, &self,
id: Id, id: Id,
search: (Vec<Scalar>, usize), search: (DynamicVector, usize),
filter: F, filter: F,
) -> Result<Vec<Pointer>, FriendlyError> ) -> Result<Vec<Pointer>, FriendlyError>
where where
F: FnMut(Pointer) -> bool, F: FnMut(Pointer) -> bool,
{ {
let view = self.view.load_full(); let view = self.view.load_full();
let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?; let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?;
let view = index.view(); let view = index.view();
match view.search(search.1, &search.0, filter) { view.search(search.1, search.0, filter)
Ok(x) => Ok(x),
Err(IndexSearchError::InvalidVector(x)) => Err(FriendlyError::BadVector(x)),
}
} }
pub fn call_search_vbase<F>( pub fn call_insert(
&self, &self,
id: Id, id: Id,
search: (Vec<Scalar>, usize), insert: (DynamicVector, Pointer),
next: F, ) -> Result<(), FriendlyError> {
) -> Result<(), FriendlyError>
where
F: FnMut(Pointer) -> bool,
{
let view = self.view.load_full(); let view = self.view.load_full();
let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?; let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?;
let view = index.view();
view.search_vbase(search.1, &search.0, next)
.map_err(|IndexSearchError::InvalidVector(x)| FriendlyError::BadVector(x))
}
pub fn call_insert(&self, id: Id, insert: (Vec<Scalar>, Pointer)) -> Result<(), FriendlyError> {
let view = self.view.load_full();
let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?;
loop { loop {
let view = index.view(); let view = index.view();
match view.insert(insert.0.clone(), insert.1) { match view.insert(insert.0.clone(), insert.1)? {
Ok(()) => break Ok(()), Ok(()) => break Ok(()),
Err(IndexInsertError::InvalidVector(x)) => break Err(FriendlyError::BadVector(x)), Err(OutdatedError(_)) => index.refresh(),
Err(IndexInsertError::OutdatedView(_)) => index.refresh(),
} }
} }
} }
@@ -127,16 +114,16 @@ impl Worker {
F: FnMut(Pointer) -> bool, F: FnMut(Pointer) -> bool,
{ {
let view = self.view.load_full(); let view = self.view.load_full();
let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?; let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?;
let view = index.view(); let view = index.view();
view.delete(f); view.delete(f);
Ok(()) Ok(())
} }
pub fn call_flush(&self, id: Id) -> Result<(), FriendlyError> { pub fn call_flush(&self, id: Id) -> Result<(), FriendlyError> {
let view = self.view.load_full(); let view = self.view.load_full();
let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?; let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?;
let view = index.view(); let view = index.view();
view.flush().unwrap(); view.flush();
Ok(()) Ok(())
} }
pub fn call_destory(&self, ids: Vec<Id>) { pub fn call_destory(&self, ids: Vec<Id>) {
@@ -149,44 +136,25 @@ impl Worker {
protect.maintain(&self.view); protect.maintain(&self.view);
} }
} }
pub fn call_stat(&self, id: Id) -> Result<VectorIndexInfo, FriendlyError> { pub fn call_stat(&self, id: Id) -> Result<IndexStat, FriendlyError> {
let view = self.view.load_full(); let view = self.view.load_full();
let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?; let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?;
let view = index.view(); Ok(index.stat())
let idx_sealed_len = view.sealed_len(); }
let idx_growing_len = view.growing_len(); pub fn get_instance(&self, id: Id) -> Result<Instance, FriendlyError> {
let idx_write = view.write_len(); let view = self.view.load_full();
let res = VectorIndexInfo { let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?;
indexing: index.indexing(), Ok(index.clone())
idx_tuples: (idx_write + idx_sealed_len + idx_growing_len)
.try_into()
.unwrap(),
idx_sealed_len: idx_sealed_len.try_into().unwrap(),
idx_growing_len: idx_growing_len.try_into().unwrap(),
idx_write: idx_write.try_into().unwrap(),
idx_sealed: view
.sealed_len_vec()
.into_iter()
.map(|x| x.try_into().unwrap())
.collect(),
idx_growing: view
.growing_len_vec()
.into_iter()
.map(|x| x.try_into().unwrap())
.collect(),
idx_config: serde_json::to_string(index.options()).unwrap(),
};
Ok(res)
} }
} }
struct WorkerView { struct WorkerView {
indexes: HashMap<Id, Arc<Index>>, indexes: HashMap<Id, Instance>,
} }
struct WorkerProtect { struct WorkerProtect {
startup: FileAtomic<WorkerStartup>, startup: FileAtomic<WorkerStartup>,
indexes: HashMap<Id, Arc<Index>>, indexes: HashMap<Id, Instance>,
} }
impl WorkerProtect { impl WorkerProtect {

View File

@@ -11,7 +11,7 @@ Why not just use Postgres to do the vector similarity search? This is the reason
UPDATE documents SET embedding = ai_embedding_vector(content) WHERE length(embedding) = 0; UPDATE documents SET embedding = ai_embedding_vector(content) WHERE length(embedding) = 0;
-- Create an index on the embedding column -- Create an index on the embedding column
CREATE INDEX ON documents USING vectors (embedding l2_ops); CREATE INDEX ON documents USING vectors (embedding vector_l2_ops);
-- Query the similar embeddings -- Query the similar embeddings
SELECT * FROM documents ORDER BY embedding <-> ai_embedding_vector('hello world') LIMIT 5; SELECT * FROM documents ORDER BY embedding <-> ai_embedding_vector('hello world') LIMIT 5;

View File

@@ -46,9 +46,9 @@ We support three operators to calculate the distance between two vectors.
-- squared Euclidean distance -- squared Euclidean distance
SELECT '[1, 2, 3]'::vector <-> '[3, 2, 1]'::vector; SELECT '[1, 2, 3]'::vector <-> '[3, 2, 1]'::vector;
-- negative dot product -- negative dot product
SELECT '[1, 2, 3]' <#> '[3, 2, 1]'; SELECT '[1, 2, 3]'::vector <#> '[3, 2, 1]'::vector;
-- negative cosine similarity -- negative cosine similarity
SELECT '[1, 2, 3]' <=> '[3, 2, 1]'; SELECT '[1, 2, 3]'::vector <=> '[3, 2, 1]'::vector;
``` ```
You can search for a vector simply like this. You can search for a vector simply like this.
@@ -58,6 +58,10 @@ You can search for a vector simply like this.
SELECT * FROM items ORDER BY embedding <-> '[3,2,1]' LIMIT 5; SELECT * FROM items ORDER BY embedding <-> '[3,2,1]' LIMIT 5;
``` ```
## Half-precision floating-point
The `vecf16` type is the same as `vector` in everything but the scalar type: it stores 16-bit floating point numbers. If you want to reduce memory usage and get better performance, you can try replacing the `vector` type with `vecf16`.
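As a rough illustration of the saving (assuming the `half` crate for the scalar type), each dimension shrinks from four bytes to two:

```rust
use half::f16;

fn main() {
    assert_eq!(std::mem::size_of::<f16>(), 2);
    assert_eq!(std::mem::size_of::<f32>(), 4);
    // A hypothetical 1536-dimensional embedding: 3072 bytes as vecf16 vs 6144 bytes as vector.
    println!("{} vs {} bytes", 1536 * 2, 1536 * 4);
}
```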
## Things You Need to Know
`vector(n)` is a valid data type only if $1 \leq n \leq 65535$. Due to limits of PostgreSQL, it's possible to create a value of type `vector(3)` with $5$ dimensions, and `vector` without a dimension is also a valid data type. However, you still cannot put $0$ scalars or more than $65535$ scalars into a vector. If you use `vector` for a column, or some values mismatch the dimension denoted by the column, you won't be able to create an index on it.

View File

@@ -5,11 +5,19 @@ Indexing is the core ability of pgvecto.rs.
Assuming there is a table `items` and there is a column named `embedding` of type `vector(n)`, you can create a vector index for squared Euclidean distance with the following SQL. Assuming there is a table `items` and there is a column named `embedding` of type `vector(n)`, you can create a vector index for squared Euclidean distance with the following SQL.
```sql ```sql
CREATE INDEX ON items USING vectors (embedding l2_ops); CREATE INDEX ON items USING vectors (embedding vector_l2_ops);
``` ```
For negative dot product, replace `l2_ops` with `dot_ops`. There is a table for you to choose a proper operator class for creating indexes.
For negative cosine similarity, replace `l2_ops` with `cosine_ops`.
| Vector type | Distance type | Operator class |
| ----------- | -------------------------- | -------------- |
| vector | squared Euclidean distance | vector_l2_ops |
| vector | negative dot product | vector_dot_ops |
| vector | negative cosine similarity | vector_cos_ops |
| vecf16 | squared Euclidean distance | vecf16_l2_ops |
| vecf16 | negative dot product | vecf16_dot_ops |
| vecf16 | negative cosine similarity | vecf16_cos_ops |
Now you can perform a KNN search with the following SQL again, but this time the vector index is used for searching. Now you can perform a KNN search with the following SQL again, but this time the vector index is used for searching.
@@ -36,14 +44,15 @@ Options for table `segment`.
| Key | Type | Description | | Key | Type | Description |
| ------------------------ | ------- | ------------------------------------------------------------------- | | ------------------------ | ------- | ------------------------------------------------------------------- |
| max_growing_segment_size | integer | Maximum size of unindexed vectors. Default value is `20_000`. | | max_growing_segment_size | integer | Maximum size of unindexed vectors. Default value is `20_000`. |
| min_sealed_segment_size | integer | Minimum size of vectors for indexing. Default value is `1_000`. |
| max_sealed_segment_size | integer | Maximum size of vectors for indexing. Default value is `1_000_000`. | | max_sealed_segment_size | integer | Maximum size of vectors for indexing. Default value is `1_000_000`. |
Options for table `optimizing`. Options for table `optimizing`.
| Key | Type | Description | | Key | Type | Description |
| ------------------ | ------- | --------------------------------------------------------------------------- | | ------------------ | ------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
| optimizing_threads | integer | Maximum threads for indexing. Default value is the sqrt of number of cores. | | optimizing_threads | integer | Maximum threads for indexing. Default value is the sqrt of number of cores. |
| sealing_secs       | integer | If a write segment larger than `sealing_size` does not accept new data for `sealing_secs` seconds, it will be turned into a sealed segment. |
| sealing_size | integer | See above. |
Options for table `indexing`. Options for table `indexing`.
@@ -99,23 +108,19 @@ Options for table `product`.
## Progress View ## Progress View
We also provide a view `pg_vector_index_info` to monitor the progress of indexing. We also provide a view `pg_vector_index_info` to monitor the progress of indexing.
Note that `idx_sealed_len` being equal to `idx_tuples` does not indicate that indexing is complete.
It may do further optimization after indexing. It may also stop indexing because there are too few tuples left.
| Column | Type | Description | | Column | Type | Description |
| --------------- | ------ | --------------------------------------------- | | ------------ | ------ | --------------------------------------------- |
| tablerelid | oid | The oid of the table. | | tablerelid | oid | The oid of the table. |
| indexrelid | oid | The oid of the index. | | indexrelid | oid | The oid of the index. |
| tablename | name | The name of the table. | | tablename | name | The name of the table. |
| indexname | name | The name of the index. | | indexname | name | The name of the index. |
| indexing | bool | Whether the background thread is indexing. | | idx_indexing | bool | Whether the background thread is indexing. |
| idx_tuples | int4 | The number of tuples. | | idx_tuples | int8 | The number of tuples. |
| idx_sealed_len | int4 | The number of tuples in sealed segments. | | idx_sealed | int8[] | The number of tuples in each sealed segment. |
| idx_growing_len | int4 | The number of tuples in growing segments. | | idx_growing | int8[] | The number of tuples in each growing segment. |
| idx_write | int4 | The number of tuples in write buffer. | | idx_write | int8 | The number of tuples in write buffer. |
| idx_sealed | int4[] | The number of tuples in each sealed segment. | | idx_config | text | The configuration of the index. |
| idx_growing | int4[] | The number of tuples in each growing segment. |
| idx_config | text | The configuration of the index. |
## Examples ## Examples
@@ -124,11 +129,11 @@ There are some examples.
```sql ```sql
-- HNSW algorithm, default settings. -- HNSW algorithm, default settings.
CREATE INDEX ON items USING vectors (embedding l2_ops); CREATE INDEX ON items USING vectors (embedding vector_l2_ops);
--- Or using bruteforce with PQ. --- Or using bruteforce with PQ.
CREATE INDEX ON items USING vectors (embedding l2_ops) CREATE INDEX ON items USING vectors (embedding vector_l2_ops)
WITH (options = $$ WITH (options = $$
[indexing.flat] [indexing.flat]
quantization.product.ratio = "x16" quantization.product.ratio = "x16"
@@ -136,7 +141,7 @@ $$);
--- Or using IVFPQ algorithm. --- Or using IVFPQ algorithm.
CREATE INDEX ON items USING vectors (embedding l2_ops) CREATE INDEX ON items USING vectors (embedding vector_l2_ops)
WITH (options = $$ WITH (options = $$
[indexing.ivf] [indexing.ivf]
quantization.product.ratio = "x16" quantization.product.ratio = "x16"
@@ -144,14 +149,14 @@ $$);
-- Use more threads for background building the index. -- Use more threads for background building the index.
CREATE INDEX ON items USING vectors (embedding l2_ops) CREATE INDEX ON items USING vectors (embedding vector_l2_ops)
WITH (options = $$ WITH (options = $$
optimizing.optimizing_threads = 16 optimizing.optimizing_threads = 16
$$); $$);
-- Prefer smaller HNSW graph. -- Prefer smaller HNSW graph.
CREATE INDEX ON items USING vectors (embedding l2_ops) CREATE INDEX ON items USING vectors (embedding vector_l2_ops)
WITH (options = $$ WITH (options = $$
segment.max_growing_segment_size = 200000 segment.max_growing_segment_size = 200000
$$); $$);

View File

@@ -19,24 +19,49 @@ To acheive full performance, please mount the volume to pg data directory by add
You can configure PostgreSQL by the reference of the parent image in https://hub.docker.com/_/postgres/. You can configure PostgreSQL by the reference of the parent image in https://hub.docker.com/_/postgres/.
## Build from source ## Install from source
Install Rust and base dependency. Install Rust and base dependency.
```sh ```sh
sudo apt install -y build-essential libpq-dev libssl-dev pkg-config gcc libreadline-dev flex bison libxml2-dev libxslt-dev libxml2-utils xsltproc zlib1g-dev ccache clang git sudo apt install -y \
build-essential \
libpq-dev \
libssl-dev \
pkg-config \
gcc \
libreadline-dev \
flex \
bison \
libxml2-dev \
libxslt-dev \
libxml2-utils \
xsltproc \
zlib1g-dev \
ccache \
clang \
git
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
``` ```
Install PostgreSQL. Install PostgreSQL.
```sh ```sh
sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" >> /etc/apt/sources.list.d/pgdg.list'
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
sudo apt-get update sudo apt-get update
sudo apt-get -y install libpq-dev postgresql-15 postgresql-server-dev-15 sudo apt-get -y install libpq-dev postgresql-15 postgresql-server-dev-15
``` ```
Install clang-16.
```sh
sudo sh -c 'echo "deb http://apt.llvm.org/$(lsb_release -cs)/ llvm-toolchain-$(lsb_release -cs)-16 main" >> /etc/apt/sources.list'
wget --quiet -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
sudo apt-get update
sudo apt-get -y install clang-16
```
Clone the Repository. Clone the Repository.
```sh ```sh
@@ -54,7 +79,7 @@ cargo pgrx init --pg15=/usr/lib/postgresql/15/bin/pg_config
Install pgvecto.rs. Install pgvecto.rs.
```sh ```sh
cargo pgrx install --release cargo pgrx install --sudo --release
``` ```
Configure your PostgreSQL by modifying the `shared_preload_libraries` to include `vectors.so`. Configure your PostgreSQL by modifying the `shared_preload_libraries` to include `vectors.so`.

View File

@@ -15,11 +15,12 @@ If `vectors.k` is set to `64`, but your SQL returned less than `64` rows, for ex
* The vector index returned `64` rows, but `32` of them are invisible to the transaction, so PostgreSQL decided to hide these rows from you.
* The vector index returned `64` rows, but only `32` of them satisfy the condition `id % 2 = 0` in the `WHERE` clause.
There are four ways to solve the problem:
* Set `vectors.k` larger. If you estimate that 20% of rows will satisfy the condition in `WHERE`, just set `vectors.k` to be 5 times larger than before.
* Set `vectors.enable_vector_index` to `off`. If you estimate that 0.0001% of rows will satisfy the condition in `WHERE`, just do not use the vector index. No algorithm will be faster than brute force by PostgreSQL.
* Set `vectors.enable_prefilter` to `on`. If you cannot estimate how many rows will satisfy the condition in `WHERE`, leave the job to the index. The index will check if the returned row can be accepted by PostgreSQL. However, it will make queries slower, so the default value for this option is `off`.
* Set `vectors.enable_vbase` to `on`. It will use vbase optimization, so that the index will pull as many rows as you need. It only works for the HNSW algorithm.
## Options ## Options
@@ -30,6 +31,6 @@ Search options are specified by PostgreSQL GUC. You can use `SET` command to app
| vectors.k | integer | Expected number of candidates returned by index. The parameter will influence the recall if you use HNSW or quantization for indexing. Default value is `64`. | | vectors.k | integer | Expected number of candidates returned by index. The parameter will influence the recall if you use HNSW or quantization for indexing. Default value is `64`. |
| vectors.enable_prefilter | boolean | Enable prefiltering or not. Default value is `off`. | | vectors.enable_prefilter | boolean | Enable prefiltering or not. Default value is `off`. |
| vectors.enable_vector_index | boolean | Enable vector indexes or not. This option is for debugging. Default value is `on`. | | vectors.enable_vector_index | boolean | Enable vector indexes or not. This option is for debugging. Default value is `on`. |
| vectors.vbase_range | int4 | The range size when using vbase optimization. When it is set to `0`, vbase optimization will be disabled. A recommended value is `86`. Default value is `0`. | | vectors.enable_vbase | boolean | Enable vbase optimization. Default value is `off`. |
Note: When `vectors.vbase_range` is enabled, it will ignore `vectors.enable_prefilter`. Note: When `vectors.enable_vbase` is enabled, prefilter does not work.

View File

@@ -6,10 +6,14 @@ if [ "$OS" == "ubuntu-latest" ]; then
sudo pg_dropcluster 14 main sudo pg_dropcluster 14 main
fi fi
sudo apt-get remove -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*' sudo apt-get remove -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*'
sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" >> /etc/apt/sources.list.d/pgdg.list'
sudo sh -c 'echo "deb http://apt.llvm.org/$(lsb_release -cs)/ llvm-toolchain-$(lsb_release -cs)-16 main" >> /etc/apt/sources.list'
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
wget --quiet -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
sudo apt-get update sudo apt-get update
sudo apt-get -y install build-essential libpq-dev postgresql-$VERSION postgresql-server-dev-$VERSION sudo apt-get -y install build-essential libpq-dev postgresql-$VERSION postgresql-server-dev-$VERSION
sudo apt-get -y install clang-16
sudo apt-get -y install crossbuild-essential-arm64
echo "local all all trust" | sudo tee /etc/postgresql/$VERSION/main/pg_hba.conf echo "local all all trust" | sudo tee /etc/postgresql/$VERSION/main/pg_hba.conf
echo "host all all 127.0.0.1/32 trust" | sudo tee -a /etc/postgresql/$VERSION/main/pg_hba.conf echo "host all all 127.0.0.1/32 trust" | sudo tee -a /etc/postgresql/$VERSION/main/pg_hba.conf
echo "host all all ::1/128 trust" | sudo tee -a /etc/postgresql/$VERSION/main/pg_hba.conf echo "host all all ::1/128 trust" | sudo tee -a /etc/postgresql/$VERSION/main/pg_hba.conf

View File

@@ -1 +0,0 @@
pub mod vamana;

View File

@@ -1,11 +1,26 @@
pub mod worker;
use self::worker::Worker;
use crate::ipc::server::RpcHandler; use crate::ipc::server::RpcHandler;
use crate::ipc::IpcError; use crate::ipc::IpcError;
use service::worker::Worker;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
pub unsafe fn init() {
use pgrx::bgworkers::BackgroundWorkerBuilder;
use pgrx::bgworkers::BgWorkerStartTime;
BackgroundWorkerBuilder::new("vectors")
.set_function("vectors_main")
.set_library("vectors")
.set_argument(None)
.enable_shmem_access(None)
.set_start_time(BgWorkerStartTime::PostmasterStart)
.load();
}
#[no_mangle]
extern "C" fn vectors_main(_arg: pgrx::pg_sys::Datum) {
let _ = std::panic::catch_unwind(crate::bgworker::main);
}
pub fn main() { pub fn main() {
{ {
let mut builder = env_logger::builder(); let mut builder = env_logger::builder();
@@ -109,10 +124,6 @@ fn session(worker: Arc<Worker>, mut handler: RpcHandler) -> Result<(), IpcError>
handler = x.leave(res)?; handler = x.leave(res)?;
} }
} }
RpcHandle::SearchVbase { id, search, mut x } => {
let res = worker.call_search_vbase(id, search, |p| x.next(p).unwrap());
handler = x.leave(res)?;
}
RpcHandle::Flush { id, x } => { RpcHandle::Flush { id, x } => {
let result = worker.call_flush(id); let result = worker.call_flush(id);
handler = x.leave(result)?; handler = x.leave(result)?;
@@ -125,11 +136,36 @@ fn session(worker: Arc<Worker>, mut handler: RpcHandler) -> Result<(), IpcError>
let result = worker.call_stat(id); let result = worker.call_stat(id);
handler = x.leave(result)?; handler = x.leave(result)?;
} }
RpcHandle::Leave {} => { RpcHandle::Vbase { id, vector, x } => {
log::debug!("Handle leave rpc."); use crate::ipc::server::VbaseHandle::*;
break; let instance = match worker.get_instance(id) {
Ok(x) => x,
Err(e) => {
x.error(Err(e))?;
break Ok(());
}
};
let view = instance.view();
let mut it = match view.vbase(vector) {
Ok(x) => x,
Err(e) => {
x.error(Err(e))?;
break Ok(());
}
};
let mut x = x.error(Ok(()))?;
loop {
match x.handle()? {
Next { x: y } => {
x = y.leave(it.next())?;
}
Leave { x } => {
handler = x;
break;
}
}
}
} }
} }
} }
Ok(())
} }
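The `Vbase` arm above keeps the result iterator alive on the server side and answers one `Next` request at a time, so a client only pays for the rows it actually pulls before sending `Leave`. A minimal sketch of that pull-based shape, with hypothetical request names:

```rust
// Hypothetical request names; the real protocol also carries an error handshake first.
enum Request {
    Next,
    Leave,
}

fn serve(mut it: impl Iterator<Item = u64>, requests: &[Request]) -> Vec<Option<u64>> {
    let mut replies = Vec::new();
    for r in requests {
        match r {
            Request::Next => replies.push(it.next()),
            Request::Leave => break,
        }
    }
    replies
}

fn main() {
    let results = (0u64..).map(|i| i * 10); // stands in for the index's vbase iterator
    let out = serve(results, &[Request::Next, Request::Next, Request::Leave]);
    assert_eq!(out, vec![Some(0), Some(10)]);
}
```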

26
src/datatype/casts_f32.rs Normal file
View File

@@ -0,0 +1,26 @@
use crate::datatype::typmod::Typmod;
use crate::datatype::vecf32::{Vecf32, Vecf32Input, Vecf32Output};
use service::prelude::*;
#[pgrx::pg_extern(immutable, parallel_safe, strict)]
fn vecf32_cast_array_to_vector(
array: pgrx::Array<f32>,
typmod: i32,
_explicit: bool,
) -> Vecf32Output {
assert!(!array.is_empty());
assert!(array.len() <= 65535);
assert!(!array.contains_nulls());
let typmod = Typmod::parse_from_i32(typmod).unwrap();
let len = typmod.dims().unwrap_or(array.len() as u16);
let mut data = vec![F32::zero(); len as usize];
for (i, x) in array.iter().enumerate() {
data[i] = F32(x.unwrap_or(f32::NAN));
}
Vecf32::new_in_postgres(&data)
}
#[pgrx::pg_extern(immutable, parallel_safe, strict)]
fn vecf32_cast_vector_to_array(vector: Vecf32Input<'_>, _typmod: i32, _explicit: bool) -> Vec<f32> {
vector.data().iter().map(|x| x.to_f32()).collect()
}

6
src/datatype/mod.rs Normal file
View File

@@ -0,0 +1,6 @@
pub mod casts_f32;
pub mod operators_f16;
pub mod operators_f32;
pub mod typmod;
pub mod vecf16;
pub mod vecf32;

View File

@@ -1,53 +1,53 @@
use crate::postgres::datatype::{Vector, VectorInput, VectorOutput}; use crate::datatype::vecf16::{Vecf16, Vecf16Input, Vecf16Output};
use crate::prelude::*; use service::prelude::*;
use std::ops::Deref; use std::ops::Deref;
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] #[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(+)] #[pgrx::opname(+)]
#[pgrx::commutator(+)] #[pgrx::commutator(+)]
fn operator_add(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput { fn vecf16_operator_add(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> Vecf16Output {
if lhs.len() != rhs.len() { if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims { FriendlyError::Unmatched {
left_dimensions: lhs.len() as _, left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _, right_dimensions: rhs.len() as _,
} }
.friendly(); .friendly();
} }
let n = lhs.len(); let n = lhs.len();
let mut v = Vector::new_zeroed(n); let mut v = vec![F16::zero(); n];
for i in 0..n { for i in 0..n {
v[i] = lhs[i] + rhs[i]; v[i] = lhs[i] + rhs[i];
} }
v.copy_into_postgres() Vecf16::new_in_postgres(&v)
} }
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] #[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(-)] #[pgrx::opname(-)]
fn operator_minus(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput { fn vecf16_operator_minus(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> Vecf16Output {
if lhs.len() != rhs.len() { if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims { FriendlyError::Unmatched {
left_dimensions: lhs.len() as _, left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _, right_dimensions: rhs.len() as _,
} }
.friendly(); .friendly();
} }
let n = lhs.len(); let n = lhs.len();
let mut v = Vector::new_zeroed(n); let mut v = vec![F16::zero(); n];
for i in 0..n { for i in 0..n {
v[i] = lhs[i] - rhs[i]; v[i] = lhs[i] - rhs[i];
} }
v.copy_into_postgres() Vecf16::new_in_postgres(&v)
} }
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] #[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(<)] #[pgrx::opname(<)]
#[pgrx::negator(>=)] #[pgrx::negator(>=)]
#[pgrx::commutator(>)] #[pgrx::commutator(>)]
#[pgrx::restrict(scalarltsel)] #[pgrx::restrict(scalarltsel)]
#[pgrx::join(scalarltjoinsel)] #[pgrx::join(scalarltjoinsel)]
fn operator_lt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { fn vecf16_operator_lt(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool {
if lhs.len() != rhs.len() { if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims { FriendlyError::Unmatched {
left_dimensions: lhs.len() as _, left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _, right_dimensions: rhs.len() as _,
} }
@@ -56,15 +56,15 @@ fn operator_lt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
lhs.deref() < rhs.deref() lhs.deref() < rhs.deref()
} }
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] #[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(<=)] #[pgrx::opname(<=)]
#[pgrx::negator(>)] #[pgrx::negator(>)]
#[pgrx::commutator(>=)] #[pgrx::commutator(>=)]
#[pgrx::restrict(scalarltsel)] #[pgrx::restrict(scalarltsel)]
#[pgrx::join(scalarltjoinsel)] #[pgrx::join(scalarltjoinsel)]
fn operator_lte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { fn vecf16_operator_lte(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool {
if lhs.len() != rhs.len() { if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims { FriendlyError::Unmatched {
left_dimensions: lhs.len() as _, left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _, right_dimensions: rhs.len() as _,
} }
@@ -73,15 +73,15 @@ fn operator_lte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
lhs.deref() <= rhs.deref() lhs.deref() <= rhs.deref()
} }
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] #[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(>)] #[pgrx::opname(>)]
#[pgrx::negator(<=)] #[pgrx::negator(<=)]
#[pgrx::commutator(<)] #[pgrx::commutator(<)]
#[pgrx::restrict(scalargtsel)] #[pgrx::restrict(scalargtsel)]
#[pgrx::join(scalargtjoinsel)] #[pgrx::join(scalargtjoinsel)]
fn operator_gt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { fn vecf16_operator_gt(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool {
if lhs.len() != rhs.len() { if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims { FriendlyError::Unmatched {
left_dimensions: lhs.len() as _, left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _, right_dimensions: rhs.len() as _,
} }
@@ -90,15 +90,15 @@ fn operator_gt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
lhs.deref() > rhs.deref() lhs.deref() > rhs.deref()
} }
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] #[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(>=)] #[pgrx::opname(>=)]
#[pgrx::negator(<)] #[pgrx::negator(<)]
#[pgrx::commutator(<=)] #[pgrx::commutator(<=)]
#[pgrx::restrict(scalargtsel)] #[pgrx::restrict(scalargtsel)]
#[pgrx::join(scalargtjoinsel)] #[pgrx::join(scalargtjoinsel)]
fn operator_gte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { fn vecf16_operator_gte(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool {
if lhs.len() != rhs.len() { if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims { FriendlyError::Unmatched {
left_dimensions: lhs.len() as _, left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _, right_dimensions: rhs.len() as _,
} }
@@ -107,15 +107,15 @@ fn operator_gte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
lhs.deref() >= rhs.deref() lhs.deref() >= rhs.deref()
} }
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] #[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(=)] #[pgrx::opname(=)]
#[pgrx::negator(<>)] #[pgrx::negator(<>)]
#[pgrx::commutator(=)] #[pgrx::commutator(=)]
#[pgrx::restrict(eqsel)] #[pgrx::restrict(eqsel)]
#[pgrx::join(eqjoinsel)] #[pgrx::join(eqjoinsel)]
fn operator_eq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { fn vecf16_operator_eq(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool {
if lhs.len() != rhs.len() { if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims { FriendlyError::Unmatched {
left_dimensions: lhs.len() as _, left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _, right_dimensions: rhs.len() as _,
} }
@@ -124,15 +124,15 @@ fn operator_eq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
lhs.deref() == rhs.deref() lhs.deref() == rhs.deref()
} }
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] #[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(<>)] #[pgrx::opname(<>)]
#[pgrx::negator(=)] #[pgrx::negator(=)]
#[pgrx::commutator(<>)] #[pgrx::commutator(<>)]
#[pgrx::restrict(eqsel)] #[pgrx::restrict(eqsel)]
#[pgrx::join(eqjoinsel)] #[pgrx::join(eqjoinsel)]
fn operator_neq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { fn vecf16_operator_neq(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool {
if lhs.len() != rhs.len() { if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims { FriendlyError::Unmatched {
left_dimensions: lhs.len() as _, left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _, right_dimensions: rhs.len() as _,
} }
@@ -141,44 +141,44 @@ fn operator_neq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool {
lhs.deref() != rhs.deref() lhs.deref() != rhs.deref()
} }
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] #[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(<=>)] #[pgrx::opname(<=>)]
#[pgrx::commutator(<=>)] #[pgrx::commutator(<=>)]
fn operator_cosine(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { fn vecf16_operator_cosine(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 {
if lhs.len() != rhs.len() { if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims { FriendlyError::Unmatched {
left_dimensions: lhs.len() as _, left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _, right_dimensions: rhs.len() as _,
} }
.friendly(); .friendly();
} }
Distance::Cosine.distance(&lhs, &rhs) F16Cos::distance(&lhs, &rhs).to_f32()
} }
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] #[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(<#>)] #[pgrx::opname(<#>)]
#[pgrx::commutator(<#>)] #[pgrx::commutator(<#>)]
fn operator_dot(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { fn vecf16_operator_dot(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 {
if lhs.len() != rhs.len() { if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims { FriendlyError::Unmatched {
left_dimensions: lhs.len() as _, left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _, right_dimensions: rhs.len() as _,
} }
.friendly(); .friendly();
} }
Distance::Dot.distance(&lhs, &rhs) F16Dot::distance(&lhs, &rhs).to_f32()
} }
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] #[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])]
#[pgrx::opname(<->)] #[pgrx::opname(<->)]
#[pgrx::commutator(<->)] #[pgrx::commutator(<->)]
fn operator_l2(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { fn vecf16_operator_l2(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 {
if lhs.len() != rhs.len() { if lhs.len() != rhs.len() {
FriendlyError::DifferentVectorDims { FriendlyError::Unmatched {
left_dimensions: lhs.len() as _, left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _, right_dimensions: rhs.len() as _,
} }
.friendly(); .friendly();
} }
Distance::L2.distance(&lhs, &rhs) F16L2::distance(&lhs, &rhs).to_f32()
} }
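
All of the comparison and arithmetic operators above follow one pattern: reject mismatched dimensions via FriendlyError::Unmatched, then work element-wise. A minimal pgrx-free sketch of that pattern, using f32 in place of the F16 wrapper and a Result in place of the error report:

// Sketch only: dimension-checked element-wise addition, mirroring vecf16_operator_add.
fn checked_add(lhs: &[f32], rhs: &[f32]) -> Result<Vec<f32>, String> {
    if lhs.len() != rhs.len() {
        // The extension raises FriendlyError::Unmatched here.
        return Err(format!(
            "dimensions do not match: left {}, right {}",
            lhs.len(),
            rhs.len()
        ));
    }
    Ok(lhs.iter().zip(rhs).map(|(a, b)| a + b).collect())
}

fn main() {
    assert_eq!(checked_add(&[1.0, 2.0], &[3.0, 4.0]).unwrap(), vec![4.0, 6.0]);
    assert!(checked_add(&[1.0], &[1.0, 2.0]).is_err());
}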


@@ -0,0 +1,184 @@
use crate::datatype::vecf32::{Vecf32, Vecf32Input, Vecf32Output};
use service::prelude::*;
use std::ops::Deref;
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(+)]
#[pgrx::commutator(+)]
fn vecf32_operator_add(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> Vecf32Output {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
let n = lhs.len();
let mut v = vec![F32::zero(); n];
for i in 0..n {
v[i] = lhs[i] + rhs[i];
}
Vecf32::new_in_postgres(&v)
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(-)]
fn vecf32_operator_minus(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> Vecf32Output {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
let n = lhs.len();
let mut v = vec![F32::zero(); n];
for i in 0..n {
v[i] = lhs[i] - rhs[i];
}
Vecf32::new_in_postgres(&v)
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(<)]
#[pgrx::negator(>=)]
#[pgrx::commutator(>)]
#[pgrx::restrict(scalarltsel)]
#[pgrx::join(scalarltjoinsel)]
fn vecf32_operator_lt(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() < rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(<=)]
#[pgrx::negator(>)]
#[pgrx::commutator(>=)]
#[pgrx::restrict(scalarltsel)]
#[pgrx::join(scalarltjoinsel)]
fn vecf32_operator_lte(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() <= rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(>)]
#[pgrx::negator(<=)]
#[pgrx::commutator(<)]
#[pgrx::restrict(scalargtsel)]
#[pgrx::join(scalargtjoinsel)]
fn vecf32_operator_gt(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() > rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(>=)]
#[pgrx::negator(<)]
#[pgrx::commutator(<=)]
#[pgrx::restrict(scalargtsel)]
#[pgrx::join(scalargtjoinsel)]
fn vecf32_operator_gte(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() >= rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(=)]
#[pgrx::negator(<>)]
#[pgrx::commutator(=)]
#[pgrx::restrict(eqsel)]
#[pgrx::join(eqjoinsel)]
fn vecf32_operator_eq(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() == rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(<>)]
#[pgrx::negator(=)]
#[pgrx::commutator(<>)]
#[pgrx::restrict(eqsel)]
#[pgrx::join(eqjoinsel)]
fn vecf32_operator_neq(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
lhs.deref() != rhs.deref()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(<=>)]
#[pgrx::commutator(<=>)]
fn vecf32_operator_cosine(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
F32Cos::distance(&lhs, &rhs).to_f32()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(<#>)]
#[pgrx::commutator(<#>)]
fn vecf32_operator_dot(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
F32Dot::distance(&lhs, &rhs).to_f32()
}
#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])]
#[pgrx::opname(<->)]
#[pgrx::commutator(<->)]
fn vecf32_operator_l2(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 {
if lhs.len() != rhs.len() {
FriendlyError::Unmatched {
left_dimensions: lhs.len() as _,
right_dimensions: rhs.len() as _,
}
.friendly();
}
F32L2::distance(&lhs, &rhs).to_f32()
}
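
The three distance operators delegate to F32Cos, F32Dot and F32L2 from the service crate, which carry the SIMD kernels. As a rough reference only, a scalar sketch of what those operators are assumed to compute, following pgvector-style conventions (squared Euclidean distance, negative dot product, one minus cosine similarity); the authoritative definitions live in the service crate:

// Sketch only: scalar reference versions of the three distances, assuming
// pgvector-style conventions. The real kernels are the SIMD paths in `service`.
fn l2(lhs: &[f32], rhs: &[f32]) -> f32 {
    lhs.iter().zip(rhs).map(|(a, b)| (a - b) * (a - b)).sum()
}

fn dot(lhs: &[f32], rhs: &[f32]) -> f32 {
    -lhs.iter().zip(rhs).map(|(a, b)| a * b).sum::<f32>()
}

fn cosine(lhs: &[f32], rhs: &[f32]) -> f32 {
    let xy: f32 = lhs.iter().zip(rhs).map(|(a, b)| a * b).sum();
    let xx: f32 = lhs.iter().map(|a| a * a).sum();
    let yy: f32 = rhs.iter().map(|b| b * b).sum();
    1.0 - xy / (xx * yy).sqrt()
}

fn main() {
    let (a, b) = ([1.0, 0.0], [0.0, 1.0]);
    assert_eq!(l2(&a, &b), 2.0);
    assert_eq!(dot(&a, &b), 0.0);
    assert_eq!(cosine(&a, &b), 1.0);
}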

src/datatype/typmod.rs (new file)

@@ -0,0 +1,77 @@
use pgrx::Array;
use serde::{Deserialize, Serialize};
use service::prelude::*;
use std::ffi::{CStr, CString};
use std::num::NonZeroU16;
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum Typmod {
Any,
Dims(NonZeroU16),
}
impl Typmod {
pub fn parse_from_str(s: &str) -> Option<Self> {
use Typmod::*;
if let Ok(x) = s.parse::<NonZeroU16>() {
Some(Dims(x))
} else {
None
}
}
pub fn parse_from_i32(x: i32) -> Option<Self> {
use Typmod::*;
if x == -1 {
Some(Any)
} else if 1 <= x && x <= u16::MAX as i32 {
Some(Dims(NonZeroU16::new(x as u16).unwrap()))
} else {
None
}
}
pub fn into_option_string(self) -> Option<String> {
use Typmod::*;
match self {
Any => None,
Dims(x) => Some(i32::from(x.get()).to_string()),
}
}
pub fn into_i32(self) -> i32 {
use Typmod::*;
match self {
Any => -1,
Dims(x) => i32::from(x.get()),
}
}
pub fn dims(self) -> Option<u16> {
use Typmod::*;
match self {
Any => None,
Dims(dims) => Some(dims.get()),
}
}
}
#[pgrx::pg_extern(immutable, parallel_safe, strict)]
fn typmod_in(list: Array<&CStr>) -> i32 {
if list.is_empty() {
-1
} else if list.len() == 1 {
let s = list.get(0).unwrap().unwrap().to_str().unwrap();
let typmod = Typmod::parse_from_str(s)
.ok_or(FriendlyError::BadTypeDimensions)
.friendly();
typmod.into_i32()
} else {
FriendlyError::BadTypeDimensions.friendly();
}
}
#[pgrx::pg_extern(immutable, parallel_safe, strict)]
fn typmod_out(typmod: i32) -> CString {
let typmod = Typmod::parse_from_i32(typmod).unwrap();
match typmod.into_option_string() {
Some(s) => CString::new(format!("({})", s)).unwrap(),
None => CString::new("()").unwrap(),
}
}
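
The typmod is just the declared dimension count packed into Postgres' i32 typmod slot, with -1 meaning "unconstrained". A standalone round-trip check of that encoding (same logic as Typmod::parse_from_i32 / into_i32 above, without the pgrx entry points):

// Sketch only: the typmod encoding round-tripped through i32.
use std::num::NonZeroU16;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Typmod {
    Any,
    Dims(NonZeroU16),
}

impl Typmod {
    fn parse_from_i32(x: i32) -> Option<Self> {
        if x == -1 {
            Some(Typmod::Any)
        } else if (1..=i32::from(u16::MAX)).contains(&x) {
            Some(Typmod::Dims(NonZeroU16::new(x as u16).unwrap()))
        } else {
            None
        }
    }
    fn into_i32(self) -> i32 {
        match self {
            Typmod::Any => -1,
            Typmod::Dims(x) => i32::from(x.get()),
        }
    }
}

fn main() {
    for x in [-1, 1, 3, 65535] {
        assert_eq!(Typmod::parse_from_i32(x).unwrap().into_i32(), x);
    }
    assert!(Typmod::parse_from_i32(0).is_none());
    assert!(Typmod::parse_from_i32(65536).is_none());
}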

src/datatype/vecf16.rs (new file)

@@ -0,0 +1,343 @@
use crate::datatype::typmod::Typmod;
use pgrx::pg_sys::Datum;
use pgrx::pg_sys::Oid;
use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError;
use pgrx::pgrx_sql_entity_graph::metadata::Returns;
use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError;
use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
use pgrx::FromDatum;
use pgrx::IntoDatum;
use service::prelude::*;
use std::alloc::Layout;
use std::cmp::Ordering;
use std::ffi::CStr;
use std::ffi::CString;
use std::ops::Deref;
use std::ops::DerefMut;
use std::ops::Index;
use std::ops::IndexMut;
use std::ptr::NonNull;
pgrx::extension_sql!(
r#"
CREATE TYPE vecf16 (
INPUT = vecf16_in,
OUTPUT = vecf16_out,
TYPMOD_IN = typmod_in,
TYPMOD_OUT = typmod_out,
STORAGE = EXTENDED,
INTERNALLENGTH = VARIABLE,
ALIGNMENT = double
);
"#,
name = "vecf16",
creates = [Type(Vecf16)],
requires = [vecf16_in, vecf16_out, typmod_in, typmod_out],
);
#[repr(C, align(8))]
pub struct Vecf16 {
varlena: u32,
len: u16,
kind: u8,
reserved: u8,
phantom: [F16; 0],
}
impl Vecf16 {
fn varlena(size: usize) -> u32 {
(size << 2) as u32
}
fn layout(len: usize) -> Layout {
u16::try_from(len).expect("Vector is too large.");
let layout_alpha = Layout::new::<Vecf16>();
let layout_beta = Layout::array::<F16>(len).unwrap();
let layout = layout_alpha.extend(layout_beta).unwrap().0;
layout.pad_to_align()
}
pub fn new_in_postgres(slice: &[F16]) -> Vecf16Output {
unsafe {
assert!(u16::try_from(slice.len()).is_ok());
let layout = Vecf16::layout(slice.len());
let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf16;
ptr.cast::<u8>().add(layout.size() - 8).write_bytes(0, 8);
std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size()));
std::ptr::addr_of_mut!((*ptr).kind).write(1);
std::ptr::addr_of_mut!((*ptr).reserved).write(0);
std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16);
std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len());
Vecf16Output(NonNull::new(ptr).unwrap())
}
}
pub fn len(&self) -> usize {
self.len as usize
}
pub fn data(&self) -> &[F16] {
debug_assert_eq!(self.varlena & 3, 0);
debug_assert_eq!(self.kind, 1);
unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.len as usize) }
}
pub fn data_mut(&mut self) -> &mut [F16] {
debug_assert_eq!(self.varlena & 3, 0);
debug_assert_eq!(self.kind, 1);
unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.len as usize) }
}
}
impl Deref for Vecf16 {
type Target = [F16];
fn deref(&self) -> &Self::Target {
self.data()
}
}
impl DerefMut for Vecf16 {
fn deref_mut(&mut self) -> &mut Self::Target {
self.data_mut()
}
}
impl AsRef<[F16]> for Vecf16 {
fn as_ref(&self) -> &[F16] {
self.data()
}
}
impl AsMut<[F16]> for Vecf16 {
fn as_mut(&mut self) -> &mut [F16] {
self.data_mut()
}
}
impl Index<usize> for Vecf16 {
type Output = F16;
fn index(&self, index: usize) -> &Self::Output {
self.data().index(index)
}
}
impl IndexMut<usize> for Vecf16 {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
self.data_mut().index_mut(index)
}
}
impl PartialEq for Vecf16 {
fn eq(&self, other: &Self) -> bool {
if self.len() != other.len() {
return false;
}
let n = self.len();
for i in 0..n {
if self[i] != other[i] {
return false;
}
}
true
}
}
impl Eq for Vecf16 {}
impl PartialOrd for Vecf16 {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Vecf16 {
fn cmp(&self, other: &Self) -> Ordering {
use Ordering::*;
if let x @ Less | x @ Greater = self.len().cmp(&other.len()) {
return x;
}
let n = self.len();
for i in 0..n {
if let x @ Less | x @ Greater = self[i].cmp(&other[i]) {
return x;
}
}
Equal
}
}
pub enum Vecf16Input<'a> {
Owned(Vecf16Output),
Borrowed(&'a Vecf16),
}
impl<'a> Vecf16Input<'a> {
pub unsafe fn new(p: NonNull<Vecf16>) -> Self {
let q = unsafe {
NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap()
};
if p != q {
Vecf16Input::Owned(Vecf16Output(q))
} else {
unsafe { Vecf16Input::Borrowed(p.as_ref()) }
}
}
}
impl Deref for Vecf16Input<'_> {
type Target = Vecf16;
fn deref(&self) -> &Self::Target {
match self {
Vecf16Input::Owned(x) => x,
Vecf16Input::Borrowed(x) => x,
}
}
}
pub struct Vecf16Output(NonNull<Vecf16>);
impl Vecf16Output {
pub fn into_raw(self) -> *mut Vecf16 {
let result = self.0.as_ptr();
std::mem::forget(self);
result
}
}
impl Deref for Vecf16Output {
type Target = Vecf16;
fn deref(&self) -> &Self::Target {
unsafe { self.0.as_ref() }
}
}
impl DerefMut for Vecf16Output {
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { self.0.as_mut() }
}
}
impl Drop for Vecf16Output {
fn drop(&mut self) {
unsafe {
pgrx::pg_sys::pfree(self.0.as_ptr() as _);
}
}
}
impl<'a> FromDatum for Vecf16Input<'a> {
unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option<Self> {
if is_null {
None
} else {
let ptr = NonNull::new(datum.cast_mut_ptr::<Vecf16>()).unwrap();
unsafe { Some(Vecf16Input::new(ptr)) }
}
}
}
impl IntoDatum for Vecf16Output {
fn into_datum(self) -> Option<Datum> {
Some(Datum::from(self.into_raw() as *mut ()))
}
fn type_oid() -> Oid {
pgrx::wrappers::regtypein("vecf16")
}
}
unsafe impl SqlTranslatable for Vecf16Input<'_> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("vecf16")))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::As(String::from("vecf16"))))
}
}
unsafe impl SqlTranslatable for Vecf16Output {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("vecf16")))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::As(String::from("vecf16"))))
}
}
#[pgrx::pg_extern(immutable, parallel_safe, strict)]
fn vecf16_in(input: &CStr, _oid: Oid, typmod: i32) -> Vecf16Output {
fn solve<T>(option: Option<T>, hint: &str) -> T {
if let Some(x) = option {
x
} else {
FriendlyError::BadLiteral {
hint: hint.to_string(),
}
.friendly()
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
MatchingLeft,
Reading,
MatchedRight,
}
use State::*;
let input = input.to_bytes();
let typmod = Typmod::parse_from_i32(typmod).unwrap();
let mut vector = Vec::<F16>::with_capacity(typmod.dims().unwrap_or(0) as usize);
let mut state = MatchingLeft;
let mut token: Option<String> = None;
for &c in input {
match (state, c) {
(MatchingLeft, b'[') => {
state = Reading;
}
(Reading, b'0'..=b'9' | b'.' | b'e' | b'+' | b'-') => {
let token = token.get_or_insert(String::new());
token.push(char::from_u32(c as u32).unwrap());
}
(Reading, b',') => {
let token = solve(token.take(), "Expect a number.");
vector.push(solve(token.parse().ok(), "Bad number."));
}
(Reading, b']') => {
if let Some(token) = token.take() {
vector.push(solve(token.parse().ok(), "Bad number."));
}
state = MatchedRight;
}
(_, b' ') => {}
_ => {
FriendlyError::BadLiteral {
hint: format!("Bad character with ASCII {:#x}.", c),
}
.friendly();
}
}
}
if state != MatchedRight {
FriendlyError::BadLiteral {
hint: "Bad sequence.".to_string(),
}
.friendly();
}
if vector.is_empty() || vector.len() > 65535 {
FriendlyError::BadValueDimensions.friendly();
}
Vecf16::new_in_postgres(&vector)
}
#[pgrx::pg_extern(immutable, parallel_safe, strict)]
fn vecf16_out(vector: Vecf16Input<'_>) -> CString {
let mut buffer = String::new();
buffer.push('[');
if let Some(&x) = vector.data().first() {
buffer.push_str(format!("{}", x).as_str());
}
for &x in vector.data().iter().skip(1) {
buffer.push_str(format!(", {}", x).as_str());
}
buffer.push(']');
CString::new(buffer).unwrap()
}
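
vecf16_in is a small byte-level state machine: expect '[', accumulate number tokens separated by ',', close on ']', and skip spaces anywhere. A standalone sketch of the same loop that parses into Vec<f32> and returns Err where the extension raises FriendlyError:

// Sketch only: the vecf16_in / vecf32_in literal parser with Result-based errors.
fn parse_vector(input: &str) -> Result<Vec<f32>, String> {
    #[derive(Clone, Copy, PartialEq)]
    enum State {
        MatchingLeft,
        Reading,
        MatchedRight,
    }
    use State::*;
    let mut state = MatchingLeft;
    let mut vector = Vec::new();
    let mut token: Option<String> = None;
    for c in input.bytes() {
        match (state, c) {
            (MatchingLeft, b'[') => state = Reading,
            (Reading, b'0'..=b'9' | b'.' | b'e' | b'+' | b'-') => {
                token.get_or_insert_with(String::new).push(c as char);
            }
            (Reading, b',') => {
                let t = token.take().ok_or("expected a number")?;
                vector.push(t.parse().map_err(|_| "bad number")?);
            }
            (Reading, b']') => {
                if let Some(t) = token.take() {
                    vector.push(t.parse().map_err(|_| "bad number")?);
                }
                state = MatchedRight;
            }
            (_, b' ') => {}
            _ => return Err(format!("bad character {:#x}", c)),
        }
    }
    if state != MatchedRight || vector.is_empty() || vector.len() > 65535 {
        return Err("bad literal".into());
    }
    Ok(vector)
}

fn main() {
    assert_eq!(parse_vector("[1, 2.5, 3]").unwrap(), vec![1.0, 2.5, 3.0]);
    assert!(parse_vector("[1, 2").is_err());
}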

src/datatype/vecf32.rs (new file)

@@ -0,0 +1,343 @@
use crate::datatype::typmod::Typmod;
use pgrx::pg_sys::Datum;
use pgrx::pg_sys::Oid;
use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError;
use pgrx::pgrx_sql_entity_graph::metadata::Returns;
use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError;
use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
use pgrx::FromDatum;
use pgrx::IntoDatum;
use service::prelude::*;
use std::alloc::Layout;
use std::cmp::Ordering;
use std::ffi::CStr;
use std::ffi::CString;
use std::ops::Deref;
use std::ops::DerefMut;
use std::ops::Index;
use std::ops::IndexMut;
use std::ptr::NonNull;
pgrx::extension_sql!(
r#"
CREATE TYPE vector (
INPUT = vecf32_in,
OUTPUT = vecf32_out,
TYPMOD_IN = typmod_in,
TYPMOD_OUT = typmod_out,
STORAGE = EXTENDED,
INTERNALLENGTH = VARIABLE,
ALIGNMENT = double
);
"#,
name = "vecf32",
creates = [Type(Vecf32)],
requires = [vecf32_in, vecf32_out, typmod_in, typmod_out],
);
#[repr(C, align(8))]
pub struct Vecf32 {
varlena: u32,
len: u16,
kind: u8,
reserved: u8,
phantom: [F32; 0],
}
impl Vecf32 {
fn varlena(size: usize) -> u32 {
(size << 2) as u32
}
fn layout(len: usize) -> Layout {
u16::try_from(len).expect("Vector is too large.");
let layout_alpha = Layout::new::<Vecf32>();
let layout_beta = Layout::array::<F32>(len).unwrap();
let layout = layout_alpha.extend(layout_beta).unwrap().0;
layout.pad_to_align()
}
pub fn new_in_postgres(slice: &[F32]) -> Vecf32Output {
unsafe {
assert!(u16::try_from(slice.len()).is_ok());
let layout = Vecf32::layout(slice.len());
let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf32;
ptr.cast::<u8>().add(layout.size() - 8).write_bytes(0, 8);
std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size()));
std::ptr::addr_of_mut!((*ptr).kind).write(0);
std::ptr::addr_of_mut!((*ptr).reserved).write(0);
std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16);
std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len());
Vecf32Output(NonNull::new(ptr).unwrap())
}
}
pub fn len(&self) -> usize {
self.len as usize
}
pub fn data(&self) -> &[F32] {
debug_assert_eq!(self.varlena & 3, 0);
debug_assert_eq!(self.kind, 0);
unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.len as usize) }
}
pub fn data_mut(&mut self) -> &mut [F32] {
debug_assert_eq!(self.varlena & 3, 0);
debug_assert_eq!(self.kind, 0);
unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.len as usize) }
}
}
impl Deref for Vecf32 {
type Target = [F32];
fn deref(&self) -> &Self::Target {
self.data()
}
}
impl DerefMut for Vecf32 {
fn deref_mut(&mut self) -> &mut Self::Target {
self.data_mut()
}
}
impl AsRef<[F32]> for Vecf32 {
fn as_ref(&self) -> &[F32] {
self.data()
}
}
impl AsMut<[F32]> for Vecf32 {
fn as_mut(&mut self) -> &mut [F32] {
self.data_mut()
}
}
impl Index<usize> for Vecf32 {
type Output = F32;
fn index(&self, index: usize) -> &Self::Output {
self.data().index(index)
}
}
impl IndexMut<usize> for Vecf32 {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
self.data_mut().index_mut(index)
}
}
impl PartialEq for Vecf32 {
fn eq(&self, other: &Self) -> bool {
if self.len() != other.len() {
return false;
}
let n = self.len();
for i in 0..n {
if self[i] != other[i] {
return false;
}
}
true
}
}
impl Eq for Vecf32 {}
impl PartialOrd for Vecf32 {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Vecf32 {
fn cmp(&self, other: &Self) -> Ordering {
use Ordering::*;
if let x @ Less | x @ Greater = self.len().cmp(&other.len()) {
return x;
}
let n = self.len();
for i in 0..n {
if let x @ Less | x @ Greater = self[i].cmp(&other[i]) {
return x;
}
}
Equal
}
}
pub enum Vecf32Input<'a> {
Owned(Vecf32Output),
Borrowed(&'a Vecf32),
}
impl<'a> Vecf32Input<'a> {
pub unsafe fn new(p: NonNull<Vecf32>) -> Self {
let q = unsafe {
NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap()
};
if p != q {
Vecf32Input::Owned(Vecf32Output(q))
} else {
unsafe { Vecf32Input::Borrowed(p.as_ref()) }
}
}
}
impl Deref for Vecf32Input<'_> {
type Target = Vecf32;
fn deref(&self) -> &Self::Target {
match self {
Vecf32Input::Owned(x) => x,
Vecf32Input::Borrowed(x) => x,
}
}
}
pub struct Vecf32Output(NonNull<Vecf32>);
impl Vecf32Output {
pub fn into_raw(self) -> *mut Vecf32 {
let result = self.0.as_ptr();
std::mem::forget(self);
result
}
}
impl Deref for Vecf32Output {
type Target = Vecf32;
fn deref(&self) -> &Self::Target {
unsafe { self.0.as_ref() }
}
}
impl DerefMut for Vecf32Output {
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { self.0.as_mut() }
}
}
impl Drop for Vecf32Output {
fn drop(&mut self) {
unsafe {
pgrx::pg_sys::pfree(self.0.as_ptr() as _);
}
}
}
impl<'a> FromDatum for Vecf32Input<'a> {
unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option<Self> {
if is_null {
None
} else {
let ptr = NonNull::new(datum.cast_mut_ptr::<Vecf32>()).unwrap();
unsafe { Some(Vecf32Input::new(ptr)) }
}
}
}
impl IntoDatum for Vecf32Output {
fn into_datum(self) -> Option<Datum> {
Some(Datum::from(self.into_raw() as *mut ()))
}
fn type_oid() -> Oid {
pgrx::wrappers::regtypein("vector")
}
}
unsafe impl SqlTranslatable for Vecf32Input<'_> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("vector")))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::As(String::from("vector"))))
}
}
unsafe impl SqlTranslatable for Vecf32Output {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("vector")))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::As(String::from("vector"))))
}
}
#[pgrx::pg_extern(immutable, parallel_safe, strict)]
fn vecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> Vecf32Output {
fn solve<T>(option: Option<T>, hint: &str) -> T {
if let Some(x) = option {
x
} else {
FriendlyError::BadLiteral {
hint: hint.to_string(),
}
.friendly()
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
MatchingLeft,
Reading,
MatchedRight,
}
use State::*;
let input = input.to_bytes();
let typmod = Typmod::parse_from_i32(typmod).unwrap();
let mut vector = Vec::<F32>::with_capacity(typmod.dims().unwrap_or(0) as usize);
let mut state = MatchingLeft;
let mut token: Option<String> = None;
for &c in input {
match (state, c) {
(MatchingLeft, b'[') => {
state = Reading;
}
(Reading, b'0'..=b'9' | b'.' | b'e' | b'+' | b'-') => {
let token = token.get_or_insert(String::new());
token.push(char::from_u32(c as u32).unwrap());
}
(Reading, b',') => {
let token = solve(token.take(), "Expect a number.");
vector.push(solve(token.parse().ok(), "Bad number."));
}
(Reading, b']') => {
if let Some(token) = token.take() {
vector.push(solve(token.parse().ok(), "Bad number."));
}
state = MatchedRight;
}
(_, b' ') => {}
_ => {
FriendlyError::BadLiteral {
hint: format!("Bad character with ASCII {:#x}.", c),
}
.friendly();
}
}
}
if state != MatchedRight {
FriendlyError::BadLiteral {
hint: "Bad sequence.".to_string(),
}
.friendly();
}
if vector.is_empty() || vector.len() > 65535 {
FriendlyError::BadValueDimensions.friendly();
}
Vecf32::new_in_postgres(&vector)
}
#[pgrx::pg_extern(immutable, parallel_safe, strict)]
fn vecf32_out(vector: Vecf32Input<'_>) -> CString {
let mut buffer = String::new();
buffer.push('[');
if let Some(&x) = vector.data().first() {
buffer.push_str(format!("{}", x).as_str());
}
for &x in vector.data().iter().skip(1) {
buffer.push_str(format!(", {}", x).as_str());
}
buffer.push(']');
CString::new(buffer).unwrap()
}
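
Vecf16 and Vecf32 are ordinary Postgres varlenas: a 4-byte length word holding the total byte size shifted left by two, a u16 element count, a one-byte kind tag plus a reserved byte, then the packed elements, padded out to 8-byte alignment (hence the write_bytes(0, 8) on the tail so no padding bytes are left uninitialized). A standalone sketch of the size arithmetic done by layout() and varlena() above:

// Sketch only: the allocation size and varlena header word that
// Vecf32::layout / Vecf32::varlena compute for `len` f32 elements.
use std::alloc::Layout;

#[repr(C, align(8))]
struct Header {
    varlena: u32,
    len: u16,
    kind: u8,
    reserved: u8,
}

fn varlena_size(len: usize) -> usize {
    let header = Layout::new::<Header>();
    let body = Layout::array::<f32>(len).unwrap();
    header.extend(body).unwrap().0.pad_to_align().size()
}

fn main() {
    let size = varlena_size(3);
    // 8-byte header + 12 bytes of data, padded up to the next multiple of 8.
    assert_eq!(size, 24);
    // The length word stores the byte size shifted left by two.
    assert_eq!(((size << 2) as u32) >> 2, 24);
}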


@@ -1,14 +1,12 @@
use super::openai::{EmbeddingCreator, OpenAIEmbedding}; use super::openai::{EmbeddingCreator, OpenAIEmbedding};
use super::Embedding; use super::Embedding;
use crate::postgres::datatype::Vector; use crate::datatype::vecf32::{Vecf32, Vecf32Output};
use crate::postgres::datatype::VectorOutput; use crate::gucs::OPENAI_API_KEY_GUC;
use crate::postgres::gucs::OPENAI_API_KEY_GUC;
use crate::prelude::Float;
use crate::prelude::Scalar;
use pgrx::prelude::*; use pgrx::prelude::*;
use service::prelude::F32;
#[pg_extern] #[pg_extern]
fn ai_embedding_vector(input: String) -> VectorOutput { fn ai_embedding_vector(input: String) -> Vecf32Output {
let api_key = match OPENAI_API_KEY_GUC.get() { let api_key = match OPENAI_API_KEY_GUC.get() {
Some(key) => key Some(key) => key
.to_str() .to_str()
@@ -26,9 +24,9 @@ fn ai_embedding_vector(input: String) -> VectorOutput {
Ok(embedding) => { Ok(embedding) => {
let embedding = embedding let embedding = embedding
.into_iter() .into_iter()
.map(|x| Scalar(x as Float)) .map(|x| F32(x as f32))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
Vector::new_in_postgres(&embedding) Vecf32::new_in_postgres(&embedding)
} }
Err(e) => { Err(e) => {
error!("{}", e) error!("{}", e)


@@ -23,7 +23,7 @@ pub static ENABLE_VECTOR_INDEX: GucSetting<bool> = GucSetting::<bool>::new(true)
pub static ENABLE_PREFILTER: GucSetting<bool> = GucSetting::<bool>::new(false); pub static ENABLE_PREFILTER: GucSetting<bool> = GucSetting::<bool>::new(false);
pub static VBASE_RANGE: GucSetting<i32> = GucSetting::<i32>::new(0); pub static ENABLE_VBASE: GucSetting<bool> = GucSetting::<bool>::new(false);
pub static TRANSPORT: GucSetting<Transport> = GucSetting::<Transport>::new(Transport::default()); pub static TRANSPORT: GucSetting<Transport> = GucSetting::<Transport>::new(Transport::default());
@@ -62,13 +62,11 @@ pub unsafe fn init() {
GucContext::Userset, GucContext::Userset,
GucFlags::default(), GucFlags::default(),
); );
GucRegistry::define_int_guc( GucRegistry::define_bool_guc(
"vectors.vbase_range", "vectors.enable_vbase",
"The range of vbase.", "Whether to enable vbase.",
"The range of vbase.", "When enabled, it will use vbase for filtering.",
&VBASE_RANGE, &ENABLE_VBASE,
0,
u16::MAX as _,
GucContext::Userset, GucContext::Userset,
GucFlags::default(), GucFlags::default(),
); );


@@ -1,11 +1,12 @@
use super::index_build; use super::am_build;
use super::index_scan; use super::am_scan;
use super::index_setup; use super::am_setup;
use super::index_update; use super::am_update;
use crate::postgres::datatype::VectorInput; use crate::gucs::ENABLE_VECTOR_INDEX;
use crate::postgres::gucs::ENABLE_VECTOR_INDEX; use crate::index::utils::from_datum;
use crate::prelude::*; use crate::prelude::*;
use crate::utils::cells::PgCell; use crate::utils::cells::PgCell;
use service::prelude::*;
static RELOPT_KIND: PgCell<pgrx::pg_sys::relopt_kind> = unsafe { PgCell::new(0) }; static RELOPT_KIND: PgCell<pgrx::pg_sys::relopt_kind> = unsafe { PgCell::new(0) };
@@ -28,9 +29,7 @@ pub unsafe fn init() {
#[pgrx::pg_extern(sql = " #[pgrx::pg_extern(sql = "
CREATE OR REPLACE FUNCTION vectors_amhandler(internal) RETURNS index_am_handler CREATE OR REPLACE FUNCTION vectors_amhandler(internal) RETURNS index_am_handler
PARALLEL SAFE IMMUTABLE STRICT LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@'; PARALLEL SAFE IMMUTABLE STRICT LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';
CREATE ACCESS METHOD vectors TYPE INDEX HANDLER vectors_amhandler; ", requires = ["vecf32"])]
COMMENT ON ACCESS METHOD vectors IS 'pgvecto.rs index access method';
", requires = ["vector"])]
fn vectors_amhandler( fn vectors_amhandler(
_fcinfo: pgrx::pg_sys::FunctionCallInfo, _fcinfo: pgrx::pg_sys::FunctionCallInfo,
) -> pgrx::PgBox<pgrx::pg_sys::IndexAmRoutine> { ) -> pgrx::PgBox<pgrx::pg_sys::IndexAmRoutine> {
@@ -85,7 +84,7 @@ const AM_HANDLER: pgrx::pg_sys::IndexAmRoutine = {
#[pgrx::pg_guard] #[pgrx::pg_guard]
pub unsafe extern "C" fn amvalidate(opclass_oid: pgrx::pg_sys::Oid) -> bool { pub unsafe extern "C" fn amvalidate(opclass_oid: pgrx::pg_sys::Oid) -> bool {
index_setup::convert_opclass_to_distance(opclass_oid); am_setup::convert_opclass_to_distance(opclass_oid);
true true
} }
@@ -99,7 +98,7 @@ pub unsafe extern "C" fn amoptions(
let tab: &[pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt { let tab: &[pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt {
optname: "options".as_pg_cstr(), optname: "options".as_pg_cstr(),
opttype: pgrx::pg_sys::relopt_type_RELOPT_TYPE_STRING, opttype: pgrx::pg_sys::relopt_type_RELOPT_TYPE_STRING,
offset: index_setup::helper_offset() as i32, offset: am_setup::helper_offset() as i32,
}]; }];
let mut noptions = 0; let mut noptions = 0;
let options = let options =
@@ -111,10 +110,10 @@ pub unsafe extern "C" fn amoptions(
relopt.gen.as_mut().unwrap().lockmode = relopt.gen.as_mut().unwrap().lockmode =
pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE; pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE;
} }
let rdopts = pgrx::pg_sys::allocateReloptStruct(index_setup::helper_size(), options, noptions); let rdopts = pgrx::pg_sys::allocateReloptStruct(am_setup::helper_size(), options, noptions);
pgrx::pg_sys::fillRelOptions( pgrx::pg_sys::fillRelOptions(
rdopts, rdopts,
index_setup::helper_size(), am_setup::helper_size(),
options, options,
noptions, noptions,
validate, validate,
@@ -136,13 +135,13 @@ pub unsafe extern "C" fn amoptions(
let tab: &[pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt { let tab: &[pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt {
optname: "options".as_pg_cstr(), optname: "options".as_pg_cstr(),
opttype: pgrx::pg_sys::relopt_type_RELOPT_TYPE_STRING, opttype: pgrx::pg_sys::relopt_type_RELOPT_TYPE_STRING,
offset: index_setup::helper_offset() as i32, offset: am_setup::helper_offset() as i32,
}]; }];
let rdopts = pgrx::pg_sys::build_reloptions( let rdopts = pgrx::pg_sys::build_reloptions(
reloptions, reloptions,
validate, validate,
RELOPT_KIND.get(), RELOPT_KIND.get(),
index_setup::helper_size(), am_setup::helper_size(),
tab.as_ptr(), tab.as_ptr(),
tab.len() as _, tab.len() as _,
); );
@@ -182,7 +181,7 @@ pub unsafe extern "C" fn ambuild(
index_info: *mut pgrx::pg_sys::IndexInfo, index_info: *mut pgrx::pg_sys::IndexInfo,
) -> *mut pgrx::pg_sys::IndexBuildResult { ) -> *mut pgrx::pg_sys::IndexBuildResult {
let result = pgrx::PgBox::<pgrx::pg_sys::IndexBuildResult>::alloc0(); let result = pgrx::PgBox::<pgrx::pg_sys::IndexBuildResult>::alloc0();
index_build::build( am_build::build(
index_relation, index_relation,
Some((heap_relation, index_info, result.as_ptr())), Some((heap_relation, index_info, result.as_ptr())),
); );
@@ -191,7 +190,7 @@ pub unsafe extern "C" fn ambuild(
#[pgrx::pg_guard] #[pgrx::pg_guard]
pub unsafe extern "C" fn ambuildempty(index_relation: pgrx::pg_sys::Relation) { pub unsafe extern "C" fn ambuildempty(index_relation: pgrx::pg_sys::Relation) {
index_build::build(index_relation, None); am_build::build(index_relation, None);
} }
#[cfg(any(feature = "pg12", feature = "pg13"))] #[cfg(any(feature = "pg12", feature = "pg13"))]
@@ -199,18 +198,16 @@ pub unsafe extern "C" fn ambuildempty(index_relation: pgrx::pg_sys::Relation) {
pub unsafe extern "C" fn aminsert( pub unsafe extern "C" fn aminsert(
index_relation: pgrx::pg_sys::Relation, index_relation: pgrx::pg_sys::Relation,
values: *mut pgrx::pg_sys::Datum, values: *mut pgrx::pg_sys::Datum,
is_null: *mut bool, _is_null: *mut bool,
heap_tid: pgrx::pg_sys::ItemPointer, heap_tid: pgrx::pg_sys::ItemPointer,
_heap_relation: pgrx::pg_sys::Relation, _heap_relation: pgrx::pg_sys::Relation,
_check_unique: pgrx::pg_sys::IndexUniqueCheck, _check_unique: pgrx::pg_sys::IndexUniqueCheck,
_index_info: *mut pgrx::pg_sys::IndexInfo, _index_info: *mut pgrx::pg_sys::IndexInfo,
) -> bool { ) -> bool {
use pgrx::FromDatum;
let oid = (*index_relation).rd_node.relNode; let oid = (*index_relation).rd_node.relNode;
let id = Id::from_sys(oid); let id = Id::from_sys(oid);
let vector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); let vector = from_datum(*values.add(0));
let vector = vector.data().to_vec(); am_update::update_insert(id, vector, *heap_tid);
index_update::update_insert(id, vector, *heap_tid);
true true
} }
@@ -219,22 +216,20 @@ pub unsafe extern "C" fn aminsert(
pub unsafe extern "C" fn aminsert( pub unsafe extern "C" fn aminsert(
index_relation: pgrx::pg_sys::Relation, index_relation: pgrx::pg_sys::Relation,
values: *mut pgrx::pg_sys::Datum, values: *mut pgrx::pg_sys::Datum,
is_null: *mut bool, _is_null: *mut bool,
heap_tid: pgrx::pg_sys::ItemPointer, heap_tid: pgrx::pg_sys::ItemPointer,
_heap_relation: pgrx::pg_sys::Relation, _heap_relation: pgrx::pg_sys::Relation,
_check_unique: pgrx::pg_sys::IndexUniqueCheck, _check_unique: pgrx::pg_sys::IndexUniqueCheck,
_index_unchanged: bool, _index_unchanged: bool,
_index_info: *mut pgrx::pg_sys::IndexInfo, _index_info: *mut pgrx::pg_sys::IndexInfo,
) -> bool { ) -> bool {
use pgrx::FromDatum;
#[cfg(any(feature = "pg14", feature = "pg15"))] #[cfg(any(feature = "pg14", feature = "pg15"))]
let oid = (*index_relation).rd_node.relNode; let oid = (*index_relation).rd_node.relNode;
#[cfg(feature = "pg16")] #[cfg(feature = "pg16")]
let oid = (*index_relation).rd_locator.relNumber; let oid = (*index_relation).rd_locator.relNumber;
let id = Id::from_sys(oid); let id = Id::from_sys(oid);
let vector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); let vector = from_datum(*values.add(0));
let vector = vector.data().to_vec(); am_update::update_insert(id, vector, *heap_tid);
index_update::update_insert(id, vector, *heap_tid);
true true
} }
@@ -242,20 +237,26 @@ pub unsafe extern "C" fn aminsert(
pub unsafe extern "C" fn ambeginscan( pub unsafe extern "C" fn ambeginscan(
index_relation: pgrx::pg_sys::Relation, index_relation: pgrx::pg_sys::Relation,
n_keys: std::os::raw::c_int, n_keys: std::os::raw::c_int,
n_order_bys: std::os::raw::c_int, n_orderbys: std::os::raw::c_int,
) -> pgrx::pg_sys::IndexScanDesc { ) -> pgrx::pg_sys::IndexScanDesc {
index_scan::make_scan(index_relation, n_keys, n_order_bys) assert!(n_keys == 0);
assert!(n_orderbys == 1);
am_scan::make_scan(index_relation)
} }
#[pgrx::pg_guard] #[pgrx::pg_guard]
pub unsafe extern "C" fn amrescan( pub unsafe extern "C" fn amrescan(
scan: pgrx::pg_sys::IndexScanDesc, scan: pgrx::pg_sys::IndexScanDesc,
keys: pgrx::pg_sys::ScanKey, _keys: pgrx::pg_sys::ScanKey,
n_keys: std::os::raw::c_int, n_keys: std::os::raw::c_int,
orderbys: pgrx::pg_sys::ScanKey, orderbys: pgrx::pg_sys::ScanKey,
n_orderbys: std::os::raw::c_int, n_orderbys: std::os::raw::c_int,
) { ) {
index_scan::start_scan(scan, keys, n_keys, orderbys, n_orderbys); assert!((*scan).numberOfKeys == n_keys);
assert!((*scan).numberOfOrderBys == n_orderbys);
assert!(n_keys == 0);
assert!(n_orderbys == 1);
am_scan::start_scan(scan, orderbys);
} }
#[pgrx::pg_guard] #[pgrx::pg_guard]
@@ -264,12 +265,12 @@ pub unsafe extern "C" fn amgettuple(
direction: pgrx::pg_sys::ScanDirection, direction: pgrx::pg_sys::ScanDirection,
) -> bool { ) -> bool {
assert!(direction == pgrx::pg_sys::ScanDirection_ForwardScanDirection); assert!(direction == pgrx::pg_sys::ScanDirection_ForwardScanDirection);
index_scan::next_scan(scan) am_scan::next_scan(scan)
} }
#[pgrx::pg_guard] #[pgrx::pg_guard]
pub unsafe extern "C" fn amendscan(scan: pgrx::pg_sys::IndexScanDesc) { pub unsafe extern "C" fn amendscan(scan: pgrx::pg_sys::IndexScanDesc) {
index_scan::end_scan(scan); am_scan::end_scan(scan);
} }
#[pgrx::pg_guard] #[pgrx::pg_guard]
@@ -285,7 +286,7 @@ pub unsafe extern "C" fn ambulkdelete(
let oid = (*(*info).index).rd_locator.relNumber; let oid = (*(*info).index).rd_locator.relNumber;
let id = Id::from_sys(oid); let id = Id::from_sys(oid);
if let Some(callback) = callback { if let Some(callback) = callback {
index_update::update_delete(id, |pointer| { am_update::update_delete(id, |pointer| {
callback( callback(
&mut pointer.into_sys() as *mut pgrx::pg_sys::ItemPointerData, &mut pointer.into_sys() as *mut pgrx::pg_sys::ItemPointerData,
callback_state, callback_state,


@@ -1,11 +1,13 @@
use super::hook_transaction::{client, flush_if_commit}; use super::hook_transaction::flush_if_commit;
use crate::ipc::client::Rpc; use crate::index::utils::from_datum;
use crate::postgres::index_setup::options; use crate::ipc::client::ClientGuard;
use crate::prelude::*; use crate::prelude::*;
use crate::{index::am_setup::options, ipc::client::Rpc};
use pgrx::pg_sys::{IndexBuildResult, IndexInfo, RelationData}; use pgrx::pg_sys::{IndexBuildResult, IndexInfo, RelationData};
use service::prelude::*;
pub struct Builder { pub struct Builder {
pub rpc: Rpc, pub rpc: ClientGuard<Rpc>,
pub heap_relation: *mut RelationData, pub heap_relation: *mut RelationData,
pub index_info: *mut IndexInfo, pub index_info: *mut IndexInfo,
pub result: *mut IndexBuildResult, pub result: *mut IndexBuildResult,
@@ -22,27 +24,22 @@ pub unsafe fn build(
let id = Id::from_sys(oid); let id = Id::from_sys(oid);
flush_if_commit(id); flush_if_commit(id);
let options = options(index); let options = options(index);
client(|mut rpc| { let mut rpc = crate::ipc::client::borrow_mut();
rpc.create(id, options).friendly(); rpc.create(id, options);
rpc
});
if let Some((heap_relation, index_info, result)) = data { if let Some((heap_relation, index_info, result)) = data {
client(|rpc| { let mut builder = Builder {
let mut builder = Builder { rpc,
rpc, heap_relation,
heap_relation, index_info,
index_info, result,
result, };
}; pgrx::pg_sys::IndexBuildHeapScan(
pgrx::pg_sys::IndexBuildHeapScan( heap_relation,
heap_relation, index,
index, index_info,
index_info, Some(callback),
Some(callback), &mut builder,
&mut builder, );
);
builder.rpc
});
} }
} }
@@ -52,21 +49,17 @@ unsafe extern "C" fn callback(
index_relation: pgrx::pg_sys::Relation, index_relation: pgrx::pg_sys::Relation,
htup: pgrx::pg_sys::HeapTuple, htup: pgrx::pg_sys::HeapTuple,
values: *mut pgrx::pg_sys::Datum, values: *mut pgrx::pg_sys::Datum,
is_null: *mut bool, _is_null: *mut bool,
_tuple_is_alive: bool, _tuple_is_alive: bool,
state: *mut std::os::raw::c_void, state: *mut std::os::raw::c_void,
) { ) {
use super::datatype::VectorInput;
use pgrx::FromDatum;
let ctid = &(*htup).t_self; let ctid = &(*htup).t_self;
let oid = (*index_relation).rd_node.relNode; let oid = (*index_relation).rd_node.relNode;
let id = Id::from_sys(oid); let id = Id::from_sys(oid);
let state = &mut *(state as *mut Builder); let state = &mut *(state as *mut Builder);
let pgvector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); let vector = from_datum(*values.add(0));
let data = (pgvector.to_vec(), Pointer::from_sys(*ctid)); let data = (vector, Pointer::from_sys(*ctid));
state.rpc.insert(id, data).friendly().friendly(); state.rpc.insert(id, data);
(*state.result).heap_tuples += 1.0; (*state.result).heap_tuples += 1.0;
(*state.result).index_tuples += 1.0; (*state.result).index_tuples += 1.0;
} }
@@ -77,22 +70,19 @@ unsafe extern "C" fn callback(
index_relation: pgrx::pg_sys::Relation, index_relation: pgrx::pg_sys::Relation,
ctid: pgrx::pg_sys::ItemPointer, ctid: pgrx::pg_sys::ItemPointer,
values: *mut pgrx::pg_sys::Datum, values: *mut pgrx::pg_sys::Datum,
is_null: *mut bool, _is_null: *mut bool,
_tuple_is_alive: bool, _tuple_is_alive: bool,
state: *mut std::os::raw::c_void, state: *mut std::os::raw::c_void,
) { ) {
use super::datatype::VectorInput;
use pgrx::FromDatum;
#[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15"))] #[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15"))]
let oid = (*index_relation).rd_node.relNode; let oid = (*index_relation).rd_node.relNode;
#[cfg(feature = "pg16")] #[cfg(feature = "pg16")]
let oid = (*index_relation).rd_locator.relNumber; let oid = (*index_relation).rd_locator.relNumber;
let id = Id::from_sys(oid); let id = Id::from_sys(oid);
let state = &mut *(state as *mut Builder); let state = &mut *(state as *mut Builder);
let pgvector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); let vector = from_datum(*values.add(0));
let data = (pgvector.to_vec(), Pointer::from_sys(*ctid)); let data = (vector, Pointer::from_sys(*ctid));
state.rpc.insert(id, data).friendly().friendly(); state.rpc.insert(id, data);
(*state.result).heap_tuples += 1.0; (*state.result).heap_tuples += 1.0;
(*state.result).index_tuples += 1.0; (*state.result).index_tuples += 1.0;
} }

src/index/am_scan.rs (new file)

@@ -0,0 +1,234 @@
use crate::gucs::ENABLE_PREFILTER;
use crate::gucs::ENABLE_VBASE;
use crate::gucs::K;
use crate::index::utils::from_datum;
use crate::ipc::client::ClientGuard;
use crate::ipc::client::Vbase;
use crate::prelude::*;
use pgrx::FromDatum;
use service::prelude::*;
pub enum Scanner {
Initial {
node: Option<*mut pgrx::pg_sys::IndexScanState>,
vector: Option<DynamicVector>,
},
Search {
node: *mut pgrx::pg_sys::IndexScanState,
data: Vec<Pointer>,
},
Vbase {
node: *mut pgrx::pg_sys::IndexScanState,
vbase: ClientGuard<Vbase>,
},
}
impl Scanner {
fn node(&self) -> Option<*mut pgrx::pg_sys::IndexScanState> {
match self {
Scanner::Initial { node, .. } => *node,
Scanner::Search { node, .. } => Some(*node),
Scanner::Vbase { node, .. } => Some(*node),
}
}
}
pub unsafe fn make_scan(index_relation: pgrx::pg_sys::Relation) -> pgrx::pg_sys::IndexScanDesc {
use pgrx::PgMemoryContexts;
let scan = pgrx::pg_sys::RelationGetIndexScan(index_relation, 0, 1);
(*scan).xs_recheck = false;
(*scan).xs_recheckorderby = false;
(*scan).opaque =
PgMemoryContexts::CurrentMemoryContext.leak_and_drop_on_delete(Scanner::Initial {
vector: None,
node: None,
}) as _;
(*scan).xs_orderbyvals = pgrx::pg_sys::palloc0(std::mem::size_of::<pgrx::pg_sys::Datum>()) as _;
(*scan).xs_orderbynulls = {
let data = pgrx::pg_sys::palloc(std::mem::size_of::<bool>()) as *mut bool;
data.write_bytes(1, 1);
data
};
scan
}
pub unsafe fn start_scan(scan: pgrx::pg_sys::IndexScanDesc, orderbys: pgrx::pg_sys::ScanKey) {
std::ptr::copy(orderbys, (*scan).orderByData, 1);
let vector = from_datum((*orderbys.add(0)).sk_argument);
let scanner = &mut *((*scan).opaque as *mut Scanner);
let scanner = std::mem::replace(
scanner,
Scanner::Initial {
node: scanner.node(),
vector: Some(vector),
},
);
match scanner {
Scanner::Initial { .. } => {}
Scanner::Search { .. } => {}
Scanner::Vbase { vbase, .. } => {
vbase.leave();
}
}
}
pub unsafe fn next_scan(scan: pgrx::pg_sys::IndexScanDesc) -> bool {
let scanner = &mut *((*scan).opaque as *mut Scanner);
if let Scanner::Initial { node, vector } = scanner {
let node = node.expect("Hook failed.");
let vector = vector.as_ref().expect("Scan failed.");
#[cfg(any(feature = "pg12", feature = "pg13", feature = "pg14", feature = "pg15"))]
let oid = (*(*scan).indexRelation).rd_node.relNode;
#[cfg(feature = "pg16")]
let oid = (*(*scan).indexRelation).rd_locator.relNumber;
let id = Id::from_sys(oid);
let mut rpc = crate::ipc::client::borrow_mut();
if ENABLE_VBASE.get() {
let vbase = rpc.vbase(id, vector.clone());
*scanner = Scanner::Vbase { node, vbase };
} else {
let k = K.get() as _;
struct Search {
node: *mut pgrx::pg_sys::IndexScanState,
}
impl crate::ipc::client::Search for Search {
fn check(&mut self, p: Pointer) -> bool {
unsafe { check(self.node, p) }
}
}
let search = Search { node };
let mut data = rpc.search(id, (vector.clone(), k), ENABLE_PREFILTER.get(), search);
data.reverse();
*scanner = Scanner::Search { node, data };
}
}
match scanner {
Scanner::Initial { .. } => unreachable!(),
Scanner::Search { data, .. } => {
if let Some(p) = data.pop() {
(*scan).xs_heaptid = p.into_sys();
true
} else {
false
}
}
Scanner::Vbase { vbase, .. } => {
if let Some(p) = vbase.next() {
(*scan).xs_heaptid = p.into_sys();
true
} else {
false
}
}
}
}
pub unsafe fn end_scan(scan: pgrx::pg_sys::IndexScanDesc) {
let scanner = &mut *((*scan).opaque as *mut Scanner);
let scanner = std::mem::replace(
scanner,
Scanner::Initial {
node: scanner.node(),
vector: None,
},
);
match scanner {
Scanner::Initial { .. } => {}
Scanner::Search { .. } => {}
Scanner::Vbase { vbase, .. } => {
vbase.leave();
}
}
}
unsafe fn execute_boolean_qual(
state: *mut pgrx::pg_sys::ExprState,
econtext: *mut pgrx::pg_sys::ExprContext,
) -> bool {
use pgrx::PgMemoryContexts;
if state.is_null() {
return true;
}
assert!((*state).flags & pgrx::pg_sys::EEO_FLAG_IS_QUAL as u8 != 0);
let mut is_null = true;
pgrx::pg_sys::MemoryContextReset((*econtext).ecxt_per_tuple_memory);
let ret = PgMemoryContexts::For((*econtext).ecxt_per_tuple_memory)
.switch_to(|_| (*state).evalfunc.unwrap()(state, econtext, &mut is_null));
assert!(!is_null);
bool::from_datum(ret, is_null).unwrap()
}
unsafe fn check_quals(node: *mut pgrx::pg_sys::IndexScanState) -> bool {
let slot = (*node).ss.ss_ScanTupleSlot;
let econtext = (*node).ss.ps.ps_ExprContext;
(*econtext).ecxt_scantuple = slot;
if (*node).ss.ps.qual.is_null() {
return true;
}
let state = (*node).ss.ps.qual;
let econtext = (*node).ss.ps.ps_ExprContext;
execute_boolean_qual(state, econtext)
}
unsafe fn check_mvcc(node: *mut pgrx::pg_sys::IndexScanState, p: Pointer) -> bool {
let scan_desc = (*node).iss_ScanDesc;
let heap_fetch = (*scan_desc).xs_heapfetch;
let index_relation = (*heap_fetch).rel;
let rd_tableam = (*index_relation).rd_tableam;
let snapshot = (*scan_desc).xs_snapshot;
let index_fetch_tuple = (*rd_tableam).index_fetch_tuple.unwrap();
let mut all_dead = false;
let slot = (*node).ss.ss_ScanTupleSlot;
let mut heap_continue = false;
let found = index_fetch_tuple(
heap_fetch,
&mut p.into_sys(),
snapshot,
slot,
&mut heap_continue,
&mut all_dead,
);
if found {
return true;
}
while heap_continue {
let found = index_fetch_tuple(
heap_fetch,
&mut p.into_sys(),
snapshot,
slot,
&mut heap_continue,
&mut all_dead,
);
if found {
return true;
}
}
false
}
unsafe fn check(node: *mut pgrx::pg_sys::IndexScanState, p: Pointer) -> bool {
if !check_mvcc(node, p) {
return false;
}
if !check_quals(node) {
return false;
}
true
}
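
The scan side is a three-state machine: Initial (the order-by vector has been captured but nothing has been asked of the service yet), Search (a complete top-k result list handed back over IPC and popped one tid at a time) and Vbase (a streaming cursor). A pgrx-free sketch of that control flow, with u64 standing in for Pointer and a Vec-backed iterator standing in for the vbase cursor; the node pointer used for the MVCC/qual prefilter checks is dropped here:

// Sketch only: the Scanner state machine from am_scan.rs, stripped of pgrx and IPC.
enum Scanner {
    Initial { vector: Option<Vec<f32>> },
    Search { data: Vec<u64> },
    Vbase { stream: std::vec::IntoIter<u64> },
}

fn next_scan(scanner: &mut Scanner, use_vbase: bool) -> Option<u64> {
    if let Scanner::Initial { vector } = scanner {
        let _vector = vector.take().expect("scan was never started");
        *scanner = if use_vbase {
            // Real code: rpc.vbase(id, vector) returns a streaming cursor.
            Scanner::Vbase { stream: vec![1, 2, 3].into_iter() }
        } else {
            // Real code: rpc.search(id, (vector, k), ...) returns a top-k list,
            // reversed so results can be popped from the back.
            Scanner::Search { data: vec![3, 2, 1] }
        };
    }
    match scanner {
        Scanner::Initial { .. } => unreachable!(),
        Scanner::Search { data } => data.pop(),
        Scanner::Vbase { stream } => stream.next(),
    }
}

fn main() {
    let mut s = Scanner::Initial { vector: Some(vec![0.0; 3]) };
    assert_eq!(next_scan(&mut s, false), Some(1));
    assert_eq!(next_scan(&mut s, false), Some(2));
}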


@@ -1,22 +1,22 @@
use crate::index::indexing::IndexingOptions; use crate::datatype::typmod::Typmod;
use crate::index::optimizing::OptimizingOptions;
use crate::index::segments::SegmentsOptions;
use crate::index::{IndexOptions, VectorOptions};
use crate::postgres::datatype::VectorTypmod;
use crate::prelude::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use service::index::indexing::IndexingOptions;
use service::index::optimizing::OptimizingOptions;
use service::index::segments::SegmentsOptions;
use service::index::{IndexOptions, VectorOptions};
use service::prelude::*;
use std::ffi::CStr; use std::ffi::CStr;
use validator::Validate; use validator::Validate;
pub fn helper_offset() -> usize { pub fn helper_offset() -> usize {
memoffset::offset_of!(Helper, offset) std::mem::offset_of!(Helper, offset)
} }
pub fn helper_size() -> usize { pub fn helper_size() -> usize {
std::mem::size_of::<Helper>() std::mem::size_of::<Helper>()
} }
pub unsafe fn convert_opclass_to_distance(opclass: pgrx::pg_sys::Oid) -> Distance { pub unsafe fn convert_opclass_to_distance(opclass: pgrx::pg_sys::Oid) -> (Distance, Kind) {
let opclass_cache_id = pgrx::pg_sys::SysCacheIdentifier_CLAOID as _; let opclass_cache_id = pgrx::pg_sys::SysCacheIdentifier_CLAOID as _;
let tuple = pgrx::pg_sys::SearchSysCache1(opclass_cache_id, opclass.into()); let tuple = pgrx::pg_sys::SearchSysCache1(opclass_cache_id, opclass.into());
assert!( assert!(
@@ -25,12 +25,12 @@ pub unsafe fn convert_opclass_to_distance(opclass: pgrx::pg_sys::Oid) -> Distanc
); );
let classform = pgrx::pg_sys::GETSTRUCT(tuple).cast::<pgrx::pg_sys::FormData_pg_opclass>(); let classform = pgrx::pg_sys::GETSTRUCT(tuple).cast::<pgrx::pg_sys::FormData_pg_opclass>();
let opfamily = (*classform).opcfamily; let opfamily = (*classform).opcfamily;
-let distance = convert_opfamily_to_distance(opfamily);
+let result = convert_opfamily_to_distance(opfamily);
pgrx::pg_sys::ReleaseSysCache(tuple);
-distance
+result
}

-pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> Distance {
+pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> (Distance, Kind) {
let opfamily_cache_id = pgrx::pg_sys::SysCacheIdentifier_OPFAMILYOID as _;
let opstrategy_cache_id = pgrx::pg_sys::SysCacheIdentifier_AMOPSTRATEGY as _;
let tuple = pgrx::pg_sys::SearchSysCache1(opfamily_cache_id, opfamily.into());
@@ -52,19 +52,25 @@ pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> Distance {
assert!((*amop).amopstrategy == 1);
assert!((*amop).amoppurpose == pgrx::pg_sys::AMOP_ORDER as libc::c_char);
let operator = (*amop).amopopr;
-let distance;
+let result;
if operator == regoperatorin("<->(vector,vector)") {
-distance = Distance::L2;
+result = (Distance::L2, Kind::F32);
} else if operator == regoperatorin("<#>(vector,vector)") {
-distance = Distance::Dot;
+result = (Distance::Dot, Kind::F32);
} else if operator == regoperatorin("<=>(vector,vector)") {
-distance = Distance::Cosine;
+result = (Distance::Cos, Kind::F32);
+} else if operator == regoperatorin("<->(vecf16,vecf16)") {
+result = (Distance::L2, Kind::F16);
+} else if operator == regoperatorin("<#>(vecf16,vecf16)") {
+result = (Distance::Dot, Kind::F16);
+} else if operator == regoperatorin("<=>(vecf16,vecf16)") {
+result = (Distance::Cos, Kind::F16);
} else {
-FriendlyError::UnsupportedOperator.friendly();
+FriendlyError::BadOptions3.friendly();
};
pgrx::pg_sys::ReleaseCatCacheList(list);
pgrx::pg_sys::ReleaseSysCache(tuple);
-distance
+result
}

pub unsafe fn options(index_relation: pgrx::pg_sys::Relation) -> IndexOptions {
@@ -72,22 +78,25 @@ pub unsafe fn options(index_relation: pgrx::pg_sys::Relation) -> IndexOptions {
assert!(nkeysatts == 1, "Can not be built on multicolumns.");
// get distance
let opfamily = (*index_relation).rd_opfamily.read();
-let d = convert_opfamily_to_distance(opfamily);
+let (d, k) = convert_opfamily_to_distance(opfamily);
// get dims
let attrs = (*(*index_relation).rd_att).attrs.as_slice(1);
let attr = &attrs[0];
-let typmod = VectorTypmod::parse_from_i32(attr.type_mod()).unwrap();
+let typmod = Typmod::parse_from_i32(attr.type_mod()).unwrap();
-let dims = typmod.dims().ok_or(FriendlyError::DimsIsNeeded).friendly();
+let dims = typmod.dims().ok_or(FriendlyError::BadOption2).friendly();
// get other options
let parsed = get_parsed_from_varlena((*index_relation).rd_options);
let options = IndexOptions {
-vector: VectorOptions { dims, d },
+vector: VectorOptions { dims, d, k },
segment: parsed.segment,
optimizing: parsed.optimizing,
indexing: parsed.indexing,
};
if let Err(errors) = options.validate() {
-FriendlyError::BadOption(errors.to_string()).friendly();
+FriendlyError::BadOption {
+validation: errors.to_string(),
+}
+.friendly();
}
options
}
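Note: the opfamily lookup above now returns both a distance function and a scalar kind, and VectorOptions gains a third field to carry it. The actual definitions of Distance, Kind and VectorOptions live outside this file and are not shown in this diff; the following is only a minimal sketch of shapes consistent with how they are used here (derives and anything not visible above are assumptions):

    // Sketch only; the real types are defined elsewhere in the codebase.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum Distance {
        L2,
        Cos,
        Dot,
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum Kind {
        F32, // 32-bit float vectors ("vector")
        F16, // 16-bit float vectors ("vecf16")
    }

    #[derive(Debug, Clone)]
    pub struct VectorOptions {
        pub dims: u16,   // from the column typmod
        pub d: Distance, // from the operator family
        pub k: Kind,     // from the operator family
    }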

src/index/am_update.rs (new file, 31 lines)

@@ -0,0 +1,31 @@
use crate::index::hook_transaction::flush_if_commit;
use crate::prelude::*;
use service::prelude::*;
pub fn update_insert(id: Id, vector: DynamicVector, tid: pgrx::pg_sys::ItemPointerData) {
flush_if_commit(id);
let p = Pointer::from_sys(tid);
let mut rpc = crate::ipc::client::borrow_mut();
rpc.insert(id, (vector, p));
}
pub fn update_delete(id: Id, hook: impl Fn(Pointer) -> bool) {
struct Delete<H> {
hook: H,
}
impl<H> crate::ipc::client::Delete for Delete<H>
where
H: Fn(Pointer) -> bool,
{
fn test(&mut self, p: Pointer) -> bool {
(self.hook)(p)
}
}
let client_delete = Delete { hook };
flush_if_commit(id);
let mut rpc = crate::ipc::client::borrow_mut();
rpc.delete(id, client_delete);
}
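update_delete above wraps the caller's closure in a local struct so it can be handed to the ipc client as an implementation of its Delete trait. The trait belongs to the extension; the wrapping itself is the ordinary closure-to-trait adapter pattern, shown standalone below with illustrative names (Predicate, Adapter and delete_where are not part of the codebase):

    // Adapter turning a closure into a named trait implementation.
    trait Predicate {
        fn test(&mut self, value: u64) -> bool;
    }

    struct Adapter<H> {
        hook: H,
    }

    impl<H> Predicate for Adapter<H>
    where
        H: FnMut(u64) -> bool,
    {
        fn test(&mut self, value: u64) -> bool {
            (self.hook)(value)
        }
    }

    // Stand-in for the server-side loop that asks the client which rows to drop.
    fn delete_where(pred: &mut dyn Predicate) -> usize {
        (0..10u64).filter(|&v| pred.test(v)).count()
    }

    fn main() {
        let mut even = Adapter { hook: |v: u64| v % 2 == 0 };
        assert_eq!(delete_where(&mut even), 5);
    }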


@@ -1,5 +1,4 @@
-use crate::postgres::index_scan::Scanner;
-use crate::postgres::index_scan::ScannerState;
+use crate::index::am_scan::Scanner;
use std::ptr::null_mut;

pub unsafe fn post_executor_start(query_desc: *mut pgrx::pg_sys::QueryDesc) {
@@ -21,7 +20,7 @@ unsafe extern "C" fn rewrite_plan_state(
if index_relation
.as_ref()
.and_then(|p| p.rd_indam.as_ref())
-.map(|p| p.amvalidate == Some(super::index::amvalidate))
+.map(|p| p.amvalidate == Some(super::am::amvalidate))
.unwrap_or(false)
{
// The logic is copied from Postgres source code.
@@ -33,6 +32,13 @@ unsafe extern "C" fn rewrite_plan_state(
(*node).iss_NumScanKeys,
(*node).iss_NumOrderByKeys,
);
+let scanner = &mut *((*(*node).iss_ScanDesc).opaque as *mut Scanner);
+*scanner = Scanner::Initial {
+node: Some(node),
+vector: None,
+};
if (*node).iss_NumRuntimeKeys == 0 || (*node).iss_RuntimeKeysReady {
pgrx::pg_sys::index_rescan(
(*node).iss_ScanDesc,
@@ -42,10 +48,6 @@ unsafe extern "C" fn rewrite_plan_state(
(*node).iss_NumOrderByKeys,
);
}
-// inject
-let scanner = &mut *((*(*node).iss_ScanDesc).opaque as *mut Scanner);
-scanner.index_scan_state = node;
-assert!(matches!(scanner.state, ScannerState::Initial { .. }));
}
}
}


@@ -0,0 +1,26 @@
use crate::utils::cells::PgRefCell;
use service::prelude::*;
use std::collections::BTreeSet;
static FLUSH_IF_COMMIT: PgRefCell<BTreeSet<Id>> = unsafe { PgRefCell::new(BTreeSet::new()) };
pub fn aborting() {
*FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new();
}
pub fn committing() {
{
let flush_if_commit = FLUSH_IF_COMMIT.borrow();
if flush_if_commit.len() != 0 {
let mut rpc = crate::ipc::client::borrow_mut();
for id in flush_if_commit.iter().copied() {
rpc.flush(id);
}
}
}
*FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new();
}
pub fn flush_if_commit(id: Id) {
FLUSH_IF_COMMIT.borrow_mut().insert(id);
}
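flush_if_commit, committing and aborting above implement a small commit-time flush set: every index touched in the current transaction is recorded, the set is drained through the ipc client at commit, and simply discarded on abort. The same bookkeeping pattern, reduced to plain std types and free of pgrx and the ipc client (all names here are illustrative, not the extension's API), looks like this:

    use std::cell::RefCell;
    use std::collections::BTreeSet;

    type Id = u32; // placeholder for the extension's index id type

    thread_local! {
        static FLUSH_IF_COMMIT: RefCell<BTreeSet<Id>> = RefCell::new(BTreeSet::new());
    }

    // Record an index that must be flushed if the transaction commits.
    fn flush_if_commit(id: Id) {
        FLUSH_IF_COMMIT.with(|s| {
            s.borrow_mut().insert(id);
        });
    }

    // Commit path: flush every recorded index, then reset the set.
    fn committing(mut flush: impl FnMut(Id)) {
        FLUSH_IF_COMMIT.with(|s| {
            for id in s.borrow().iter().copied() {
                flush(id);
            }
            s.borrow_mut().clear();
        });
    }

    // Abort path: drop the pending set without flushing anything.
    fn aborting() {
        FLUSH_IF_COMMIT.with(|s| s.borrow_mut().clear());
    }

    fn main() {
        flush_if_commit(42);
        flush_if_commit(7);
        committing(|id| println!("flush index {id}"));
        aborting();
    }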


@@ -1,5 +1,5 @@
-use crate::postgres::hook_transaction::client;
use crate::prelude::*;
+use service::prelude::*;

static mut PREV_EXECUTOR_START: pgrx::pg_sys::ExecutorStart_hook_type = None;
@@ -46,10 +46,8 @@ unsafe fn xact_delete() {
.iter()
.map(|node| Id::from_sys(node.relNode))
.collect::<Vec<_>>();
-client(|mut rpc| {
-rpc.destory(ids).friendly();
-rpc
-});
+let mut rpc = crate::ipc::client::borrow_mut();
+rpc.destory(ids);
}
}
@@ -63,9 +61,7 @@ unsafe fn xact_delete() {
.iter()
.map(|node| Id::from_sys(node.relNumber))
.collect::<Vec<_>>();
-client(|mut rpc| {
-rpc.destory(ids).friendly();
-rpc
-});
+let mut rpc = crate::ipc::client::borrow_mut();
+rpc.destory(ids);
}
}


@@ -1,552 +1,17 @@
+#![allow(unsafe_op_in_unsafe_fn)]
+
+mod am;
+mod am_build;
+mod am_scan;
+mod am_setup;
+mod am_update;
+mod hook_executor;
+mod hook_transaction;
+mod hooks;
+mod utils;
+mod views;
+
+pub unsafe fn init() {
+self::hooks::init();
+self::am::init();
+}
pub mod delete;
pub mod indexing;
pub mod optimizing;
pub mod segments;
use self::delete::Delete;
use self::indexing::IndexingOptions;
use self::optimizing::OptimizingOptions;
use self::segments::growing::GrowingSegment;
use self::segments::growing::GrowingSegmentInsertError;
use self::segments::sealed::SealedSegment;
use self::segments::SegmentsOptions;
use crate::index::indexing::DynamicIndexIter;
use crate::prelude::*;
use crate::utils::clean::clean;
use crate::utils::dir_ops::sync_dir;
use crate::utils::file_atomic::FileAtomic;
use arc_swap::ArcSwap;
use crossbeam::sync::Parker;
use crossbeam::sync::Unparker;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::{Arc, Weak};
use thiserror::Error;
use uuid::Uuid;
use validator::Validate;
#[derive(Debug, Error)]
pub enum IndexInsertError {
#[error("The vector is invalid.")]
InvalidVector(Vec<Scalar>),
#[error("The index view is outdated.")]
OutdatedView(#[from] Option<GrowingSegmentInsertError>),
}
#[derive(Debug, Error)]
pub enum IndexSearchError {
#[error("The vector is invalid.")]
InvalidVector(Vec<Scalar>),
}
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct VectorOptions {
#[validate(range(min = 1, max = 65535))]
#[serde(rename = "dimensions")]
pub dims: u16,
#[serde(rename = "distance")]
pub d: Distance,
}
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct IndexOptions {
#[validate]
pub vector: VectorOptions,
#[validate]
pub segment: SegmentsOptions,
#[validate]
pub optimizing: OptimizingOptions,
#[validate]
pub indexing: IndexingOptions,
}
pub struct Index {
path: PathBuf,
options: IndexOptions,
delete: Arc<Delete>,
protect: Mutex<IndexProtect>,
view: ArcSwap<IndexView>,
optimize_unparker: Unparker,
indexing: Mutex<bool>,
_tracker: Arc<IndexTracker>,
}
impl Index {
pub fn create(path: PathBuf, options: IndexOptions) -> Arc<Self> {
assert!(options.validate().is_ok());
std::fs::create_dir(&path).unwrap();
std::fs::create_dir(path.join("segments")).unwrap();
let startup = FileAtomic::create(
path.join("startup"),
IndexStartup {
sealeds: HashSet::new(),
growings: HashSet::new(),
},
);
let delete = Delete::create(path.join("delete"));
sync_dir(&path);
let parker = Parker::new();
let index = Arc::new(Index {
path: path.clone(),
options: options.clone(),
delete: delete.clone(),
protect: Mutex::new(IndexProtect {
startup,
sealed: HashMap::new(),
growing: HashMap::new(),
write: None,
}),
view: ArcSwap::new(Arc::new(IndexView {
options: options.clone(),
sealed: HashMap::new(),
growing: HashMap::new(),
delete: delete.clone(),
write: None,
})),
optimize_unparker: parker.unparker().clone(),
indexing: Mutex::new(true),
_tracker: Arc::new(IndexTracker { path }),
});
IndexBackground {
index: Arc::downgrade(&index),
parker,
}
.spawn();
index
}
pub fn open(path: PathBuf, options: IndexOptions) -> Arc<Self> {
let tracker = Arc::new(IndexTracker { path: path.clone() });
let startup = FileAtomic::<IndexStartup>::open(path.join("startup"));
clean(
path.join("segments"),
startup
.get()
.sealeds
.iter()
.map(|s| s.to_string())
.chain(startup.get().growings.iter().map(|s| s.to_string())),
);
let sealed = startup
.get()
.sealeds
.iter()
.map(|&uuid| {
(
uuid,
SealedSegment::open(
tracker.clone(),
path.join("segments").join(uuid.to_string()),
uuid,
options.clone(),
),
)
})
.collect::<HashMap<_, _>>();
let growing = startup
.get()
.growings
.iter()
.map(|&uuid| {
(
uuid,
GrowingSegment::open(
tracker.clone(),
path.join("segments").join(uuid.to_string()),
uuid,
options.clone(),
),
)
})
.collect::<HashMap<_, _>>();
let delete = Delete::open(path.join("delete"));
let parker = Parker::new();
let index = Arc::new(Index {
path: path.clone(),
options: options.clone(),
delete: delete.clone(),
protect: Mutex::new(IndexProtect {
startup,
sealed: sealed.clone(),
growing: growing.clone(),
write: None,
}),
view: ArcSwap::new(Arc::new(IndexView {
options: options.clone(),
delete: delete.clone(),
sealed,
growing,
write: None,
})),
optimize_unparker: parker.unparker().clone(),
indexing: Mutex::new(true),
_tracker: tracker,
});
IndexBackground {
index: Arc::downgrade(&index),
parker,
}
.spawn();
index
}
pub fn options(&self) -> &IndexOptions {
&self.options
}
pub fn view(&self) -> Arc<IndexView> {
self.view.load_full()
}
pub fn refresh(&self) {
let mut protect = self.protect.lock();
if let Some((uuid, write)) = protect.write.clone() {
write.seal();
protect.growing.insert(uuid, write);
}
let write_segment_uuid = Uuid::new_v4();
let write_segment = GrowingSegment::create(
self._tracker.clone(),
self.path
.join("segments")
.join(write_segment_uuid.to_string()),
write_segment_uuid,
self.options.clone(),
);
protect.write = Some((write_segment_uuid, write_segment));
protect.maintain(self.options.clone(), self.delete.clone(), &self.view);
self.optimize_unparker.unpark();
}
pub fn indexing(&self) -> bool {
*self.indexing.lock()
}
}
impl Drop for Index {
fn drop(&mut self) {
self.optimize_unparker.unpark();
}
}
#[derive(Debug, Clone)]
pub struct IndexTracker {
path: PathBuf,
}
impl Drop for IndexTracker {
fn drop(&mut self) {
std::fs::remove_dir_all(&self.path).unwrap();
}
}
pub struct IndexView {
options: IndexOptions,
delete: Arc<Delete>,
sealed: HashMap<Uuid, Arc<SealedSegment>>,
growing: HashMap<Uuid, Arc<GrowingSegment>>,
write: Option<(Uuid, Arc<GrowingSegment>)>,
}
impl IndexView {
pub fn sealed_len(&self) -> u32 {
self.sealed.values().map(|x| x.len()).sum::<u32>()
}
pub fn growing_len(&self) -> u32 {
self.growing.values().map(|x| x.len()).sum::<u32>()
}
pub fn write_len(&self) -> u32 {
self.write.as_ref().map(|x| x.1.len()).unwrap_or(0)
}
pub fn sealed_len_vec(&self) -> Vec<u32> {
self.sealed.values().map(|x| x.len()).collect()
}
pub fn growing_len_vec(&self) -> Vec<u32> {
self.growing.values().map(|x| x.len()).collect()
}
pub fn search<F: FnMut(Pointer) -> bool>(
&self,
k: usize,
vector: &[Scalar],
mut filter: F,
) -> Result<Vec<Pointer>, IndexSearchError> {
if self.options.vector.dims as usize != vector.len() {
return Err(IndexSearchError::InvalidVector(vector.to_vec()));
}
struct Comparer(BinaryHeap<Reverse<HeapElement>>);
impl PartialEq for Comparer {
fn eq(&self, other: &Self) -> bool {
self.cmp(other).is_eq()
}
}
impl Eq for Comparer {}
impl PartialOrd for Comparer {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Comparer {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0.peek().cmp(&other.0.peek()).reverse()
}
}
let mut filter = |payload| {
if let Some(p) = self.delete.check(payload) {
filter(p)
} else {
false
}
};
let n = self.sealed.len() + self.growing.len() + 1;
let mut result = Heap::new(k);
let mut heaps = BinaryHeap::with_capacity(1 + n);
for (_, sealed) in self.sealed.iter() {
let p = sealed.search(k, vector, &mut filter).into_reversed_heap();
heaps.push(Comparer(p));
}
for (_, growing) in self.growing.iter() {
let p = growing.search(k, vector, &mut filter).into_reversed_heap();
heaps.push(Comparer(p));
}
if let Some((_, write)) = &self.write {
let p = write.search(k, vector, &mut filter).into_reversed_heap();
heaps.push(Comparer(p));
}
while let Some(Comparer(mut heap)) = heaps.pop() {
if let Some(Reverse(x)) = heap.pop() {
result.push(x);
heaps.push(Comparer(heap));
}
}
Ok(result
.into_sorted_vec()
.iter()
.map(|x| Pointer::from_u48(x.payload >> 16))
.collect())
}
pub fn search_vbase<F>(
&self,
range: usize,
vector: &[Scalar],
mut next: F,
) -> Result<(), IndexSearchError>
where
F: FnMut(Pointer) -> bool,
{
if self.options.vector.dims as usize != vector.len() {
return Err(IndexSearchError::InvalidVector(vector.to_vec()));
}
struct Comparer<'index, 'vector> {
iter: ComparerIter<'index, 'vector>,
item: Option<HeapElement>,
}
enum ComparerIter<'index, 'vector> {
Sealed(DynamicIndexIter<'index, 'vector>),
Growing(std::vec::IntoIter<HeapElement>),
}
impl PartialEq for Comparer<'_, '_> {
fn eq(&self, other: &Self) -> bool {
self.cmp(other).is_eq()
}
}
impl Eq for Comparer<'_, '_> {}
impl PartialOrd for Comparer<'_, '_> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Comparer<'_, '_> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.item.cmp(&other.item).reverse()
}
}
impl Iterator for ComparerIter<'_, '_> {
type Item = HeapElement;
fn next(&mut self) -> Option<Self::Item> {
match self {
Self::Sealed(iter) => iter.next(),
Self::Growing(iter) => iter.next(),
}
}
}
impl Iterator for Comparer<'_, '_> {
type Item = HeapElement;
fn next(&mut self) -> Option<Self::Item> {
let item = self.item.take();
self.item = self.iter.next();
item
}
}
fn from_iter<'index, 'vector>(
mut iter: ComparerIter<'index, 'vector>,
) -> Comparer<'index, 'vector> {
let item = iter.next();
Comparer { iter, item }
}
use ComparerIter::*;
let filter = |payload| self.delete.check(payload).is_some();
let n = self.sealed.len() + self.growing.len() + 1;
let mut heaps: BinaryHeap<Comparer> = BinaryHeap::with_capacity(1 + n);
for (_, sealed) in self.sealed.iter() {
let res = sealed.search_vbase(range, vector);
heaps.push(from_iter(Sealed(res)));
}
for (_, growing) in self.growing.iter() {
let mut res = growing.search_all(vector);
res.sort_unstable();
heaps.push(from_iter(Growing(res.into_iter())));
}
if let Some((_, write)) = &self.write {
let mut res = write.search_all(vector);
res.sort_unstable();
heaps.push(from_iter(Growing(res.into_iter())));
}
while let Some(mut iter) = heaps.pop() {
if let Some(x) = iter.next() {
if !filter(x.payload) {
continue;
}
let stop = next(Pointer::from_u48(x.payload >> 16));
if stop {
break;
}
heaps.push(iter);
}
}
Ok(())
}
pub fn insert(&self, vector: Vec<Scalar>, pointer: Pointer) -> Result<(), IndexInsertError> {
if self.options.vector.dims as usize != vector.len() {
return Err(IndexInsertError::InvalidVector(vector));
}
let payload = (pointer.as_u48() << 16) | self.delete.version(pointer) as Payload;
if let Some((_, growing)) = self.write.as_ref() {
Ok(growing.insert(vector, payload)?)
} else {
Err(IndexInsertError::OutdatedView(None))
}
}
pub fn delete<F: FnMut(Pointer) -> bool>(&self, mut f: F) {
for (_, sealed) in self.sealed.iter() {
let n = sealed.len();
for i in 0..n {
if let Some(p) = self.delete.check(sealed.payload(i)) {
if f(p) {
self.delete.delete(p);
}
}
}
}
for (_, growing) in self.growing.iter() {
let n = growing.len();
for i in 0..n {
if let Some(p) = self.delete.check(growing.payload(i)) {
if f(p) {
self.delete.delete(p);
}
}
}
}
if let Some((_, write)) = &self.write {
let n = write.len();
for i in 0..n {
if let Some(p) = self.delete.check(write.payload(i)) {
if f(p) {
self.delete.delete(p);
}
}
}
}
}
pub fn flush(&self) -> Result<(), IndexInsertError> {
self.delete.flush();
if let Some((_, write)) = &self.write {
write.flush();
}
Ok(())
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct IndexStartup {
sealeds: HashSet<Uuid>,
growings: HashSet<Uuid>,
}
struct IndexProtect {
startup: FileAtomic<IndexStartup>,
sealed: HashMap<Uuid, Arc<SealedSegment>>,
growing: HashMap<Uuid, Arc<GrowingSegment>>,
write: Option<(Uuid, Arc<GrowingSegment>)>,
}
impl IndexProtect {
fn maintain(&mut self, options: IndexOptions, delete: Arc<Delete>, swap: &ArcSwap<IndexView>) {
let view: Arc<IndexView> = Arc::new(IndexView {
options,
delete,
sealed: self.sealed.clone(),
growing: self.growing.clone(),
write: self.write.clone(),
});
let startup_write = self.write.as_ref().map(|(uuid, _)| *uuid);
let startup_sealeds = self.sealed.keys().copied().collect();
let startup_growings = self.growing.keys().copied().chain(startup_write).collect();
self.startup.set(IndexStartup {
sealeds: startup_sealeds,
growings: startup_growings,
});
swap.swap(view);
}
}
pub struct IndexBackground {
index: Weak<Index>,
parker: Parker,
}
impl IndexBackground {
pub fn main(self) {
let pool;
if let Some(index) = self.index.upgrade() {
pool = rayon::ThreadPoolBuilder::new()
.num_threads(index.options.optimizing.optimizing_threads)
.build()
.unwrap();
} else {
return;
}
while let Some(index) = self.index.upgrade() {
let done = pool.install(|| optimizing::indexing::optimizing_indexing(index.clone()));
if done {
*index.indexing.lock() = false;
drop(index);
self.parker.park();
if let Some(index) = self.index.upgrade() {
*index.indexing.lock() = true;
}
}
}
}
pub fn spawn(self) {
std::thread::spawn(move || {
self.main();
});
}
}
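The bulk of the removed src/index.rs above is the in-process index implementation (segments, background optimizing, and the merged top-k search), which appears to move out of the extension crate in this commit. One detail worth keeping from the removed IndexView::search is how results from sealed, growing and write segments are combined: each segment yields candidates ordered by distance, and an outer heap repeatedly takes the globally best one. Below is a compact, self-contained illustration of that k-way merge, simplified to integer distances and without the payload/delete filtering (the function name and test values are made up):

    use std::cmp::Reverse;
    use std::collections::BinaryHeap;

    /// Merge several per-segment result lists, each sorted by ascending distance,
    /// and return the k globally closest (distance, pointer) pairs.
    fn merge_top_k(segments: Vec<Vec<(u32, u64)>>, k: usize) -> Vec<(u32, u64)> {
        // The outer heap holds (Reverse(best_element), segment_index, cursor) so
        // that pop() yields the segment whose next candidate is globally closest.
        let mut outer = BinaryHeap::new();
        for (i, seg) in segments.iter().enumerate() {
            if let Some(&first) = seg.first() {
                outer.push(Reverse((first, i, 0usize)));
            }
        }
        let mut result = Vec::with_capacity(k);
        while let Some(Reverse((item, i, cursor))) = outer.pop() {
            result.push(item);
            if result.len() == k {
                break;
            }
            if let Some(&next) = segments[i].get(cursor + 1) {
                outer.push(Reverse((next, i, cursor + 1)));
            }
        }
        result
    }

    fn main() {
        let sealed = vec![(1, 100), (5, 101), (9, 102)];
        let growing = vec![(2, 200), (3, 201)];
        let write = vec![(4, 300)];
        // Distances 1, 2 and 3 are the three closest across all segments.
        assert_eq!(
            merge_top_k(vec![sealed, growing, write], 3),
            vec![(1, 100), (2, 200), (3, 201)]
        );
    }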

Some files were not shown because too many files have changed in this diff.