mirror of
https://github.com/tensorchord/pgvecto.rs.git
synced 2025-04-18 21:44:00 +03:00
feat: synchronous remote index dropping (#414)
* feat: synchronous remote index dropping Signed-off-by: usamoi <usamoi@outlook.com> * fix: comments Signed-off-by: usamoi <usamoi@outlook.com> --------- Signed-off-by: usamoi <usamoi@outlook.com>
This commit is contained in:
parent
a1e79bba57
commit
afe2b65d9a
12
.github/workflows/check.yml
vendored
12
.github/workflows/check.yml
vendored
@ -136,19 +136,17 @@ jobs:
|
||||
run: |
|
||||
cargo install cargo-pgrx@$(grep 'pgrx = {' Cargo.toml | cut -d '"' -f 2 | head -n 1) --debug
|
||||
cargo pgrx init --pg$VERSION=$(which pg_config)
|
||||
- name: Format check
|
||||
run: cargo fmt --check
|
||||
- name: Semantic check
|
||||
run: |
|
||||
cargo clippy --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu
|
||||
cargo clippy --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu
|
||||
cargo clippy --no-default-features --features "pg${{ matrix.version }}" --target x86_64-unknown-linux-gnu
|
||||
cargo clippy --no-default-features --features "pg${{ matrix.version }}" --target aarch64-unknown-linux-gnu
|
||||
- name: Debug build
|
||||
run: |
|
||||
cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu
|
||||
cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu
|
||||
cargo build --no-default-features --features "pg${{ matrix.version }}" --target x86_64-unknown-linux-gnu
|
||||
cargo build --no-default-features --features "pg${{ matrix.version }}" --target aarch64-unknown-linux-gnu
|
||||
- name: Test
|
||||
run: |
|
||||
cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu -- --nocapture
|
||||
cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }}" --target x86_64-unknown-linux-gnu -- --nocapture
|
||||
- name: Cache
|
||||
uses: actions/cache/save@v4
|
||||
if: ${{ !steps.cache.outputs.cache-hit }}
|
||||
|
28
.github/workflows/style.yml
vendored
Normal file
28
.github/workflows/style.yml
vendored
Normal file
@ -0,0 +1,28 @@
|
||||
name: Style check
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
run:
|
||||
name: check
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Actions Repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Typos check
|
||||
uses: crate-ci/typos@master
|
||||
|
||||
- name: Rust format check
|
||||
run: cargo fmt --check
|
||||
|
||||
- name: Toml format check
|
||||
run: |
|
||||
curl -fsSL https://github.com/tamasfe/taplo/releases/download/0.8.1/taplo-full-linux-x86_64.gz | gzip -d - | install -m 755 /dev/stdin /usr/local/bin/taplo
|
||||
taplo fmt --check
|
22
.github/workflows/typos.yml
vendored
22
.github/workflows/typos.yml
vendored
@ -1,22 +0,0 @@
|
||||
name: Typos check
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
run:
|
||||
name: Spell Check with Typos
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Actions Repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Check spelling of file.txt
|
||||
uses: crate-ci/typos@master
|
||||
with:
|
||||
config: .typos.toml
|
14
.taplo.toml
Normal file
14
.taplo.toml
Normal file
@ -0,0 +1,14 @@
|
||||
formatting.indent_string = " "
|
||||
|
||||
[[rule]]
|
||||
include = ["**/Cargo.toml"]
|
||||
keys = [
|
||||
"dependencies",
|
||||
"dev-dependencies",
|
||||
"lints",
|
||||
"target.*.dependencies",
|
||||
"workspace.dependencies",
|
||||
"workspace.lints",
|
||||
]
|
||||
formatting.reorder_arrays = true
|
||||
formatting.reorder_keys = true
|
615
Cargo.lock
generated
615
Cargo.lock
generated
@ -167,7 +167,7 @@ checksum = "05b1b633a2115cd122d73b955eadd9916c18c8f510ec9cd1686404c60ad1c29c"
|
||||
dependencies = [
|
||||
"async-channel 2.2.0",
|
||||
"async-executor",
|
||||
"async-io 2.3.1",
|
||||
"async-io 2.3.2",
|
||||
"async-lock 3.3.0",
|
||||
"blocking",
|
||||
"futures-lite 2.2.0",
|
||||
@ -196,9 +196,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "async-io"
|
||||
version = "2.3.1"
|
||||
version = "2.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8f97ab0c5b00a7cdbe5a371b9a782ee7be1316095885c8a4ea1daf490eb0ef65"
|
||||
checksum = "dcccb0f599cfa2f8ace422d3555572f47424da5648a4382a9dd0310ff8210884"
|
||||
dependencies = [
|
||||
"async-lock 3.3.0",
|
||||
"cfg-if",
|
||||
@ -265,7 +265,7 @@ version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e47d90f65a225c4527103a8d747001fc56e375203592b25ad103e1ca13124c5"
|
||||
dependencies = [
|
||||
"async-io 2.3.1",
|
||||
"async-io 2.3.2",
|
||||
"async-lock 2.8.0",
|
||||
"atomic-waker",
|
||||
"cfg-if",
|
||||
@ -462,15 +462,6 @@ dependencies = [
|
||||
"wyz",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "block-buffer"
|
||||
version = "0.10.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "blocking"
|
||||
version = "1.5.1"
|
||||
@ -489,9 +480,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "3.15.3"
|
||||
version = "3.15.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ea184aa71bb362a1157c896979544cc23974e08fd265f29ea96b59f0b4a555b"
|
||||
checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa"
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck"
|
||||
@ -504,9 +495,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck_derive"
|
||||
version = "1.5.0"
|
||||
version = "1.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1"
|
||||
checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@ -535,38 +526,6 @@ dependencies = [
|
||||
"rand",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "camino"
|
||||
version = "1.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c59e92b5a388f549b863a7bea62612c09f24c8393560709a54558a9abdfb3b9c"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cargo-platform"
|
||||
version = "0.1.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "694c8807f2ae16faecc43dc17d74b3eb042482789fd0eb64b39a2e04e087053f"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cargo_metadata"
|
||||
version = "0.18.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2d886547e41f740c616ae73108f6eb70afe6d940c7bc697cb30f13daec073037"
|
||||
dependencies = [
|
||||
"camino",
|
||||
"cargo-platform",
|
||||
"semver 1.0.22",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cargo_toml"
|
||||
version = "0.19.2"
|
||||
@ -579,9 +538,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.89"
|
||||
version = "1.0.90"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a0ba8f7aaa012f30d5b2861462f6708eccd49c3c39863fe083a308035f63d723"
|
||||
checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5"
|
||||
|
||||
[[package]]
|
||||
name = "cexpr"
|
||||
@ -609,61 +568,24 @@ dependencies = [
|
||||
"libloading",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c918d541ef2913577a0f9566e9ce27cb35b6df072075769e0b26cb5a554520da"
|
||||
dependencies = [
|
||||
"clap_builder",
|
||||
"clap_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap-cargo"
|
||||
version = "0.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f6e2fd20c8f8c7cc395f69a86a61eb9d93e1de8fadc00338508cde2ffc656388"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"cargo_metadata",
|
||||
"clap",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_builder"
|
||||
version = "4.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9f3e7391dad68afb0c2ede1bf619f579a3dc9c2ec67f089baa397123a2f3d1eb"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"clap_lex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_derive"
|
||||
version = "4.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "307bc0538d5f0f83b8248db3087aa92fe504e4691294d0c96c0eabc33f47ba47"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.52",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_lex"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce"
|
||||
|
||||
[[package]]
|
||||
name = "colorchoice"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
|
||||
|
||||
[[package]]
|
||||
name = "common"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"log",
|
||||
"memmap2",
|
||||
"rustix 0.38.31",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "concurrent-queue"
|
||||
version = "2.4.0"
|
||||
@ -698,15 +620,6 @@ version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f"
|
||||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
version = "0.2.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crc32fast"
|
||||
version = "1.4.0"
|
||||
@ -778,16 +691,6 @@ version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
|
||||
|
||||
[[package]]
|
||||
name = "crypto-common"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cty"
|
||||
version = "0.2.2"
|
||||
@ -850,17 +753,6 @@ dependencies = [
|
||||
"std_detect",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
|
||||
dependencies = [
|
||||
"block-buffer",
|
||||
"crypto-common",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dirs-next"
|
||||
version = "2.0.0"
|
||||
@ -888,6 +780,18 @@ version = "1.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a"
|
||||
|
||||
[[package]]
|
||||
name = "elkan_k_means"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"base",
|
||||
"bytemuck",
|
||||
"common",
|
||||
"num-traits",
|
||||
"rand",
|
||||
"rayon 0.0.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "embedding"
|
||||
version = "0.0.0"
|
||||
@ -949,9 +853,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.11.2"
|
||||
version = "0.11.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c012a26a7f605efc424dd53697843a72be7dc86ad2d01f7814337794a12231d"
|
||||
checksum = "38b35839ba51819680ba087cd351788c9a3c476841207e0b8cee0b04722343b9"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
@ -1045,12 +949,6 @@ dependencies = [
|
||||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fallible-iterator"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "1.9.0"
|
||||
@ -1066,18 +964,23 @@ version = "2.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5"
|
||||
|
||||
[[package]]
|
||||
name = "finl_unicode"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8fcfdc7a0362c9f4444381a9e697c79d435fe65b52a37466fc2c1184cee9edc6"
|
||||
|
||||
[[package]]
|
||||
name = "fixedbitset"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
|
||||
|
||||
[[package]]
|
||||
name = "flat"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"base",
|
||||
"common",
|
||||
"quantization",
|
||||
"rayon 0.0.0",
|
||||
"storage",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
@ -1106,7 +1009,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1181,7 +1083,6 @@ dependencies = [
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-macro",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"memchr",
|
||||
"pin-project-lite",
|
||||
@ -1189,16 +1090,6 @@ dependencies = [
|
||||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.14.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
|
||||
dependencies = [
|
||||
"typenum",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.12"
|
||||
@ -1299,12 +1190,6 @@ dependencies = [
|
||||
"stable_deref_trait",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.3.9"
|
||||
@ -1312,12 +1197,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
|
||||
|
||||
[[package]]
|
||||
name = "hmac"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
|
||||
name = "hnsw"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"digest",
|
||||
"base",
|
||||
"bytemuck",
|
||||
"common",
|
||||
"parking_lot",
|
||||
"quantization",
|
||||
"rayon 0.0.0",
|
||||
"storage",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1457,6 +1346,35 @@ version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683"
|
||||
|
||||
[[package]]
|
||||
name = "index"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"base",
|
||||
"bincode",
|
||||
"byteorder",
|
||||
"common",
|
||||
"crc32fast",
|
||||
"crossbeam",
|
||||
"dashmap",
|
||||
"elkan_k_means",
|
||||
"flat",
|
||||
"hnsw",
|
||||
"ivf",
|
||||
"log",
|
||||
"parking_lot",
|
||||
"quantization",
|
||||
"rand",
|
||||
"rayon 0.0.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"storage",
|
||||
"thiserror",
|
||||
"uuid",
|
||||
"validator",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.2.5"
|
||||
@ -1525,6 +1443,21 @@ version = "1.0.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
|
||||
|
||||
[[package]]
|
||||
name = "ivf"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"base",
|
||||
"common",
|
||||
"elkan_k_means",
|
||||
"num-traits",
|
||||
"quantization",
|
||||
"rand",
|
||||
"rayon 0.0.0",
|
||||
"serde_json",
|
||||
"storage",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "js-sys"
|
||||
version = "0.3.69"
|
||||
@ -1600,9 +1533,9 @@ checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd"
|
||||
|
||||
[[package]]
|
||||
name = "libloading"
|
||||
version = "0.8.2"
|
||||
version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2caa5afb8bf9f3a2652760ce7d4f62d21c4d5a423e68466fca30df82f2330164"
|
||||
checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"windows-targets 0.52.4",
|
||||
@ -1656,16 +1589,6 @@ dependencies = [
|
||||
"value-bag",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "md-5"
|
||||
version = "0.10.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.7.1"
|
||||
@ -1769,15 +1692,6 @@ dependencies = [
|
||||
"minimal-lexical",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ntapi"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.18"
|
||||
@ -1982,37 +1896,6 @@ dependencies = [
|
||||
"unescape",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pgrx-tests"
|
||||
version = "0.12.0-alpha.1"
|
||||
source = "git+https://github.com/tensorchord/pgrx.git?branch=v0.12.0-alpha.1-patch#1a3459f597396a8d3dad0947a1d646f4cbe8e1ae"
|
||||
dependencies = [
|
||||
"clap-cargo",
|
||||
"eyre",
|
||||
"libc",
|
||||
"owo-colors",
|
||||
"pgrx",
|
||||
"pgrx-macros",
|
||||
"pgrx-pg-config",
|
||||
"postgres",
|
||||
"proptest",
|
||||
"rand",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sysinfo",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "phf"
|
||||
version = "0.11.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc"
|
||||
dependencies = [
|
||||
"phf_shared 0.11.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "phf_shared"
|
||||
version = "0.10.0"
|
||||
@ -2022,15 +1905,6 @@ dependencies = [
|
||||
"siphasher",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "phf_shared"
|
||||
version = "0.11.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b"
|
||||
dependencies = [
|
||||
"siphasher",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pico-args"
|
||||
version = "0.5.0"
|
||||
@ -2090,49 +1964,6 @@ dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "postgres"
|
||||
version = "0.19.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7915b33ed60abc46040cbcaa25ffa1c7ec240668e0477c4f3070786f5916d451"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "postgres-protocol"
|
||||
version = "0.6.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49b6c5ef183cd3ab4ba005f1ca64c21e8bd97ce4699cfea9e8d9a2c4958ca520"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"hmac",
|
||||
"md-5",
|
||||
"memchr",
|
||||
"rand",
|
||||
"sha2",
|
||||
"stringprep",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "postgres-types"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8d2234cdee9408b523530a9b6d2d6b373d1db34f6a8e51dc03ded1828d7fb67c"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"postgres-protocol",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.17"
|
||||
@ -2179,31 +2010,18 @@ dependencies = [
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proptest"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "31b476131c3c86cb68032fdc5cb6d5a1045e3e42d96b69fa599fd77701e1f5bf"
|
||||
name = "quantization"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"bit-set",
|
||||
"bit-vec",
|
||||
"bitflags 2.4.2",
|
||||
"lazy_static",
|
||||
"base",
|
||||
"common",
|
||||
"elkan_k_means",
|
||||
"multiversion",
|
||||
"num-traits",
|
||||
"rand",
|
||||
"rand_chacha",
|
||||
"rand_xorshift",
|
||||
"regex-syntax",
|
||||
"rusty-fork",
|
||||
"tempfile",
|
||||
"unarray",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quick-error"
|
||||
version = "1.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.35"
|
||||
@ -2260,12 +2078,11 @@ dependencies = [
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_xorshift"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f"
|
||||
name = "rayon"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"rand_core",
|
||||
"log",
|
||||
"rayon 1.9.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2339,9 +2156,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.11.24"
|
||||
version = "0.11.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c6920094eb85afde5e4a138be3f2de8bbdf28000f0029e72c45025a56b042251"
|
||||
checksum = "0eea5a9eb898d3783f17c6407670e3592fd174cb81a10e51d4c37f49450b9946"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"bytes",
|
||||
@ -2411,7 +2228,7 @@ version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0dfe2087c51c460008730de8b57e6a320782fbfb312e1f4d520e6c6fae155ee"
|
||||
dependencies = [
|
||||
"semver 0.11.0",
|
||||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2478,18 +2295,6 @@ version = "1.0.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4"
|
||||
|
||||
[[package]]
|
||||
name = "rusty-fork"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cb3dcc6e454c328bb824492db107ab7c0ae8fcffe4ad210136ef014458c1bc4f"
|
||||
dependencies = [
|
||||
"fnv",
|
||||
"quick-error",
|
||||
"tempfile",
|
||||
"wait-timeout",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.17"
|
||||
@ -2536,15 +2341,6 @@ dependencies = [
|
||||
"semver-parser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "1.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver-parser"
|
||||
version = "0.10.2"
|
||||
@ -2647,35 +2443,12 @@ version = "0.0.0"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"base",
|
||||
"bincode",
|
||||
"bytemuck",
|
||||
"byteorder",
|
||||
"crc32fast",
|
||||
"crossbeam",
|
||||
"dashmap",
|
||||
"log",
|
||||
"memmap2",
|
||||
"num-traits",
|
||||
"common",
|
||||
"index",
|
||||
"parking_lot",
|
||||
"rand",
|
||||
"rayon",
|
||||
"rustix 0.38.31",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"uuid",
|
||||
"validator",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha2"
|
||||
version = "0.10.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2767,6 +2540,14 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "storage"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"base",
|
||||
"common",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "string_cache"
|
||||
version = "0.8.7"
|
||||
@ -2776,33 +2557,16 @@ dependencies = [
|
||||
"new_debug_unreachable",
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"phf_shared 0.10.0",
|
||||
"phf_shared",
|
||||
"precomputed-hash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "stringprep"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bb41d74e231a107a1b4ee36bd1214b11285b77768d2e3824aedafa988fd36ee6"
|
||||
dependencies = [
|
||||
"finl_unicode",
|
||||
"unicode-bidi",
|
||||
"unicode-normalization",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
||||
|
||||
[[package]]
|
||||
name = "subtle"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.109"
|
||||
@ -2831,37 +2595,22 @@ version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
|
||||
|
||||
[[package]]
|
||||
name = "sysinfo"
|
||||
version = "0.29.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cd727fc423c2060f6c92d9534cef765c65a6ed3f428a03d7def74a8c4348e666"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
"ntapi",
|
||||
"once_cell",
|
||||
"rayon",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "system-configuration"
|
||||
version = "0.5.1"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
|
||||
checksum = "658bc6ee10a9b4fcf576e9b0819d95ec16f4d2c02d39fd83ac1c8789785c4a42"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"bitflags 2.4.2",
|
||||
"core-foundation",
|
||||
"system-configuration-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "system-configuration-sys"
|
||||
version = "0.5.0"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9"
|
||||
checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
@ -2879,18 +2628,6 @@ version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cfb5fa503293557c5158bd215fdc225695e567a77e453f5d4452a50a193969bd"
|
||||
|
||||
[[package]]
|
||||
name = "tempfile"
|
||||
version = "3.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"fastrand 2.0.1",
|
||||
"rustix 0.38.31",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "term"
|
||||
version = "0.7.0"
|
||||
@ -2922,6 +2659,26 @@ dependencies = [
|
||||
"syn 2.0.52",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tikv-jemalloc-sys"
|
||||
version = "0.5.4+5.3.0-patched"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9402443cb8fd499b6f327e40565234ff34dbda27460c5b47db0db77443dd85d1"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tikv-jemallocator"
|
||||
version = "0.5.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "965fe0c26be5c56c94e38ba547249074803efd52adfb66de62107d95aab3eaca"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"tikv-jemalloc-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tiny-keccak"
|
||||
version = "2.0.2"
|
||||
@ -2975,32 +2732,6 @@ dependencies = [
|
||||
"syn 2.0.52",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-postgres"
|
||||
version = "0.7.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d340244b32d920260ae7448cb72b6e238bddc3d4f7603394e7dd46ed8e48f5b8"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"log",
|
||||
"parking_lot",
|
||||
"percent-encoding",
|
||||
"phf",
|
||||
"pin-project-lite",
|
||||
"postgres-protocol",
|
||||
"postgres-types",
|
||||
"rand",
|
||||
"socket2 0.5.6",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"whoami",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.24.1"
|
||||
@ -3090,12 +2821,6 @@ version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
||||
|
||||
[[package]]
|
||||
name = "ucd-trie"
|
||||
version = "0.1.6"
|
||||
@ -3111,12 +2836,6 @@ dependencies = [
|
||||
"cty",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unarray"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94"
|
||||
|
||||
[[package]]
|
||||
name = "unescape"
|
||||
version = "0.1.0"
|
||||
@ -3237,7 +2956,6 @@ dependencies = [
|
||||
"detect",
|
||||
"embedding",
|
||||
"env_logger",
|
||||
"httpmock",
|
||||
"interprocess_atomic_wait",
|
||||
"libc",
|
||||
"log",
|
||||
@ -3246,14 +2964,15 @@ dependencies = [
|
||||
"num-traits",
|
||||
"paste",
|
||||
"pgrx",
|
||||
"pgrx-tests",
|
||||
"rand",
|
||||
"rustix 0.38.31",
|
||||
"scopeguard",
|
||||
"send_fd",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"service",
|
||||
"thiserror",
|
||||
"tikv-jemallocator",
|
||||
"toml",
|
||||
"validator",
|
||||
]
|
||||
@ -3264,15 +2983,6 @@ version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||
|
||||
[[package]]
|
||||
name = "wait-timeout"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9f200f5b12eb75f8c1ed65abd4b2db8a6e1b138a20de009dacee265a2498f3f6"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "waker-fn"
|
||||
version = "1.1.1"
|
||||
@ -3304,12 +3014,6 @@ version = "0.11.0+wasi-snapshot-preview1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
|
||||
|
||||
[[package]]
|
||||
name = "wasite"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b"
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen"
|
||||
version = "0.2.92"
|
||||
@ -3392,17 +3096,6 @@ version = "0.25.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1"
|
||||
|
||||
[[package]]
|
||||
name = "whoami"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0fec781d48b41f8163426ed18e8fc2864c12937df9ce54c88ede7bd47270893e"
|
||||
dependencies = [
|
||||
"redox_syscall",
|
||||
"wasite",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
|
62
Cargo.toml
62
Cargo.toml
@ -12,45 +12,44 @@ path = "./src/bin/pgrx_embed.rs"
|
||||
|
||||
[features]
|
||||
default = ["pg15"]
|
||||
pg14 = ["pgrx/pg14", "pgrx-tests/pg14"]
|
||||
pg15 = ["pgrx/pg15", "pgrx-tests/pg15"]
|
||||
pg16 = ["pgrx/pg16", "pgrx-tests/pg16"]
|
||||
pg_test = []
|
||||
pg14 = ["pgrx/pg14"]
|
||||
pg15 = ["pgrx/pg15"]
|
||||
pg16 = ["pgrx/pg16"]
|
||||
|
||||
[dependencies]
|
||||
arrayvec.workspace = true
|
||||
bincode.workspace = true
|
||||
bytemuck.workspace = true
|
||||
byteorder.workspace = true
|
||||
env_logger = "0.11.2"
|
||||
libc.workspace = true
|
||||
log.workspace = true
|
||||
memmap2.workspace = true
|
||||
num-traits.workspace = true
|
||||
paste.workspace = true
|
||||
pgrx = { version = "0.12.0-alpha.1", default-features = false, features = [] }
|
||||
rand.workspace = true
|
||||
rustix.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
thiserror.workspace = true
|
||||
tikv-jemallocator = { version = "0.5.4", features = [
|
||||
"disable_initial_exec_tls",
|
||||
] }
|
||||
toml = "0.8.10"
|
||||
validator.workspace = true
|
||||
|
||||
base = { path = "crates/base" }
|
||||
detect = { path = "crates/detect" }
|
||||
send_fd = { path = "crates/send_fd" }
|
||||
service = { path = "crates/service" }
|
||||
embedding = { path = "crates/embedding" }
|
||||
interprocess_atomic_wait = { path = "crates/interprocess-atomic-wait" }
|
||||
memfd = { path = "crates/memfd" }
|
||||
pgrx = { version = "0.12.0-alpha.1", default-features = false, features = [] }
|
||||
env_logger = "0.11.2"
|
||||
toml = "0.8.10"
|
||||
|
||||
[dev-dependencies]
|
||||
pgrx-tests = "0.12.0-alpha.1"
|
||||
httpmock = "0.7"
|
||||
scopeguard = "1.2.0"
|
||||
send_fd = { path = "crates/send_fd" }
|
||||
service = { path = "crates/service" }
|
||||
|
||||
[patch.crates-io]
|
||||
pgrx = { git = "https://github.com/tensorchord/pgrx.git", branch = "v0.12.0-alpha.1-patch" }
|
||||
pgrx-tests = { git = "https://github.com/tensorchord/pgrx.git", branch = "v0.12.0-alpha.1-patch" }
|
||||
|
||||
[lints]
|
||||
rust.unsafe_op_in_unsafe_fn = "deny"
|
||||
@ -66,29 +65,32 @@ version = "0.0.0"
|
||||
edition = "2021"
|
||||
|
||||
[workspace.dependencies]
|
||||
arrayvec = "~0.7"
|
||||
bincode = "~1.3"
|
||||
bytemuck = { version = "~1.14", features = ["extern_crate_alloc"] }
|
||||
byteorder = "~1.5"
|
||||
half = { version = "~2.4", features = [
|
||||
arc-swap = "1.7.0"
|
||||
arrayvec = "0.7.4"
|
||||
bincode = "1.3.3"
|
||||
bytemuck = { version = "1.14.3", features = ["extern_crate_alloc"] }
|
||||
byteorder = "1.5.0"
|
||||
half = { version = "2.4.0", features = [
|
||||
"bytemuck",
|
||||
"num-traits",
|
||||
"rand_distr",
|
||||
"serde",
|
||||
"use-intrinsics",
|
||||
"rand_distr",
|
||||
] }
|
||||
libc = "~0.2"
|
||||
log = "~0.4"
|
||||
libc = "0.2.153"
|
||||
log = "0.4.21"
|
||||
memmap2 = "0.9.4"
|
||||
num-traits = "~0.2"
|
||||
paste = "~1.0"
|
||||
multiversion = "0.7.3"
|
||||
num-traits = "0.2.18"
|
||||
parking_lot = "0.12.1"
|
||||
paste = "1.0.14"
|
||||
rand = "0.8.5"
|
||||
rustix = { version = "~0.38", features = ["fs", "net", "mm"] }
|
||||
serde = "~1.0"
|
||||
serde_json = "~1.0"
|
||||
thiserror = "~1.0"
|
||||
uuid = { version = "1.7.0", features = ["v4", "serde"] }
|
||||
validator = { version = "~0.17", features = ["derive"] }
|
||||
rustix = { version = "0.38.31", features = ["fs", "mm", "net"] }
|
||||
serde = "1"
|
||||
serde_json = "1"
|
||||
thiserror = "1"
|
||||
uuid = { version = "1.7.0", features = ["serde", "v4"] }
|
||||
validator = { version = "0.17.0", features = ["derive"] }
|
||||
|
||||
[workspace.lints]
|
||||
rust.unsafe_op_in_unsafe_fn = "forbid"
|
||||
|
@ -3,42 +3,29 @@ name = "pgvecto-rs"
|
||||
version = "0.1.4"
|
||||
description = "Python binding for pgvecto.rs"
|
||||
authors = [
|
||||
{ name = "TensorChord", email = "envd-maintainers@tensorchord.ai" },
|
||||
{ name = "盐粒 Yanli", email = "mail@yanli.one" },
|
||||
]
|
||||
dependencies = [
|
||||
"numpy>=1.23",
|
||||
"toml>=0.10",
|
||||
{ name = "TensorChord", email = "envd-maintainers@tensorchord.ai" },
|
||||
{ name = "盐粒 Yanli", email = "mail@yanli.one" },
|
||||
]
|
||||
dependencies = ["numpy>=1.23", "toml>=0.10"]
|
||||
requires-python = ">=3.8"
|
||||
readme = "README.md"
|
||||
license = { text = "Apache-2.0" }
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
build-backend = "pdm.backend"
|
||||
requires = [
|
||||
"pdm-backend",
|
||||
]
|
||||
requires = ["pdm-backend"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
psycopg3 = [
|
||||
"psycopg[binary]>=3.1.12",
|
||||
]
|
||||
sdk = [
|
||||
"openai>=1.2.2",
|
||||
"pgvecto_rs[sqlalchemy]",
|
||||
]
|
||||
sqlalchemy = [
|
||||
"pgvecto_rs[psycopg3]",
|
||||
"SQLAlchemy>=2.0.23",
|
||||
]
|
||||
psycopg3 = ["psycopg[binary]>=3.1.12"]
|
||||
sdk = ["openai>=1.2.2", "pgvecto_rs[sqlalchemy]"]
|
||||
sqlalchemy = ["pgvecto_rs[psycopg3]", "SQLAlchemy>=2.0.23"]
|
||||
[tool.pdm.dev-dependencies]
|
||||
lint = ["ruff>=0.1.5"]
|
||||
test = ["pytest>=7.4.3"]
|
||||
@ -51,20 +38,20 @@ check = { composite = ["ruff format . --check", "ruff ."] }
|
||||
|
||||
[tool.ruff]
|
||||
select = [
|
||||
"E", #https://docs.astral.sh/ruff/rules/#error-e
|
||||
"F", #https://docs.astral.sh/ruff/rules/#pyflakes-f
|
||||
"I", #https://docs.astral.sh/ruff/rules/#isort-i
|
||||
"TID", #https://docs.astral.sh/ruff/rules/#flake8-tidy-imports-tid
|
||||
"S", #https://docs.astral.sh/ruff/rules/#flake8-bandit-s
|
||||
"B", #https://docs.astral.sh/ruff/rules/#flake8-bugbear-b
|
||||
"SIM", #https://docs.astral.sh/ruff/rules/#flake8-simplify-sim
|
||||
"N", #https://docs.astral.sh/ruff/rules/#pep8-naming-n
|
||||
"PT", #https://docs.astral.sh/ruff/rules/#flake8-pytest-style-pt
|
||||
"TRY", #https://docs.astral.sh/ruff/rules/#tryceratops-try
|
||||
"FLY", #https://docs.astral.sh/ruff/rules/#flynt-fly
|
||||
"PL", #https://docs.astral.sh/ruff/rules/#pylint-pl
|
||||
"NPY", #https://docs.astral.sh/ruff/rules/#numpy-specific-rules-npy
|
||||
"RUF", #https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
|
||||
"E", #https://docs.astral.sh/ruff/rules/#error-e
|
||||
"F", #https://docs.astral.sh/ruff/rules/#pyflakes-f
|
||||
"I", #https://docs.astral.sh/ruff/rules/#isort-i
|
||||
"TID", #https://docs.astral.sh/ruff/rules/#flake8-tidy-imports-tid
|
||||
"S", #https://docs.astral.sh/ruff/rules/#flake8-bandit-s
|
||||
"B", #https://docs.astral.sh/ruff/rules/#flake8-bugbear-b
|
||||
"SIM", #https://docs.astral.sh/ruff/rules/#flake8-simplify-sim
|
||||
"N", #https://docs.astral.sh/ruff/rules/#pep8-naming-n
|
||||
"PT", #https://docs.astral.sh/ruff/rules/#flake8-pytest-style-pt
|
||||
"TRY", #https://docs.astral.sh/ruff/rules/#tryceratops-try
|
||||
"FLY", #https://docs.astral.sh/ruff/rules/#flynt-fly
|
||||
"PL", #https://docs.astral.sh/ruff/rules/#pylint-pl
|
||||
"NPY", #https://docs.astral.sh/ruff/rules/#numpy-specific-rules-npy
|
||||
"RUF", #https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
|
||||
]
|
||||
ignore = ["S101", "E731", "E501"]
|
||||
src = ["src"]
|
||||
|
@ -7,15 +7,16 @@ edition.workspace = true
|
||||
bytemuck.workspace = true
|
||||
half.workspace = true
|
||||
libc.workspace = true
|
||||
multiversion.workspace = true
|
||||
num-traits.workspace = true
|
||||
rand.workspace = true
|
||||
serde.workspace = true
|
||||
thiserror.workspace = true
|
||||
uuid.workspace = true
|
||||
validator.workspace = true
|
||||
|
||||
c = { path = "../c" }
|
||||
detect = { path = "../detect" }
|
||||
multiversion = "0.7.3"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
@ -1,95 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
// control plane
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum CreateError {
|
||||
#[error("Index of given name already exists.")]
|
||||
Exist,
|
||||
#[error("Invalid index options.")]
|
||||
InvalidIndexOptions { reason: String },
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum DropError {
|
||||
#[error("Index not found.")]
|
||||
NotExist,
|
||||
}
|
||||
|
||||
// data plane
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum FlushError {
|
||||
#[error("Index not found.")]
|
||||
NotExist,
|
||||
#[error("Maintenance should be done.")]
|
||||
Upgrade,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum InsertError {
|
||||
#[error("Index not found.")]
|
||||
NotExist,
|
||||
#[error("Maintenance should be done.")]
|
||||
Upgrade,
|
||||
#[error("Invalid vector.")]
|
||||
InvalidVector,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum DeleteError {
|
||||
#[error("Index not found.")]
|
||||
NotExist,
|
||||
#[error("Maintenance should be done.")]
|
||||
Upgrade,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum BasicError {
|
||||
#[error("Index not found.")]
|
||||
NotExist,
|
||||
#[error("Maintenance should be done.")]
|
||||
Upgrade,
|
||||
#[error("Invalid vector.")]
|
||||
InvalidVector,
|
||||
#[error("Invalid search options.")]
|
||||
InvalidSearchOptions { reason: String },
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum VbaseError {
|
||||
#[error("Index not found.")]
|
||||
NotExist,
|
||||
#[error("Maintenance should be done.")]
|
||||
Upgrade,
|
||||
#[error("Invalid vector.")]
|
||||
InvalidVector,
|
||||
#[error("Invalid search options.")]
|
||||
InvalidSearchOptions { reason: String },
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum ListError {
|
||||
#[error("Index not found.")]
|
||||
NotExist,
|
||||
#[error("Maintenance should be done.")]
|
||||
Upgrade,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum StatError {
|
||||
#[error("Index not found.")]
|
||||
NotExist,
|
||||
#[error("Maintenance should be done.")]
|
||||
Upgrade,
|
||||
}
|
@ -1,404 +0,0 @@
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::Float;
|
||||
|
||||
#[inline(always)]
|
||||
pub fn cosine<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 {
|
||||
let lhs = lhs.data();
|
||||
let rhs = rhs.data();
|
||||
assert!(lhs.len() == rhs.len());
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn cosine(lhs: &[usize], rhs: &[usize]) -> F32 {
|
||||
let mut xy = 0;
|
||||
let mut xx = 0;
|
||||
let mut yy = 0;
|
||||
for i in 0..lhs.len() {
|
||||
xy += (lhs[i] & rhs[i]).count_ones();
|
||||
xx += lhs[i].count_ones();
|
||||
yy += rhs[i].count_ones();
|
||||
}
|
||||
let rxy = xy as f32;
|
||||
let rxx = xx as f32;
|
||||
let ryy = yy as f32;
|
||||
F32(rxy / (rxx * ryy).sqrt())
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
unsafe fn cosine_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 {
|
||||
use std::arch::x86_64::*;
|
||||
#[inline]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i {
|
||||
let mut dst: __m512i;
|
||||
unsafe {
|
||||
std::arch::asm!(
|
||||
"vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]",
|
||||
p = in(reg) mem_addr,
|
||||
k = in(kreg) k,
|
||||
dst = out(zmm_reg) dst,
|
||||
options(pure, readonly, nostack)
|
||||
);
|
||||
}
|
||||
dst
|
||||
}
|
||||
assert_eq!(lhs.len(), rhs.len());
|
||||
unsafe {
|
||||
const WIDTH: usize = 512 / 8 / std::mem::size_of::<usize>();
|
||||
let mut xy = _mm512_setzero_si512();
|
||||
let mut xx = _mm512_setzero_si512();
|
||||
let mut yy = _mm512_setzero_si512();
|
||||
let mut a = lhs.as_ptr();
|
||||
let mut b = rhs.as_ptr();
|
||||
let mut n = lhs.len();
|
||||
while n >= WIDTH {
|
||||
let x = _mm512_loadu_si512(a.cast());
|
||||
let y = _mm512_loadu_si512(b.cast());
|
||||
a = a.add(WIDTH);
|
||||
b = b.add(WIDTH);
|
||||
n -= WIDTH;
|
||||
xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y)));
|
||||
xx = _mm512_add_epi64(xx, _mm512_popcnt_epi64(x));
|
||||
yy = _mm512_add_epi64(yy, _mm512_popcnt_epi64(y));
|
||||
}
|
||||
if n > 0 {
|
||||
let mask = _bzhi_u32(0xFFFF, n as u32) as u8;
|
||||
let x = _mm512_maskz_loadu_epi64(mask, a.cast());
|
||||
let y = _mm512_maskz_loadu_epi64(mask, b.cast());
|
||||
xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y)));
|
||||
xx = _mm512_add_epi64(xx, _mm512_popcnt_epi64(x));
|
||||
yy = _mm512_add_epi64(yy, _mm512_popcnt_epi64(y));
|
||||
}
|
||||
let rxy = _mm512_reduce_add_epi64(xy) as f32;
|
||||
let rxx = _mm512_reduce_add_epi64(xx) as f32;
|
||||
let ryy = _mm512_reduce_add_epi64(yy) as f32;
|
||||
F32(rxy / (rxx * ryy).sqrt())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_avx512vpopcntdq() {
|
||||
unsafe {
|
||||
return cosine_avx512vpopcntdq(lhs, rhs);
|
||||
}
|
||||
}
|
||||
cosine(lhs, rhs)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn dot<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 {
|
||||
let lhs = lhs.data();
|
||||
let rhs = rhs.data();
|
||||
assert!(lhs.len() == rhs.len());
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn dot(lhs: &[usize], rhs: &[usize]) -> F32 {
|
||||
let mut xy = 0;
|
||||
for i in 0..lhs.len() {
|
||||
xy += (lhs[i] & rhs[i]).count_ones();
|
||||
}
|
||||
F32(xy as f32)
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
unsafe fn dot_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 {
|
||||
use std::arch::x86_64::*;
|
||||
#[inline]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i {
|
||||
let mut dst: __m512i;
|
||||
unsafe {
|
||||
std::arch::asm!(
|
||||
"vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]",
|
||||
p = in(reg) mem_addr,
|
||||
k = in(kreg) k,
|
||||
dst = out(zmm_reg) dst,
|
||||
options(pure, readonly, nostack)
|
||||
);
|
||||
}
|
||||
dst
|
||||
}
|
||||
assert_eq!(lhs.len(), rhs.len());
|
||||
unsafe {
|
||||
const WIDTH: usize = 512 / 8 / std::mem::size_of::<usize>();
|
||||
let mut xy = _mm512_setzero_si512();
|
||||
let mut a = lhs.as_ptr();
|
||||
let mut b = rhs.as_ptr();
|
||||
let mut n = lhs.len();
|
||||
while n >= WIDTH {
|
||||
let x = _mm512_loadu_si512(a.cast());
|
||||
let y = _mm512_loadu_si512(b.cast());
|
||||
a = a.add(WIDTH);
|
||||
b = b.add(WIDTH);
|
||||
n -= WIDTH;
|
||||
xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y)));
|
||||
}
|
||||
if n > 0 {
|
||||
let mask = _bzhi_u32(0xFFFF, n as u32) as u8;
|
||||
let x = _mm512_maskz_loadu_epi64(mask, a.cast());
|
||||
let y = _mm512_maskz_loadu_epi64(mask, b.cast());
|
||||
xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y)));
|
||||
}
|
||||
let rxy = _mm512_reduce_add_epi64(xy) as f32;
|
||||
F32(rxy)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_avx512vpopcntdq() {
|
||||
unsafe {
|
||||
return dot_avx512vpopcntdq(lhs, rhs);
|
||||
}
|
||||
}
|
||||
dot(lhs, rhs)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn sl2<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 {
|
||||
let lhs = lhs.data();
|
||||
let rhs = rhs.data();
|
||||
assert!(lhs.len() == rhs.len());
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn sl2(lhs: &[usize], rhs: &[usize]) -> F32 {
|
||||
let mut dd = 0;
|
||||
for i in 0..lhs.len() {
|
||||
dd += (lhs[i] ^ rhs[i]).count_ones();
|
||||
}
|
||||
F32(dd as f32)
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
unsafe fn sl2_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 {
|
||||
use std::arch::x86_64::*;
|
||||
#[inline]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i {
|
||||
let mut dst: __m512i;
|
||||
unsafe {
|
||||
std::arch::asm!(
|
||||
"vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]",
|
||||
p = in(reg) mem_addr,
|
||||
k = in(kreg) k,
|
||||
dst = out(zmm_reg) dst,
|
||||
options(pure, readonly, nostack)
|
||||
);
|
||||
}
|
||||
dst
|
||||
}
|
||||
assert_eq!(lhs.len(), rhs.len());
|
||||
unsafe {
|
||||
const WIDTH: usize = 512 / 8 / std::mem::size_of::<usize>();
|
||||
let mut dd = _mm512_setzero_si512();
|
||||
let mut a = lhs.as_ptr();
|
||||
let mut b = rhs.as_ptr();
|
||||
let mut n = lhs.len();
|
||||
while n >= WIDTH {
|
||||
let x = _mm512_loadu_si512(a.cast());
|
||||
let y = _mm512_loadu_si512(b.cast());
|
||||
a = a.add(WIDTH);
|
||||
b = b.add(WIDTH);
|
||||
n -= WIDTH;
|
||||
dd = _mm512_add_epi64(dd, _mm512_popcnt_epi64(_mm512_xor_si512(x, y)));
|
||||
}
|
||||
if n > 0 {
|
||||
let mask = _bzhi_u32(0xFFFF, n as u32) as u8;
|
||||
let x = _mm512_maskz_loadu_epi64(mask, a.cast());
|
||||
let y = _mm512_maskz_loadu_epi64(mask, b.cast());
|
||||
dd = _mm512_add_epi64(dd, _mm512_popcnt_epi64(_mm512_xor_si512(x, y)));
|
||||
}
|
||||
let rdd = _mm512_reduce_add_epi64(dd) as f32;
|
||||
F32(rdd)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_avx512vpopcntdq() {
|
||||
unsafe {
|
||||
return sl2_avx512vpopcntdq(lhs, rhs);
|
||||
}
|
||||
}
|
||||
sl2(lhs, rhs)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn jaccard<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 {
|
||||
let lhs = lhs.data();
|
||||
let rhs = rhs.data();
|
||||
assert!(lhs.len() == rhs.len());
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn jaccard(lhs: &[usize], rhs: &[usize]) -> F32 {
|
||||
let mut inter = 0;
|
||||
let mut union = 0;
|
||||
for i in 0..lhs.len() {
|
||||
inter += (lhs[i] & rhs[i]).count_ones();
|
||||
union += (lhs[i] | rhs[i]).count_ones();
|
||||
}
|
||||
F32(inter as f32 / union as f32)
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
unsafe fn jaccard_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 {
|
||||
use std::arch::x86_64::*;
|
||||
#[inline]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i {
|
||||
let mut dst: __m512i;
|
||||
unsafe {
|
||||
std::arch::asm!(
|
||||
"vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]",
|
||||
p = in(reg) mem_addr,
|
||||
k = in(kreg) k,
|
||||
dst = out(zmm_reg) dst,
|
||||
options(pure, readonly, nostack)
|
||||
);
|
||||
}
|
||||
dst
|
||||
}
|
||||
assert_eq!(lhs.len(), rhs.len());
|
||||
unsafe {
|
||||
const WIDTH: usize = 512 / 8 / std::mem::size_of::<usize>();
|
||||
let mut inter = _mm512_setzero_si512();
|
||||
let mut union = _mm512_setzero_si512();
|
||||
let mut a = lhs.as_ptr();
|
||||
let mut b = rhs.as_ptr();
|
||||
let mut n = lhs.len();
|
||||
while n >= WIDTH {
|
||||
let x = _mm512_loadu_si512(a.cast());
|
||||
let y = _mm512_loadu_si512(b.cast());
|
||||
a = a.add(WIDTH);
|
||||
b = b.add(WIDTH);
|
||||
n -= WIDTH;
|
||||
inter = _mm512_add_epi64(inter, _mm512_popcnt_epi64(_mm512_and_si512(x, y)));
|
||||
union = _mm512_add_epi64(union, _mm512_popcnt_epi64(_mm512_or_si512(x, y)));
|
||||
}
|
||||
if n > 0 {
|
||||
let mask = _bzhi_u32(0xFFFF, n as u32) as u8;
|
||||
let x = _mm512_maskz_loadu_epi64(mask, a.cast());
|
||||
let y = _mm512_maskz_loadu_epi64(mask, b.cast());
|
||||
inter = _mm512_add_epi64(inter, _mm512_popcnt_epi64(_mm512_and_si512(x, y)));
|
||||
union = _mm512_add_epi64(union, _mm512_popcnt_epi64(_mm512_or_si512(x, y)));
|
||||
}
|
||||
let rinter = _mm512_reduce_add_epi64(inter) as f32;
|
||||
let runion = _mm512_reduce_add_epi64(union) as f32;
|
||||
F32(rinter / runion)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_avx512vpopcntdq() {
|
||||
unsafe {
|
||||
return jaccard_avx512vpopcntdq(lhs, rhs);
|
||||
}
|
||||
}
|
||||
jaccard(lhs, rhs)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn length(vector: BVecf32Borrowed<'_>) -> F32 {
|
||||
let vector = vector.data();
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn length(vector: &[usize]) -> F32 {
|
||||
let mut l = 0;
|
||||
for i in 0..vector.len() {
|
||||
l += vector[i].count_ones();
|
||||
}
|
||||
F32(l as f32).sqrt()
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
unsafe fn length_avx512vpopcntdq(lhs: &[usize]) -> F32 {
|
||||
use std::arch::x86_64::*;
|
||||
#[inline]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i {
|
||||
let mut dst: __m512i;
|
||||
unsafe {
|
||||
std::arch::asm!(
|
||||
"vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]",
|
||||
p = in(reg) mem_addr,
|
||||
k = in(kreg) k,
|
||||
dst = out(zmm_reg) dst,
|
||||
options(pure, readonly, nostack)
|
||||
);
|
||||
}
|
||||
dst
|
||||
}
|
||||
unsafe {
|
||||
const WIDTH: usize = 512 / 8 / std::mem::size_of::<usize>();
|
||||
let mut cnt = _mm512_setzero_si512();
|
||||
let mut a = lhs.as_ptr();
|
||||
let mut n = lhs.len();
|
||||
while n >= WIDTH {
|
||||
let x = _mm512_loadu_si512(a.cast());
|
||||
a = a.add(WIDTH);
|
||||
n -= WIDTH;
|
||||
cnt = _mm512_add_epi64(cnt, _mm512_popcnt_epi64(x));
|
||||
}
|
||||
if n > 0 {
|
||||
let mask = _bzhi_u32(0xFFFF, n as u32) as u8;
|
||||
let x = _mm512_maskz_loadu_epi64(mask, a.cast());
|
||||
cnt = _mm512_add_epi64(cnt, _mm512_popcnt_epi64(x));
|
||||
}
|
||||
let rcnt = _mm512_reduce_add_epi64(cnt) as f32;
|
||||
F32(rcnt.sqrt())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_avx512vpopcntdq() {
|
||||
unsafe {
|
||||
return length_avx512vpopcntdq(vector);
|
||||
}
|
||||
}
|
||||
length(vector)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn l2_normalize<'a>(vector: BVecf32Borrowed<'a>) -> Vecf32Owned {
|
||||
let l = length(vector);
|
||||
Vecf32Owned::new(vector.iter().map(|i| F32(i as u32 as f32) / l).collect())
|
||||
}
|
@ -1,104 +0,0 @@
|
||||
use super::*;
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::Float;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum BVecf32Cos {}
|
||||
|
||||
impl Global for BVecf32Cos {
|
||||
type VectorOwned = BVecf32Owned;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::BVecf32;
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Cos;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
F32(1.0) - super::bvecf32::cosine(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalElkanKMeans for BVecf32Cos {
|
||||
type VectorNormalized = Vecf32Owned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
super::vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Vecf32Owned {
|
||||
super::bvecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::dot(lhs.slice(), rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalScalarQuantization for BVecf32Cos {
|
||||
fn scalar_quantization_distance(
|
||||
_dims: u16,
|
||||
_max: &[F32],
|
||||
_min: &[F32],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn scalar_quantization_distance2(
|
||||
_dims: u16,
|
||||
_max: &[Scalar<Self>],
|
||||
_min: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalProductQuantization for BVecf32Cos {
|
||||
type ProductQuantizationL2 = BVecf32L2;
|
||||
|
||||
fn product_quantization_distance(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_distance2(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_distance_with_delta(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
_delta: &[Scalar<Self>],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_l2_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs, rhs)
|
||||
}
|
||||
|
||||
fn product_quantization_dense_distance(_: &[Scalar<Self>], _: &[Scalar<Self>]) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
@ -1,104 +0,0 @@
|
||||
use super::*;
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::Float;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum BVecf32Dot {}
|
||||
|
||||
impl Global for BVecf32Dot {
|
||||
type VectorOwned = BVecf32Owned;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::BVecf32;
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
super::bvecf32::dot(lhs, rhs) * (-1.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalElkanKMeans for BVecf32Dot {
|
||||
type VectorNormalized = Vecf32Owned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
super::vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Vecf32Owned {
|
||||
super::bvecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::dot(lhs.slice(), rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalScalarQuantization for BVecf32Dot {
|
||||
fn scalar_quantization_distance(
|
||||
_dims: u16,
|
||||
_max: &[F32],
|
||||
_min: &[F32],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn scalar_quantization_distance2(
|
||||
_dims: u16,
|
||||
_max: &[Scalar<Self>],
|
||||
_min: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalProductQuantization for BVecf32Dot {
|
||||
type ProductQuantizationL2 = BVecf32L2;
|
||||
|
||||
fn product_quantization_distance(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_distance2(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_distance_with_delta(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
_delta: &[Scalar<Self>],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_l2_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs, rhs)
|
||||
}
|
||||
|
||||
fn product_quantization_dense_distance(_: &[Scalar<Self>], _: &[Scalar<Self>]) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
@ -1,104 +0,0 @@
|
||||
use super::*;
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::Float;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum BVecf32Jaccard {}
|
||||
|
||||
impl Global for BVecf32Jaccard {
|
||||
type VectorOwned = BVecf32Owned;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::BVecf32;
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Jaccard;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
F32(1.) - super::bvecf32::jaccard(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalElkanKMeans for BVecf32Jaccard {
|
||||
type VectorNormalized = Vecf32Owned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
super::vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Vecf32Owned {
|
||||
Vecf32Owned::new(vector.to_vec())
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs, rhs).sqrt()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs.slice(), rhs).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalScalarQuantization for BVecf32Jaccard {
|
||||
fn scalar_quantization_distance(
|
||||
_dims: u16,
|
||||
_max: &[F32],
|
||||
_min: &[F32],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn scalar_quantization_distance2(
|
||||
_dims: u16,
|
||||
_max: &[Scalar<Self>],
|
||||
_min: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalProductQuantization for BVecf32Jaccard {
|
||||
type ProductQuantizationL2 = BVecf32L2;
|
||||
|
||||
fn product_quantization_distance(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_distance2(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_distance_with_delta(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
_delta: &[Scalar<Self>],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_l2_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs, rhs)
|
||||
}
|
||||
|
||||
fn product_quantization_dense_distance(_: &[Scalar<Self>], _: &[Scalar<Self>]) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
@ -1,104 +0,0 @@
|
||||
use super::*;
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::Float;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum BVecf32L2 {}
|
||||
|
||||
impl Global for BVecf32L2 {
|
||||
type VectorOwned = BVecf32Owned;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::BVecf32;
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::L2;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
super::bvecf32::sl2(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalElkanKMeans for BVecf32L2 {
|
||||
type VectorNormalized = Vecf32Owned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
super::vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Vecf32Owned {
|
||||
Vecf32Owned::new(vector.to_vec())
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs, rhs).sqrt()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs.slice(), rhs).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalScalarQuantization for BVecf32L2 {
|
||||
fn scalar_quantization_distance(
|
||||
_dims: u16,
|
||||
_max: &[F32],
|
||||
_min: &[F32],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn scalar_quantization_distance2(
|
||||
_dims: u16,
|
||||
_max: &[Scalar<Self>],
|
||||
_min: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalProductQuantization for BVecf32L2 {
|
||||
type ProductQuantizationL2 = BVecf32L2;
|
||||
|
||||
fn product_quantization_distance(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_distance2(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_distance_with_delta(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
_delta: &[Scalar<Self>],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_l2_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs, rhs)
|
||||
}
|
||||
|
||||
fn product_quantization_dense_distance(_: &[Scalar<Self>], _: &[Scalar<Self>]) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
@ -1,114 +0,0 @@
|
||||
mod bvecf32;
|
||||
mod bvecf32_cos;
|
||||
mod bvecf32_dot;
|
||||
mod bvecf32_jaccard;
|
||||
mod bvecf32_l2;
|
||||
mod svecf32;
|
||||
mod svecf32_cos;
|
||||
mod svecf32_dot;
|
||||
mod svecf32_l2;
|
||||
mod vecf16;
|
||||
mod vecf16_cos;
|
||||
mod vecf16_dot;
|
||||
mod vecf16_l2;
|
||||
mod vecf32;
|
||||
mod vecf32_cos;
|
||||
mod vecf32_dot;
|
||||
mod vecf32_l2;
|
||||
mod veci8;
|
||||
mod veci8_cos;
|
||||
mod veci8_dot;
|
||||
mod veci8_l2;
|
||||
|
||||
pub use bvecf32_cos::BVecf32Cos;
|
||||
pub use bvecf32_dot::BVecf32Dot;
|
||||
pub use bvecf32_jaccard::BVecf32Jaccard;
|
||||
pub use bvecf32_l2::BVecf32L2;
|
||||
pub use svecf32_cos::SVecf32Cos;
|
||||
pub use svecf32_dot::SVecf32Dot;
|
||||
pub use svecf32_l2::SVecf32L2;
|
||||
pub use vecf16_cos::Vecf16Cos;
|
||||
pub use vecf16_dot::Vecf16Dot;
|
||||
pub use vecf16_l2::Vecf16L2;
|
||||
pub use vecf32_cos::Vecf32Cos;
|
||||
pub use vecf32_dot::Vecf32Dot;
|
||||
pub use vecf32_l2::Vecf32L2;
|
||||
pub use veci8_cos::Veci8Cos;
|
||||
pub use veci8_dot::Veci8Dot;
|
||||
pub use veci8_l2::Veci8L2;
|
||||
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
pub trait GlobalElkanKMeans: Global {
|
||||
type VectorNormalized: VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]);
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Self::VectorNormalized;
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32;
|
||||
fn elkan_k_means_distance2(
|
||||
lhs: <Self::VectorNormalized as VectorOwned>::Borrowed<'_>,
|
||||
rhs: &[Scalar<Self>],
|
||||
) -> F32;
|
||||
}
|
||||
|
||||
pub trait GlobalScalarQuantization: Global {
|
||||
fn scalar_quantization_distance(
|
||||
dims: u16,
|
||||
max: &[Scalar<Self>],
|
||||
min: &[Scalar<Self>],
|
||||
lhs: Borrowed<'_, Self>,
|
||||
rhs: &[u8],
|
||||
) -> F32;
|
||||
fn scalar_quantization_distance2(
|
||||
dims: u16,
|
||||
max: &[Scalar<Self>],
|
||||
min: &[Scalar<Self>],
|
||||
lhs: &[u8],
|
||||
rhs: &[u8],
|
||||
) -> F32;
|
||||
}
|
||||
|
||||
pub trait GlobalProductQuantization: Global {
|
||||
type ProductQuantizationL2: Global<VectorOwned = Self::VectorOwned>
|
||||
+ GlobalElkanKMeans
|
||||
+ GlobalProductQuantization;
|
||||
fn product_quantization_distance(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[Scalar<Self>],
|
||||
lhs: Borrowed<'_, Self>,
|
||||
rhs: &[u8],
|
||||
) -> F32;
|
||||
fn product_quantization_distance2(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[Scalar<Self>],
|
||||
lhs: &[u8],
|
||||
rhs: &[u8],
|
||||
) -> F32;
|
||||
fn product_quantization_distance_with_delta(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[Scalar<Self>],
|
||||
lhs: Borrowed<'_, Self>,
|
||||
rhs: &[u8],
|
||||
delta: &[Scalar<Self>],
|
||||
) -> F32;
|
||||
fn product_quantization_l2_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32;
|
||||
fn product_quantization_dense_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32;
|
||||
}
|
||||
|
||||
pub trait Global: Copy + 'static {
|
||||
type VectorOwned: VectorOwned;
|
||||
|
||||
const VECTOR_KIND: VectorKind;
|
||||
const DISTANCE_KIND: DistanceKind;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32;
|
||||
}
|
||||
|
||||
pub type Owned<T> = <T as Global>::VectorOwned;
|
||||
pub type Borrowed<'a, T> = <<T as Global>::VectorOwned as VectorOwned>::Borrowed<'a>;
|
||||
pub type Scalar<T> = <<T as Global>::VectorOwned as VectorOwned>::Scalar;
|
@ -1,169 +0,0 @@
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::{Float, Zero};
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn cosine<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 {
|
||||
let mut lhs_pos = 0;
|
||||
let mut rhs_pos = 0;
|
||||
let size1 = lhs.len() as usize;
|
||||
let size2 = rhs.len() as usize;
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
while lhs_pos < size1 && rhs_pos < size2 {
|
||||
let lhs_index = lhs.indexes()[lhs_pos];
|
||||
let rhs_index = rhs.indexes()[rhs_pos];
|
||||
let lhs_value = lhs.values()[lhs_pos];
|
||||
let rhs_value = rhs.values()[rhs_pos];
|
||||
xy += F32((lhs_index == rhs_index) as u32 as f32) * lhs_value * rhs_value;
|
||||
x2 += F32((lhs_index <= rhs_index) as u32 as f32) * lhs_value * lhs_value;
|
||||
y2 += F32((lhs_index >= rhs_index) as u32 as f32) * rhs_value * rhs_value;
|
||||
lhs_pos += (lhs_index <= rhs_index) as usize;
|
||||
rhs_pos += (lhs_index >= rhs_index) as usize;
|
||||
}
|
||||
for i in lhs_pos..size1 {
|
||||
x2 += lhs.values()[i] * lhs.values()[i];
|
||||
}
|
||||
for i in rhs_pos..size2 {
|
||||
y2 += rhs.values()[i] * rhs.values()[i];
|
||||
}
|
||||
xy / (x2 * y2).sqrt()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn dot<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 {
|
||||
let mut lhs_pos = 0;
|
||||
let mut rhs_pos = 0;
|
||||
let size1 = lhs.len() as usize;
|
||||
let size2 = rhs.len() as usize;
|
||||
let mut xy = F32::zero();
|
||||
while lhs_pos < size1 && rhs_pos < size2 {
|
||||
let lhs_index = lhs.indexes()[lhs_pos];
|
||||
let rhs_index = rhs.indexes()[rhs_pos];
|
||||
let lhs_value = lhs.values()[lhs_pos];
|
||||
let rhs_value = rhs.values()[rhs_pos];
|
||||
xy += F32((lhs_index == rhs_index) as u32 as f32) * lhs_value * rhs_value;
|
||||
lhs_pos += (lhs_index <= rhs_index) as usize;
|
||||
rhs_pos += (lhs_index >= rhs_index) as usize;
|
||||
}
|
||||
xy
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn dot_2<'a>(lhs: SVecf32Borrowed<'a>, rhs: &[F32]) -> F32 {
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..lhs.len() as usize {
|
||||
xy += lhs.values()[i] * rhs[lhs.indexes()[i] as usize];
|
||||
}
|
||||
xy
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn sl2<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 {
|
||||
let mut lhs_pos = 0;
|
||||
let mut rhs_pos = 0;
|
||||
let size1 = lhs.len() as usize;
|
||||
let size2 = rhs.len() as usize;
|
||||
let mut d2 = F32::zero();
|
||||
while lhs_pos < size1 && rhs_pos < size2 {
|
||||
let lhs_index = lhs.indexes()[lhs_pos];
|
||||
let rhs_index = rhs.indexes()[rhs_pos];
|
||||
let lhs_value = lhs.values()[lhs_pos];
|
||||
let rhs_value = rhs.values()[rhs_pos];
|
||||
let d = F32((lhs_index <= rhs_index) as u32 as f32) * lhs_value
|
||||
- F32((lhs_index >= rhs_index) as u32 as f32) * rhs_value;
|
||||
d2 += d * d;
|
||||
lhs_pos += (lhs_index <= rhs_index) as usize;
|
||||
rhs_pos += (lhs_index >= rhs_index) as usize;
|
||||
}
|
||||
for i in lhs_pos..size1 {
|
||||
d2 += lhs.values()[i] * lhs.values()[i];
|
||||
}
|
||||
for i in rhs_pos..size2 {
|
||||
d2 += rhs.values()[i] * rhs.values()[i];
|
||||
}
|
||||
d2
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn sl2_2<'a>(lhs: SVecf32Borrowed<'a>, rhs: &[F32]) -> F32 {
|
||||
let mut d2 = F32::zero();
|
||||
let mut lhs_pos = 0;
|
||||
let mut rhs_pos = 0;
|
||||
while lhs_pos < lhs.len() {
|
||||
let index_eq = lhs.indexes()[lhs_pos as usize] == rhs_pos;
|
||||
let d =
|
||||
F32(index_eq as u32 as f32) * lhs.values()[lhs_pos as usize] - rhs[rhs_pos as usize];
|
||||
d2 += d * d;
|
||||
lhs_pos += index_eq as u32;
|
||||
rhs_pos += 1;
|
||||
}
|
||||
for i in rhs_pos..rhs.len() as u32 {
|
||||
d2 += rhs[i as usize] * rhs[i as usize];
|
||||
}
|
||||
d2
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn length<'a>(vector: SVecf32Borrowed<'a>) -> F32 {
|
||||
let mut dot = F32::zero();
|
||||
for &i in vector.values() {
|
||||
dot += i * i;
|
||||
}
|
||||
dot.sqrt()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn l2_normalize(vector: &mut SVecf32Owned) {
|
||||
let l = length(vector.for_borrow());
|
||||
let dims = vector.dims();
|
||||
let indexes = vector.indexes().to_vec();
|
||||
let mut values = vector.values().to_vec();
|
||||
for i in values.iter_mut() {
|
||||
*i /= l;
|
||||
}
|
||||
*vector = SVecf32Owned::new(dims, indexes, values);
|
||||
}
|
@ -1,106 +0,0 @@
|
||||
use super::*;
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::Float;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum SVecf32Cos {}
|
||||
|
||||
impl Global for SVecf32Cos {
|
||||
type VectorOwned = SVecf32Owned;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::SVecf32;
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Cos;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
F32(1.0) - super::svecf32::cosine(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalElkanKMeans for SVecf32Cos {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
super::vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> SVecf32Owned {
|
||||
let mut vector = vector.for_own();
|
||||
super::svecf32::l2_normalize(&mut vector);
|
||||
vector
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::svecf32::dot_2(lhs, rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalScalarQuantization for SVecf32Cos {
|
||||
fn scalar_quantization_distance(
|
||||
_dims: u16,
|
||||
_max: &[F32],
|
||||
_min: &[F32],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn scalar_quantization_distance2(
|
||||
_dims: u16,
|
||||
_max: &[Scalar<Self>],
|
||||
_min: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalProductQuantization for SVecf32Cos {
|
||||
type ProductQuantizationL2 = SVecf32L2;
|
||||
|
||||
fn product_quantization_distance(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_distance2(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_distance_with_delta(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
_delta: &[Scalar<Self>],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_l2_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs, rhs)
|
||||
}
|
||||
|
||||
fn product_quantization_dense_distance(_: &[Scalar<Self>], _: &[Scalar<Self>]) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
@ -1,106 +0,0 @@
|
||||
use super::*;
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::Float;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum SVecf32Dot {}
|
||||
|
||||
impl Global for SVecf32Dot {
|
||||
type VectorOwned = SVecf32Owned;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::SVecf32;
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
super::svecf32::dot(lhs, rhs) * (-1.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalElkanKMeans for SVecf32Dot {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
super::vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> SVecf32Owned {
|
||||
let mut vector = vector.for_own();
|
||||
super::svecf32::l2_normalize(&mut vector);
|
||||
vector
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::svecf32::dot_2(lhs, rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalScalarQuantization for SVecf32Dot {
|
||||
fn scalar_quantization_distance(
|
||||
_dims: u16,
|
||||
_max: &[Scalar<Self>],
|
||||
_min: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn scalar_quantization_distance2(
|
||||
_dims: u16,
|
||||
_max: &[Scalar<Self>],
|
||||
_min: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalProductQuantization for SVecf32Dot {
|
||||
type ProductQuantizationL2 = SVecf32L2;
|
||||
|
||||
fn product_quantization_distance(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_distance2(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_distance_with_delta(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
_delta: &[Scalar<Self>],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_l2_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs, rhs)
|
||||
}
|
||||
|
||||
fn product_quantization_dense_distance(_: &[Scalar<Self>], _: &[Scalar<Self>]) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
@ -1,102 +0,0 @@
|
||||
use super::*;
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::Float;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum SVecf32L2 {}
|
||||
|
||||
impl Global for SVecf32L2 {
|
||||
type VectorOwned = SVecf32Owned;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::SVecf32;
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::L2;
|
||||
|
||||
fn distance(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
|
||||
super::svecf32::sl2(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalElkanKMeans for SVecf32L2 {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(_: &mut [Scalar<Self>]) {}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: SVecf32Borrowed<'_>) -> SVecf32Owned {
|
||||
vector.for_own()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs, rhs).sqrt()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: SVecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::svecf32::sl2_2(lhs, rhs).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalScalarQuantization for SVecf32L2 {
|
||||
fn scalar_quantization_distance(
|
||||
_dims: u16,
|
||||
_max: &[Scalar<Self>],
|
||||
_min: &[Scalar<Self>],
|
||||
_lhs: SVecf32Borrowed<'_>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn scalar_quantization_distance2(
|
||||
_dims: u16,
|
||||
_max: &[Scalar<Self>],
|
||||
_min: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalProductQuantization for SVecf32L2 {
|
||||
type ProductQuantizationL2 = SVecf32L2;
|
||||
|
||||
fn product_quantization_distance(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: SVecf32Borrowed<'_>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_distance2(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_distance_with_delta(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: SVecf32Borrowed<'_>,
|
||||
_rhs: &[u8],
|
||||
_delta: &[Scalar<Self>],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn product_quantization_l2_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs, rhs)
|
||||
}
|
||||
|
||||
fn product_quantization_dense_distance(_: &[Scalar<Self>], _: &[Scalar<Self>]) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
@ -1,170 +0,0 @@
|
||||
use crate::scalar::*;
|
||||
use num_traits::{Float, Zero};
|
||||
|
||||
pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i].to_f() * rhs[i].to_f();
|
||||
x2 += lhs[i].to_f() * lhs[i].to_f();
|
||||
y2 += rhs[i].to_f() * rhs[i].to_f();
|
||||
}
|
||||
xy / (x2 * y2).sqrt()
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_avx512fp16() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_cosine_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_v4() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_cosine_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_v3() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_cosine_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
cosine(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i].to_f() * rhs[i].to_f();
|
||||
}
|
||||
xy
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_avx512fp16() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_dot_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_v4() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_dot_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_v3() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_dot_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
dot(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut d2 = F32::zero();
|
||||
for i in 0..n {
|
||||
let d = lhs[i].to_f() - rhs[i].to_f();
|
||||
d2 += d * d;
|
||||
}
|
||||
d2
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_avx512fp16() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_sl2_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_v4() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_sl2_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_v3() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_sl2_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
sl2(lhs, rhs)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn length(vector: &[F16]) -> F16 {
|
||||
let n = vector.len();
|
||||
let mut dot = F16::zero();
|
||||
for i in 0..n {
|
||||
dot += vector[i] * vector[i];
|
||||
}
|
||||
dot.sqrt()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn l2_normalize(vector: &mut [F16]) {
|
||||
let n = vector.len();
|
||||
let l = length(vector);
|
||||
for i in 0..n {
|
||||
vector[i] /= l;
|
||||
}
|
||||
}
|
@ -1,245 +0,0 @@
|
||||
use super::*;
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::{Float, Zero};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Vecf16Cos {}
|
||||
|
||||
impl Global for Vecf16Cos {
|
||||
type VectorOwned = Vecf16Owned;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::Vecf16;
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Cos;
|
||||
|
||||
fn distance(lhs: Vecf16Borrowed<'_>, rhs: Vecf16Borrowed<'_>) -> F32 {
|
||||
F32(1.0) - super::vecf16::cosine(lhs.slice(), rhs.slice())
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalElkanKMeans for Vecf16Cos {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [F16]) {
|
||||
super::vecf16::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Vecf16Borrowed<'_>) -> Vecf16Owned {
|
||||
let mut vector = vector.for_own();
|
||||
super::vecf16::l2_normalize(vector.slice_mut());
|
||||
vector
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
super::vecf16::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf16Borrowed<'_>, rhs: &[F16]) -> F32 {
|
||||
super::vecf16::dot(lhs.slice(), rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalScalarQuantization for Vecf16Cos {
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn scalar_quantization_distance<'a>(
|
||||
dims: u16,
|
||||
max: &[F16],
|
||||
min: &[F16],
|
||||
lhs: Vecf16Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..dims as usize {
|
||||
let _x = lhs[i].to_f();
|
||||
let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
|
||||
xy += _x * _y;
|
||||
x2 += _x * _x;
|
||||
y2 += _y * _y;
|
||||
}
|
||||
F32(1.0) - xy / (x2 * y2).sqrt()
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn scalar_quantization_distance2(
|
||||
dims: u16,
|
||||
max: &[F16],
|
||||
min: &[F16],
|
||||
lhs: &[u8],
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..dims as usize {
|
||||
let _x = F32(lhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
|
||||
let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
|
||||
xy += _x * _y;
|
||||
x2 += _x * _x;
|
||||
y2 += _y * _y;
|
||||
}
|
||||
F32(1.0) - xy / (x2 * y2).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalProductQuantization for Vecf16Cos {
|
||||
type ProductQuantizationL2 = Vecf16L2;
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance<'a>(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F16],
|
||||
lhs: Vecf16Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs);
|
||||
xy += _xy;
|
||||
x2 += _x2;
|
||||
y2 += _y2;
|
||||
}
|
||||
F32(1.0) - xy / (x2 * y2).sqrt()
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance2(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F16],
|
||||
lhs: &[u8],
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhsp = lhs[i as usize] as usize * dims as usize;
|
||||
let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs);
|
||||
xy += _xy;
|
||||
x2 += _x2;
|
||||
y2 += _y2;
|
||||
}
|
||||
F32(1.0) - xy / (x2 * y2).sqrt()
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance_with_delta<'a>(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F16],
|
||||
lhs: Vecf16Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
delta: &[F16],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let del = &delta[(i * ratio) as usize..][..k as usize];
|
||||
let (_xy, _x2, _y2) = xy_x2_y2_delta(lhs, rhs, del);
|
||||
xy += _xy;
|
||||
x2 += _x2;
|
||||
y2 += _y2;
|
||||
}
|
||||
F32(1.0) - xy / (x2 * y2).sqrt()
|
||||
}
|
||||
|
||||
fn product_quantization_l2_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf16::sl2(lhs, rhs)
|
||||
}
|
||||
|
||||
fn product_quantization_dense_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
F32(1.0) - super::vecf16::cosine(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn xy_x2_y2(lhs: &[F16], rhs: &[F16]) -> (F32, F32, F32) {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i].to_f() * rhs[i].to_f();
|
||||
x2 += lhs[i].to_f() * lhs[i].to_f();
|
||||
y2 += rhs[i].to_f() * rhs[i].to_f();
|
||||
}
|
||||
(xy, x2, y2)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn xy_x2_y2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> (F32, F32, F32) {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i].to_f() * (rhs[i].to_f() + del[i].to_f());
|
||||
x2 += lhs[i].to_f() * lhs[i].to_f();
|
||||
y2 += (rhs[i].to_f() + del[i].to_f()) * (rhs[i].to_f() + del[i].to_f());
|
||||
}
|
||||
(xy, x2, y2)
|
||||
}
|
@ -1,200 +0,0 @@
|
||||
use super::*;
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::{Float, Zero};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Vecf16Dot {}
|
||||
|
||||
impl Global for Vecf16Dot {
|
||||
type VectorOwned = Vecf16Owned;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::Vecf16;
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;
|
||||
|
||||
fn distance(lhs: Vecf16Borrowed<'_>, rhs: Vecf16Borrowed<'_>) -> F32 {
|
||||
super::vecf16::dot(lhs.slice(), rhs.slice()) * (-1.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalElkanKMeans for Vecf16Dot {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [F16]) {
|
||||
super::vecf16::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Vecf16Borrowed<'_>) -> Vecf16Owned {
|
||||
let mut vector = vector.for_own();
|
||||
super::vecf16::l2_normalize(vector.slice_mut());
|
||||
vector
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
super::vecf16::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf16Borrowed<'_>, rhs: &[F16]) -> F32 {
|
||||
super::vecf16::dot(lhs.slice(), rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalScalarQuantization for Vecf16Dot {
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn scalar_quantization_distance<'a>(
|
||||
dims: u16,
|
||||
max: &[F16],
|
||||
min: &[F16],
|
||||
lhs: Vecf16Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..dims as usize {
|
||||
let _x = lhs[i].to_f();
|
||||
let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
|
||||
xy += _x * _y;
|
||||
}
|
||||
xy * (-1.0)
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn scalar_quantization_distance2(
|
||||
dims: u16,
|
||||
max: &[F16],
|
||||
min: &[F16],
|
||||
lhs: &[u8],
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..dims as usize {
|
||||
let _x = F32(lhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
|
||||
let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
|
||||
xy += _x * _y;
|
||||
}
|
||||
xy * (-1.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalProductQuantization for Vecf16Dot {
|
||||
type ProductQuantizationL2 = Vecf16L2;
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance<'a>(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F16],
|
||||
lhs: Vecf16Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let _xy = super::vecf16::dot(lhs, rhs);
|
||||
xy += _xy;
|
||||
}
|
||||
xy * (-1.0)
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance2(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F16],
|
||||
lhs: &[u8],
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhsp = lhs[i as usize] as usize * dims as usize;
|
||||
let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let _xy = super::vecf16::dot(lhs, rhs);
|
||||
xy += _xy;
|
||||
}
|
||||
xy * (-1.0)
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance_with_delta<'a>(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F16],
|
||||
lhs: Vecf16Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
delta: &[F16],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let del = &delta[(i * ratio) as usize..][..k as usize];
|
||||
let _xy = dot_delta(lhs, rhs, del);
|
||||
xy += _xy;
|
||||
}
|
||||
xy * (-1.0)
|
||||
}
|
||||
|
||||
fn product_quantization_l2_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf16::sl2(lhs, rhs)
|
||||
}
|
||||
|
||||
fn product_quantization_dense_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf16::dot(lhs, rhs) * (-1.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn dot_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n: usize = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i].to_f() * (rhs[i].to_f() + del[i].to_f());
|
||||
}
|
||||
xy
|
||||
}
|
@ -1,194 +0,0 @@
|
||||
use super::*;
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::{Float, Zero};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Vecf16L2 {}
|
||||
|
||||
impl Global for Vecf16L2 {
|
||||
type VectorOwned = Vecf16Owned;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::Vecf16;
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::L2;
|
||||
|
||||
fn distance(lhs: Vecf16Borrowed<'_>, rhs: Vecf16Borrowed<'_>) -> F32 {
|
||||
super::vecf16::sl2(lhs.slice(), rhs.slice())
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalElkanKMeans for Vecf16L2 {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(_: &mut [F16]) {}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Vecf16Borrowed<'_>) -> Vecf16Owned {
|
||||
vector.for_own()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
super::vecf16::sl2(lhs, rhs).sqrt()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf16Borrowed<'_>, rhs: &[F16]) -> F32 {
|
||||
super::vecf16::sl2(lhs.slice(), rhs).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalScalarQuantization for Vecf16L2 {
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn scalar_quantization_distance<'a>(
|
||||
dims: u16,
|
||||
max: &[F16],
|
||||
min: &[F16],
|
||||
lhs: Vecf16Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let mut result = F32::zero();
|
||||
for i in 0..dims as usize {
|
||||
let _x = lhs[i].to_f();
|
||||
let _y = (F32(rhs[i] as f32) / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
|
||||
result += (_x - _y) * (_x - _y);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn scalar_quantization_distance2(
|
||||
dims: u16,
|
||||
max: &[F16],
|
||||
min: &[F16],
|
||||
lhs: &[u8],
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let mut result = F32::zero();
|
||||
for i in 0..dims as usize {
|
||||
let _x = F32(lhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
|
||||
let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f();
|
||||
result += (_x - _y) * (_x - _y);
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalProductQuantization for Vecf16L2 {
|
||||
type ProductQuantizationL2 = Vecf16L2;
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance<'a>(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F16],
|
||||
lhs: Vecf16Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut result = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
result += super::vecf16::sl2(lhs, rhs);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance2(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F16],
|
||||
lhs: &[u8],
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut result = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhsp = lhs[i as usize] as usize * dims as usize;
|
||||
let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
result += super::vecf16::sl2(lhs, rhs);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance_with_delta<'a>(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F16],
|
||||
lhs: Vecf16Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
delta: &[F16],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut result = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let del = &delta[(i * ratio) as usize..][..k as usize];
|
||||
result += distance_squared_l2_delta(lhs, rhs, del);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn product_quantization_l2_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf16::sl2(lhs, rhs)
|
||||
}
|
||||
|
||||
fn product_quantization_dense_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf16::sl2(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn distance_squared_l2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut d2 = F32::zero();
|
||||
for i in 0..n {
|
||||
let d = lhs[i].to_f() - (rhs[i].to_f() + del[i].to_f());
|
||||
d2 += d * d;
|
||||
}
|
||||
d2
|
||||
}
|
@ -1,89 +0,0 @@
|
||||
use crate::scalar::*;
|
||||
use num_traits::{Float, Zero};
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i] * rhs[i];
|
||||
x2 += lhs[i] * lhs[i];
|
||||
y2 += rhs[i] * rhs[i];
|
||||
}
|
||||
xy / (x2 * y2).sqrt()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i] * rhs[i];
|
||||
}
|
||||
xy
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn sl2(lhs: &[F32], rhs: &[F32]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut d2 = F32::zero();
|
||||
for i in 0..n {
|
||||
let d = lhs[i] - rhs[i];
|
||||
d2 += d * d;
|
||||
}
|
||||
d2
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn length(vector: &[F32]) -> F32 {
|
||||
let n = vector.len();
|
||||
let mut dot = F32::zero();
|
||||
for i in 0..n {
|
||||
dot += vector[i] * vector[i];
|
||||
}
|
||||
dot.sqrt()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn l2_normalize(vector: &mut [F32]) {
|
||||
let n = vector.len();
|
||||
let l = length(vector);
|
||||
for i in 0..n {
|
||||
vector[i] /= l;
|
||||
}
|
||||
}
|
@ -1,245 +0,0 @@
|
||||
use super::*;
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::{Float, Zero};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Vecf32Cos {}
|
||||
|
||||
impl Global for Vecf32Cos {
|
||||
type VectorOwned = Vecf32Owned;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::Vecf32;
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Cos;
|
||||
|
||||
fn distance(lhs: Vecf32Borrowed<'_>, rhs: Vecf32Borrowed<'_>) -> F32 {
|
||||
F32(1.0) - super::vecf32::cosine(lhs.slice(), rhs.slice())
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalElkanKMeans for Vecf32Cos {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [F32]) {
|
||||
super::vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Vecf32Borrowed<'_>) -> Vecf32Owned {
|
||||
let mut vector = vector.for_own();
|
||||
super::vecf32::l2_normalize(vector.slice_mut());
|
||||
vector
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 {
|
||||
super::vecf32::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[F32]) -> F32 {
|
||||
super::vecf32::dot(lhs.slice(), rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalScalarQuantization for Vecf32Cos {
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn scalar_quantization_distance<'a>(
|
||||
dims: u16,
|
||||
max: &[F32],
|
||||
min: &[F32],
|
||||
lhs: Vecf32Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..dims as usize {
|
||||
let _x = lhs[i];
|
||||
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
|
||||
xy += _x * _y;
|
||||
x2 += _x * _x;
|
||||
y2 += _y * _y;
|
||||
}
|
||||
F32(1.0) - xy / (x2 * y2).sqrt()
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn scalar_quantization_distance2(
|
||||
dims: u16,
|
||||
max: &[F32],
|
||||
min: &[F32],
|
||||
lhs: &[u8],
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..dims as usize {
|
||||
let _x = F32(lhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
|
||||
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
|
||||
xy += _x * _y;
|
||||
x2 += _x * _x;
|
||||
y2 += _y * _y;
|
||||
}
|
||||
F32(1.0) - xy / (x2 * y2).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalProductQuantization for Vecf32Cos {
|
||||
type ProductQuantizationL2 = Vecf32L2;
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance<'a>(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F32],
|
||||
lhs: Vecf32Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs);
|
||||
xy += _xy;
|
||||
x2 += _x2;
|
||||
y2 += _y2;
|
||||
}
|
||||
F32(1.0) - xy / (x2 * y2).sqrt()
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance2(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F32],
|
||||
lhs: &[u8],
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhsp = lhs[i as usize] as usize * dims as usize;
|
||||
let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs);
|
||||
xy += _xy;
|
||||
x2 += _x2;
|
||||
y2 += _y2;
|
||||
}
|
||||
F32(1.0) - xy / (x2 * y2).sqrt()
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance_with_delta<'a>(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F32],
|
||||
lhs: Vecf32Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
delta: &[F32],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let del = &delta[(i * ratio) as usize..][..k as usize];
|
||||
let (_xy, _x2, _y2) = xy_x2_y2_delta(lhs, rhs, del);
|
||||
xy += _xy;
|
||||
x2 += _x2;
|
||||
y2 += _y2;
|
||||
}
|
||||
F32(1.0) - xy / (x2 * y2).sqrt()
|
||||
}
|
||||
|
||||
fn product_quantization_l2_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs, rhs)
|
||||
}
|
||||
|
||||
fn product_quantization_dense_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
F32(1.0) - super::vecf32::cosine(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn xy_x2_y2(lhs: &[F32], rhs: &[F32]) -> (F32, F32, F32) {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i] * rhs[i];
|
||||
x2 += lhs[i] * lhs[i];
|
||||
y2 += rhs[i] * rhs[i];
|
||||
}
|
||||
(xy, x2, y2)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn xy_x2_y2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> (F32, F32, F32) {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i] * (rhs[i] + del[i]);
|
||||
x2 += lhs[i] * lhs[i];
|
||||
y2 += (rhs[i] + del[i]) * (rhs[i] + del[i]);
|
||||
}
|
||||
(xy, x2, y2)
|
||||
}
|
@ -1,200 +0,0 @@
|
||||
use super::*;
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::{Float, Zero};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Vecf32Dot {}
|
||||
|
||||
impl Global for Vecf32Dot {
|
||||
type VectorOwned = Vecf32Owned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;
|
||||
const VECTOR_KIND: VectorKind = VectorKind::Vecf32;
|
||||
|
||||
fn distance(lhs: Vecf32Borrowed<'_>, rhs: Vecf32Borrowed<'_>) -> F32 {
|
||||
super::vecf32::dot(lhs.slice(), rhs.slice()) * (-1.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalElkanKMeans for Vecf32Dot {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [F32]) {
|
||||
super::vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Vecf32Borrowed<'_>) -> Vecf32Owned {
|
||||
let mut vector = vector.for_own();
|
||||
super::vecf32::l2_normalize(vector.slice_mut());
|
||||
vector
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 {
|
||||
super::vecf32::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[F32]) -> F32 {
|
||||
super::vecf32::dot(lhs.slice(), rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalScalarQuantization for Vecf32Dot {
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn scalar_quantization_distance<'a>(
|
||||
dims: u16,
|
||||
max: &[F32],
|
||||
min: &[F32],
|
||||
lhs: Vecf32Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..dims as usize {
|
||||
let _x = lhs[i];
|
||||
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
|
||||
xy += _x * _y;
|
||||
}
|
||||
xy * (-1.0)
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn scalar_quantization_distance2(
|
||||
dims: u16,
|
||||
max: &[F32],
|
||||
min: &[F32],
|
||||
lhs: &[u8],
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..dims as usize {
|
||||
let _x = F32(lhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
|
||||
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
|
||||
xy += _x * _y;
|
||||
}
|
||||
xy * (-1.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalProductQuantization for Vecf32Dot {
|
||||
type ProductQuantizationL2 = Vecf32L2;
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance<'a>(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F32],
|
||||
lhs: Vecf32Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let _xy = super::vecf32::dot(lhs, rhs);
|
||||
xy += _xy;
|
||||
}
|
||||
xy * (-1.0)
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance2(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F32],
|
||||
lhs: &[u8],
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhsp = lhs[i as usize] as usize * dims as usize;
|
||||
let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let _xy = super::vecf32::dot(lhs, rhs);
|
||||
xy += _xy;
|
||||
}
|
||||
xy * (-1.0)
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance_with_delta<'a>(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F32],
|
||||
lhs: Vecf32Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
delta: &[F32],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let del = &delta[(i * ratio) as usize..][..k as usize];
|
||||
let _xy = dot_delta(lhs, rhs, del);
|
||||
xy += _xy;
|
||||
}
|
||||
xy * (-1.0)
|
||||
}
|
||||
|
||||
fn product_quantization_l2_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs, rhs)
|
||||
}
|
||||
|
||||
fn product_quantization_dense_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::dot(lhs, rhs) * (-1.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn dot_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n: usize = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i] * (rhs[i] + del[i]);
|
||||
}
|
||||
xy
|
||||
}
|
@ -1,194 +0,0 @@
|
||||
use super::*;
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::{Float, Zero};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Vecf32L2 {}
|
||||
|
||||
impl Global for Vecf32L2 {
|
||||
type VectorOwned = Vecf32Owned;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::Vecf32;
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::L2;
|
||||
|
||||
fn distance(lhs: Vecf32Borrowed<'_>, rhs: Vecf32Borrowed<'_>) -> F32 {
|
||||
super::vecf32::sl2(lhs.slice(), rhs.slice())
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalElkanKMeans for Vecf32L2 {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(_: &mut [F32]) {}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Vecf32Borrowed<'_>) -> Vecf32Owned {
|
||||
vector.for_own()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs, rhs).sqrt()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs.slice(), rhs).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalScalarQuantization for Vecf32L2 {
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn scalar_quantization_distance<'a>(
|
||||
dims: u16,
|
||||
max: &[F32],
|
||||
min: &[F32],
|
||||
lhs: Vecf32Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let mut result = F32::zero();
|
||||
for i in 0..dims as usize {
|
||||
let _x = lhs[i];
|
||||
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
|
||||
result += (_x - _y) * (_x - _y);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn scalar_quantization_distance2(
|
||||
dims: u16,
|
||||
max: &[F32],
|
||||
min: &[F32],
|
||||
lhs: &[u8],
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let mut result = F32::zero();
|
||||
for i in 0..dims as usize {
|
||||
let _x = F32(lhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
|
||||
let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i];
|
||||
result += (_x - _y) * (_x - _y);
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalProductQuantization for Vecf32L2 {
|
||||
type ProductQuantizationL2 = Vecf32L2;
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance<'a>(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F32],
|
||||
lhs: Vecf32Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut result = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
result += super::vecf32::sl2(lhs, rhs);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance2(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F32],
|
||||
lhs: &[u8],
|
||||
rhs: &[u8],
|
||||
) -> F32 {
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut result = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhsp = lhs[i as usize] as usize * dims as usize;
|
||||
let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
result += super::vecf32::sl2(lhs, rhs);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn product_quantization_distance_with_delta<'a>(
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: &[F32],
|
||||
lhs: Vecf32Borrowed<'a>,
|
||||
rhs: &[u8],
|
||||
delta: &[F32],
|
||||
) -> F32 {
|
||||
let lhs = lhs.slice();
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut result = F32::zero();
|
||||
for i in 0..width {
|
||||
let k = std::cmp::min(ratio, dims - ratio * i);
|
||||
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
|
||||
let rhsp = rhs[i as usize] as usize * dims as usize;
|
||||
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize];
|
||||
let del = &delta[(i * ratio) as usize..][..k as usize];
|
||||
result += distance_squared_l2_delta(lhs, rhs, del);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn product_quantization_l2_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs, rhs)
|
||||
}
|
||||
|
||||
fn product_quantization_dense_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn distance_squared_l2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut d2 = F32::zero();
|
||||
for i in 0..n {
|
||||
let d = lhs[i] - (rhs[i] + del[i]);
|
||||
d2 += d * d;
|
||||
}
|
||||
d2
|
||||
}
|
@ -1,232 +0,0 @@
|
||||
use crate::scalar::{F32, I8};
|
||||
|
||||
use super::Veci8Borrowed;
|
||||
|
||||
pub fn dot(x: &[I8], y: &[I8]) -> F32 {
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
{
|
||||
if detect::x86_64::test_avx512vnni() {
|
||||
return unsafe { dot_i8_avx512vnni(x, y) };
|
||||
}
|
||||
}
|
||||
dot_i8_fallback(x, y)
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn dot_i8_fallback(x: &[I8], y: &[I8]) -> F32 {
|
||||
// i8 * i8 fall in range of i16. Since our length is less than (2^16 - 1), the result won't overflow.
|
||||
let mut sum = 0;
|
||||
assert_eq!(x.len(), y.len());
|
||||
let length = x.len();
|
||||
// according to https://godbolt.org/z/ff48vW4es, this loop will be autovectorized
|
||||
for i in 0..length {
|
||||
sum += (x[i].0 as i16 * y[i].0 as i16) as i32;
|
||||
}
|
||||
F32(sum as f32)
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx512vnni,avx512bw,avx512f,bmi2")]
|
||||
unsafe fn dot_i8_avx512vnni(x: &[I8], y: &[I8]) -> F32 {
|
||||
use std::arch::x86_64::*;
|
||||
#[inline]
|
||||
#[target_feature(enable = "avx512vnni,avx512bw,avx512f,bmi2")]
|
||||
pub unsafe fn _mm512_maskz_loadu_epi8(k: __mmask64, mem_addr: *const i8) -> __m512i {
|
||||
let mut dst: __m512i;
|
||||
unsafe {
|
||||
std::arch::asm!(
|
||||
"vmovdqu8 {dst}{{{k}}} {{z}}, [{p}]",
|
||||
p = in(reg) mem_addr,
|
||||
k = in(kreg) k,
|
||||
dst = out(zmm_reg) dst,
|
||||
options(pure, readonly, nostack)
|
||||
);
|
||||
}
|
||||
dst
|
||||
}
|
||||
assert_eq!(x.len(), y.len());
|
||||
let mut sum = 0;
|
||||
let mut i = x.len();
|
||||
let mut p_x = x.as_ptr() as *const i8;
|
||||
let mut p_y = y.as_ptr() as *const i8;
|
||||
let mut vec_x;
|
||||
let mut vec_y;
|
||||
unsafe {
|
||||
let mut result = _mm512_setzero_si512();
|
||||
let zero = _mm512_setzero_si512();
|
||||
while i > 0 {
|
||||
if i < 64 {
|
||||
let mask = _bzhi_u64(0xFFFF_FFFF_FFFF_FFFF, i as u32);
|
||||
vec_x = _mm512_maskz_loadu_epi8(mask, p_x);
|
||||
vec_y = _mm512_maskz_loadu_epi8(mask, p_y);
|
||||
i = 0;
|
||||
} else {
|
||||
vec_x = _mm512_loadu_epi8(p_x);
|
||||
vec_y = _mm512_loadu_epi8(p_y);
|
||||
i -= 64;
|
||||
p_x = p_x.add(64);
|
||||
p_y = p_y.add(64);
|
||||
}
|
||||
// There are only _mm512_dpbusd_epi32 support, dpbusd will zeroextend a[i] and signextend b[i] first, so we need to convert a[i] positive and change corresponding b[i] to get right result.
|
||||
// And because we use -b[i] here, the range of quantization should be [-127, 127] instead of [-128, 127] to avoid overflow.
|
||||
let neg_mask = _mm512_movepi8_mask(vec_x);
|
||||
vec_x = _mm512_mask_abs_epi8(vec_x, neg_mask, vec_x);
|
||||
// Get -b[i] here, use saturating sub to avoid overflow. There are some precision loss here.
|
||||
vec_y = _mm512_mask_subs_epi8(vec_y, neg_mask, zero, vec_y);
|
||||
result = _mm512_dpbusd_epi32(result, vec_x, vec_y);
|
||||
}
|
||||
sum += _mm512_reduce_add_epi32(result);
|
||||
}
|
||||
F32(sum as f32)
|
||||
}
|
||||
|
||||
pub fn dot_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 {
|
||||
// (alpha_x * x[i] + offset_x) * (alpha_y * y[i] + offset_y)
|
||||
// = alpha_x * alpha_y * x[i] * y[i] + alpha_x * offset_y * x[i] + alpha_y * offset_x * y[i] + offset_x * offset_y
|
||||
// Sum(dot(origin_x[i] , origin_y[i])) = alpha_x * alpha_y * Sum(dot(x[i], y[i])) + offset_y * Sum(alpha_x * x[i]) + offset_x * Sum(alpha_y * y[i]) + offset_x * offset_y * dims
|
||||
let dot_xy = dot(x.data(), y.data());
|
||||
x.alpha() * y.alpha() * dot_xy
|
||||
+ x.offset() * y.sum()
|
||||
+ y.offset() * x.sum()
|
||||
+ x.offset() * y.offset() * F32(x.dims() as f32)
|
||||
}
|
||||
|
||||
pub fn l2_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 {
|
||||
// Sum(l2(origin_x[i] - origin_y[i])) = sum(x[i] ^ 2 - 2 * x[i] * y[i] + y[i] ^ 2)
|
||||
// = dot(x, x) - 2 * dot(x, y) + dot(y, y)
|
||||
x.l2_norm() * x.l2_norm() - F32(2.0) * dot_distance(x, y) + y.l2_norm() * y.l2_norm()
|
||||
}
|
||||
|
||||
pub fn cosine_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 {
|
||||
// dot(x, y) / (l2(x) * l2(y))
|
||||
let dot_xy = dot_distance(x, y);
|
||||
let l2_x = x.l2_norm();
|
||||
let l2_y = y.l2_norm();
|
||||
dot_xy / (l2_x * l2_y)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn l2_2<'a>(lhs: Veci8Borrowed<'a>, rhs: &[F32]) -> F32 {
|
||||
let data = lhs.data();
|
||||
assert_eq!(data.len(), rhs.len());
|
||||
data.iter()
|
||||
.zip(rhs.iter())
|
||||
.map(|(&x, &y)| {
|
||||
(x.to_f32() * lhs.alpha() + lhs.offset() - y)
|
||||
* (x.to_f32() * lhs.alpha() + lhs.offset() - y)
|
||||
})
|
||||
.sum::<F32>()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn dot_2<'a>(lhs: Veci8Borrowed<'a>, rhs: &[F32]) -> F32 {
|
||||
let data = lhs.data();
|
||||
assert_eq!(data.len(), rhs.len());
|
||||
data.iter()
|
||||
.zip(rhs.iter())
|
||||
.map(|(&x, &y)| (x.to_f32() * lhs.alpha() + lhs.offset()) * y)
|
||||
.sum::<F32>()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
global::{Veci8Owned, VectorOwned},
|
||||
vector::i8_quantization,
|
||||
};
|
||||
|
||||
fn new_random_vec_f32(size: usize) -> Vec<F32> {
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
(0..size)
|
||||
.map(|_| F32(rng.gen_range(-100000.0..100000.0)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn vec_to_owned(vec: Vec<F32>) -> Veci8Owned {
|
||||
let (v, alpha, offset) = i8_quantization(&vec);
|
||||
Veci8Owned::new(v.len() as u32, v, alpha, offset)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dot_i8() {
|
||||
let x = vec![F32(1.0), F32(2.0), F32(3.0)];
|
||||
let y = vec![F32(3.0), F32(2.0), F32(1.0)];
|
||||
let x_owned = vec_to_owned(x);
|
||||
let ref_x = x_owned.for_borrow();
|
||||
let y_owned = vec_to_owned(y);
|
||||
let ref_y = y_owned.for_borrow();
|
||||
let result = dot_distance(&ref_x, &ref_y);
|
||||
assert!((result.0 - 10.0).abs() < 0.1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cos_i8() {
|
||||
let x = vec![F32(1.0), F32(2.0), F32(3.0)];
|
||||
let y = vec![F32(3.0), F32(2.0), F32(1.0)];
|
||||
let x_owned = vec_to_owned(x);
|
||||
let ref_x = x_owned.for_borrow();
|
||||
let y_owned = vec_to_owned(y);
|
||||
let ref_y = y_owned.for_borrow();
|
||||
let result = cosine_distance(&ref_x, &ref_y);
|
||||
assert!((result.0 - (10.0 / 14.0)).abs() < 0.1);
|
||||
// test cos_i8 using random generated data, check the precision
|
||||
let x = new_random_vec_f32(1000);
|
||||
let y = new_random_vec_f32(1000);
|
||||
let xy = x.iter().zip(y.iter()).map(|(&x, &y)| x * y).sum::<F32>().0;
|
||||
let l2_x = x.iter().map(|&x| x * x).sum::<F32>().0.sqrt();
|
||||
let l2_y = y.iter().map(|&y| y * y).sum::<F32>().0.sqrt();
|
||||
let result_expected = xy / (l2_x * l2_y);
|
||||
let x_owned = vec_to_owned(x);
|
||||
let ref_x = x_owned.for_borrow();
|
||||
let y_owned = vec_to_owned(y);
|
||||
let ref_y = y_owned.for_borrow();
|
||||
let result = cosine_distance(&ref_x, &ref_y);
|
||||
assert!((result.0 - result_expected).abs() / result_expected < 0.25);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_l2_i8() {
|
||||
let x = vec![F32(1.0), F32(2.0), F32(3.0)];
|
||||
let y = vec![F32(3.0), F32(2.0), F32(1.0)];
|
||||
let x_owned = vec_to_owned(x);
|
||||
let ref_x = x_owned.for_borrow();
|
||||
let y_owned = vec_to_owned(y);
|
||||
let ref_y = y_owned.for_borrow();
|
||||
let result = l2_distance(&ref_x, &ref_y);
|
||||
assert!((result.0 - 8.0).abs() < 0.1);
|
||||
// test l2_i8 using random generated data, check the precision
|
||||
let x = new_random_vec_f32(1000);
|
||||
let y = new_random_vec_f32(1000);
|
||||
let result_expected = x
|
||||
.iter()
|
||||
.zip(y.iter())
|
||||
.map(|(&x, &y)| (x - y) * (x - y))
|
||||
.sum::<F32>()
|
||||
.0;
|
||||
let x_owned = vec_to_owned(x);
|
||||
let ref_x = x_owned.for_borrow();
|
||||
let y_owned = vec_to_owned(y);
|
||||
let ref_y = y_owned.for_borrow();
|
||||
let result = l2_distance(&ref_x, &ref_y);
|
||||
assert!((result.0 - result_expected).abs() / result_expected < 0.05);
|
||||
}
|
||||
}
|
@ -1,99 +0,0 @@
|
||||
use super::*;
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::Float;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Veci8Cos {}
|
||||
|
||||
impl Global for Veci8Cos {
|
||||
type VectorOwned = Veci8Owned;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::Veci8;
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Cos;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
F32(1.0) - super::veci8::cosine_distance(&lhs, &rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalElkanKMeans for Veci8Cos {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
super::vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Veci8Owned {
|
||||
vector.normalize()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::veci8::dot_2(lhs, rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalScalarQuantization for Veci8Cos {
|
||||
fn scalar_quantization_distance(
|
||||
_dims: u16,
|
||||
_max: &[Scalar<Self>],
|
||||
_min: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
fn scalar_quantization_distance2(
|
||||
_dims: u16,
|
||||
_max: &[Scalar<Self>],
|
||||
_min: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalProductQuantization for Veci8Cos {
|
||||
type ProductQuantizationL2 = Veci8Cos;
|
||||
|
||||
fn product_quantization_distance(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
fn product_quantization_distance2(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
fn product_quantization_distance_with_delta(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
_delta: &[Scalar<Self>],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
fn product_quantization_l2_distance(_lhs: &[Scalar<Self>], _rhs: &[Scalar<Self>]) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
fn product_quantization_dense_distance(_lhs: &[Scalar<Self>], _rhs: &[Scalar<Self>]) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
@ -1,99 +0,0 @@
|
||||
use super::*;
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::Float;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Veci8Dot {}
|
||||
|
||||
impl Global for Veci8Dot {
|
||||
type VectorOwned = Veci8Owned;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::Veci8;
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
super::veci8::dot_distance(&lhs, &rhs) * (-1.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalElkanKMeans for Veci8Dot {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
super::vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Veci8Owned {
|
||||
vector.normalize()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::veci8::dot_2(lhs, rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalScalarQuantization for Veci8Dot {
|
||||
fn scalar_quantization_distance(
|
||||
_dims: u16,
|
||||
_max: &[Scalar<Self>],
|
||||
_min: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
fn scalar_quantization_distance2(
|
||||
_dims: u16,
|
||||
_max: &[Scalar<Self>],
|
||||
_min: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalProductQuantization for Veci8Dot {
|
||||
type ProductQuantizationL2 = Veci8Dot;
|
||||
|
||||
fn product_quantization_distance(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
fn product_quantization_distance2(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
fn product_quantization_distance_with_delta(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
_delta: &[Scalar<Self>],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
fn product_quantization_l2_distance(_lhs: &[Scalar<Self>], _rhs: &[Scalar<Self>]) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
fn product_quantization_dense_distance(_lhs: &[Scalar<Self>], _rhs: &[Scalar<Self>]) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
@ -1,99 +0,0 @@
|
||||
use super::*;
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
use num_traits::Float;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Veci8L2 {}
|
||||
|
||||
impl Global for Veci8L2 {
|
||||
type VectorOwned = Veci8Owned;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::Veci8;
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
super::veci8::l2_distance(&lhs, &rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalElkanKMeans for Veci8L2 {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
super::vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Veci8Owned {
|
||||
vector.normalize()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::vecf32::sl2(lhs, rhs).sqrt()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
super::veci8::l2_2(lhs, rhs).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalScalarQuantization for Veci8L2 {
|
||||
fn scalar_quantization_distance(
|
||||
_dims: u16,
|
||||
_max: &[Scalar<Self>],
|
||||
_min: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
fn scalar_quantization_distance2(
|
||||
_dims: u16,
|
||||
_max: &[Scalar<Self>],
|
||||
_min: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalProductQuantization for Veci8L2 {
|
||||
type ProductQuantizationL2 = Veci8L2;
|
||||
|
||||
fn product_quantization_distance(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
fn product_quantization_distance2(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: &[u8],
|
||||
_rhs: &[u8],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
fn product_quantization_distance_with_delta(
|
||||
_dims: u32,
|
||||
_ratio: u32,
|
||||
_centroids: &[Scalar<Self>],
|
||||
_lhs: Borrowed<'_, Self>,
|
||||
_rhs: &[u8],
|
||||
_delta: &[Scalar<Self>],
|
||||
) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
fn product_quantization_l2_distance(_lhs: &[Scalar<Self>], _rhs: &[Scalar<Self>]) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
fn product_quantization_dense_distance(_lhs: &[Scalar<Self>], _rhs: &[Scalar<Self>]) -> F32 {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
@ -1,9 +1,83 @@
|
||||
use crate::distance::*;
|
||||
use crate::vector::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
use validator::{Validate, ValidationError};
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum CreateError {
|
||||
#[error("Invalid index options.")]
|
||||
InvalidIndexOptions { reason: String },
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum DropError {
|
||||
#[error("Index not found.")]
|
||||
NotExist,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum FlushError {
|
||||
#[error("Index not found.")]
|
||||
NotExist,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum InsertError {
|
||||
#[error("Index not found.")]
|
||||
NotExist,
|
||||
#[error("Invalid vector.")]
|
||||
InvalidVector,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum DeleteError {
|
||||
#[error("Index not found.")]
|
||||
NotExist,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum BasicError {
|
||||
#[error("Index not found.")]
|
||||
NotExist,
|
||||
#[error("Invalid vector.")]
|
||||
InvalidVector,
|
||||
#[error("Invalid search options.")]
|
||||
InvalidSearchOptions { reason: String },
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum VbaseError {
|
||||
#[error("Index not found.")]
|
||||
NotExist,
|
||||
#[error("Invalid vector.")]
|
||||
InvalidVector,
|
||||
#[error("Invalid search options.")]
|
||||
InvalidSearchOptions { reason: String },
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum ListError {
|
||||
#[error("Index not found.")]
|
||||
NotExist,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
|
||||
pub enum StatError {
|
||||
#[error("Index not found.")]
|
||||
NotExist,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
#[validate(schema(function = "IndexOptions::validate_index_options"))]
|
||||
|
@ -9,9 +9,8 @@
|
||||
#![allow(clippy::nonminimal_bool)]
|
||||
|
||||
pub mod distance;
|
||||
pub mod error;
|
||||
pub mod global;
|
||||
pub mod index;
|
||||
pub mod operator;
|
||||
pub mod scalar;
|
||||
pub mod search;
|
||||
pub mod vector;
|
||||
|
17
crates/base/src/operator/bvecf32_cos.rs
Normal file
17
crates/base/src/operator/bvecf32_cos.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use crate::distance::*;
|
||||
use crate::operator::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum BVecf32Cos {}
|
||||
|
||||
impl Operator for BVecf32Cos {
|
||||
type VectorOwned = BVecf32Owned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Cos;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
F32(1.0) - bvecf32::cosine(lhs, rhs)
|
||||
}
|
||||
}
|
17
crates/base/src/operator/bvecf32_dot.rs
Normal file
17
crates/base/src/operator/bvecf32_dot.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use crate::distance::*;
|
||||
use crate::operator::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum BVecf32Dot {}
|
||||
|
||||
impl Operator for BVecf32Dot {
|
||||
type VectorOwned = BVecf32Owned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
bvecf32::dot(lhs, rhs) * (-1.0)
|
||||
}
|
||||
}
|
17
crates/base/src/operator/bvecf32_jaccard.rs
Normal file
17
crates/base/src/operator/bvecf32_jaccard.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use crate::distance::*;
|
||||
use crate::operator::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum BVecf32Jaccard {}
|
||||
|
||||
impl Operator for BVecf32Jaccard {
|
||||
type VectorOwned = BVecf32Owned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Jaccard;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
F32(1.) - bvecf32::jaccard(lhs, rhs)
|
||||
}
|
||||
}
|
17
crates/base/src/operator/bvecf32_l2.rs
Normal file
17
crates/base/src/operator/bvecf32_l2.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use crate::distance::*;
|
||||
use crate::operator::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum BVecf32L2 {}
|
||||
|
||||
impl Operator for BVecf32L2 {
|
||||
type VectorOwned = BVecf32Owned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::L2;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
bvecf32::sl2(lhs, rhs)
|
||||
}
|
||||
}
|
49
crates/base/src/operator/mod.rs
Normal file
49
crates/base/src/operator/mod.rs
Normal file
@ -0,0 +1,49 @@
|
||||
mod bvecf32_cos;
|
||||
mod bvecf32_dot;
|
||||
mod bvecf32_jaccard;
|
||||
mod bvecf32_l2;
|
||||
mod svecf32_cos;
|
||||
mod svecf32_dot;
|
||||
mod svecf32_l2;
|
||||
mod vecf16_cos;
|
||||
mod vecf16_dot;
|
||||
mod vecf16_l2;
|
||||
mod vecf32_cos;
|
||||
mod vecf32_dot;
|
||||
mod vecf32_l2;
|
||||
mod veci8_cos;
|
||||
mod veci8_dot;
|
||||
mod veci8_l2;
|
||||
|
||||
pub use bvecf32_cos::BVecf32Cos;
|
||||
pub use bvecf32_dot::BVecf32Dot;
|
||||
pub use bvecf32_jaccard::BVecf32Jaccard;
|
||||
pub use bvecf32_l2::BVecf32L2;
|
||||
pub use svecf32_cos::SVecf32Cos;
|
||||
pub use svecf32_dot::SVecf32Dot;
|
||||
pub use svecf32_l2::SVecf32L2;
|
||||
pub use vecf16_cos::Vecf16Cos;
|
||||
pub use vecf16_dot::Vecf16Dot;
|
||||
pub use vecf16_l2::Vecf16L2;
|
||||
pub use vecf32_cos::Vecf32Cos;
|
||||
pub use vecf32_dot::Vecf32Dot;
|
||||
pub use vecf32_l2::Vecf32L2;
|
||||
pub use veci8_cos::Veci8Cos;
|
||||
pub use veci8_dot::Veci8Dot;
|
||||
pub use veci8_l2::Veci8L2;
|
||||
|
||||
use crate::distance::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
pub trait Operator: Copy + 'static {
|
||||
type VectorOwned: VectorOwned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32;
|
||||
}
|
||||
|
||||
pub type Owned<T> = <T as Operator>::VectorOwned;
|
||||
pub type Borrowed<'a, T> = <<T as Operator>::VectorOwned as VectorOwned>::Borrowed<'a>;
|
||||
pub type Scalar<T> = <<T as Operator>::VectorOwned as VectorOwned>::Scalar;
|
17
crates/base/src/operator/svecf32_cos.rs
Normal file
17
crates/base/src/operator/svecf32_cos.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use crate::distance::*;
|
||||
use crate::operator::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum SVecf32Cos {}
|
||||
|
||||
impl Operator for SVecf32Cos {
|
||||
type VectorOwned = SVecf32Owned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Cos;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
F32(1.0) - svecf32::cosine(lhs, rhs)
|
||||
}
|
||||
}
|
17
crates/base/src/operator/svecf32_dot.rs
Normal file
17
crates/base/src/operator/svecf32_dot.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use crate::distance::*;
|
||||
use crate::operator::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum SVecf32Dot {}
|
||||
|
||||
impl Operator for SVecf32Dot {
|
||||
type VectorOwned = SVecf32Owned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
svecf32::dot(lhs, rhs) * (-1.0)
|
||||
}
|
||||
}
|
17
crates/base/src/operator/svecf32_l2.rs
Normal file
17
crates/base/src/operator/svecf32_l2.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use crate::distance::*;
|
||||
use crate::operator::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum SVecf32L2 {}
|
||||
|
||||
impl Operator for SVecf32L2 {
|
||||
type VectorOwned = SVecf32Owned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::L2;
|
||||
|
||||
fn distance(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
|
||||
svecf32::sl2(lhs, rhs)
|
||||
}
|
||||
}
|
17
crates/base/src/operator/vecf16_cos.rs
Normal file
17
crates/base/src/operator/vecf16_cos.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use crate::distance::*;
|
||||
use crate::operator::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Vecf16Cos {}
|
||||
|
||||
impl Operator for Vecf16Cos {
|
||||
type VectorOwned = Vecf16Owned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Cos;
|
||||
|
||||
fn distance(lhs: Vecf16Borrowed<'_>, rhs: Vecf16Borrowed<'_>) -> F32 {
|
||||
F32(1.0) - vecf16::cosine(lhs.slice(), rhs.slice())
|
||||
}
|
||||
}
|
17
crates/base/src/operator/vecf16_dot.rs
Normal file
17
crates/base/src/operator/vecf16_dot.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use crate::distance::*;
|
||||
use crate::operator::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Vecf16Dot {}
|
||||
|
||||
impl Operator for Vecf16Dot {
|
||||
type VectorOwned = Vecf16Owned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;
|
||||
|
||||
fn distance(lhs: Vecf16Borrowed<'_>, rhs: Vecf16Borrowed<'_>) -> F32 {
|
||||
vecf16::dot(lhs.slice(), rhs.slice()) * (-1.0)
|
||||
}
|
||||
}
|
17
crates/base/src/operator/vecf16_l2.rs
Normal file
17
crates/base/src/operator/vecf16_l2.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use crate::distance::*;
|
||||
use crate::operator::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Vecf16L2 {}
|
||||
|
||||
impl Operator for Vecf16L2 {
|
||||
type VectorOwned = Vecf16Owned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::L2;
|
||||
|
||||
fn distance(lhs: Vecf16Borrowed<'_>, rhs: Vecf16Borrowed<'_>) -> F32 {
|
||||
vecf16::sl2(lhs.slice(), rhs.slice())
|
||||
}
|
||||
}
|
17
crates/base/src/operator/vecf32_cos.rs
Normal file
17
crates/base/src/operator/vecf32_cos.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use crate::distance::*;
|
||||
use crate::operator::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Vecf32Cos {}
|
||||
|
||||
impl Operator for Vecf32Cos {
|
||||
type VectorOwned = Vecf32Owned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Cos;
|
||||
|
||||
fn distance(lhs: Vecf32Borrowed<'_>, rhs: Vecf32Borrowed<'_>) -> F32 {
|
||||
F32(1.0) - vecf32::cosine(lhs.slice(), rhs.slice())
|
||||
}
|
||||
}
|
17
crates/base/src/operator/vecf32_dot.rs
Normal file
17
crates/base/src/operator/vecf32_dot.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use crate::distance::*;
|
||||
use crate::operator::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Vecf32Dot {}
|
||||
|
||||
impl Operator for Vecf32Dot {
|
||||
type VectorOwned = Vecf32Owned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;
|
||||
|
||||
fn distance(lhs: Vecf32Borrowed<'_>, rhs: Vecf32Borrowed<'_>) -> F32 {
|
||||
vecf32::dot(lhs.slice(), rhs.slice()) * (-1.0)
|
||||
}
|
||||
}
|
17
crates/base/src/operator/vecf32_l2.rs
Normal file
17
crates/base/src/operator/vecf32_l2.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use crate::distance::*;
|
||||
use crate::operator::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Vecf32L2 {}
|
||||
|
||||
impl Operator for Vecf32L2 {
|
||||
type VectorOwned = Vecf32Owned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::L2;
|
||||
|
||||
fn distance(lhs: Vecf32Borrowed<'_>, rhs: Vecf32Borrowed<'_>) -> F32 {
|
||||
vecf32::sl2(lhs.slice(), rhs.slice())
|
||||
}
|
||||
}
|
17
crates/base/src/operator/veci8_cos.rs
Normal file
17
crates/base/src/operator/veci8_cos.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use crate::distance::*;
|
||||
use crate::operator::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Veci8Cos {}
|
||||
|
||||
impl Operator for Veci8Cos {
|
||||
type VectorOwned = Veci8Owned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Cos;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
F32(1.0) - veci8::cosine_distance(&lhs, &rhs)
|
||||
}
|
||||
}
|
17
crates/base/src/operator/veci8_dot.rs
Normal file
17
crates/base/src/operator/veci8_dot.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use crate::distance::*;
|
||||
use crate::operator::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Veci8Dot {}
|
||||
|
||||
impl Operator for Veci8Dot {
|
||||
type VectorOwned = Veci8Owned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
veci8::dot_distance(&lhs, &rhs) * (-1.0)
|
||||
}
|
||||
}
|
17
crates/base/src/operator/veci8_l2.rs
Normal file
17
crates/base/src/operator/veci8_l2.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use crate::distance::*;
|
||||
use crate::operator::*;
|
||||
use crate::scalar::*;
|
||||
use crate::vector::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Veci8L2 {}
|
||||
|
||||
impl Operator for Veci8L2 {
|
||||
type VectorOwned = Veci8Owned;
|
||||
|
||||
const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;
|
||||
|
||||
fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
|
||||
veci8::l2_distance(&lhs, &rhs)
|
||||
}
|
||||
}
|
@ -2,6 +2,7 @@ use super::ScalarLike;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::Ordering;
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::iter::Sum;
|
||||
use std::num::ParseFloatError;
|
||||
use std::ops::*;
|
||||
@ -12,6 +13,20 @@ use std::str::FromStr;
|
||||
#[serde(transparent)]
|
||||
pub struct F32(pub f32);
|
||||
|
||||
impl Hash for F32 {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
use num_traits::{Float, Zero};
|
||||
let bits = if self.is_nan() {
|
||||
f32::NAN.to_bits()
|
||||
} else if self.is_zero() {
|
||||
f32::zero().to_bits()
|
||||
} else {
|
||||
self.0.to_bits()
|
||||
};
|
||||
bits.hash(state)
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for F32 {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
Debug::fmt(&self.0, f)
|
||||
|
@ -1,57 +1,81 @@
|
||||
use crate::operator::{Borrowed, Operator};
|
||||
use crate::scalar::F32;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fmt::Display, num::ParseIntError, str::FromStr};
|
||||
|
||||
pub type Payload = u64;
|
||||
|
||||
pub trait Filter: Clone {
|
||||
fn check(&mut self, payload: Payload) -> bool;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct Element {
|
||||
pub distance: F32,
|
||||
pub payload: Payload,
|
||||
}
|
||||
use std::fmt::Display;
|
||||
|
||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct Handle {
|
||||
pub newtype: u32,
|
||||
newtype: u128,
|
||||
}
|
||||
|
||||
impl Handle {
|
||||
pub fn as_u32(self) -> u32 {
|
||||
pub fn new(newtype: u128) -> Self {
|
||||
Self { newtype }
|
||||
}
|
||||
pub fn as_u128(self) -> u128 {
|
||||
self.newtype
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Handle {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.as_u32())
|
||||
write!(f, "{:x}", self.as_u128())
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Handle {
|
||||
type Err = ParseIntError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
Ok(Handle {
|
||||
newtype: u32::from_str(s)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct Pointer {
|
||||
pub newtype: u64,
|
||||
newtype: u64,
|
||||
}
|
||||
|
||||
impl Pointer {
|
||||
pub fn from_u48(value: u64) -> Self {
|
||||
assert!(value < (1u64 << 48));
|
||||
pub fn new(value: u64) -> Self {
|
||||
Self { newtype: value }
|
||||
}
|
||||
pub fn as_u48(self) -> u64 {
|
||||
pub fn as_u64(self) -> u64 {
|
||||
self.newtype
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
#[repr(C)]
|
||||
pub struct Payload {
|
||||
pointer: Pointer,
|
||||
time: u64,
|
||||
}
|
||||
|
||||
impl Payload {
|
||||
pub fn new(pointer: Pointer, time: u64) -> Self {
|
||||
Self { pointer, time }
|
||||
}
|
||||
pub fn pointer(&self) -> Pointer {
|
||||
self.pointer
|
||||
}
|
||||
pub fn time(&self) -> u64 {
|
||||
self.time
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl bytemuck::Zeroable for Payload {}
|
||||
unsafe impl bytemuck::Pod for Payload {}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct Element {
|
||||
pub distance: F32,
|
||||
pub payload: Payload,
|
||||
}
|
||||
|
||||
pub trait Filter: Clone {
|
||||
fn check(&mut self, payload: Payload) -> bool;
|
||||
}
|
||||
|
||||
pub trait Collection<O: Operator> {
|
||||
fn dims(&self) -> u32;
|
||||
fn len(&self) -> u32;
|
||||
fn vector(&self, i: u32) -> Borrowed<'_, O>;
|
||||
fn payload(&self, i: u32) -> Payload;
|
||||
}
|
||||
|
||||
pub trait Source<O: Operator>: Collection<O> {
|
||||
// ..
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
use super::{VectorBorrowed, VectorOwned};
|
||||
use crate::scalar::F32;
|
||||
use crate::vector::{Vecf32Owned, VectorBorrowed, VectorKind, VectorOwned};
|
||||
use num_traits::Float;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub const BVEC_WIDTH: usize = usize::BITS as usize;
|
||||
@ -43,6 +44,8 @@ impl VectorOwned for BVecf32Owned {
|
||||
type Scalar = F32;
|
||||
type Borrowed<'a> = BVecf32Borrowed<'a>;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::BVecf32;
|
||||
|
||||
#[inline(always)]
|
||||
fn dims(&self) -> u32 {
|
||||
self.dims as u32
|
||||
@ -159,3 +162,404 @@ impl<'a> PartialOrd for BVecf32Borrowed<'a> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn cosine<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 {
|
||||
let lhs = lhs.data();
|
||||
let rhs = rhs.data();
|
||||
assert!(lhs.len() == rhs.len());
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn cosine(lhs: &[usize], rhs: &[usize]) -> F32 {
|
||||
let mut xy = 0;
|
||||
let mut xx = 0;
|
||||
let mut yy = 0;
|
||||
for i in 0..lhs.len() {
|
||||
xy += (lhs[i] & rhs[i]).count_ones();
|
||||
xx += lhs[i].count_ones();
|
||||
yy += rhs[i].count_ones();
|
||||
}
|
||||
let rxy = xy as f32;
|
||||
let rxx = xx as f32;
|
||||
let ryy = yy as f32;
|
||||
F32(rxy / (rxx * ryy).sqrt())
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
unsafe fn cosine_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 {
|
||||
use std::arch::x86_64::*;
|
||||
#[inline]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i {
|
||||
let mut dst: __m512i;
|
||||
unsafe {
|
||||
std::arch::asm!(
|
||||
"vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]",
|
||||
p = in(reg) mem_addr,
|
||||
k = in(kreg) k,
|
||||
dst = out(zmm_reg) dst,
|
||||
options(pure, readonly, nostack)
|
||||
);
|
||||
}
|
||||
dst
|
||||
}
|
||||
assert_eq!(lhs.len(), rhs.len());
|
||||
unsafe {
|
||||
const WIDTH: usize = 512 / 8 / std::mem::size_of::<usize>();
|
||||
let mut xy = _mm512_setzero_si512();
|
||||
let mut xx = _mm512_setzero_si512();
|
||||
let mut yy = _mm512_setzero_si512();
|
||||
let mut a = lhs.as_ptr();
|
||||
let mut b = rhs.as_ptr();
|
||||
let mut n = lhs.len();
|
||||
while n >= WIDTH {
|
||||
let x = _mm512_loadu_si512(a.cast());
|
||||
let y = _mm512_loadu_si512(b.cast());
|
||||
a = a.add(WIDTH);
|
||||
b = b.add(WIDTH);
|
||||
n -= WIDTH;
|
||||
xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y)));
|
||||
xx = _mm512_add_epi64(xx, _mm512_popcnt_epi64(x));
|
||||
yy = _mm512_add_epi64(yy, _mm512_popcnt_epi64(y));
|
||||
}
|
||||
if n > 0 {
|
||||
let mask = _bzhi_u32(0xFFFF, n as u32) as u8;
|
||||
let x = _mm512_maskz_loadu_epi64(mask, a.cast());
|
||||
let y = _mm512_maskz_loadu_epi64(mask, b.cast());
|
||||
xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y)));
|
||||
xx = _mm512_add_epi64(xx, _mm512_popcnt_epi64(x));
|
||||
yy = _mm512_add_epi64(yy, _mm512_popcnt_epi64(y));
|
||||
}
|
||||
let rxy = _mm512_reduce_add_epi64(xy) as f32;
|
||||
let rxx = _mm512_reduce_add_epi64(xx) as f32;
|
||||
let ryy = _mm512_reduce_add_epi64(yy) as f32;
|
||||
F32(rxy / (rxx * ryy).sqrt())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_avx512vpopcntdq() {
|
||||
unsafe {
|
||||
return cosine_avx512vpopcntdq(lhs, rhs);
|
||||
}
|
||||
}
|
||||
cosine(lhs, rhs)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn dot<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 {
|
||||
let lhs = lhs.data();
|
||||
let rhs = rhs.data();
|
||||
assert!(lhs.len() == rhs.len());
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn dot(lhs: &[usize], rhs: &[usize]) -> F32 {
|
||||
let mut xy = 0;
|
||||
for i in 0..lhs.len() {
|
||||
xy += (lhs[i] & rhs[i]).count_ones();
|
||||
}
|
||||
F32(xy as f32)
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
unsafe fn dot_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 {
|
||||
use std::arch::x86_64::*;
|
||||
#[inline]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i {
|
||||
let mut dst: __m512i;
|
||||
unsafe {
|
||||
std::arch::asm!(
|
||||
"vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]",
|
||||
p = in(reg) mem_addr,
|
||||
k = in(kreg) k,
|
||||
dst = out(zmm_reg) dst,
|
||||
options(pure, readonly, nostack)
|
||||
);
|
||||
}
|
||||
dst
|
||||
}
|
||||
assert_eq!(lhs.len(), rhs.len());
|
||||
unsafe {
|
||||
const WIDTH: usize = 512 / 8 / std::mem::size_of::<usize>();
|
||||
let mut xy = _mm512_setzero_si512();
|
||||
let mut a = lhs.as_ptr();
|
||||
let mut b = rhs.as_ptr();
|
||||
let mut n = lhs.len();
|
||||
while n >= WIDTH {
|
||||
let x = _mm512_loadu_si512(a.cast());
|
||||
let y = _mm512_loadu_si512(b.cast());
|
||||
a = a.add(WIDTH);
|
||||
b = b.add(WIDTH);
|
||||
n -= WIDTH;
|
||||
xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y)));
|
||||
}
|
||||
if n > 0 {
|
||||
let mask = _bzhi_u32(0xFFFF, n as u32) as u8;
|
||||
let x = _mm512_maskz_loadu_epi64(mask, a.cast());
|
||||
let y = _mm512_maskz_loadu_epi64(mask, b.cast());
|
||||
xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y)));
|
||||
}
|
||||
let rxy = _mm512_reduce_add_epi64(xy) as f32;
|
||||
F32(rxy)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_avx512vpopcntdq() {
|
||||
unsafe {
|
||||
return dot_avx512vpopcntdq(lhs, rhs);
|
||||
}
|
||||
}
|
||||
dot(lhs, rhs)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn sl2<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 {
|
||||
let lhs = lhs.data();
|
||||
let rhs = rhs.data();
|
||||
assert!(lhs.len() == rhs.len());
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn sl2(lhs: &[usize], rhs: &[usize]) -> F32 {
|
||||
let mut dd = 0;
|
||||
for i in 0..lhs.len() {
|
||||
dd += (lhs[i] ^ rhs[i]).count_ones();
|
||||
}
|
||||
F32(dd as f32)
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
unsafe fn sl2_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 {
|
||||
use std::arch::x86_64::*;
|
||||
#[inline]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i {
|
||||
let mut dst: __m512i;
|
||||
unsafe {
|
||||
std::arch::asm!(
|
||||
"vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]",
|
||||
p = in(reg) mem_addr,
|
||||
k = in(kreg) k,
|
||||
dst = out(zmm_reg) dst,
|
||||
options(pure, readonly, nostack)
|
||||
);
|
||||
}
|
||||
dst
|
||||
}
|
||||
assert_eq!(lhs.len(), rhs.len());
|
||||
unsafe {
|
||||
const WIDTH: usize = 512 / 8 / std::mem::size_of::<usize>();
|
||||
let mut dd = _mm512_setzero_si512();
|
||||
let mut a = lhs.as_ptr();
|
||||
let mut b = rhs.as_ptr();
|
||||
let mut n = lhs.len();
|
||||
while n >= WIDTH {
|
||||
let x = _mm512_loadu_si512(a.cast());
|
||||
let y = _mm512_loadu_si512(b.cast());
|
||||
a = a.add(WIDTH);
|
||||
b = b.add(WIDTH);
|
||||
n -= WIDTH;
|
||||
dd = _mm512_add_epi64(dd, _mm512_popcnt_epi64(_mm512_xor_si512(x, y)));
|
||||
}
|
||||
if n > 0 {
|
||||
let mask = _bzhi_u32(0xFFFF, n as u32) as u8;
|
||||
let x = _mm512_maskz_loadu_epi64(mask, a.cast());
|
||||
let y = _mm512_maskz_loadu_epi64(mask, b.cast());
|
||||
dd = _mm512_add_epi64(dd, _mm512_popcnt_epi64(_mm512_xor_si512(x, y)));
|
||||
}
|
||||
let rdd = _mm512_reduce_add_epi64(dd) as f32;
|
||||
F32(rdd)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_avx512vpopcntdq() {
|
||||
unsafe {
|
||||
return sl2_avx512vpopcntdq(lhs, rhs);
|
||||
}
|
||||
}
|
||||
sl2(lhs, rhs)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn jaccard<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 {
|
||||
let lhs = lhs.data();
|
||||
let rhs = rhs.data();
|
||||
assert!(lhs.len() == rhs.len());
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn jaccard(lhs: &[usize], rhs: &[usize]) -> F32 {
|
||||
let mut inter = 0;
|
||||
let mut union = 0;
|
||||
for i in 0..lhs.len() {
|
||||
inter += (lhs[i] & rhs[i]).count_ones();
|
||||
union += (lhs[i] | rhs[i]).count_ones();
|
||||
}
|
||||
F32(inter as f32 / union as f32)
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
unsafe fn jaccard_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 {
|
||||
use std::arch::x86_64::*;
|
||||
#[inline]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i {
|
||||
let mut dst: __m512i;
|
||||
unsafe {
|
||||
std::arch::asm!(
|
||||
"vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]",
|
||||
p = in(reg) mem_addr,
|
||||
k = in(kreg) k,
|
||||
dst = out(zmm_reg) dst,
|
||||
options(pure, readonly, nostack)
|
||||
);
|
||||
}
|
||||
dst
|
||||
}
|
||||
assert_eq!(lhs.len(), rhs.len());
|
||||
unsafe {
|
||||
const WIDTH: usize = 512 / 8 / std::mem::size_of::<usize>();
|
||||
let mut inter = _mm512_setzero_si512();
|
||||
let mut union = _mm512_setzero_si512();
|
||||
let mut a = lhs.as_ptr();
|
||||
let mut b = rhs.as_ptr();
|
||||
let mut n = lhs.len();
|
||||
while n >= WIDTH {
|
||||
let x = _mm512_loadu_si512(a.cast());
|
||||
let y = _mm512_loadu_si512(b.cast());
|
||||
a = a.add(WIDTH);
|
||||
b = b.add(WIDTH);
|
||||
n -= WIDTH;
|
||||
inter = _mm512_add_epi64(inter, _mm512_popcnt_epi64(_mm512_and_si512(x, y)));
|
||||
union = _mm512_add_epi64(union, _mm512_popcnt_epi64(_mm512_or_si512(x, y)));
|
||||
}
|
||||
if n > 0 {
|
||||
let mask = _bzhi_u32(0xFFFF, n as u32) as u8;
|
||||
let x = _mm512_maskz_loadu_epi64(mask, a.cast());
|
||||
let y = _mm512_maskz_loadu_epi64(mask, b.cast());
|
||||
inter = _mm512_add_epi64(inter, _mm512_popcnt_epi64(_mm512_and_si512(x, y)));
|
||||
union = _mm512_add_epi64(union, _mm512_popcnt_epi64(_mm512_or_si512(x, y)));
|
||||
}
|
||||
let rinter = _mm512_reduce_add_epi64(inter) as f32;
|
||||
let runion = _mm512_reduce_add_epi64(union) as f32;
|
||||
F32(rinter / runion)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_avx512vpopcntdq() {
|
||||
unsafe {
|
||||
return jaccard_avx512vpopcntdq(lhs, rhs);
|
||||
}
|
||||
}
|
||||
jaccard(lhs, rhs)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn length(vector: BVecf32Borrowed<'_>) -> F32 {
|
||||
let vector = vector.data();
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn length(vector: &[usize]) -> F32 {
|
||||
let mut l = 0;
|
||||
for i in 0..vector.len() {
|
||||
l += vector[i].count_ones();
|
||||
}
|
||||
F32(l as f32).sqrt()
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
unsafe fn length_avx512vpopcntdq(lhs: &[usize]) -> F32 {
|
||||
use std::arch::x86_64::*;
|
||||
#[inline]
|
||||
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")]
|
||||
pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i {
|
||||
let mut dst: __m512i;
|
||||
unsafe {
|
||||
std::arch::asm!(
|
||||
"vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]",
|
||||
p = in(reg) mem_addr,
|
||||
k = in(kreg) k,
|
||||
dst = out(zmm_reg) dst,
|
||||
options(pure, readonly, nostack)
|
||||
);
|
||||
}
|
||||
dst
|
||||
}
|
||||
unsafe {
|
||||
const WIDTH: usize = 512 / 8 / std::mem::size_of::<usize>();
|
||||
let mut cnt = _mm512_setzero_si512();
|
||||
let mut a = lhs.as_ptr();
|
||||
let mut n = lhs.len();
|
||||
while n >= WIDTH {
|
||||
let x = _mm512_loadu_si512(a.cast());
|
||||
a = a.add(WIDTH);
|
||||
n -= WIDTH;
|
||||
cnt = _mm512_add_epi64(cnt, _mm512_popcnt_epi64(x));
|
||||
}
|
||||
if n > 0 {
|
||||
let mask = _bzhi_u32(0xFFFF, n as u32) as u8;
|
||||
let x = _mm512_maskz_loadu_epi64(mask, a.cast());
|
||||
cnt = _mm512_add_epi64(cnt, _mm512_popcnt_epi64(x));
|
||||
}
|
||||
let rcnt = _mm512_reduce_add_epi64(cnt) as f32;
|
||||
F32(rcnt.sqrt())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_avx512vpopcntdq() {
|
||||
unsafe {
|
||||
return length_avx512vpopcntdq(vector);
|
||||
}
|
||||
}
|
||||
length(vector)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn l2_normalize<'a>(vector: BVecf32Borrowed<'a>) -> Vecf32Owned {
|
||||
let l = length(vector);
|
||||
Vecf32Owned::new(vector.iter().map(|i| F32(i as u32 as f32) / l).collect())
|
||||
}
|
||||
|
@ -1,14 +1,14 @@
|
||||
mod bvecf32;
|
||||
mod svecf32;
|
||||
mod vecf16;
|
||||
mod vecf32;
|
||||
mod veci8;
|
||||
pub mod bvecf32;
|
||||
pub mod svecf32;
|
||||
pub mod vecf16;
|
||||
pub mod vecf32;
|
||||
pub mod veci8;
|
||||
|
||||
pub use bvecf32::{BVecf32Borrowed, BVecf32Owned, BVEC_WIDTH};
|
||||
pub use svecf32::{SVecf32Borrowed, SVecf32Owned};
|
||||
pub use vecf16::{Vecf16Borrowed, Vecf16Owned};
|
||||
pub use vecf32::{Vecf32Borrowed, Vecf32Owned};
|
||||
pub use veci8::{i8_dequantization, i8_precompute, i8_quantization, Veci8Borrowed, Veci8Owned};
|
||||
pub use veci8::{Veci8Borrowed, Veci8Owned};
|
||||
|
||||
use crate::scalar::ScalarLike;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@ -27,6 +27,8 @@ pub trait VectorOwned: Clone + Serialize + for<'a> Deserialize<'a> + 'static {
|
||||
type Scalar: ScalarLike;
|
||||
type Borrowed<'a>: VectorBorrowed<Scalar = Self::Scalar, Owned = Self>;
|
||||
|
||||
const VECTOR_KIND: VectorKind;
|
||||
|
||||
fn for_borrow(&self) -> Self::Borrowed<'_>;
|
||||
|
||||
fn dims(&self) -> u32;
|
||||
|
@ -1,6 +1,6 @@
|
||||
use super::{VectorBorrowed, VectorOwned};
|
||||
use crate::scalar::F32;
|
||||
use num_traits::Zero;
|
||||
use crate::vector::{VectorBorrowed, VectorKind, VectorOwned};
|
||||
use num_traits::{Float, Zero};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@ -67,6 +67,8 @@ impl VectorOwned for SVecf32Owned {
|
||||
type Scalar = F32;
|
||||
type Borrowed<'a> = SVecf32Borrowed<'a>;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::SVecf32;
|
||||
|
||||
#[inline(always)]
|
||||
fn dims(&self) -> u32 {
|
||||
self.dims
|
||||
@ -181,3 +183,169 @@ impl<'a> SVecf32Borrowed<'a> {
|
||||
self.indexes.len().try_into().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn cosine<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 {
|
||||
let mut lhs_pos = 0;
|
||||
let mut rhs_pos = 0;
|
||||
let size1 = lhs.len() as usize;
|
||||
let size2 = rhs.len() as usize;
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
while lhs_pos < size1 && rhs_pos < size2 {
|
||||
let lhs_index = lhs.indexes()[lhs_pos];
|
||||
let rhs_index = rhs.indexes()[rhs_pos];
|
||||
let lhs_value = lhs.values()[lhs_pos];
|
||||
let rhs_value = rhs.values()[rhs_pos];
|
||||
xy += F32((lhs_index == rhs_index) as u32 as f32) * lhs_value * rhs_value;
|
||||
x2 += F32((lhs_index <= rhs_index) as u32 as f32) * lhs_value * lhs_value;
|
||||
y2 += F32((lhs_index >= rhs_index) as u32 as f32) * rhs_value * rhs_value;
|
||||
lhs_pos += (lhs_index <= rhs_index) as usize;
|
||||
rhs_pos += (lhs_index >= rhs_index) as usize;
|
||||
}
|
||||
for i in lhs_pos..size1 {
|
||||
x2 += lhs.values()[i] * lhs.values()[i];
|
||||
}
|
||||
for i in rhs_pos..size2 {
|
||||
y2 += rhs.values()[i] * rhs.values()[i];
|
||||
}
|
||||
xy / (x2 * y2).sqrt()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn dot<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 {
|
||||
let mut lhs_pos = 0;
|
||||
let mut rhs_pos = 0;
|
||||
let size1 = lhs.len() as usize;
|
||||
let size2 = rhs.len() as usize;
|
||||
let mut xy = F32::zero();
|
||||
while lhs_pos < size1 && rhs_pos < size2 {
|
||||
let lhs_index = lhs.indexes()[lhs_pos];
|
||||
let rhs_index = rhs.indexes()[rhs_pos];
|
||||
let lhs_value = lhs.values()[lhs_pos];
|
||||
let rhs_value = rhs.values()[rhs_pos];
|
||||
xy += F32((lhs_index == rhs_index) as u32 as f32) * lhs_value * rhs_value;
|
||||
lhs_pos += (lhs_index <= rhs_index) as usize;
|
||||
rhs_pos += (lhs_index >= rhs_index) as usize;
|
||||
}
|
||||
xy
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn dot_2<'a>(lhs: SVecf32Borrowed<'a>, rhs: &[F32]) -> F32 {
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..lhs.len() as usize {
|
||||
xy += lhs.values()[i] * rhs[lhs.indexes()[i] as usize];
|
||||
}
|
||||
xy
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn sl2<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 {
|
||||
let mut lhs_pos = 0;
|
||||
let mut rhs_pos = 0;
|
||||
let size1 = lhs.len() as usize;
|
||||
let size2 = rhs.len() as usize;
|
||||
let mut d2 = F32::zero();
|
||||
while lhs_pos < size1 && rhs_pos < size2 {
|
||||
let lhs_index = lhs.indexes()[lhs_pos];
|
||||
let rhs_index = rhs.indexes()[rhs_pos];
|
||||
let lhs_value = lhs.values()[lhs_pos];
|
||||
let rhs_value = rhs.values()[rhs_pos];
|
||||
let d = F32((lhs_index <= rhs_index) as u32 as f32) * lhs_value
|
||||
- F32((lhs_index >= rhs_index) as u32 as f32) * rhs_value;
|
||||
d2 += d * d;
|
||||
lhs_pos += (lhs_index <= rhs_index) as usize;
|
||||
rhs_pos += (lhs_index >= rhs_index) as usize;
|
||||
}
|
||||
for i in lhs_pos..size1 {
|
||||
d2 += lhs.values()[i] * lhs.values()[i];
|
||||
}
|
||||
for i in rhs_pos..size2 {
|
||||
d2 += rhs.values()[i] * rhs.values()[i];
|
||||
}
|
||||
d2
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn sl2_2<'a>(lhs: SVecf32Borrowed<'a>, rhs: &[F32]) -> F32 {
|
||||
let mut d2 = F32::zero();
|
||||
let mut lhs_pos = 0;
|
||||
let mut rhs_pos = 0;
|
||||
while lhs_pos < lhs.len() {
|
||||
let index_eq = lhs.indexes()[lhs_pos as usize] == rhs_pos;
|
||||
let d =
|
||||
F32(index_eq as u32 as f32) * lhs.values()[lhs_pos as usize] - rhs[rhs_pos as usize];
|
||||
d2 += d * d;
|
||||
lhs_pos += index_eq as u32;
|
||||
rhs_pos += 1;
|
||||
}
|
||||
for i in rhs_pos..rhs.len() as u32 {
|
||||
d2 += rhs[i as usize] * rhs[i as usize];
|
||||
}
|
||||
d2
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn length<'a>(vector: SVecf32Borrowed<'a>) -> F32 {
|
||||
let mut dot = F32::zero();
|
||||
for &i in vector.values() {
|
||||
dot += i * i;
|
||||
}
|
||||
dot.sqrt()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn l2_normalize(vector: &mut SVecf32Owned) {
|
||||
let l = length(vector.for_borrow());
|
||||
let dims = vector.dims();
|
||||
let indexes = vector.indexes().to_vec();
|
||||
let mut values = vector.values().to_vec();
|
||||
for i in values.iter_mut() {
|
||||
*i /= l;
|
||||
}
|
||||
*vector = SVecf32Owned::new(dims, indexes, values);
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
use super::{VectorBorrowed, VectorOwned};
|
||||
use crate::scalar::F16;
|
||||
use super::{VectorBorrowed, VectorKind, VectorOwned};
|
||||
use crate::scalar::{ScalarLike, F16, F32};
|
||||
use num_traits::{Float, Zero};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@ -39,6 +40,8 @@ impl VectorOwned for Vecf16Owned {
|
||||
type Scalar = F16;
|
||||
type Borrowed<'a> = Vecf16Borrowed<'a>;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::Vecf16;
|
||||
|
||||
fn dims(&self) -> u32 {
|
||||
self.0.len() as u32
|
||||
}
|
||||
@ -97,3 +100,248 @@ impl<'a> VectorBorrowed for Vecf16Borrowed<'a> {
|
||||
self.0.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i].to_f() * rhs[i].to_f();
|
||||
x2 += lhs[i].to_f() * lhs[i].to_f();
|
||||
y2 += rhs[i].to_f() * rhs[i].to_f();
|
||||
}
|
||||
xy / (x2 * y2).sqrt()
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_avx512fp16() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_cosine_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_v4() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_cosine_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_v3() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_cosine_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
cosine(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i].to_f() * rhs[i].to_f();
|
||||
}
|
||||
xy
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_avx512fp16() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_dot_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_v4() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_dot_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_v3() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_dot_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
dot(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut d2 = F32::zero();
|
||||
for i in 0..n {
|
||||
let d = lhs[i].to_f() - rhs[i].to_f();
|
||||
d2 += d * d;
|
||||
}
|
||||
d2
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_avx512fp16() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_sl2_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_v4() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_sl2_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
if detect::x86_64::detect_v3() {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
unsafe {
|
||||
return c::v_f16_sl2_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
|
||||
}
|
||||
}
|
||||
sl2(lhs, rhs)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn length(vector: &[F16]) -> F16 {
|
||||
let n = vector.len();
|
||||
let mut dot = F16::zero();
|
||||
for i in 0..n {
|
||||
dot += vector[i] * vector[i];
|
||||
}
|
||||
dot.sqrt()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn l2_normalize(vector: &mut [F16]) {
|
||||
let n = vector.len();
|
||||
let l = length(vector);
|
||||
for i in 0..n {
|
||||
vector[i] /= l;
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn xy_x2_y2(lhs: &[F16], rhs: &[F16]) -> (F32, F32, F32) {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i].to_f() * rhs[i].to_f();
|
||||
x2 += lhs[i].to_f() * lhs[i].to_f();
|
||||
y2 += rhs[i].to_f() * rhs[i].to_f();
|
||||
}
|
||||
(xy, x2, y2)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn xy_x2_y2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> (F32, F32, F32) {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i].to_f() * (rhs[i].to_f() + del[i].to_f());
|
||||
x2 += lhs[i].to_f() * lhs[i].to_f();
|
||||
y2 += (rhs[i].to_f() + del[i].to_f()) * (rhs[i].to_f() + del[i].to_f());
|
||||
}
|
||||
(xy, x2, y2)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn dot_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n: usize = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i].to_f() * (rhs[i].to_f() + del[i].to_f());
|
||||
}
|
||||
xy
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn distance_squared_l2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut d2 = F32::zero();
|
||||
for i in 0..n {
|
||||
let d = lhs[i].to_f() - (rhs[i].to_f() + del[i].to_f());
|
||||
d2 += d * d;
|
||||
}
|
||||
d2
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
use super::{VectorBorrowed, VectorOwned};
|
||||
use super::{VectorBorrowed, VectorKind, VectorOwned};
|
||||
use crate::scalar::F32;
|
||||
use num_traits::{Float, Zero};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@ -39,6 +40,8 @@ impl VectorOwned for Vecf32Owned {
|
||||
type Scalar = F32;
|
||||
type Borrowed<'a> = Vecf32Borrowed<'a>;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::Vecf32;
|
||||
|
||||
fn dims(&self) -> u32 {
|
||||
self.0.len() as u32
|
||||
}
|
||||
@ -97,3 +100,167 @@ impl<'a> VectorBorrowed for Vecf32Borrowed<'a> {
|
||||
self.0.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i] * rhs[i];
|
||||
x2 += lhs[i] * lhs[i];
|
||||
y2 += rhs[i] * rhs[i];
|
||||
}
|
||||
xy / (x2 * y2).sqrt()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i] * rhs[i];
|
||||
}
|
||||
xy
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn sl2(lhs: &[F32], rhs: &[F32]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut d2 = F32::zero();
|
||||
for i in 0..n {
|
||||
let d = lhs[i] - rhs[i];
|
||||
d2 += d * d;
|
||||
}
|
||||
d2
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn length(vector: &[F32]) -> F32 {
|
||||
let n = vector.len();
|
||||
let mut dot = F32::zero();
|
||||
for i in 0..n {
|
||||
dot += vector[i] * vector[i];
|
||||
}
|
||||
dot.sqrt()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn l2_normalize(vector: &mut [F32]) {
|
||||
let n = vector.len();
|
||||
let l = length(vector);
|
||||
for i in 0..n {
|
||||
vector[i] /= l;
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn xy_x2_y2(lhs: &[F32], rhs: &[F32]) -> (F32, F32, F32) {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i] * rhs[i];
|
||||
x2 += lhs[i] * lhs[i];
|
||||
y2 += rhs[i] * rhs[i];
|
||||
}
|
||||
(xy, x2, y2)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn xy_x2_y2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> (F32, F32, F32) {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
let mut x2 = F32::zero();
|
||||
let mut y2 = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i] * (rhs[i] + del[i]);
|
||||
x2 += lhs[i] * lhs[i];
|
||||
y2 += (rhs[i] + del[i]) * (rhs[i] + del[i]);
|
||||
}
|
||||
(xy, x2, y2)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn dot_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n: usize = lhs.len();
|
||||
let mut xy = F32::zero();
|
||||
for i in 0..n {
|
||||
xy += lhs[i] * (rhs[i] + del[i]);
|
||||
}
|
||||
xy
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn distance_squared_l2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 {
|
||||
assert!(lhs.len() == rhs.len());
|
||||
let n = lhs.len();
|
||||
let mut d2 = F32::zero();
|
||||
for i in 0..n {
|
||||
let d = lhs[i] - (rhs[i] + del[i]);
|
||||
d2 += d * d;
|
||||
}
|
||||
d2
|
||||
}
|
||||
|
@ -1,56 +1,8 @@
|
||||
use super::{VectorBorrowed, VectorOwned};
|
||||
use super::{VectorBorrowed, VectorKind, VectorOwned};
|
||||
use crate::scalar::{F32, I8};
|
||||
use num_traits::Float;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn i8_quantization(vector: &[F32]) -> (Vec<I8>, F32, F32) {
|
||||
let min = vector.iter().copied().fold(F32::infinity(), Float::min);
|
||||
let max = vector.iter().copied().fold(F32::neg_infinity(), Float::max);
|
||||
let alpha = (max - min) / 254.0;
|
||||
let offset = (max + min) / 2.0;
|
||||
let result = vector
|
||||
.iter()
|
||||
.map(|&x| ((x - offset) / alpha).into())
|
||||
.collect();
|
||||
(result, alpha, offset)
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn i8_dequantization(vector: &[I8], alpha: F32, offset: F32) -> Vec<F32> {
|
||||
vector
|
||||
.iter()
|
||||
.map(|&x| (x.to_f32() * alpha + offset))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn i8_precompute(data: &[I8], alpha: F32, offset: F32) -> (F32, F32) {
|
||||
let sum = data.iter().map(|&x| x.to_f32() * alpha).sum();
|
||||
let l2_norm = data
|
||||
.iter()
|
||||
.map(|&x| (x.to_f32() * alpha + offset) * (x.to_f32() * alpha + offset))
|
||||
.sum::<F32>()
|
||||
.sqrt();
|
||||
(sum, l2_norm)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Veci8Owned {
|
||||
dims: u32,
|
||||
@ -139,6 +91,8 @@ impl VectorOwned for Veci8Owned {
|
||||
type Scalar = F32;
|
||||
type Borrowed<'a> = Veci8Borrowed<'a>;
|
||||
|
||||
const VECTOR_KIND: VectorKind = VectorKind::Veci8;
|
||||
|
||||
#[inline(always)]
|
||||
fn dims(&self) -> u32 {
|
||||
self.dims
|
||||
@ -335,6 +289,54 @@ impl<'a> From<&'a Veci8Owned> for Veci8Borrowed<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn i8_quantization(vector: &[F32]) -> (Vec<I8>, F32, F32) {
|
||||
let min = vector.iter().copied().fold(F32::infinity(), Float::min);
|
||||
let max = vector.iter().copied().fold(F32::neg_infinity(), Float::max);
|
||||
let alpha = (max - min) / 254.0;
|
||||
let offset = (max + min) / 2.0;
|
||||
let result = vector
|
||||
.iter()
|
||||
.map(|&x| ((x - offset) / alpha).into())
|
||||
.collect();
|
||||
(result, alpha, offset)
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn i8_dequantization(vector: &[I8], alpha: F32, offset: F32) -> Vec<F32> {
|
||||
vector
|
||||
.iter()
|
||||
.map(|&x| (x.to_f32() * alpha + offset))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn i8_precompute(data: &[I8], alpha: F32, offset: F32) -> (F32, F32) {
|
||||
let sum = data.iter().map(|&x| x.to_f32() * alpha).sum();
|
||||
let l2_norm = data
|
||||
.iter()
|
||||
.map(|&x| (x.to_f32() * alpha + offset) * (x.to_f32() * alpha + offset))
|
||||
.sum::<F32>()
|
||||
.sqrt();
|
||||
(sum, l2_norm)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -352,3 +354,228 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dot(x: &[I8], y: &[I8]) -> F32 {
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
{
|
||||
if detect::x86_64::test_avx512vnni() {
|
||||
return unsafe { dot_i8_avx512vnni(x, y) };
|
||||
}
|
||||
}
|
||||
dot_i8_fallback(x, y)
|
||||
}
|
||||
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
fn dot_i8_fallback(x: &[I8], y: &[I8]) -> F32 {
|
||||
// i8 * i8 fall in range of i16. Since our length is less than (2^16 - 1), the result won't overflow.
|
||||
let mut sum = 0;
|
||||
assert_eq!(x.len(), y.len());
|
||||
let length = x.len();
|
||||
// according to https://godbolt.org/z/ff48vW4es, this loop will be autovectorized
|
||||
for i in 0..length {
|
||||
sum += (x[i].0 as i16 * y[i].0 as i16) as i32;
|
||||
}
|
||||
F32(sum as f32)
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx512vnni,avx512bw,avx512f,bmi2")]
|
||||
unsafe fn dot_i8_avx512vnni(x: &[I8], y: &[I8]) -> F32 {
|
||||
use std::arch::x86_64::*;
|
||||
#[inline]
|
||||
#[target_feature(enable = "avx512vnni,avx512bw,avx512f,bmi2")]
|
||||
pub unsafe fn _mm512_maskz_loadu_epi8(k: __mmask64, mem_addr: *const i8) -> __m512i {
|
||||
let mut dst: __m512i;
|
||||
unsafe {
|
||||
std::arch::asm!(
|
||||
"vmovdqu8 {dst}{{{k}}} {{z}}, [{p}]",
|
||||
p = in(reg) mem_addr,
|
||||
k = in(kreg) k,
|
||||
dst = out(zmm_reg) dst,
|
||||
options(pure, readonly, nostack)
|
||||
);
|
||||
}
|
||||
dst
|
||||
}
|
||||
assert_eq!(x.len(), y.len());
|
||||
let mut sum = 0;
|
||||
let mut i = x.len();
|
||||
let mut p_x = x.as_ptr() as *const i8;
|
||||
let mut p_y = y.as_ptr() as *const i8;
|
||||
let mut vec_x;
|
||||
let mut vec_y;
|
||||
unsafe {
|
||||
let mut result = _mm512_setzero_si512();
|
||||
let zero = _mm512_setzero_si512();
|
||||
while i > 0 {
|
||||
if i < 64 {
|
||||
let mask = _bzhi_u64(0xFFFF_FFFF_FFFF_FFFF, i as u32);
|
||||
vec_x = _mm512_maskz_loadu_epi8(mask, p_x);
|
||||
vec_y = _mm512_maskz_loadu_epi8(mask, p_y);
|
||||
i = 0;
|
||||
} else {
|
||||
vec_x = _mm512_loadu_epi8(p_x);
|
||||
vec_y = _mm512_loadu_epi8(p_y);
|
||||
i -= 64;
|
||||
p_x = p_x.add(64);
|
||||
p_y = p_y.add(64);
|
||||
}
|
||||
// There are only _mm512_dpbusd_epi32 support, dpbusd will zeroextend a[i] and signextend b[i] first, so we need to convert a[i] positive and change corresponding b[i] to get right result.
|
||||
// And because we use -b[i] here, the range of quantization should be [-127, 127] instead of [-128, 127] to avoid overflow.
|
||||
let neg_mask = _mm512_movepi8_mask(vec_x);
|
||||
vec_x = _mm512_mask_abs_epi8(vec_x, neg_mask, vec_x);
|
||||
// Get -b[i] here, use saturating sub to avoid overflow. There are some precision loss here.
|
||||
vec_y = _mm512_mask_subs_epi8(vec_y, neg_mask, zero, vec_y);
|
||||
result = _mm512_dpbusd_epi32(result, vec_x, vec_y);
|
||||
}
|
||||
sum += _mm512_reduce_add_epi32(result);
|
||||
}
|
||||
F32(sum as f32)
|
||||
}
|
||||
|
||||
pub fn dot_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 {
|
||||
// (alpha_x * x[i] + offset_x) * (alpha_y * y[i] + offset_y)
|
||||
// = alpha_x * alpha_y * x[i] * y[i] + alpha_x * offset_y * x[i] + alpha_y * offset_x * y[i] + offset_x * offset_y
|
||||
// Sum(dot(origin_x[i] , origin_y[i])) = alpha_x * alpha_y * Sum(dot(x[i], y[i])) + offset_y * Sum(alpha_x * x[i]) + offset_x * Sum(alpha_y * y[i]) + offset_x * offset_y * dims
|
||||
let dot_xy = dot(x.data(), y.data());
|
||||
x.alpha() * y.alpha() * dot_xy
|
||||
+ x.offset() * y.sum()
|
||||
+ y.offset() * x.sum()
|
||||
+ x.offset() * y.offset() * F32(x.dims() as f32)
|
||||
}
|
||||
|
||||
pub fn l2_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 {
|
||||
// Sum(l2(origin_x[i] - origin_y[i])) = sum(x[i] ^ 2 - 2 * x[i] * y[i] + y[i] ^ 2)
|
||||
// = dot(x, x) - 2 * dot(x, y) + dot(y, y)
|
||||
x.l2_norm() * x.l2_norm() - F32(2.0) * dot_distance(x, y) + y.l2_norm() * y.l2_norm()
|
||||
}
|
||||
|
||||
pub fn cosine_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 {
|
||||
// dot(x, y) / (l2(x) * l2(y))
|
||||
let dot_xy = dot_distance(x, y);
|
||||
let l2_x = x.l2_norm();
|
||||
let l2_y = y.l2_norm();
|
||||
dot_xy / (l2_x * l2_y)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn l2_2<'a>(lhs: Veci8Borrowed<'a>, rhs: &[F32]) -> F32 {
|
||||
let data = lhs.data();
|
||||
assert_eq!(data.len(), rhs.len());
|
||||
data.iter()
|
||||
.zip(rhs.iter())
|
||||
.map(|(&x, &y)| {
|
||||
(x.to_f32() * lhs.alpha() + lhs.offset() - y)
|
||||
* (x.to_f32() * lhs.alpha() + lhs.offset() - y)
|
||||
})
|
||||
.sum::<F32>()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[multiversion::multiversion(targets(
|
||||
"x86_64/x86-64-v4",
|
||||
"x86_64/x86-64-v3",
|
||||
"x86_64/x86-64-v2",
|
||||
"aarch64+neon"
|
||||
))]
|
||||
pub fn dot_2<'a>(lhs: Veci8Borrowed<'a>, rhs: &[F32]) -> F32 {
|
||||
let data = lhs.data();
|
||||
assert_eq!(data.len(), rhs.len());
|
||||
data.iter()
|
||||
.zip(rhs.iter())
|
||||
.map(|(&x, &y)| (x.to_f32() * lhs.alpha() + lhs.offset()) * y)
|
||||
.sum::<F32>()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests_2 {
|
||||
use super::*;
|
||||
|
||||
fn new_random_vec_f32(size: usize) -> Vec<F32> {
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
(0..size)
|
||||
.map(|_| F32(rng.gen_range(-100000.0..100000.0)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn vec_to_owned(vec: Vec<F32>) -> Veci8Owned {
|
||||
let (v, alpha, offset) = i8_quantization(&vec);
|
||||
Veci8Owned::new(v.len() as u32, v, alpha, offset)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dot_i8() {
|
||||
let x = vec![F32(1.0), F32(2.0), F32(3.0)];
|
||||
let y = vec![F32(3.0), F32(2.0), F32(1.0)];
|
||||
let x_owned = vec_to_owned(x);
|
||||
let ref_x = x_owned.for_borrow();
|
||||
let y_owned = vec_to_owned(y);
|
||||
let ref_y = y_owned.for_borrow();
|
||||
let result = dot_distance(&ref_x, &ref_y);
|
||||
assert!((result.0 - 10.0).abs() < 0.1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cos_i8() {
|
||||
let x = vec![F32(1.0), F32(2.0), F32(3.0)];
|
||||
let y = vec![F32(3.0), F32(2.0), F32(1.0)];
|
||||
let x_owned = vec_to_owned(x);
|
||||
let ref_x = x_owned.for_borrow();
|
||||
let y_owned = vec_to_owned(y);
|
||||
let ref_y = y_owned.for_borrow();
|
||||
let result = cosine_distance(&ref_x, &ref_y);
|
||||
assert!((result.0 - (10.0 / 14.0)).abs() < 0.1);
|
||||
// test cos_i8 using random generated data, check the precision
|
||||
let x = new_random_vec_f32(1000);
|
||||
let y = new_random_vec_f32(1000);
|
||||
let xy = x.iter().zip(y.iter()).map(|(&x, &y)| x * y).sum::<F32>().0;
|
||||
let l2_x = x.iter().map(|&x| x * x).sum::<F32>().0.sqrt();
|
||||
let l2_y = y.iter().map(|&y| y * y).sum::<F32>().0.sqrt();
|
||||
let result_expected = xy / (l2_x * l2_y);
|
||||
let x_owned = vec_to_owned(x);
|
||||
let ref_x = x_owned.for_borrow();
|
||||
let y_owned = vec_to_owned(y);
|
||||
let ref_y = y_owned.for_borrow();
|
||||
let result = cosine_distance(&ref_x, &ref_y);
|
||||
assert!((result.0 - result_expected).abs() / result_expected < 0.25);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_l2_i8() {
|
||||
let x = vec![F32(1.0), F32(2.0), F32(3.0)];
|
||||
let y = vec![F32(3.0), F32(2.0), F32(1.0)];
|
||||
let x_owned = vec_to_owned(x);
|
||||
let ref_x = x_owned.for_borrow();
|
||||
let y_owned = vec_to_owned(y);
|
||||
let ref_y = y_owned.for_borrow();
|
||||
let result = l2_distance(&ref_x, &ref_y);
|
||||
assert!((result.0 - 8.0).abs() < 0.1);
|
||||
// test l2_i8 using random generated data, check the precision
|
||||
let x = new_random_vec_f32(1000);
|
||||
let y = new_random_vec_f32(1000);
|
||||
let result_expected = x
|
||||
.iter()
|
||||
.zip(y.iter())
|
||||
.map(|(&x, &y)| (x - y) * (x - y))
|
||||
.sum::<F32>()
|
||||
.0;
|
||||
let x_owned = vec_to_owned(x);
|
||||
let ref_x = x_owned.for_borrow();
|
||||
let y_owned = vec_to_owned(y);
|
||||
let ref_y = y_owned.for_borrow();
|
||||
let result = l2_distance(&ref_x, &ref_y);
|
||||
assert!((result.0 - result_expected).abs() / result_expected < 0.05);
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,3 @@
|
||||
use crate::error::*;
|
||||
use crate::index::*;
|
||||
use crate::search::*;
|
||||
use crate::vector::*;
|
||||
|
@ -6,6 +6,7 @@ edition.workspace = true
|
||||
[dev-dependencies]
|
||||
half.workspace = true
|
||||
rand.workspace = true
|
||||
|
||||
detect = { path = "../detect" }
|
||||
|
||||
[build-dependencies]
|
||||
|
15
crates/common/Cargo.toml
Normal file
15
crates/common/Cargo.toml
Normal file
@ -0,0 +1,15 @@
|
||||
[package]
|
||||
name = "common"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
bytemuck.workspace = true
|
||||
log.workspace = true
|
||||
memmap2.workspace = true
|
||||
rustix.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
7
crates/common/src/dir_ops.rs
Normal file
7
crates/common/src/dir_ops.rs
Normal file
@ -0,0 +1,7 @@
|
||||
use std::fs::File;
|
||||
use std::path::Path;
|
||||
|
||||
pub fn sync_dir(path: impl AsRef<Path>) {
|
||||
let file = File::open(path).expect("Failed to sync dir.");
|
||||
file.sync_all().expect("Failed to sync dir.");
|
||||
}
|
@ -1,8 +1,5 @@
|
||||
pub mod clean;
|
||||
pub mod dir_ops;
|
||||
pub mod element_heap;
|
||||
pub mod file_atomic;
|
||||
pub mod file_wal;
|
||||
pub mod mmap_array;
|
||||
pub mod tournament_tree;
|
||||
pub mod vec2;
|
@ -57,6 +57,9 @@ where
|
||||
pub fn len(&self) -> usize {
|
||||
self.info.len
|
||||
}
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Deref for MmapArray<T> {
|
@ -20,6 +20,9 @@ impl<T: Zeroable + Ord> Vec2<T> {
|
||||
pub fn len(&self) -> usize {
|
||||
self.v.len() / self.dims as usize
|
||||
}
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
pub fn argsort(&self) -> Vec<usize> {
|
||||
let mut index: Vec<usize> = (0..self.len()).collect();
|
||||
index.sort_by_key(|i| &self[*i]);
|
16
crates/elkan_k_means/Cargo.toml
Normal file
16
crates/elkan_k_means/Cargo.toml
Normal file
@ -0,0 +1,16 @@
|
||||
[package]
|
||||
name = "elkan_k_means"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
bytemuck.workspace = true
|
||||
num-traits.workspace = true
|
||||
rand.workspace = true
|
||||
|
||||
base = { path = "../base" }
|
||||
common = { path = "../common" }
|
||||
rayon = { path = "../rayon" }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
@ -1,26 +1,33 @@
|
||||
use crate::prelude::*;
|
||||
use crate::utils::vec2::Vec2;
|
||||
#![allow(clippy::needless_range_loop)]
|
||||
|
||||
pub mod operator;
|
||||
|
||||
use crate::operator::OperatorElkanKMeans;
|
||||
use base::operator::*;
|
||||
use base::scalar::*;
|
||||
use common::vec2::Vec2;
|
||||
use num_traits::{Float, Zero};
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
|
||||
use rayon::slice::ParallelSliceMut;
|
||||
use std::ops::{Index, IndexMut};
|
||||
|
||||
pub struct ElkanKMeans<S: Global> {
|
||||
pub struct ElkanKMeans<O: Operator> {
|
||||
dims: u32,
|
||||
c: usize,
|
||||
pub centroids: Vec2<Scalar<S>>,
|
||||
pub centroids: Vec2<Scalar<O>>,
|
||||
lowerbound: Square,
|
||||
upperbound: Vec<F32>,
|
||||
assign: Vec<usize>,
|
||||
rand: StdRng,
|
||||
samples: Vec2<Scalar<S>>,
|
||||
samples: Vec2<Scalar<O>>,
|
||||
}
|
||||
|
||||
const DELTA: f32 = 1.0 / 1024.0;
|
||||
|
||||
impl<S: GlobalElkanKMeans> ElkanKMeans<S> {
|
||||
pub fn new(c: usize, samples: Vec2<Scalar<S>>) -> Self {
|
||||
impl<O: OperatorElkanKMeans> ElkanKMeans<O> {
|
||||
pub fn new(c: usize, samples: Vec2<Scalar<O>>) -> Self {
|
||||
let n = samples.len();
|
||||
let dims = samples.dims();
|
||||
|
||||
@ -37,7 +44,7 @@ impl<S: GlobalElkanKMeans> ElkanKMeans<S> {
|
||||
for i in 0..c {
|
||||
let mut sum = F32::zero();
|
||||
dis.par_iter_mut().enumerate().for_each(|(j, x)| {
|
||||
*x = S::elkan_k_means_distance(&samples[j], ¢roids[i]);
|
||||
*x = O::elkan_k_means_distance(&samples[j], ¢roids[i]);
|
||||
});
|
||||
for j in 0..n {
|
||||
lowerbound[(j, i)] = dis[j];
|
||||
@ -104,14 +111,14 @@ impl<S: GlobalElkanKMeans> ElkanKMeans<S> {
|
||||
centroids[i].copy_from_slice(&samples[*index]);
|
||||
} else {
|
||||
let rand_centroids: Vec<_> = (0..dims)
|
||||
.map(|_| Scalar::<S>::from_f32(rand.gen_range(0.0..1.0f32)))
|
||||
.map(|_| Scalar::<O>::from_f32(rand.gen_range(0.0..1.0f32)))
|
||||
.collect();
|
||||
centroids[i].copy_from_slice(rand_centroids.as_slice());
|
||||
}
|
||||
}
|
||||
for i in n..c {
|
||||
let rand_centroids: Vec<_> = (0..dims)
|
||||
.map(|_| Scalar::<S>::from_f32(rand.gen_range(0.0..1.0f32)))
|
||||
.map(|_| Scalar::<O>::from_f32(rand.gen_range(0.0..1.0f32)))
|
||||
.collect();
|
||||
centroids[i].copy_from_slice(rand_centroids.as_slice());
|
||||
}
|
||||
@ -140,7 +147,7 @@ impl<S: GlobalElkanKMeans> ElkanKMeans<S> {
|
||||
let i = ii / c;
|
||||
let j = ii % c;
|
||||
if i <= j {
|
||||
*v = S::elkan_k_means_distance(¢roids[i], ¢roids[j]) * 0.5;
|
||||
*v = O::elkan_k_means_distance(¢roids[i], ¢roids[j]) * 0.5;
|
||||
}
|
||||
});
|
||||
for i in 1..c {
|
||||
@ -165,7 +172,7 @@ impl<S: GlobalElkanKMeans> ElkanKMeans<S> {
|
||||
let mut dis = vec![F32::zero(); n];
|
||||
dis.par_iter_mut().enumerate().for_each(|(i, x)| {
|
||||
if upperbound[i] > sp[assign[i]] {
|
||||
*x = S::elkan_k_means_distance(&samples[i], ¢roids[assign[i]]);
|
||||
*x = O::elkan_k_means_distance(&samples[i], ¢roids[assign[i]]);
|
||||
}
|
||||
});
|
||||
for i in 0..n {
|
||||
@ -188,7 +195,7 @@ impl<S: GlobalElkanKMeans> ElkanKMeans<S> {
|
||||
continue;
|
||||
}
|
||||
if minimal > lowerbound[(i, j)] || minimal > dist0[(assign[i], j)] {
|
||||
let dis = S::elkan_k_means_distance(&samples[i], ¢roids[j]);
|
||||
let dis = O::elkan_k_means_distance(&samples[i], ¢roids[j]);
|
||||
lowerbound[(i, j)] = dis;
|
||||
if dis < minimal {
|
||||
minimal = dis;
|
||||
@ -203,7 +210,7 @@ impl<S: GlobalElkanKMeans> ElkanKMeans<S> {
|
||||
// Step 4, 7
|
||||
let old = std::mem::replace(centroids, Vec2::new(dims, c));
|
||||
let mut count = vec![F32::zero(); c];
|
||||
centroids.fill(Scalar::<S>::zero());
|
||||
centroids.fill(Scalar::<O>::zero());
|
||||
for i in 0..n {
|
||||
for j in 0..dims as usize {
|
||||
centroids[self.assign[i]][j] += samples[i][j];
|
||||
@ -215,7 +222,7 @@ impl<S: GlobalElkanKMeans> ElkanKMeans<S> {
|
||||
continue;
|
||||
}
|
||||
for dim in 0..dims as usize {
|
||||
centroids[i][dim] /= Scalar::<S>::from_f32(count[i].into());
|
||||
centroids[i][dim] /= Scalar::<O>::from_f32(count[i].into());
|
||||
}
|
||||
}
|
||||
for i in 0..c {
|
||||
@ -234,24 +241,24 @@ impl<S: GlobalElkanKMeans> ElkanKMeans<S> {
|
||||
centroids.copy_within(o, i);
|
||||
for dim in 0..dims as usize {
|
||||
if dim % 2 == 0 {
|
||||
centroids[i][dim] *= Scalar::<S>::from_f32(1.0 + DELTA);
|
||||
centroids[o][dim] *= Scalar::<S>::from_f32(1.0 - DELTA);
|
||||
centroids[i][dim] *= Scalar::<O>::from_f32(1.0 + DELTA);
|
||||
centroids[o][dim] *= Scalar::<O>::from_f32(1.0 - DELTA);
|
||||
} else {
|
||||
centroids[i][dim] *= Scalar::<S>::from_f32(1.0 - DELTA);
|
||||
centroids[o][dim] *= Scalar::<S>::from_f32(1.0 + DELTA);
|
||||
centroids[i][dim] *= Scalar::<O>::from_f32(1.0 - DELTA);
|
||||
centroids[o][dim] *= Scalar::<O>::from_f32(1.0 + DELTA);
|
||||
}
|
||||
}
|
||||
count[i] = count[o] / 2.0;
|
||||
count[o] = count[o] - count[i];
|
||||
}
|
||||
centroids.par_chunks_mut(dims as usize).for_each(|v| {
|
||||
S::elkan_k_means_normalize(v);
|
||||
O::elkan_k_means_normalize(v);
|
||||
});
|
||||
|
||||
// Step 5, 6
|
||||
let mut dist1 = vec![F32::zero(); c];
|
||||
dist1.par_iter_mut().enumerate().for_each(|(i, v)| {
|
||||
*v = S::elkan_k_means_distance(&old[i], ¢roids[i]);
|
||||
*v = O::elkan_k_means_distance(&old[i], ¢roids[i]);
|
||||
});
|
||||
for i in 0..n {
|
||||
for j in 0..c {
|
||||
@ -266,7 +273,7 @@ impl<S: GlobalElkanKMeans> ElkanKMeans<S> {
|
||||
change == 0
|
||||
}
|
||||
|
||||
pub fn finish(self) -> Vec2<Scalar<S>> {
|
||||
pub fn finish(self) -> Vec2<Scalar<O>> {
|
||||
self.centroids
|
||||
}
|
||||
}
|
342
crates/elkan_k_means/src/operator.rs
Normal file
342
crates/elkan_k_means/src/operator.rs
Normal file
@ -0,0 +1,342 @@
|
||||
use base::operator::*;
|
||||
use base::scalar::*;
|
||||
use base::vector::*;
|
||||
use num_traits::Float;
|
||||
|
||||
pub trait OperatorElkanKMeans: Operator {
|
||||
type VectorNormalized: VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]);
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Self::VectorNormalized;
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32;
|
||||
fn elkan_k_means_distance2(
|
||||
lhs: <Self::VectorNormalized as VectorOwned>::Borrowed<'_>,
|
||||
rhs: &[Scalar<Self>],
|
||||
) -> F32;
|
||||
}
|
||||
|
||||
impl OperatorElkanKMeans for BVecf32Cos {
|
||||
type VectorNormalized = Vecf32Owned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Vecf32Owned {
|
||||
bvecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
vecf32::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
vecf32::dot(lhs.slice(), rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperatorElkanKMeans for BVecf32Dot {
|
||||
type VectorNormalized = Vecf32Owned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Vecf32Owned {
|
||||
bvecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
vecf32::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
vecf32::dot(lhs.slice(), rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperatorElkanKMeans for BVecf32Jaccard {
|
||||
type VectorNormalized = Vecf32Owned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Vecf32Owned {
|
||||
Vecf32Owned::new(vector.to_vec())
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
vecf32::sl2(lhs, rhs).sqrt()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
vecf32::sl2(lhs.slice(), rhs).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperatorElkanKMeans for BVecf32L2 {
|
||||
type VectorNormalized = Vecf32Owned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Vecf32Owned {
|
||||
Vecf32Owned::new(vector.to_vec())
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
vecf32::sl2(lhs, rhs).sqrt()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
vecf32::sl2(lhs.slice(), rhs).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperatorElkanKMeans for SVecf32Cos {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> SVecf32Owned {
|
||||
let mut vector = vector.for_own();
|
||||
svecf32::l2_normalize(&mut vector);
|
||||
vector
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
vecf32::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
svecf32::dot_2(lhs, rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperatorElkanKMeans for SVecf32Dot {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> SVecf32Owned {
|
||||
let mut vector = vector.for_own();
|
||||
svecf32::l2_normalize(&mut vector);
|
||||
vector
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
vecf32::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
svecf32::dot_2(lhs, rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperatorElkanKMeans for SVecf32L2 {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(_: &mut [Scalar<Self>]) {}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: SVecf32Borrowed<'_>) -> SVecf32Owned {
|
||||
vector.for_own()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
vecf32::sl2(lhs, rhs).sqrt()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: SVecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
svecf32::sl2_2(lhs, rhs).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperatorElkanKMeans for Vecf16Cos {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [F16]) {
|
||||
vecf16::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Vecf16Borrowed<'_>) -> Vecf16Owned {
|
||||
let mut vector = vector.for_own();
|
||||
vecf16::l2_normalize(vector.slice_mut());
|
||||
vector
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
vecf16::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf16Borrowed<'_>, rhs: &[F16]) -> F32 {
|
||||
vecf16::dot(lhs.slice(), rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperatorElkanKMeans for Vecf16Dot {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [F16]) {
|
||||
vecf16::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Vecf16Borrowed<'_>) -> Vecf16Owned {
|
||||
let mut vector = vector.for_own();
|
||||
vecf16::l2_normalize(vector.slice_mut());
|
||||
vector
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
vecf16::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf16Borrowed<'_>, rhs: &[F16]) -> F32 {
|
||||
vecf16::dot(lhs.slice(), rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperatorElkanKMeans for Vecf16L2 {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(_: &mut [F16]) {}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Vecf16Borrowed<'_>) -> Vecf16Owned {
|
||||
vector.for_own()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 {
|
||||
vecf16::sl2(lhs, rhs).sqrt()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf16Borrowed<'_>, rhs: &[F16]) -> F32 {
|
||||
vecf16::sl2(lhs.slice(), rhs).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperatorElkanKMeans for Vecf32Cos {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [F32]) {
|
||||
vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Vecf32Borrowed<'_>) -> Vecf32Owned {
|
||||
let mut vector = vector.for_own();
|
||||
vecf32::l2_normalize(vector.slice_mut());
|
||||
vector
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 {
|
||||
vecf32::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[F32]) -> F32 {
|
||||
vecf32::dot(lhs.slice(), rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperatorElkanKMeans for Vecf32Dot {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [F32]) {
|
||||
vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Vecf32Borrowed<'_>) -> Vecf32Owned {
|
||||
let mut vector = vector.for_own();
|
||||
vecf32::l2_normalize(vector.slice_mut());
|
||||
vector
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 {
|
||||
vecf32::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[F32]) -> F32 {
|
||||
vecf32::dot(lhs.slice(), rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperatorElkanKMeans for Vecf32L2 {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(_: &mut [F32]) {}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Vecf32Borrowed<'_>) -> Vecf32Owned {
|
||||
vector.for_own()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
vecf32::sl2(lhs, rhs).sqrt()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
vecf32::sl2(lhs.slice(), rhs).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperatorElkanKMeans for Veci8Cos {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Veci8Owned {
|
||||
vector.normalize()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
vecf32::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
veci8::dot_2(lhs, rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperatorElkanKMeans for Veci8Dot {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Veci8Owned {
|
||||
vector.normalize()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
vecf32::dot(lhs, rhs).acos()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
veci8::dot_2(lhs, rhs).acos()
|
||||
}
|
||||
}
|
||||
|
||||
impl OperatorElkanKMeans for Veci8L2 {
|
||||
type VectorNormalized = Self::VectorOwned;
|
||||
|
||||
fn elkan_k_means_normalize(vector: &mut [Scalar<Self>]) {
|
||||
vecf32::l2_normalize(vector)
|
||||
}
|
||||
|
||||
fn elkan_k_means_normalize2(vector: Borrowed<'_, Self>) -> Veci8Owned {
|
||||
vector.normalize()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance(lhs: &[Scalar<Self>], rhs: &[Scalar<Self>]) -> F32 {
|
||||
vecf32::sl2(lhs, rhs).sqrt()
|
||||
}
|
||||
|
||||
fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar<Self>]) -> F32 {
|
||||
veci8::l2_2(lhs, rhs).sqrt()
|
||||
}
|
||||
}
|
@ -4,13 +4,17 @@ version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
|
||||
thiserror = "~1.0"
|
||||
serde = "~1.0"
|
||||
reqwest = { version = "0.11.25", default-features = false, features = [
|
||||
"blocking",
|
||||
"json",
|
||||
"rustls-tls",
|
||||
] }
|
||||
serde = "1"
|
||||
thiserror = "1"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
httpmock = "0.7"
|
||||
serde_json = "~1.0"
|
||||
httpmock = "0.7.0"
|
||||
serde_json = "1"
|
||||
|
14
crates/flat/Cargo.toml
Normal file
14
crates/flat/Cargo.toml
Normal file
@ -0,0 +1,14 @@
|
||||
[package]
|
||||
name = "flat"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
base = { path = "../base" }
|
||||
common = { path = "../common" }
|
||||
quantization = { path = "../quantization" }
|
||||
rayon = { path = "../rayon" }
|
||||
storage = { path = "../storage" }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
159
crates/flat/src/lib.rs
Normal file
159
crates/flat/src/lib.rs
Normal file
@ -0,0 +1,159 @@
|
||||
#![feature(trait_alias)]
|
||||
#![allow(clippy::len_without_is_empty)]
|
||||
|
||||
use base::index::*;
|
||||
use base::operator::*;
|
||||
use base::search::*;
|
||||
use common::dir_ops::sync_dir;
|
||||
use quantization::operator::OperatorQuantization;
|
||||
use quantization::Quantization;
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::BinaryHeap;
|
||||
use std::fs::create_dir;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use storage::operator::OperatorStorage;
|
||||
use storage::StorageCollection;
|
||||
|
||||
pub trait OperatorFlat = Operator + OperatorQuantization + OperatorStorage;
|
||||
|
||||
pub struct Flat<O: OperatorFlat> {
|
||||
mmap: FlatMmap<O>,
|
||||
}
|
||||
|
||||
impl<O: OperatorFlat> Flat<O> {
|
||||
pub fn create<S: Source<O>>(path: &Path, options: IndexOptions, source: &S) -> Self {
|
||||
create_dir(path).unwrap();
|
||||
let ram = make(path, options, source);
|
||||
let mmap = save(path, ram);
|
||||
sync_dir(path);
|
||||
Self { mmap }
|
||||
}
|
||||
|
||||
pub fn open(path: &Path, options: IndexOptions) -> Self {
|
||||
let mmap = open(path, options);
|
||||
Self { mmap }
|
||||
}
|
||||
|
||||
pub fn basic(
|
||||
&self,
|
||||
vector: Borrowed<'_, O>,
|
||||
_opts: &SearchOptions,
|
||||
filter: impl Filter,
|
||||
) -> BinaryHeap<Reverse<Element>> {
|
||||
basic(&self.mmap, vector, filter)
|
||||
}
|
||||
|
||||
pub fn vbase<'a>(
|
||||
&'a self,
|
||||
vector: Borrowed<'a, O>,
|
||||
_opts: &'a SearchOptions,
|
||||
filter: impl Filter + 'a,
|
||||
) -> (Vec<Element>, Box<(dyn Iterator<Item = Element> + 'a)>) {
|
||||
vbase(&self.mmap, vector, filter)
|
||||
}
|
||||
|
||||
pub fn len(&self) -> u32 {
|
||||
self.mmap.storage.len()
|
||||
}
|
||||
|
||||
pub fn vector(&self, i: u32) -> Borrowed<'_, O> {
|
||||
self.mmap.storage.vector(i)
|
||||
}
|
||||
|
||||
pub fn payload(&self, i: u32) -> Payload {
|
||||
self.mmap.storage.payload(i)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<O: OperatorFlat> Send for Flat<O> {}
|
||||
unsafe impl<O: OperatorFlat> Sync for Flat<O> {}
|
||||
|
||||
pub struct FlatRam<O: OperatorFlat> {
|
||||
storage: Arc<StorageCollection<O>>,
|
||||
quantization: Quantization<O, StorageCollection<O>>,
|
||||
}
|
||||
|
||||
pub struct FlatMmap<O: OperatorFlat> {
|
||||
storage: Arc<StorageCollection<O>>,
|
||||
quantization: Quantization<O, StorageCollection<O>>,
|
||||
}
|
||||
|
||||
unsafe impl<O: OperatorFlat> Send for FlatMmap<O> {}
|
||||
unsafe impl<O: OperatorFlat> Sync for FlatMmap<O> {}
|
||||
|
||||
pub fn make<O: OperatorFlat, S: Source<O>>(
|
||||
path: &Path,
|
||||
options: IndexOptions,
|
||||
source: &S,
|
||||
) -> FlatRam<O> {
|
||||
let idx_opts = options.indexing.clone().unwrap_flat();
|
||||
let storage = Arc::new(StorageCollection::create(&path.join("raw"), source));
|
||||
let quantization = Quantization::create(
|
||||
&path.join("quantization"),
|
||||
options.clone(),
|
||||
idx_opts.quantization,
|
||||
&storage,
|
||||
(0..storage.len()).collect::<Vec<_>>(),
|
||||
);
|
||||
FlatRam {
|
||||
storage,
|
||||
quantization,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save<O: OperatorFlat>(_: &Path, ram: FlatRam<O>) -> FlatMmap<O> {
|
||||
FlatMmap {
|
||||
storage: ram.storage,
|
||||
quantization: ram.quantization,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn open<O: OperatorFlat>(path: &Path, options: IndexOptions) -> FlatMmap<O> {
|
||||
let idx_opts = options.indexing.clone().unwrap_flat();
|
||||
let storage = Arc::new(StorageCollection::open(&path.join("raw"), options.clone()));
|
||||
rayon::check();
|
||||
let quantization = Quantization::open(
|
||||
&path.join("quantization"),
|
||||
options.clone(),
|
||||
idx_opts.quantization,
|
||||
&storage,
|
||||
);
|
||||
rayon::check();
|
||||
FlatMmap {
|
||||
storage,
|
||||
quantization,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn basic<O: OperatorFlat>(
|
||||
mmap: &FlatMmap<O>,
|
||||
vector: Borrowed<'_, O>,
|
||||
mut filter: impl Filter,
|
||||
) -> BinaryHeap<Reverse<Element>> {
|
||||
let mut result = BinaryHeap::new();
|
||||
for i in 0..mmap.storage.len() {
|
||||
let distance = mmap.quantization.distance(vector, i);
|
||||
let payload = mmap.storage.payload(i);
|
||||
if filter.check(payload) {
|
||||
result.push(Reverse(Element { distance, payload }));
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub fn vbase<'a, O: OperatorFlat>(
|
||||
mmap: &'a FlatMmap<O>,
|
||||
vector: Borrowed<'a, O>,
|
||||
mut filter: impl Filter + 'a,
|
||||
) -> (Vec<Element>, Box<dyn Iterator<Item = Element> + 'a>) {
|
||||
let mut result = Vec::new();
|
||||
for i in 0..mmap.storage.len() {
|
||||
let distance = mmap.quantization.distance(vector, i);
|
||||
let payload = mmap.storage.payload(i);
|
||||
if filter.check(payload) {
|
||||
result.push(Element { distance, payload });
|
||||
}
|
||||
}
|
||||
(result, Box::new(std::iter::empty()))
|
||||
}
|
17
crates/hnsw/Cargo.toml
Normal file
17
crates/hnsw/Cargo.toml
Normal file
@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "hnsw"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
bytemuck.workspace = true
|
||||
parking_lot.workspace = true
|
||||
|
||||
base = { path = "../base" }
|
||||
common = { path = "../common" }
|
||||
quantization = { path = "../quantization" }
|
||||
rayon = { path = "../rayon" }
|
||||
storage = { path = "../storage" }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
@ -1,34 +1,36 @@
|
||||
use super::quantization::Quantization;
|
||||
use super::raw::Raw;
|
||||
use crate::index::segments::growing::GrowingSegment;
|
||||
use crate::index::segments::sealed::SealedSegment;
|
||||
use crate::prelude::*;
|
||||
use crate::utils::dir_ops::sync_dir;
|
||||
use crate::utils::element_heap::ElementHeap;
|
||||
use crate::utils::mmap_array::MmapArray;
|
||||
#![feature(trait_alias)]
|
||||
#![allow(clippy::len_without_is_empty)]
|
||||
|
||||
use base::index::*;
|
||||
use base::operator::*;
|
||||
use base::scalar::F32;
|
||||
use base::search::*;
|
||||
use bytemuck::{Pod, Zeroable};
|
||||
use common::dir_ops::sync_dir;
|
||||
use common::mmap_array::MmapArray;
|
||||
use parking_lot::{Mutex, RwLock, RwLockWriteGuard};
|
||||
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
|
||||
use quantization::operator::OperatorQuantization;
|
||||
use quantization::Quantization;
|
||||
use rayon::iter::{IntoParallelIterator, ParallelIterator};
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::BinaryHeap;
|
||||
use std::fs::create_dir;
|
||||
use std::ops::RangeInclusive;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use storage::operator::OperatorStorage;
|
||||
use storage::StorageCollection;
|
||||
|
||||
pub struct Hnsw<S: G> {
|
||||
mmap: HnswMmap<S>,
|
||||
pub trait OperatorHnsw = Operator + OperatorQuantization + OperatorStorage;
|
||||
|
||||
pub struct Hnsw<O: OperatorHnsw> {
|
||||
mmap: HnswMmap<O>,
|
||||
}
|
||||
|
||||
impl<S: G> Hnsw<S> {
|
||||
pub fn create(
|
||||
path: &Path,
|
||||
options: IndexOptions,
|
||||
sealed: Vec<Arc<SealedSegment<S>>>,
|
||||
growing: Vec<Arc<GrowingSegment<S>>>,
|
||||
) -> Self {
|
||||
impl<O: OperatorHnsw> Hnsw<O> {
|
||||
pub fn create<S: Source<O>>(path: &Path, options: IndexOptions, source: &S) -> Self {
|
||||
create_dir(path).unwrap();
|
||||
let ram = make(path, sealed, growing, options);
|
||||
let ram = make(path, options, source);
|
||||
let mmap = save(ram, path);
|
||||
sync_dir(path);
|
||||
Self { mmap }
|
||||
@ -41,7 +43,7 @@ impl<S: G> Hnsw<S> {
|
||||
|
||||
pub fn basic(
|
||||
&self,
|
||||
vector: Borrowed<'_, S>,
|
||||
vector: Borrowed<'_, O>,
|
||||
opts: &SearchOptions,
|
||||
filter: impl Filter,
|
||||
) -> BinaryHeap<Reverse<Element>> {
|
||||
@ -50,7 +52,7 @@ impl<S: G> Hnsw<S> {
|
||||
|
||||
pub fn vbase<'a>(
|
||||
&'a self,
|
||||
vector: Borrowed<'a, S>,
|
||||
vector: Borrowed<'a, O>,
|
||||
opts: &'a SearchOptions,
|
||||
filter: impl Filter + 'a,
|
||||
) -> (Vec<Element>, Box<(dyn Iterator<Item = Element> + 'a)>) {
|
||||
@ -58,24 +60,24 @@ impl<S: G> Hnsw<S> {
|
||||
}
|
||||
|
||||
pub fn len(&self) -> u32 {
|
||||
self.mmap.raw.len()
|
||||
self.mmap.storage.len()
|
||||
}
|
||||
|
||||
pub fn vector(&self, i: u32) -> Borrowed<'_, S> {
|
||||
self.mmap.raw.vector(i)
|
||||
pub fn vector(&self, i: u32) -> Borrowed<'_, O> {
|
||||
self.mmap.storage.vector(i)
|
||||
}
|
||||
|
||||
pub fn payload(&self, i: u32) -> Payload {
|
||||
self.mmap.raw.payload(i)
|
||||
self.mmap.storage.payload(i)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<S: G> Send for Hnsw<S> {}
|
||||
unsafe impl<S: G> Sync for Hnsw<S> {}
|
||||
unsafe impl<O: OperatorHnsw> Send for Hnsw<O> {}
|
||||
unsafe impl<O: OperatorHnsw> Sync for Hnsw<O> {}
|
||||
|
||||
pub struct HnswRam<S: G> {
|
||||
raw: Arc<Raw<S>>,
|
||||
quantization: Quantization<S>,
|
||||
pub struct HnswRam<O: OperatorHnsw> {
|
||||
storage: Arc<StorageCollection<O>>,
|
||||
quantization: Quantization<O, StorageCollection<O>>,
|
||||
// ----------------------
|
||||
m: u32,
|
||||
// ----------------------
|
||||
@ -102,9 +104,9 @@ struct HnswRamLayer {
|
||||
edges: Vec<(F32, u32)>,
|
||||
}
|
||||
|
||||
pub struct HnswMmap<S: G> {
|
||||
raw: Arc<Raw<S>>,
|
||||
quantization: Quantization<S>,
|
||||
pub struct HnswMmap<O: OperatorHnsw> {
|
||||
storage: Arc<StorageCollection<O>>,
|
||||
quantization: Quantization<O, StorageCollection<O>>,
|
||||
// ----------------------
|
||||
m: u32,
|
||||
// ----------------------
|
||||
@ -120,36 +122,32 @@ struct HnswMmapEdge(#[allow(dead_code)] F32, u32);
|
||||
// we may convert a memory-mapped graph to a memory graph
|
||||
// so that it speeds merging sealed segments
|
||||
|
||||
unsafe impl<S: G> Send for HnswMmap<S> {}
|
||||
unsafe impl<S: G> Sync for HnswMmap<S> {}
|
||||
unsafe impl<O: OperatorHnsw> Send for HnswMmap<O> {}
|
||||
unsafe impl<O: OperatorHnsw> Sync for HnswMmap<O> {}
|
||||
unsafe impl Pod for HnswMmapEdge {}
|
||||
unsafe impl Zeroable for HnswMmapEdge {}
|
||||
|
||||
pub fn make<S: G>(
|
||||
pub fn make<O: OperatorHnsw, S: Source<O>>(
|
||||
path: &Path,
|
||||
sealed: Vec<Arc<SealedSegment<S>>>,
|
||||
growing: Vec<Arc<GrowingSegment<S>>>,
|
||||
options: IndexOptions,
|
||||
) -> HnswRam<S> {
|
||||
source: &S,
|
||||
) -> HnswRam<O> {
|
||||
let HnswIndexingOptions {
|
||||
m,
|
||||
ef_construction,
|
||||
quantization: quantization_opts,
|
||||
} = options.indexing.clone().unwrap_hnsw();
|
||||
let raw = Arc::new(Raw::create(
|
||||
&path.join("raw"),
|
||||
options.clone(),
|
||||
sealed,
|
||||
growing,
|
||||
));
|
||||
let storage = Arc::new(StorageCollection::create(&path.join("raw"), source));
|
||||
rayon::check();
|
||||
let quantization = Quantization::create(
|
||||
&path.join("quantization"),
|
||||
options.clone(),
|
||||
quantization_opts,
|
||||
&raw,
|
||||
(0..raw.len()).collect::<Vec<_>>(),
|
||||
&storage,
|
||||
(0..storage.len()).collect::<Vec<_>>(),
|
||||
);
|
||||
let n = raw.len();
|
||||
rayon::check();
|
||||
let n = storage.len();
|
||||
let graph = HnswRamGraph {
|
||||
vertexs: (0..n)
|
||||
.into_par_iter()
|
||||
@ -161,14 +159,14 @@ pub fn make<S: G>(
|
||||
.collect(),
|
||||
};
|
||||
let entry = RwLock::<Option<u32>>::new(None);
|
||||
let visited = VisitedPool::new(raw.len());
|
||||
let visited = VisitedPool::new(storage.len());
|
||||
(0..n).into_par_iter().for_each(|i| {
|
||||
fn fast_search<S: G>(
|
||||
quantization: &Quantization<S>,
|
||||
fn fast_search<O: OperatorHnsw>(
|
||||
quantization: &Quantization<O, StorageCollection<O>>,
|
||||
graph: &HnswRamGraph,
|
||||
levels: RangeInclusive<u8>,
|
||||
u: u32,
|
||||
target: Borrowed<'_, S>,
|
||||
target: Borrowed<'_, O>,
|
||||
) -> u32 {
|
||||
let mut u = u;
|
||||
let mut u_dis = quantization.distance(target, u);
|
||||
@ -189,11 +187,11 @@ pub fn make<S: G>(
|
||||
}
|
||||
u
|
||||
}
|
||||
fn local_search<S: G>(
|
||||
quantization: &Quantization<S>,
|
||||
fn local_search<O: OperatorHnsw>(
|
||||
quantization: &Quantization<O, StorageCollection<O>>,
|
||||
graph: &HnswRamGraph,
|
||||
visited: &mut VisitedGuard,
|
||||
vector: Borrowed<'_, S>,
|
||||
vector: Borrowed<'_, O>,
|
||||
s: u32,
|
||||
k: usize,
|
||||
i: u8,
|
||||
@ -230,7 +228,11 @@ pub fn make<S: G>(
|
||||
}
|
||||
results.into_sorted_vec()
|
||||
}
|
||||
fn select<S: G>(quantization: &Quantization<S>, input: &mut Vec<(F32, u32)>, size: u32) {
|
||||
fn select<O: OperatorHnsw>(
|
||||
quantization: &Quantization<O, StorageCollection<O>>,
|
||||
input: &mut Vec<(F32, u32)>,
|
||||
size: u32,
|
||||
) {
|
||||
if input.len() <= size as usize {
|
||||
return;
|
||||
}
|
||||
@ -249,8 +251,9 @@ pub fn make<S: G>(
|
||||
}
|
||||
*input = res;
|
||||
}
|
||||
rayon::check();
|
||||
let mut visited = visited.fetch();
|
||||
let target = raw.vector(i);
|
||||
let target = storage.vector(i);
|
||||
let levels = graph.vertexs[i as usize].levels();
|
||||
let local_entry;
|
||||
let update_entry;
|
||||
@ -325,7 +328,7 @@ pub fn make<S: G>(
|
||||
}
|
||||
});
|
||||
HnswRam {
|
||||
raw,
|
||||
storage,
|
||||
quantization,
|
||||
m,
|
||||
graph,
|
||||
@ -333,7 +336,7 @@ pub fn make<S: G>(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save<S: G>(mut ram: HnswRam<S>, path: &Path) -> HnswMmap<S> {
|
||||
pub fn save<O: OperatorHnsw>(mut ram: HnswRam<O>, path: &Path) -> HnswMmap<O> {
|
||||
let edges = MmapArray::create(
|
||||
&path.join("edges"),
|
||||
ram.graph
|
||||
@ -343,19 +346,22 @@ pub fn save<S: G>(mut ram: HnswRam<S>, path: &Path) -> HnswMmap<S> {
|
||||
.flat_map(|v| &v.get_mut().edges)
|
||||
.map(|&(_0, _1)| HnswMmapEdge(_0, _1)),
|
||||
);
|
||||
rayon::check();
|
||||
let by_layer_id = MmapArray::create(&path.join("by_layer_id"), {
|
||||
let iter = ram.graph.vertexs.iter_mut();
|
||||
let iter = iter.flat_map(|v| v.layers.iter_mut());
|
||||
let iter = iter.map(|v| v.get_mut().edges.len());
|
||||
caluate_offsets(iter)
|
||||
});
|
||||
rayon::check();
|
||||
let by_vertex_id = MmapArray::create(&path.join("by_vertex_id"), {
|
||||
let iter = ram.graph.vertexs.iter_mut();
|
||||
let iter = iter.map(|v| v.layers.len());
|
||||
caluate_offsets(iter)
|
||||
});
|
||||
rayon::check();
|
||||
HnswMmap {
|
||||
raw: ram.raw,
|
||||
storage: ram.storage,
|
||||
quantization: ram.quantization,
|
||||
m: ram.m,
|
||||
edges,
|
||||
@ -365,22 +371,22 @@ pub fn save<S: G>(mut ram: HnswRam<S>, path: &Path) -> HnswMmap<S> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn open<S: G>(path: &Path, options: IndexOptions) -> HnswMmap<S> {
|
||||
pub fn open<O: OperatorHnsw>(path: &Path, options: IndexOptions) -> HnswMmap<O> {
|
||||
let idx_opts = options.indexing.clone().unwrap_hnsw();
|
||||
let raw = Arc::new(Raw::open(&path.join("raw"), options.clone()));
|
||||
let storage = Arc::new(StorageCollection::open(&path.join("raw"), options.clone()));
|
||||
let quantization = Quantization::open(
|
||||
&path.join("quantization"),
|
||||
options.clone(),
|
||||
idx_opts.quantization,
|
||||
&raw,
|
||||
&storage,
|
||||
);
|
||||
let edges = MmapArray::open(&path.join("edges"));
|
||||
let by_layer_id = MmapArray::open(&path.join("by_layer_id"));
|
||||
let by_vertex_id = MmapArray::open(&path.join("by_vertex_id"));
|
||||
let idx_opts = options.indexing.unwrap_hnsw();
|
||||
let n = raw.len();
|
||||
let n = storage.len();
|
||||
HnswMmap {
|
||||
raw,
|
||||
storage,
|
||||
quantization,
|
||||
m: idx_opts.m,
|
||||
edges,
|
||||
@ -390,9 +396,9 @@ pub fn open<S: G>(path: &Path, options: IndexOptions) -> HnswMmap<S> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn basic<S: G>(
|
||||
mmap: &HnswMmap<S>,
|
||||
vector: Borrowed<'_, S>,
|
||||
pub fn basic<O: OperatorHnsw>(
|
||||
mmap: &HnswMmap<O>,
|
||||
vector: Borrowed<'_, O>,
|
||||
ef_search: usize,
|
||||
filter: impl Filter,
|
||||
) -> BinaryHeap<Reverse<Element>> {
|
||||
@ -404,9 +410,9 @@ pub fn basic<S: G>(
|
||||
local_search_basic(mmap, ef_search, u, vector, filter).into_reversed_heap()
|
||||
}
|
||||
|
||||
pub fn vbase<'a, S: G>(
|
||||
mmap: &'a HnswMmap<S>,
|
||||
vector: Borrowed<'a, S>,
|
||||
pub fn vbase<'a, O: OperatorHnsw>(
|
||||
mmap: &'a HnswMmap<O>,
|
||||
vector: Borrowed<'a, O>,
|
||||
range: usize,
|
||||
filter: impl Filter + 'a,
|
||||
) -> (Vec<Element>, Box<(dyn Iterator<Item = Element> + 'a)>) {
|
||||
@ -432,9 +438,9 @@ pub fn vbase<'a, S: G>(
|
||||
(stage1, Box::new(iter))
|
||||
}
|
||||
|
||||
pub fn entry<S: G>(mmap: &HnswMmap<S>, mut filter: impl Filter) -> Option<u32> {
|
||||
pub fn entry<O: OperatorHnsw>(mmap: &HnswMmap<O>, mut filter: impl Filter) -> Option<u32> {
|
||||
let m = mmap.m;
|
||||
let n = mmap.raw.len();
|
||||
let n = mmap.storage.len();
|
||||
let mut shift = 1u64;
|
||||
let mut count = 0u64;
|
||||
while shift * m as u64 <= n as u64 {
|
||||
@ -445,7 +451,7 @@ pub fn entry<S: G>(mmap: &HnswMmap<S>, mut filter: impl Filter) -> Option<u32> {
|
||||
while i * shift <= n as u64 {
|
||||
let e = (i * shift - 1) as u32;
|
||||
if i % m as u64 != 0 {
|
||||
if filter.check(mmap.raw.payload(e)) {
|
||||
if filter.check(mmap.storage.payload(e)) {
|
||||
return Some(e);
|
||||
}
|
||||
count += 1;
|
||||
@ -460,11 +466,11 @@ pub fn entry<S: G>(mmap: &HnswMmap<S>, mut filter: impl Filter) -> Option<u32> {
|
||||
None
|
||||
}
|
||||
|
||||
pub fn fast_search<S: G>(
|
||||
mmap: &HnswMmap<S>,
|
||||
pub fn fast_search<O: OperatorHnsw>(
|
||||
mmap: &HnswMmap<O>,
|
||||
levels: RangeInclusive<u8>,
|
||||
u: u32,
|
||||
vector: Borrowed<'_, S>,
|
||||
vector: Borrowed<'_, O>,
|
||||
mut filter: impl Filter,
|
||||
) -> u32 {
|
||||
let mut u = u;
|
||||
@ -475,7 +481,7 @@ pub fn fast_search<S: G>(
|
||||
changed = false;
|
||||
let edges = find_edges(mmap, u, i);
|
||||
for &HnswMmapEdge(_, v) in edges.iter() {
|
||||
if !filter.check(mmap.raw.payload(v)) {
|
||||
if !filter.check(mmap.storage.payload(v)) {
|
||||
continue;
|
||||
}
|
||||
let v_dis = mmap.quantization.distance(vector, v);
|
||||
@ -490,11 +496,11 @@ pub fn fast_search<S: G>(
|
||||
u
|
||||
}
|
||||
|
||||
pub fn local_search_basic<S: G>(
|
||||
mmap: &HnswMmap<S>,
|
||||
pub fn local_search_basic<O: OperatorHnsw>(
|
||||
mmap: &HnswMmap<O>,
|
||||
k: usize,
|
||||
s: u32,
|
||||
vector: Borrowed<'_, S>,
|
||||
vector: Borrowed<'_, O>,
|
||||
mut filter: impl Filter,
|
||||
) -> ElementHeap {
|
||||
let mut visited = mmap.visited.fetch();
|
||||
@ -506,7 +512,7 @@ pub fn local_search_basic<S: G>(
|
||||
candidates.push(Reverse((s_dis, s)));
|
||||
results.push(Element {
|
||||
distance: s_dis,
|
||||
payload: mmap.raw.payload(s),
|
||||
payload: mmap.storage.payload(s),
|
||||
});
|
||||
while let Some(Reverse((u_dis, u))) = candidates.pop() {
|
||||
if !results.check(u_dis) {
|
||||
@ -518,7 +524,7 @@ pub fn local_search_basic<S: G>(
|
||||
continue;
|
||||
}
|
||||
visited.mark(v);
|
||||
if !filter.check(mmap.raw.payload(v)) {
|
||||
if !filter.check(mmap.storage.payload(v)) {
|
||||
continue;
|
||||
}
|
||||
let v_dis = mmap.quantization.distance(vector, v);
|
||||
@ -528,17 +534,17 @@ pub fn local_search_basic<S: G>(
|
||||
candidates.push(Reverse((v_dis, v)));
|
||||
results.push(Element {
|
||||
distance: v_dis,
|
||||
payload: mmap.raw.payload(v),
|
||||
payload: mmap.storage.payload(v),
|
||||
});
|
||||
}
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
pub fn local_search_vbase<'a, S: G>(
|
||||
mmap: &'a HnswMmap<S>,
|
||||
pub fn local_search_vbase<'a, O: OperatorHnsw>(
|
||||
mmap: &'a HnswMmap<O>,
|
||||
s: u32,
|
||||
vector: Borrowed<'a, S>,
|
||||
vector: Borrowed<'a, O>,
|
||||
mut filter: impl Filter + 'a,
|
||||
) -> impl Iterator<Item = Element> + 'a {
|
||||
let mut visited = mmap.visited.fetch2();
|
||||
@ -555,7 +561,7 @@ pub fn local_search_vbase<'a, S: G>(
|
||||
continue;
|
||||
}
|
||||
visited.mark(v);
|
||||
if filter.check(mmap.raw.payload(v)) {
|
||||
if filter.check(mmap.storage.payload(v)) {
|
||||
let v_dis = mmap.quantization.distance(vector, v);
|
||||
candidates.push(Reverse((v_dis, v)));
|
||||
}
|
||||
@ -563,7 +569,7 @@ pub fn local_search_vbase<'a, S: G>(
|
||||
}
|
||||
Some(Element {
|
||||
distance: u_dis,
|
||||
payload: mmap.raw.payload(u),
|
||||
payload: mmap.storage.payload(u),
|
||||
})
|
||||
})
|
||||
}
|
||||
@ -596,7 +602,7 @@ fn caluate_offsets(iter: impl Iterator<Item = usize>) -> impl Iterator<Item = us
|
||||
})
|
||||
}
|
||||
|
||||
fn find_edges<S: G>(mmap: &HnswMmap<S>, u: u32, level: u8) -> &[HnswMmapEdge] {
|
||||
fn find_edges<O: OperatorHnsw>(mmap: &HnswMmap<O>, u: u32, level: u8) -> &[HnswMmapEdge] {
|
||||
let offset = u as usize;
|
||||
let index = mmap.by_vertex_id[offset]..mmap.by_vertex_id[offset + 1];
|
||||
let offset = index.start + level as usize;
|
||||
@ -720,3 +726,32 @@ impl VisitedBuffer {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ElementHeap {
|
||||
binary_heap: BinaryHeap<Element>,
|
||||
k: usize,
|
||||
}
|
||||
|
||||
impl ElementHeap {
|
||||
pub fn new(k: usize) -> Self {
|
||||
assert!(k != 0);
|
||||
Self {
|
||||
binary_heap: BinaryHeap::new(),
|
||||
k,
|
||||
}
|
||||
}
|
||||
pub fn check(&self, distance: F32) -> bool {
|
||||
self.binary_heap.len() < self.k || distance < self.binary_heap.peek().unwrap().distance
|
||||
}
|
||||
pub fn push(&mut self, element: Element) -> Option<Element> {
|
||||
self.binary_heap.push(element);
|
||||
if self.binary_heap.len() == self.k + 1 {
|
||||
self.binary_heap.pop()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
pub fn into_reversed_heap(self) -> BinaryHeap<Reverse<Element>> {
|
||||
self.binary_heap.into_iter().map(Reverse).collect()
|
||||
}
|
||||
}
|
35
crates/index/Cargo.toml
Normal file
35
crates/index/Cargo.toml
Normal file
@ -0,0 +1,35 @@
|
||||
[package]
|
||||
name = "index"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
arc-swap.workspace = true
|
||||
bincode.workspace = true
|
||||
byteorder.workspace = true
|
||||
crc32fast = "1.4.0"
|
||||
crossbeam = "0.8.4"
|
||||
dashmap = "5.5.3"
|
||||
log.workspace = true
|
||||
parking_lot.workspace = true
|
||||
rand.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
thiserror.workspace = true
|
||||
uuid.workspace = true
|
||||
validator.workspace = true
|
||||
|
||||
base = { path = "../base" }
|
||||
common = { path = "../common" }
|
||||
elkan_k_means = { path = "../elkan_k_means" }
|
||||
quantization = { path = "../quantization" }
|
||||
rayon = { path = "../rayon" }
|
||||
storage = { path = "../storage" }
|
||||
|
||||
# algorithms
|
||||
flat = { path = "../flat" }
|
||||
hnsw = { path = "../hnsw" }
|
||||
ivf = { path = "../ivf" }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
@ -1,5 +1,8 @@
|
||||
use crate::prelude::*;
|
||||
use crate::utils::file_wal::FileWal;
|
||||
pub use base::distance::*;
|
||||
pub use base::index::*;
|
||||
pub use base::search::*;
|
||||
pub use base::vector::*;
|
||||
use dashmap::mapref::entry::Entry;
|
||||
use dashmap::DashMap;
|
||||
use parking_lot::Mutex;
|
||||
@ -8,7 +11,7 @@ use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct Delete {
|
||||
version: DashMap<Pointer, u16>,
|
||||
version: DashMap<Pointer, u64>,
|
||||
wal: Mutex<FileWal>,
|
||||
}
|
||||
|
||||
@ -23,7 +26,7 @@ impl Delete {
|
||||
}
|
||||
pub fn open(path: PathBuf) -> Arc<Self> {
|
||||
let mut wal = FileWal::open(path);
|
||||
let version = DashMap::<Pointer, u16>::new();
|
||||
let version = DashMap::<Pointer, u64>::new();
|
||||
while let Some(log) = wal.read() {
|
||||
let log = bincode::deserialize::<Log>(&log).unwrap();
|
||||
let key = log.key;
|
||||
@ -43,10 +46,10 @@ impl Delete {
|
||||
})
|
||||
}
|
||||
pub fn check(&self, payload: Payload) -> Option<Pointer> {
|
||||
let pointer = Pointer::from_u48(payload >> 16);
|
||||
let pointer = payload.pointer();
|
||||
match self.version.get(&pointer) {
|
||||
Some(e) => {
|
||||
if (payload as u16) == *e {
|
||||
if payload.time() == *e {
|
||||
Some(pointer)
|
||||
} else {
|
||||
None
|
||||
@ -69,7 +72,7 @@ impl Delete {
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn version(&self, key: Pointer) -> u16 {
|
||||
pub fn version(&self, key: Pointer) -> u64 {
|
||||
match self.version.get(&key) {
|
||||
Some(e) => *e,
|
||||
None => 0,
|
86
crates/index/src/indexing.rs
Normal file
86
crates/index/src/indexing.rs
Normal file
@ -0,0 +1,86 @@
|
||||
use crate::Op;
|
||||
pub use base::distance::*;
|
||||
pub use base::index::*;
|
||||
use base::operator::*;
|
||||
pub use base::search::*;
|
||||
pub use base::vector::*;
|
||||
use flat::Flat;
|
||||
use hnsw::Hnsw;
|
||||
use ivf::Ivf;
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::BinaryHeap;
|
||||
use std::path::Path;
|
||||
|
||||
pub enum Indexing<O: Op> {
|
||||
Flat(Flat<O>),
|
||||
Ivf(Ivf<O>),
|
||||
Hnsw(Hnsw<O>),
|
||||
}
|
||||
|
||||
impl<O: Op> Indexing<O> {
|
||||
pub fn create<S: Source<O>>(path: &Path, options: IndexOptions, source: &S) -> Self {
|
||||
match options.indexing {
|
||||
IndexingOptions::Flat(_) => Self::Flat(Flat::create(path, options, source)),
|
||||
IndexingOptions::Ivf(_) => Self::Ivf(Ivf::create(path, options, source)),
|
||||
IndexingOptions::Hnsw(_) => Self::Hnsw(Hnsw::create(path, options, source)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn open(path: &Path, options: IndexOptions) -> Self {
|
||||
match options.indexing {
|
||||
IndexingOptions::Flat(_) => Self::Flat(Flat::open(path, options)),
|
||||
IndexingOptions::Ivf(_) => Self::Ivf(Ivf::open(path, options)),
|
||||
IndexingOptions::Hnsw(_) => Self::Hnsw(Hnsw::open(path, options)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn basic(
|
||||
&self,
|
||||
vector: Borrowed<'_, O>,
|
||||
opts: &SearchOptions,
|
||||
filter: impl Filter,
|
||||
) -> BinaryHeap<Reverse<Element>> {
|
||||
match self {
|
||||
Indexing::Flat(x) => x.basic(vector, opts, filter),
|
||||
Indexing::Ivf(x) => x.basic(vector, opts, filter),
|
||||
Indexing::Hnsw(x) => x.basic(vector, opts, filter),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vbase<'a>(
|
||||
&'a self,
|
||||
vector: Borrowed<'a, O>,
|
||||
opts: &'a SearchOptions,
|
||||
filter: impl Filter + 'a,
|
||||
) -> (Vec<Element>, Box<(dyn Iterator<Item = Element> + 'a)>) {
|
||||
match self {
|
||||
Indexing::Flat(x) => x.vbase(vector, opts, filter),
|
||||
Indexing::Ivf(x) => x.vbase(vector, opts, filter),
|
||||
Indexing::Hnsw(x) => x.vbase(vector, opts, filter),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn len(&self) -> u32 {
|
||||
match self {
|
||||
Indexing::Flat(x) => x.len(),
|
||||
Indexing::Ivf(x) => x.len(),
|
||||
Indexing::Hnsw(x) => x.len(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vector(&self, i: u32) -> Borrowed<'_, O> {
|
||||
match self {
|
||||
Indexing::Flat(x) => x.vector(i),
|
||||
Indexing::Ivf(x) => x.vector(i),
|
||||
Indexing::Hnsw(x) => x.vector(i),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn payload(&self, i: u32) -> Payload {
|
||||
match self {
|
||||
Indexing::Flat(x) => x.payload(i),
|
||||
Indexing::Ivf(x) => x.payload(i),
|
||||
Indexing::Hnsw(x) => x.payload(i),
|
||||
}
|
||||
}
|
||||
}
|
@ -1,48 +1,67 @@
|
||||
#![feature(trait_alias)]
|
||||
#![allow(clippy::len_without_is_empty)]
|
||||
|
||||
pub mod delete;
|
||||
pub mod indexing;
|
||||
pub mod optimizing;
|
||||
pub mod segments;
|
||||
|
||||
mod utils;
|
||||
|
||||
use self::delete::Delete;
|
||||
use self::segments::growing::GrowingSegment;
|
||||
use self::segments::sealed::SealedSegment;
|
||||
use crate::index::optimizing::indexing::OptimizerIndexing;
|
||||
use crate::index::optimizing::sealing::OptimizerSealing;
|
||||
use crate::prelude::*;
|
||||
use crate::utils::clean::clean;
|
||||
use crate::utils::dir_ops::sync_dir;
|
||||
use crate::utils::file_atomic::FileAtomic;
|
||||
use crate::optimizing::indexing::OptimizerIndexing;
|
||||
use crate::optimizing::sealing::OptimizerSealing;
|
||||
use crate::utils::tournament_tree::LoserTree;
|
||||
use arc_swap::ArcSwap;
|
||||
pub use base::distance::*;
|
||||
pub use base::index::*;
|
||||
use base::operator::*;
|
||||
pub use base::search::*;
|
||||
pub use base::vector::*;
|
||||
use common::clean::clean;
|
||||
use common::dir_ops::sync_dir;
|
||||
use common::file_atomic::FileAtomic;
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
use crossbeam::channel::Sender;
|
||||
use elkan_k_means::operator::OperatorElkanKMeans;
|
||||
use parking_lot::Mutex;
|
||||
use quantization::operator::OperatorQuantization;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::convert::Infallible;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::thread::JoinHandle;
|
||||
use std::time::Instant;
|
||||
use storage::operator::OperatorStorage;
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
use validator::Validate;
|
||||
|
||||
pub trait Op = Operator + OperatorElkanKMeans + OperatorQuantization + OperatorStorage;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("The index view is outdated.")]
|
||||
pub struct OutdatedError;
|
||||
|
||||
pub struct Index<S: G> {
|
||||
pub struct Index<O: Op> {
|
||||
path: PathBuf,
|
||||
options: IndexOptions,
|
||||
delete: Arc<Delete>,
|
||||
protect: Mutex<IndexProtect<S>>,
|
||||
view: ArcSwap<IndexView<S>>,
|
||||
protect: Mutex<IndexProtect<O>>,
|
||||
view: ArcSwap<IndexView<O>>,
|
||||
instant_index: AtomicCell<Instant>,
|
||||
instant_write: AtomicCell<Instant>,
|
||||
background_indexing: Mutex<Option<(Sender<Infallible>, JoinHandle<()>)>>,
|
||||
background_sealing: Mutex<Option<(Sender<Infallible>, JoinHandle<()>)>>,
|
||||
_tracker: Arc<IndexTracker>,
|
||||
}
|
||||
|
||||
impl<S: G> Index<S> {
|
||||
impl<O: Op> Index<O> {
|
||||
pub fn create(path: PathBuf, options: IndexOptions) -> Result<Arc<Self>, CreateError> {
|
||||
if let Err(err) = options.validate() {
|
||||
return Err(CreateError::InvalidIndexOptions {
|
||||
@ -84,10 +103,10 @@ impl<S: G> Index<S> {
|
||||
})),
|
||||
instant_index: AtomicCell::new(Instant::now()),
|
||||
instant_write: AtomicCell::new(Instant::now()),
|
||||
background_indexing: Mutex::new(None),
|
||||
background_sealing: Mutex::new(None),
|
||||
_tracker: Arc::new(IndexTracker { path }),
|
||||
});
|
||||
OptimizerIndexing::new(index.clone()).spawn();
|
||||
OptimizerSealing::new(index.clone()).spawn();
|
||||
Ok(index)
|
||||
}
|
||||
|
||||
@ -113,7 +132,7 @@ impl<S: G> Index<S> {
|
||||
.map(|&uuid| {
|
||||
(
|
||||
uuid,
|
||||
SealedSegment::open(
|
||||
SealedSegment::<O>::open(
|
||||
tracker.clone(),
|
||||
path.join("segments").join(uuid.to_string()),
|
||||
uuid,
|
||||
@ -139,7 +158,7 @@ impl<S: G> Index<S> {
|
||||
})
|
||||
.collect::<HashMap<_, _>>();
|
||||
let delete = Delete::open(path.join("delete"));
|
||||
let index = Arc::new(Index {
|
||||
Arc::new(Index {
|
||||
path: path.clone(),
|
||||
options: options.clone(),
|
||||
delete: delete.clone(),
|
||||
@ -158,16 +177,15 @@ impl<S: G> Index<S> {
|
||||
})),
|
||||
instant_index: AtomicCell::new(Instant::now()),
|
||||
instant_write: AtomicCell::new(Instant::now()),
|
||||
background_indexing: Mutex::new(None),
|
||||
background_sealing: Mutex::new(None),
|
||||
_tracker: tracker,
|
||||
});
|
||||
OptimizerIndexing::new(index.clone()).spawn();
|
||||
OptimizerSealing::new(index.clone()).spawn();
|
||||
index
|
||||
})
|
||||
}
|
||||
pub fn options(&self) -> &IndexOptions {
|
||||
&self.options
|
||||
}
|
||||
pub fn view(&self) -> Arc<IndexView<S>> {
|
||||
pub fn view(&self) -> Arc<IndexView<O>> {
|
||||
self.view.load_full()
|
||||
}
|
||||
pub fn refresh(&self) {
|
||||
@ -225,9 +243,42 @@ impl<S: G> Index<S> {
|
||||
},
|
||||
}
|
||||
}
|
||||
pub fn start(self: &Arc<Self>) {
|
||||
{
|
||||
let mut background_indexing = self.background_indexing.lock();
|
||||
if background_indexing.is_none() {
|
||||
*background_indexing = Some(OptimizerIndexing::new(self.clone()).spawn());
|
||||
}
|
||||
}
|
||||
{
|
||||
let mut background_sealing = self.background_sealing.lock();
|
||||
if background_sealing.is_none() {
|
||||
*background_sealing = Some(OptimizerSealing::new(self.clone()).spawn());
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn stop(&self) {
|
||||
{
|
||||
let mut background_indexing = self.background_indexing.lock();
|
||||
if let Some((sender, join_handle)) = background_indexing.take() {
|
||||
drop(sender);
|
||||
let _ = join_handle.join();
|
||||
}
|
||||
}
|
||||
{
|
||||
let mut background_sealing = self.background_sealing.lock();
|
||||
if let Some((sender, join_handle)) = background_sealing.take() {
|
||||
drop(sender);
|
||||
let _ = join_handle.join();
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn wait(&self) -> Arc<IndexTracker> {
|
||||
Arc::clone(&self._tracker)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: G> Drop for Index<S> {
|
||||
impl<O: Op> Drop for Index<O> {
|
||||
fn drop(&mut self) {}
|
||||
}
|
||||
|
||||
@ -242,18 +293,18 @@ impl Drop for IndexTracker {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct IndexView<S: G> {
|
||||
pub struct IndexView<O: Op> {
|
||||
pub options: IndexOptions,
|
||||
pub delete: Arc<Delete>,
|
||||
pub sealed: HashMap<Uuid, Arc<SealedSegment<S>>>,
|
||||
pub growing: HashMap<Uuid, Arc<GrowingSegment<S>>>,
|
||||
pub write: Option<(Uuid, Arc<GrowingSegment<S>>)>,
|
||||
pub sealed: HashMap<Uuid, Arc<SealedSegment<O>>>,
|
||||
pub growing: HashMap<Uuid, Arc<GrowingSegment<O>>>,
|
||||
pub write: Option<(Uuid, Arc<GrowingSegment<O>>)>,
|
||||
}
|
||||
|
||||
impl<S: G> IndexView<S> {
|
||||
impl<O: Op> IndexView<O> {
|
||||
pub fn basic<'a, F: Fn(Pointer) -> bool + Clone + 'a>(
|
||||
&'a self,
|
||||
vector: Borrowed<'_, S>,
|
||||
vector: Borrowed<'_, O>,
|
||||
opts: &'a SearchOptions,
|
||||
filter: F,
|
||||
) -> Result<impl Iterator<Item = Pointer> + 'a, BasicError> {
|
||||
@ -295,8 +346,7 @@ impl<S: G> IndexView<S> {
|
||||
impl<'a, F: FnMut(Pointer) -> bool + Clone> Filter for Filtering<'a, F> {
|
||||
fn check(&mut self, payload: Payload) -> bool {
|
||||
!self.enable
|
||||
|| (self.delete.check(payload).is_some()
|
||||
&& (self.external)(Pointer::from_u48(payload >> 16)))
|
||||
|| (self.delete.check(payload).is_some() && (self.external)(payload.pointer()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -323,7 +373,7 @@ impl<S: G> IndexView<S> {
|
||||
let loser = LoserTree::new(heaps);
|
||||
Ok(loser.filter_map(|x| {
|
||||
if opts.prefilter_enable || self.delete.check(x.payload).is_some() {
|
||||
Some(Pointer::from_u48(x.payload >> 16))
|
||||
Some(x.payload.pointer())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@ -331,7 +381,7 @@ impl<S: G> IndexView<S> {
|
||||
}
|
||||
pub fn vbase<'a, F: FnMut(Pointer) -> bool + Clone + 'a>(
|
||||
&'a self,
|
||||
vector: Borrowed<'a, S>,
|
||||
vector: Borrowed<'a, O>,
|
||||
opts: &'a SearchOptions,
|
||||
filter: F,
|
||||
) -> Result<impl Iterator<Item = Pointer> + 'a, VbaseError> {
|
||||
@ -363,8 +413,7 @@ impl<S: G> IndexView<S> {
|
||||
impl<'a, F: FnMut(Pointer) -> bool + Clone + 'a> Filter for Filtering<'a, F> {
|
||||
fn check(&mut self, payload: Payload) -> bool {
|
||||
!self.enable
|
||||
|| (self.delete.check(payload).is_some()
|
||||
&& (self.external)(Pointer::from_u48(payload >> 16)))
|
||||
|| (self.delete.check(payload).is_some() && (self.external)(payload.pointer()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -397,7 +446,7 @@ impl<S: G> IndexView<S> {
|
||||
let loser = LoserTree::new(beta);
|
||||
Ok(loser.filter_map(|x| {
|
||||
if opts.prefilter_enable || self.delete.check(x.payload).is_some() {
|
||||
Some(Pointer::from_u48(x.payload >> 16))
|
||||
Some(x.payload.pointer())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@ -425,16 +474,16 @@ impl<S: G> IndexView<S> {
|
||||
}
|
||||
pub fn insert(
|
||||
&self,
|
||||
vector: Owned<S>,
|
||||
vector: Owned<O>,
|
||||
pointer: Pointer,
|
||||
) -> Result<Result<(), OutdatedError>, InsertError> {
|
||||
if self.options.vector.dims != vector.dims() {
|
||||
return Err(InsertError::InvalidVector);
|
||||
}
|
||||
|
||||
let payload = (pointer.as_u48() << 16) | self.delete.version(pointer) as Payload;
|
||||
let payload = Payload::new(pointer, self.delete.version(pointer));
|
||||
if let Some((_, growing)) = self.write.as_ref() {
|
||||
use crate::index::segments::growing::GrowingSegmentInsertError;
|
||||
use crate::segments::growing::GrowingSegmentInsertError;
|
||||
if let Err(GrowingSegmentInsertError) = growing.insert(vector, payload) {
|
||||
return Ok(Err(OutdatedError));
|
||||
}
|
||||
@ -462,19 +511,19 @@ struct IndexStartup {
|
||||
growings: HashSet<Uuid>,
|
||||
}
|
||||
|
||||
struct IndexProtect<S: G> {
|
||||
struct IndexProtect<O: Op> {
|
||||
startup: FileAtomic<IndexStartup>,
|
||||
sealed: HashMap<Uuid, Arc<SealedSegment<S>>>,
|
||||
growing: HashMap<Uuid, Arc<GrowingSegment<S>>>,
|
||||
write: Option<(Uuid, Arc<GrowingSegment<S>>)>,
|
||||
sealed: HashMap<Uuid, Arc<SealedSegment<O>>>,
|
||||
growing: HashMap<Uuid, Arc<GrowingSegment<O>>>,
|
||||
write: Option<(Uuid, Arc<GrowingSegment<O>>)>,
|
||||
}
|
||||
|
||||
impl<S: G> IndexProtect<S> {
|
||||
impl<O: Op> IndexProtect<O> {
|
||||
fn maintain(
|
||||
&mut self,
|
||||
options: IndexOptions,
|
||||
delete: Arc<Delete>,
|
||||
swap: &ArcSwap<IndexView<S>>,
|
||||
swap: &ArcSwap<IndexView<O>>,
|
||||
) {
|
||||
let view = Arc::new(IndexView {
|
||||
options,
|
239
crates/index/src/optimizing/indexing.rs
Normal file
239
crates/index/src/optimizing/indexing.rs
Normal file
@ -0,0 +1,239 @@
|
||||
use crate::GrowingSegment;
|
||||
use crate::Index;
|
||||
use crate::Op;
|
||||
use crate::SealedSegment;
|
||||
pub use base::distance::*;
|
||||
pub use base::index::*;
|
||||
use base::operator::Borrowed;
|
||||
pub use base::search::*;
|
||||
pub use base::vector::*;
|
||||
use crossbeam::channel::RecvError;
|
||||
use crossbeam::channel::TryRecvError;
|
||||
use crossbeam::channel::{bounded, Receiver, RecvTimeoutError, Sender};
|
||||
use std::cmp::Reverse;
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
use std::thread::JoinHandle;
|
||||
use std::time::Instant;
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub struct IndexSource<O: Op> {
|
||||
sealed: Vec<Arc<SealedSegment<O>>>,
|
||||
growing: Vec<Arc<GrowingSegment<O>>>,
|
||||
dims: u32,
|
||||
}
|
||||
|
||||
impl<O: Op> IndexSource<O> {
|
||||
pub fn new(
|
||||
options: IndexOptions,
|
||||
sealed: Vec<Arc<SealedSegment<O>>>,
|
||||
growing: Vec<Arc<GrowingSegment<O>>>,
|
||||
) -> Self {
|
||||
IndexSource {
|
||||
sealed,
|
||||
growing,
|
||||
dims: options.vector.dims,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<O: Op> Collection<O> for IndexSource<O> {
|
||||
fn dims(&self) -> u32 {
|
||||
self.dims
|
||||
}
|
||||
|
||||
fn len(&self) -> u32 {
|
||||
self.sealed.iter().map(|x| x.len()).sum::<u32>()
|
||||
+ self.growing.iter().map(|x| x.len()).sum::<u32>()
|
||||
}
|
||||
|
||||
fn vector(&self, mut index: u32) -> Borrowed<'_, O> {
|
||||
for x in self.sealed.iter() {
|
||||
if index < x.len() {
|
||||
return x.vector(index);
|
||||
}
|
||||
index -= x.len();
|
||||
}
|
||||
for x in self.growing.iter() {
|
||||
if index < x.len() {
|
||||
return x.vector(index);
|
||||
}
|
||||
index -= x.len();
|
||||
}
|
||||
panic!("Out of bound.")
|
||||
}
|
||||
|
||||
fn payload(&self, mut index: u32) -> Payload {
|
||||
for x in self.sealed.iter() {
|
||||
if index < x.len() {
|
||||
return x.payload(index);
|
||||
}
|
||||
index -= x.len();
|
||||
}
|
||||
for x in self.growing.iter() {
|
||||
if index < x.len() {
|
||||
return x.payload(index);
|
||||
}
|
||||
index -= x.len();
|
||||
}
|
||||
panic!("Out of bound.")
|
||||
}
|
||||
}
|
||||
|
||||
impl<O: Op> Source<O> for IndexSource<O> {}
|
||||
|
||||
pub struct OptimizerIndexing<O: Op> {
|
||||
index: Arc<Index<O>>,
|
||||
}
|
||||
|
||||
impl<O: Op> OptimizerIndexing<O> {
|
||||
pub fn new(index: Arc<Index<O>>) -> Self {
|
||||
Self { index }
|
||||
}
|
||||
pub fn spawn(self) -> (Sender<Infallible>, JoinHandle<()>) {
|
||||
let (tx, rx) = bounded(1);
|
||||
(
|
||||
tx,
|
||||
std::thread::spawn(move || {
|
||||
self.main(rx);
|
||||
}),
|
||||
)
|
||||
}
|
||||
fn main(self, shutdown: Receiver<Infallible>) {
|
||||
let index = self.index;
|
||||
rayon::ThreadPoolBuilder::new()
|
||||
.num_threads(index.options.optimizing.optimizing_threads)
|
||||
.build_scoped(|pool| {
|
||||
std::thread::scope(|scope| {
|
||||
scope.spawn(|| match shutdown.recv() {
|
||||
Ok(never) => match never {},
|
||||
Err(RecvError) => {
|
||||
pool.stop();
|
||||
}
|
||||
});
|
||||
loop {
|
||||
if let Ok(()) = pool.install(|| optimizing_indexing(index.clone())) {
|
||||
match shutdown.try_recv() {
|
||||
Ok(never) => match never {},
|
||||
Err(TryRecvError::Disconnected) => return,
|
||||
Err(TryRecvError::Empty) => (),
|
||||
}
|
||||
continue;
|
||||
}
|
||||
match shutdown.recv_timeout(std::time::Duration::from_secs(60)) {
|
||||
Ok(never) => match never {},
|
||||
Err(RecvTimeoutError::Disconnected) => return,
|
||||
Err(RecvTimeoutError::Timeout) => (),
|
||||
}
|
||||
}
|
||||
});
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
enum Seg<O: Op> {
|
||||
Sealed(Arc<SealedSegment<O>>),
|
||||
Growing(Arc<GrowingSegment<O>>),
|
||||
}
|
||||
|
||||
impl<O: Op> Seg<O> {
|
||||
fn uuid(&self) -> Uuid {
|
||||
use Seg::*;
|
||||
match self {
|
||||
Sealed(x) => x.uuid(),
|
||||
Growing(x) => x.uuid(),
|
||||
}
|
||||
}
|
||||
fn len(&self) -> u32 {
|
||||
use Seg::*;
|
||||
match self {
|
||||
Sealed(x) => x.len(),
|
||||
Growing(x) => x.len(),
|
||||
}
|
||||
}
|
||||
fn get_sealed(&self) -> Option<Arc<SealedSegment<O>>> {
|
||||
match self {
|
||||
Seg::Sealed(x) => Some(x.clone()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
fn get_growing(&self) -> Option<Arc<GrowingSegment<O>>> {
|
||||
match self {
|
||||
Seg::Growing(x) => Some(x.clone()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("Interrupted, retry again.")]
|
||||
pub struct RetryError;
|
||||
|
||||
pub fn optimizing_indexing<O: Op>(index: Arc<Index<O>>) -> Result<(), RetryError> {
|
||||
use Seg::*;
|
||||
let segs = {
|
||||
let protect = index.protect.lock();
|
||||
let mut segs_0 = Vec::new();
|
||||
segs_0.extend(protect.growing.values().map(|x| Growing(x.clone())));
|
||||
segs_0.extend(protect.sealed.values().map(|x| Sealed(x.clone())));
|
||||
segs_0.sort_by_key(|case| Reverse(case.len()));
|
||||
let mut segs_1 = Vec::new();
|
||||
let mut total = 0u64;
|
||||
let mut count = 0;
|
||||
while let Some(seg) = segs_0.pop() {
|
||||
if total + seg.len() as u64 <= index.options.segment.max_sealed_segment_size as u64 {
|
||||
total += seg.len() as u64;
|
||||
if let Growing(_) = seg {
|
||||
count += 1;
|
||||
}
|
||||
segs_1.push(seg);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if segs_1.is_empty() || (segs_1.len() == 1 && count == 0) {
|
||||
index.instant_index.store(Instant::now());
|
||||
return Err(RetryError);
|
||||
}
|
||||
segs_1
|
||||
};
|
||||
let sealed_segment = merge(&index, &segs);
|
||||
{
|
||||
let mut protect = index.protect.lock();
|
||||
for seg in segs.iter() {
|
||||
if protect.sealed.contains_key(&seg.uuid()) {
|
||||
continue;
|
||||
}
|
||||
if protect.growing.contains_key(&seg.uuid()) {
|
||||
continue;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
for seg in segs.iter() {
|
||||
protect.sealed.remove(&seg.uuid());
|
||||
protect.growing.remove(&seg.uuid());
|
||||
}
|
||||
protect.sealed.insert(sealed_segment.uuid(), sealed_segment);
|
||||
protect.maintain(index.options.clone(), index.delete.clone(), &index.view);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn merge<O: Op>(index: &Arc<Index<O>>, segs: &[Seg<O>]) -> Arc<SealedSegment<O>> {
|
||||
let sealed = segs.iter().filter_map(|x| x.get_sealed()).collect();
|
||||
let growing = segs.iter().filter_map(|x| x.get_growing()).collect();
|
||||
let sealed_segment_uuid = Uuid::new_v4();
|
||||
let collection = IndexSource::new(index.options().clone(), sealed, growing);
|
||||
SealedSegment::create(
|
||||
index._tracker.clone(),
|
||||
index
|
||||
.path
|
||||
.join("segments")
|
||||
.join(sealed_segment_uuid.to_string()),
|
||||
sealed_segment_uuid,
|
||||
index.options.clone(),
|
||||
&collection,
|
||||
)
|
||||
}
|
57
crates/index/src/optimizing/sealing.rs
Normal file
57
crates/index/src/optimizing/sealing.rs
Normal file
@ -0,0 +1,57 @@
|
||||
use crate::Index;
|
||||
use crate::Op;
|
||||
pub use base::distance::*;
|
||||
pub use base::index::*;
|
||||
pub use base::search::*;
|
||||
pub use base::vector::*;
|
||||
use crossbeam::channel::{bounded, Receiver, RecvTimeoutError, Sender};
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
use std::thread::JoinHandle;
|
||||
use std::time::Duration;
|
||||
|
||||
pub struct OptimizerSealing<O: Op> {
|
||||
index: Arc<Index<O>>,
|
||||
}
|
||||
|
||||
impl<O: Op> OptimizerSealing<O> {
|
||||
pub fn new(index: Arc<Index<O>>) -> Self {
|
||||
Self { index }
|
||||
}
|
||||
pub fn spawn(self) -> (Sender<Infallible>, JoinHandle<()>) {
|
||||
let (tx, rx) = bounded(1);
|
||||
(
|
||||
tx,
|
||||
std::thread::spawn(move || {
|
||||
self.main(rx);
|
||||
}),
|
||||
)
|
||||
}
|
||||
fn main(self, shutdown: Receiver<Infallible>) {
|
||||
let index = self.index;
|
||||
let dur = Duration::from_secs(index.options.optimizing.sealing_secs);
|
||||
let least = index.options.optimizing.sealing_size;
|
||||
let mut check = None;
|
||||
loop {
|
||||
let view = index.view();
|
||||
let stamp = view
|
||||
.write
|
||||
.as_ref()
|
||||
.map(|(uuid, segment)| (*uuid, segment.len()));
|
||||
if stamp == check {
|
||||
if let Some((uuid, len)) = stamp {
|
||||
if len >= least {
|
||||
index.seal(uuid);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
check = stamp;
|
||||
}
|
||||
match shutdown.recv_timeout(dur) {
|
||||
Ok(never) => match never {},
|
||||
Err(RecvTimeoutError::Disconnected) => return,
|
||||
Err(RecvTimeoutError::Timeout) => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,8 +1,12 @@
|
||||
use super::SegmentTracker;
|
||||
use crate::index::IndexTracker;
|
||||
use crate::prelude::*;
|
||||
use crate::utils::dir_ops::sync_dir;
|
||||
use crate::utils::file_wal::FileWal;
|
||||
use crate::IndexTracker;
|
||||
use crate::Op;
|
||||
use base::index::*;
|
||||
use base::operator::*;
|
||||
use base::search::*;
|
||||
use base::vector::*;
|
||||
use common::dir_ops::sync_dir;
|
||||
use parking_lot::Mutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cell::UnsafeCell;
|
||||
@ -19,16 +23,16 @@ use uuid::Uuid;
|
||||
#[error("`GrowingSegment` stopped growing.")]
|
||||
pub struct GrowingSegmentInsertError;
|
||||
|
||||
pub struct GrowingSegment<S: G> {
|
||||
pub struct GrowingSegment<O: Op> {
|
||||
uuid: Uuid,
|
||||
vec: Vec<UnsafeCell<MaybeUninit<Log<S>>>>,
|
||||
vec: Vec<MaybeUninit<UnsafeCell<Log<O>>>>,
|
||||
wal: Mutex<FileWal>,
|
||||
len: AtomicUsize,
|
||||
pro: Mutex<Protect>,
|
||||
_tracker: Arc<SegmentTracker>,
|
||||
}
|
||||
|
||||
impl<S: G> GrowingSegment<S> {
|
||||
impl<O: Op> GrowingSegment<O> {
|
||||
pub fn create(
|
||||
_tracker: Arc<IndexTracker>,
|
||||
path: PathBuf,
|
||||
@ -41,7 +45,6 @@ impl<S: G> GrowingSegment<S> {
|
||||
sync_dir(&path);
|
||||
Arc::new(Self {
|
||||
uuid,
|
||||
#[allow(clippy::uninit_vec)]
|
||||
vec: unsafe {
|
||||
let mut vec = Vec::with_capacity(capacity as usize);
|
||||
vec.set_len(capacity as usize);
|
||||
@ -66,8 +69,8 @@ impl<S: G> GrowingSegment<S> {
|
||||
let mut wal = FileWal::open(path.join("wal"));
|
||||
let mut vec = Vec::new();
|
||||
while let Some(log) = wal.read() {
|
||||
let log = bincode::deserialize::<Log<S>>(&log).unwrap();
|
||||
vec.push(UnsafeCell::new(MaybeUninit::new(log)));
|
||||
let log = bincode::deserialize::<Log<O>>(&log).unwrap();
|
||||
vec.push(MaybeUninit::new(UnsafeCell::new(log)));
|
||||
}
|
||||
wal.truncate();
|
||||
let n = vec.len();
|
||||
@ -122,7 +125,7 @@ impl<S: G> GrowingSegment<S> {
|
||||
|
||||
pub fn insert(
|
||||
&self,
|
||||
vector: S::VectorOwned,
|
||||
vector: O::VectorOwned,
|
||||
payload: Payload,
|
||||
) -> Result<(), GrowingSegmentInsertError> {
|
||||
let log = Log { vector, payload };
|
||||
@ -136,7 +139,7 @@ impl<S: G> GrowingSegment<S> {
|
||||
pro.inflight += 1;
|
||||
}
|
||||
unsafe {
|
||||
(*self.vec[i].get()).write(log.clone());
|
||||
UnsafeCell::raw_get(self.vec[i].as_ptr()).write(log.clone());
|
||||
}
|
||||
while self.len.load(Ordering::Acquire) != i {
|
||||
std::hint::spin_loop();
|
||||
@ -144,7 +147,7 @@ impl<S: G> GrowingSegment<S> {
|
||||
self.len.store(1 + i, Ordering::Release);
|
||||
self.wal
|
||||
.lock()
|
||||
.write(&bincode::serialize::<Log<S>>(&log).unwrap());
|
||||
.write(&bincode::serialize::<Log<O>>(&log).unwrap());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -157,7 +160,7 @@ impl<S: G> GrowingSegment<S> {
|
||||
id: self.uuid,
|
||||
typ: "growing".to_string(),
|
||||
length: self.len() as usize,
|
||||
size: (self.len() as u64) * (std::mem::size_of::<Log<S>>() as u64),
|
||||
size: (self.len() as u64) * (std::mem::size_of::<Log<O>>() as u64),
|
||||
}
|
||||
}
|
||||
|
||||
@ -166,16 +169,16 @@ impl<S: G> GrowingSegment<S> {
|
||||
id: self.uuid,
|
||||
typ: "write".to_string(),
|
||||
length: self.len() as usize,
|
||||
size: (self.len() as u64) * (std::mem::size_of::<Log<S>>() as u64),
|
||||
size: (self.len() as u64) * (std::mem::size_of::<Log<O>>() as u64),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vector(&self, i: u32) -> Borrowed<'_, S> {
|
||||
pub fn vector(&self, i: u32) -> Borrowed<'_, O> {
|
||||
let i = i as usize;
|
||||
if i >= self.len.load(Ordering::Acquire) {
|
||||
panic!("Out of bound.");
|
||||
}
|
||||
let log = unsafe { (*self.vec[i].get()).assume_init_ref() };
|
||||
let log = unsafe { &*self.vec[i].assume_init_ref().get().cast_const() };
|
||||
log.vector.for_borrow()
|
||||
}
|
||||
|
||||
@ -184,22 +187,22 @@ impl<S: G> GrowingSegment<S> {
|
||||
if i >= self.len.load(Ordering::Acquire) {
|
||||
panic!("Out of bound.");
|
||||
}
|
||||
let log = unsafe { (*self.vec[i].get()).assume_init_ref() };
|
||||
let log = unsafe { &*self.vec[i].assume_init_ref().get().cast_const() };
|
||||
log.payload
|
||||
}
|
||||
|
||||
pub fn basic(
|
||||
&self,
|
||||
vector: Borrowed<'_, S>,
|
||||
vector: Borrowed<'_, O>,
|
||||
_opts: &SearchOptions,
|
||||
mut filter: impl Filter,
|
||||
) -> BinaryHeap<Reverse<Element>> {
|
||||
let n = self.len.load(Ordering::Acquire);
|
||||
let mut result = BinaryHeap::new();
|
||||
for i in 0..n {
|
||||
let log = unsafe { (*self.vec[i].get()).assume_init_ref() };
|
||||
let log = unsafe { &*self.vec[i].assume_init_ref().get().cast_const() };
|
||||
if filter.check(log.payload) {
|
||||
let distance = S::distance(vector, log.vector.for_borrow());
|
||||
let distance = O::distance(vector, log.vector.for_borrow());
|
||||
result.push(Reverse(Element {
|
||||
distance,
|
||||
payload: log.payload,
|
||||
@ -211,16 +214,16 @@ impl<S: G> GrowingSegment<S> {
|
||||
|
||||
pub fn vbase<'a>(
|
||||
&'a self,
|
||||
vector: Borrowed<'a, S>,
|
||||
vector: Borrowed<'a, O>,
|
||||
_opts: &SearchOptions,
|
||||
mut filter: impl Filter + 'a,
|
||||
) -> (Vec<Element>, Box<dyn Iterator<Item = Element> + 'a>) {
|
||||
let n = self.len.load(Ordering::Acquire);
|
||||
let mut result = Vec::new();
|
||||
for i in 0..n {
|
||||
let log = unsafe { (*self.vec[i].get()).assume_init_ref() };
|
||||
let log = unsafe { &*self.vec[i].assume_init_ref().get().cast_const() };
|
||||
if filter.check(log.payload) {
|
||||
let distance = S::distance(vector, log.vector.for_borrow());
|
||||
let distance = O::distance(vector, log.vector.for_borrow());
|
||||
result.push(Element {
|
||||
distance,
|
||||
payload: log.payload,
|
||||
@ -231,23 +234,23 @@ impl<S: G> GrowingSegment<S> {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<S: G> Send for GrowingSegment<S> {}
|
||||
unsafe impl<S: G> Sync for GrowingSegment<S> {}
|
||||
unsafe impl<O: Op> Send for GrowingSegment<O> {}
|
||||
unsafe impl<O: Op> Sync for GrowingSegment<O> {}
|
||||
|
||||
impl<S: G> Drop for GrowingSegment<S> {
|
||||
impl<O: Op> Drop for GrowingSegment<O> {
|
||||
fn drop(&mut self) {
|
||||
let n = *self.len.get_mut();
|
||||
for i in 0..n {
|
||||
unsafe {
|
||||
self.vec[i].get_mut().assume_init_drop();
|
||||
self.vec[i].assume_init_drop();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct Log<S: G> {
|
||||
vector: S::VectorOwned,
|
||||
struct Log<O: Op> {
|
||||
vector: O::VectorOwned,
|
||||
payload: Payload,
|
||||
}
|
||||
|
@ -1,32 +1,34 @@
|
||||
use super::growing::GrowingSegment;
|
||||
use super::SegmentTracker;
|
||||
use crate::index::indexing::DynamicIndexing;
|
||||
use crate::index::IndexTracker;
|
||||
use crate::prelude::*;
|
||||
use crate::utils::dir_ops::{dir_size, sync_dir};
|
||||
use crate::indexing::Indexing;
|
||||
use crate::utils::dir_ops::dir_size;
|
||||
use crate::IndexTracker;
|
||||
use crate::Op;
|
||||
use base::index::*;
|
||||
use base::operator::*;
|
||||
use base::search::*;
|
||||
use common::dir_ops::sync_dir;
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::BinaryHeap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub struct SealedSegment<S: G> {
|
||||
pub struct SealedSegment<O: Op> {
|
||||
uuid: Uuid,
|
||||
indexing: DynamicIndexing<S>,
|
||||
indexing: Indexing<O>,
|
||||
_tracker: Arc<SegmentTracker>,
|
||||
}
|
||||
|
||||
impl<S: G> SealedSegment<S> {
|
||||
pub fn create(
|
||||
impl<O: Op> SealedSegment<O> {
|
||||
pub fn create<S: Source<O>>(
|
||||
_tracker: Arc<IndexTracker>,
|
||||
path: PathBuf,
|
||||
uuid: Uuid,
|
||||
options: IndexOptions,
|
||||
sealed: Vec<Arc<SealedSegment<S>>>,
|
||||
growing: Vec<Arc<GrowingSegment<S>>>,
|
||||
source: &S,
|
||||
) -> Arc<Self> {
|
||||
std::fs::create_dir(&path).unwrap();
|
||||
let indexing = DynamicIndexing::create(&path.join("indexing"), options, sealed, growing);
|
||||
let indexing = Indexing::create(&path.join("indexing"), options, source);
|
||||
sync_dir(&path);
|
||||
Arc::new(Self {
|
||||
uuid,
|
||||
@ -41,7 +43,7 @@ impl<S: G> SealedSegment<S> {
|
||||
uuid: Uuid,
|
||||
options: IndexOptions,
|
||||
) -> Arc<Self> {
|
||||
let indexing = DynamicIndexing::open(&path.join("indexing"), options);
|
||||
let indexing = Indexing::open(&path.join("indexing"), options);
|
||||
Arc::new(Self {
|
||||
uuid,
|
||||
indexing,
|
||||
@ -65,7 +67,7 @@ impl<S: G> SealedSegment<S> {
|
||||
|
||||
pub fn basic(
|
||||
&self,
|
||||
vector: Borrowed<'_, S>,
|
||||
vector: Borrowed<'_, O>,
|
||||
opts: &SearchOptions,
|
||||
filter: impl Filter,
|
||||
) -> BinaryHeap<Reverse<Element>> {
|
||||
@ -74,7 +76,7 @@ impl<S: G> SealedSegment<S> {
|
||||
|
||||
pub fn vbase<'a>(
|
||||
&'a self,
|
||||
vector: Borrowed<'a, S>,
|
||||
vector: Borrowed<'a, O>,
|
||||
opts: &'a SearchOptions,
|
||||
filter: impl Filter + 'a,
|
||||
) -> (Vec<Element>, Box<dyn Iterator<Item = Element> + 'a>) {
|
||||
@ -85,7 +87,7 @@ impl<S: G> SealedSegment<S> {
|
||||
self.indexing.len()
|
||||
}
|
||||
|
||||
pub fn vector(&self, i: u32) -> Borrowed<'_, S> {
|
||||
pub fn vector(&self, i: u32) -> Borrowed<'_, O> {
|
||||
self.indexing.vector(i)
|
||||
}
|
||||
|
@ -1,12 +1,7 @@
|
||||
use std::fs::{read_dir, File};
|
||||
use std::fs::read_dir;
|
||||
use std::io;
|
||||
use std::path::Path;
|
||||
|
||||
pub fn sync_dir(path: impl AsRef<Path>) {
|
||||
let file = File::open(path).expect("Failed to sync dir.");
|
||||
file.sync_all().expect("Failed to sync dir.");
|
||||
}
|
||||
|
||||
pub fn dir_size(dir: &Path) -> io::Result<u64> {
|
||||
let mut size = 0;
|
||||
if dir.is_dir() {
|
3
crates/index/src/utils/mod.rs
Normal file
3
crates/index/src/utils/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
pub mod dir_ops;
|
||||
pub mod file_wal;
|
||||
pub mod tournament_tree;
|
19
crates/ivf/Cargo.toml
Normal file
19
crates/ivf/Cargo.toml
Normal file
@ -0,0 +1,19 @@
|
||||
[package]
|
||||
name = "ivf"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
num-traits.workspace = true
|
||||
rand.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
||||
base = { path = "../base" }
|
||||
common = { path = "../common" }
|
||||
elkan_k_means = { path = "../elkan_k_means" }
|
||||
quantization = { path = "../quantization" }
|
||||
rayon = { path = "../rayon" }
|
||||
storage = { path = "../storage" }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
@ -1,36 +1,33 @@
|
||||
use crate::algorithms::clustering::elkan_k_means::ElkanKMeans;
|
||||
use crate::algorithms::quantization::Quantization;
|
||||
use crate::algorithms::raw::Raw;
|
||||
use crate::index::segments::growing::GrowingSegment;
|
||||
use crate::index::segments::sealed::SealedSegment;
|
||||
use crate::prelude::*;
|
||||
use crate::utils::dir_ops::sync_dir;
|
||||
use crate::utils::element_heap::ElementHeap;
|
||||
use crate::utils::mmap_array::MmapArray;
|
||||
use crate::utils::vec2::Vec2;
|
||||
use super::OperatorIvf as Op;
|
||||
use base::index::*;
|
||||
use base::operator::*;
|
||||
use base::scalar::F32;
|
||||
use base::search::*;
|
||||
use base::vector::*;
|
||||
use common::dir_ops::sync_dir;
|
||||
use common::mmap_array::MmapArray;
|
||||
use common::vec2::Vec2;
|
||||
use elkan_k_means::ElkanKMeans;
|
||||
use num_traits::Float;
|
||||
use quantization::Quantization;
|
||||
use rand::seq::index::sample;
|
||||
use rand::thread_rng;
|
||||
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator};
|
||||
use rayon::prelude::ParallelIterator;
|
||||
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::BinaryHeap;
|
||||
use std::fs::create_dir;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use storage::StorageCollection;
|
||||
|
||||
pub struct IvfNaive<S: G> {
|
||||
mmap: IvfMmap<S>,
|
||||
pub struct IvfNaive<O: Op> {
|
||||
mmap: IvfMmap<O>,
|
||||
}
|
||||
|
||||
impl<S: G> IvfNaive<S> {
|
||||
pub fn create(
|
||||
path: &Path,
|
||||
options: IndexOptions,
|
||||
sealed: Vec<Arc<SealedSegment<S>>>,
|
||||
growing: Vec<Arc<GrowingSegment<S>>>,
|
||||
) -> Self {
|
||||
impl<O: Op> IvfNaive<O> {
|
||||
pub fn create<S: Source<O>>(path: &Path, options: IndexOptions, source: &S) -> Self {
|
||||
create_dir(path).unwrap();
|
||||
let ram = make(path, sealed, growing, options);
|
||||
let ram = make(path, options, source);
|
||||
let mmap = save(ram, path);
|
||||
sync_dir(path);
|
||||
Self { mmap }
|
||||
@ -42,20 +39,20 @@ impl<S: G> IvfNaive<S> {
|
||||
}
|
||||
|
||||
pub fn len(&self) -> u32 {
|
||||
self.mmap.raw.len()
|
||||
self.mmap.storage.len()
|
||||
}
|
||||
|
||||
pub fn vector(&self, i: u32) -> Borrowed<'_, S> {
|
||||
self.mmap.raw.vector(i)
|
||||
pub fn vector(&self, i: u32) -> Borrowed<'_, O> {
|
||||
self.mmap.storage.vector(i)
|
||||
}
|
||||
|
||||
pub fn payload(&self, i: u32) -> Payload {
|
||||
self.mmap.raw.payload(i)
|
||||
self.mmap.storage.payload(i)
|
||||
}
|
||||
|
||||
pub fn basic(
|
||||
&self,
|
||||
vector: Borrowed<'_, S>,
|
||||
vector: Borrowed<'_, O>,
|
||||
opts: &SearchOptions,
|
||||
filter: impl Filter,
|
||||
) -> BinaryHeap<Reverse<Element>> {
|
||||
@ -64,7 +61,7 @@ impl<S: G> IvfNaive<S> {
|
||||
|
||||
pub fn vbase<'a>(
|
||||
&'a self,
|
||||
vector: Borrowed<'a, S>,
|
||||
vector: Borrowed<'a, O>,
|
||||
opts: &'a SearchOptions,
|
||||
filter: impl Filter + 'a,
|
||||
) -> (Vec<Element>, Box<(dyn Iterator<Item = Element> + 'a)>) {
|
||||
@ -72,55 +69,50 @@ impl<S: G> IvfNaive<S> {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<S: G> Send for IvfNaive<S> {}
|
||||
unsafe impl<S: G> Sync for IvfNaive<S> {}
|
||||
unsafe impl<O: Op> Send for IvfNaive<O> {}
|
||||
unsafe impl<O: Op> Sync for IvfNaive<O> {}
|
||||
|
||||
pub struct IvfRam<S: G> {
|
||||
raw: Arc<Raw<S>>,
|
||||
quantization: Quantization<S>,
|
||||
pub struct IvfRam<O: Op> {
|
||||
storage: Arc<StorageCollection<O>>,
|
||||
quantization: Quantization<O, StorageCollection<O>>,
|
||||
// ----------------------
|
||||
dims: u32,
|
||||
// ----------------------
|
||||
nlist: u32,
|
||||
// ----------------------
|
||||
centroids: Vec2<Scalar<S>>,
|
||||
centroids: Vec2<Scalar<O>>,
|
||||
ptr: Vec<usize>,
|
||||
payloads: Vec<Payload>,
|
||||
}
|
||||
|
||||
unsafe impl<S: G> Send for IvfRam<S> {}
|
||||
unsafe impl<S: G> Sync for IvfRam<S> {}
|
||||
unsafe impl<O: Op> Send for IvfRam<O> {}
|
||||
unsafe impl<O: Op> Sync for IvfRam<O> {}
|
||||
|
||||
pub struct IvfMmap<S: G> {
|
||||
raw: Arc<Raw<S>>,
|
||||
quantization: Quantization<S>,
|
||||
pub struct IvfMmap<O: Op> {
|
||||
storage: Arc<StorageCollection<O>>,
|
||||
quantization: Quantization<O, StorageCollection<O>>,
|
||||
// ----------------------
|
||||
dims: u32,
|
||||
// ----------------------
|
||||
nlist: u32,
|
||||
// ----------------------
|
||||
centroids: MmapArray<Scalar<S>>,
|
||||
centroids: MmapArray<Scalar<O>>,
|
||||
ptr: MmapArray<usize>,
|
||||
payloads: MmapArray<Payload>,
|
||||
}
|
||||
|
||||
unsafe impl<S: G> Send for IvfMmap<S> {}
|
||||
unsafe impl<S: G> Sync for IvfMmap<S> {}
|
||||
unsafe impl<O: Op> Send for IvfMmap<O> {}
|
||||
unsafe impl<O: Op> Sync for IvfMmap<O> {}
|
||||
|
||||
impl<S: G> IvfMmap<S> {
|
||||
fn centroids(&self, i: u32) -> &[Scalar<S>] {
|
||||
impl<O: Op> IvfMmap<O> {
|
||||
fn centroids(&self, i: u32) -> &[Scalar<O>] {
|
||||
let s = i as usize * self.dims as usize;
|
||||
let e = (i + 1) as usize * self.dims as usize;
|
||||
&self.centroids[s..e]
|
||||
}
|
||||
}
|
||||
|
||||
pub fn make<S: G>(
|
||||
path: &Path,
|
||||
sealed: Vec<Arc<SealedSegment<S>>>,
|
||||
growing: Vec<Arc<GrowingSegment<S>>>,
|
||||
options: IndexOptions,
|
||||
) -> IvfRam<S> {
|
||||
pub fn make<O: Op, S: Source<O>>(path: &Path, options: IndexOptions, source: &S) -> IvfRam<O> {
|
||||
let VectorOptions { dims, .. } = options.vector;
|
||||
let IvfIndexingOptions {
|
||||
least_iterations,
|
||||
@ -129,25 +121,23 @@ pub fn make<S: G>(
|
||||
nsample,
|
||||
quantization: quantization_opts,
|
||||
} = options.indexing.clone().unwrap_ivf();
|
||||
let raw = Arc::new(Raw::<S>::create(
|
||||
&path.join("raw"),
|
||||
options.clone(),
|
||||
sealed,
|
||||
growing,
|
||||
));
|
||||
let n = raw.len();
|
||||
let storage = Arc::new(StorageCollection::<O>::create(&path.join("raw"), source));
|
||||
let n = storage.len();
|
||||
let m = std::cmp::min(nsample, n);
|
||||
let f = sample(&mut thread_rng(), n as usize, m as usize).into_vec();
|
||||
let mut samples = Vec2::new(dims, m as usize);
|
||||
for i in 0..m {
|
||||
samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32).to_vec().as_ref());
|
||||
S::elkan_k_means_normalize(&mut samples[i as usize]);
|
||||
samples[i as usize].copy_from_slice(storage.vector(f[i as usize] as u32).to_vec().as_ref());
|
||||
O::elkan_k_means_normalize(&mut samples[i as usize]);
|
||||
}
|
||||
let mut k_means = ElkanKMeans::<S>::new(nlist as usize, samples);
|
||||
rayon::check();
|
||||
let mut k_means = ElkanKMeans::<O>::new(nlist as usize, samples);
|
||||
for _ in 0..least_iterations {
|
||||
rayon::check();
|
||||
k_means.iterate();
|
||||
}
|
||||
for _ in least_iterations..iterations {
|
||||
rayon::check();
|
||||
if k_means.iterate() {
|
||||
break;
|
||||
}
|
||||
@ -155,11 +145,12 @@ pub fn make<S: G>(
|
||||
let centroids = k_means.finish();
|
||||
let mut idx = vec![0usize; n as usize];
|
||||
idx.par_iter_mut().enumerate().for_each(|(i, x)| {
|
||||
let vector = raw.vector(i as u32);
|
||||
let vector = S::elkan_k_means_normalize2(vector);
|
||||
rayon::check();
|
||||
let vector = storage.vector(i as u32);
|
||||
let vector = O::elkan_k_means_normalize2(vector);
|
||||
let mut result = (F32::infinity(), 0);
|
||||
for i in 0..nlist as usize {
|
||||
let dis = S::elkan_k_means_distance2(vector.for_borrow(), ¢roids[i]);
|
||||
let dis = O::elkan_k_means_distance2(vector.for_borrow(), ¢roids[i]);
|
||||
result = std::cmp::min(result, (dis, i));
|
||||
}
|
||||
*x = result.1;
|
||||
@ -168,27 +159,31 @@ pub fn make<S: G>(
|
||||
let mut invlists_payloads = vec![Vec::new(); nlist as usize];
|
||||
for i in 0..n {
|
||||
invlists_ids[idx[i as usize]].push(i);
|
||||
invlists_payloads[idx[i as usize]].push(raw.payload(i));
|
||||
invlists_payloads[idx[i as usize]].push(storage.payload(i));
|
||||
}
|
||||
rayon::check();
|
||||
let permutation = Vec::from_iter((0..nlist).flat_map(|i| &invlists_ids[i as usize]).copied());
|
||||
rayon::check();
|
||||
let payloads = Vec::from_iter(
|
||||
(0..nlist)
|
||||
.flat_map(|i| &invlists_payloads[i as usize])
|
||||
.copied(),
|
||||
);
|
||||
rayon::check();
|
||||
let quantization = Quantization::create(
|
||||
&path.join("quantization"),
|
||||
options.clone(),
|
||||
quantization_opts,
|
||||
&raw,
|
||||
&storage,
|
||||
permutation,
|
||||
);
|
||||
rayon::check();
|
||||
let mut ptr = vec![0usize; nlist as usize + 1];
|
||||
for i in 0..nlist {
|
||||
ptr[i as usize + 1] = ptr[i as usize] + invlists_ids[i as usize].len();
|
||||
}
|
||||
IvfRam {
|
||||
raw,
|
||||
storage,
|
||||
quantization,
|
||||
centroids,
|
||||
nlist,
|
||||
@ -198,7 +193,7 @@ pub fn make<S: G>(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save<S: G>(ram: IvfRam<S>, path: &Path) -> IvfMmap<S> {
|
||||
pub fn save<O: Op>(ram: IvfRam<O>, path: &Path) -> IvfMmap<O> {
|
||||
let centroids = MmapArray::create(
|
||||
&path.join("centroids"),
|
||||
(0..ram.nlist)
|
||||
@ -208,7 +203,7 @@ pub fn save<S: G>(ram: IvfRam<S>, path: &Path) -> IvfMmap<S> {
|
||||
let ptr = MmapArray::create(&path.join("ptr"), ram.ptr.iter().copied());
|
||||
let payloads = MmapArray::create(&path.join("payload"), ram.payloads.iter().copied());
|
||||
IvfMmap {
|
||||
raw: ram.raw,
|
||||
storage: ram.storage,
|
||||
quantization: ram.quantization,
|
||||
dims: ram.dims,
|
||||
nlist: ram.nlist,
|
||||
@ -218,20 +213,20 @@ pub fn save<S: G>(ram: IvfRam<S>, path: &Path) -> IvfMmap<S> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn open<S: G>(path: &Path, options: IndexOptions) -> IvfMmap<S> {
|
||||
let raw = Arc::new(Raw::open(&path.join("raw"), options.clone()));
|
||||
pub fn open<O: Op>(path: &Path, options: IndexOptions) -> IvfMmap<O> {
|
||||
let storage = Arc::new(StorageCollection::open(&path.join("raw"), options.clone()));
|
||||
let quantization = Quantization::open(
|
||||
&path.join("quantization"),
|
||||
options.clone(),
|
||||
options.indexing.clone().unwrap_ivf().quantization,
|
||||
&raw,
|
||||
&storage,
|
||||
);
|
||||
let centroids = MmapArray::open(&path.join("centroids"));
|
||||
let ptr = MmapArray::open(&path.join("ptr"));
|
||||
let payloads = MmapArray::open(&path.join("payload"));
|
||||
let IvfIndexingOptions { nlist, .. } = options.indexing.unwrap_ivf();
|
||||
IvfMmap {
|
||||
raw,
|
||||
storage,
|
||||
quantization,
|
||||
dims: options.vector.dims,
|
||||
nlist,
|
||||
@ -241,27 +236,25 @@ pub fn open<S: G>(path: &Path, options: IndexOptions) -> IvfMmap<S> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn basic<S: G>(
|
||||
mmap: &IvfMmap<S>,
|
||||
vector: Borrowed<'_, S>,
|
||||
pub fn basic<O: Op>(
|
||||
mmap: &IvfMmap<O>,
|
||||
vector: Borrowed<'_, O>,
|
||||
nprobe: u32,
|
||||
mut filter: impl Filter,
|
||||
) -> BinaryHeap<Reverse<Element>> {
|
||||
let target = S::elkan_k_means_normalize2(vector);
|
||||
let mut lists = ElementHeap::new(nprobe as usize);
|
||||
let target = O::elkan_k_means_normalize2(vector);
|
||||
let mut lists = Vec::with_capacity(mmap.nlist as usize);
|
||||
for i in 0..mmap.nlist {
|
||||
let centroid = mmap.centroids(i);
|
||||
let distance = S::elkan_k_means_distance2(target.for_borrow(), centroid);
|
||||
if lists.check(distance) {
|
||||
lists.push(Element {
|
||||
distance,
|
||||
payload: i as Payload,
|
||||
});
|
||||
}
|
||||
let distance = O::elkan_k_means_distance2(target.for_borrow(), centroid);
|
||||
lists.push((distance, i));
|
||||
}
|
||||
if nprobe < mmap.nlist {
|
||||
lists.select_nth_unstable(nprobe as usize);
|
||||
lists.truncate(nprobe as usize);
|
||||
}
|
||||
let lists = lists.into_sorted_vec();
|
||||
let mut result = BinaryHeap::new();
|
||||
for i in lists.iter().map(|e| e.payload as usize) {
|
||||
for i in lists.iter().map(|(_, i)| *i as usize) {
|
||||
let start = mmap.ptr[i];
|
||||
let end = mmap.ptr[i + 1];
|
||||
for j in start..end {
|
||||
@ -275,27 +268,25 @@ pub fn basic<S: G>(
|
||||
result
|
||||
}
|
||||
|
||||
pub fn vbase<'a, S: G>(
|
||||
mmap: &'a IvfMmap<S>,
|
||||
vector: Borrowed<'a, S>,
|
||||
pub fn vbase<'a, O: Op>(
|
||||
mmap: &'a IvfMmap<O>,
|
||||
vector: Borrowed<'a, O>,
|
||||
nprobe: u32,
|
||||
mut filter: impl Filter + 'a,
|
||||
) -> (Vec<Element>, Box<(dyn Iterator<Item = Element> + 'a)>) {
|
||||
let target = S::elkan_k_means_normalize2(vector);
|
||||
let mut lists = ElementHeap::new(nprobe as usize);
|
||||
let target = O::elkan_k_means_normalize2(vector);
|
||||
let mut lists = Vec::with_capacity(mmap.nlist as usize);
|
||||
for i in 0..mmap.nlist {
|
||||
let centroid = mmap.centroids(i);
|
||||
let distance = S::elkan_k_means_distance2(target.for_borrow(), centroid);
|
||||
if lists.check(distance) {
|
||||
lists.push(Element {
|
||||
distance,
|
||||
payload: i as Payload,
|
||||
});
|
||||
}
|
||||
let distance = O::elkan_k_means_distance2(target.for_borrow(), centroid);
|
||||
lists.push((distance, i));
|
||||
}
|
||||
if nprobe < mmap.nlist {
|
||||
lists.select_nth_unstable(nprobe as usize);
|
||||
lists.truncate(nprobe as usize);
|
||||
}
|
||||
let lists = lists.into_sorted_vec();
|
||||
let mut result = Vec::new();
|
||||
for i in lists.iter().map(|e| e.payload as usize) {
|
||||
for i in lists.iter().map(|(_, i)| *i as usize) {
|
||||
let start = mmap.ptr[i];
|
||||
let end = mmap.ptr[i + 1];
|
||||
for j in start..end {
|
585
crates/ivf/src/ivf_pq.rs
Normal file
585
crates/ivf/src/ivf_pq.rs
Normal file
@ -0,0 +1,585 @@
|
||||
use super::OperatorIvf as Op;
|
||||
use base::distance::*;
|
||||
use base::index::*;
|
||||
use base::operator::*;
|
||||
use base::scalar::*;
|
||||
use base::search::*;
|
||||
use base::vector::*;
|
||||
use common::dir_ops::sync_dir;
|
||||
use common::mmap_array::MmapArray;
|
||||
use common::vec2::Vec2;
|
||||
use elkan_k_means::ElkanKMeans;
|
||||
use num_traits::{Float, Zero};
|
||||
use quantization::product::operator::OperatorProductQuantization;
|
||||
use rand::seq::index::sample;
|
||||
use rand::thread_rng;
|
||||
use rayon::iter::IndexedParallelIterator;
|
||||
use rayon::iter::IntoParallelRefMutIterator;
|
||||
use rayon::iter::ParallelIterator;
|
||||
use rayon::slice::ParallelSliceMut;
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::BinaryHeap;
|
||||
use std::fs::create_dir;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use storage::StorageCollection;
|
||||
|
||||
pub struct IvfPq<O: Op> {
|
||||
mmap: IvfMmap<O>,
|
||||
}
|
||||
|
||||
impl<O: Op> IvfPq<O> {
|
||||
pub fn create<S: Source<O>>(path: &Path, options: IndexOptions, source: &S) -> Self {
|
||||
create_dir(path).unwrap();
|
||||
let ram = make(path, options, source);
|
||||
let mmap = save(ram, path);
|
||||
sync_dir(path);
|
||||
Self { mmap }
|
||||
}
|
||||
|
||||
pub fn open(path: &Path, options: IndexOptions) -> Self {
|
||||
let mmap = open(path, options);
|
||||
Self { mmap }
|
||||
}
|
||||
|
||||
pub fn len(&self) -> u32 {
|
||||
self.mmap.storage.len()
|
||||
}
|
||||
|
||||
pub fn vector(&self, i: u32) -> Borrowed<'_, O> {
|
||||
self.mmap.storage.vector(i)
|
||||
}
|
||||
|
||||
pub fn payload(&self, i: u32) -> Payload {
|
||||
self.mmap.storage.payload(i)
|
||||
}
|
||||
|
||||
pub fn basic(
|
||||
&self,
|
||||
vector: Borrowed<'_, O>,
|
||||
opts: &SearchOptions,
|
||||
filter: impl Filter,
|
||||
) -> BinaryHeap<Reverse<Element>> {
|
||||
basic(&self.mmap, vector, opts.ivf_nprobe, filter)
|
||||
}
|
||||
|
||||
pub fn vbase<'a>(
|
||||
&'a self,
|
||||
vector: Borrowed<'a, O>,
|
||||
opts: &'a SearchOptions,
|
||||
filter: impl Filter + 'a,
|
||||
) -> (Vec<Element>, Box<(dyn Iterator<Item = Element> + 'a)>) {
|
||||
vbase(&self.mmap, vector, opts.ivf_nprobe, filter)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<O: Op> Send for IvfPq<O> {}
|
||||
unsafe impl<O: Op> Sync for IvfPq<O> {}
|
||||
|
||||
pub struct IvfRam<O: Op> {
|
||||
storage: Arc<StorageCollection<O>>,
|
||||
quantization: ProductQuantization<O>,
|
||||
// ----------------------
|
||||
dims: u32,
|
||||
// ----------------------
|
||||
nlist: u32,
|
||||
// ----------------------
|
||||
centroids: Vec2<Scalar<O>>,
|
||||
ptr: Vec<usize>,
|
||||
payloads: Vec<Payload>,
|
||||
}
|
||||
|
||||
unsafe impl<O: Op> Send for IvfRam<O> {}
|
||||
unsafe impl<O: Op> Sync for IvfRam<O> {}
|
||||
|
||||
pub struct IvfMmap<O: Op> {
|
||||
storage: Arc<StorageCollection<O>>,
|
||||
quantization: ProductQuantization<O>,
|
||||
// ----------------------
|
||||
dims: u32,
|
||||
// ----------------------
|
||||
nlist: u32,
|
||||
// ----------------------
|
||||
centroids: MmapArray<Scalar<O>>,
|
||||
ptr: MmapArray<usize>,
|
||||
payloads: MmapArray<Payload>,
|
||||
}
|
||||
|
||||
unsafe impl<O: Op> Send for IvfMmap<O> {}
|
||||
unsafe impl<O: Op> Sync for IvfMmap<O> {}
|
||||
|
||||
impl<O: Op> IvfMmap<O> {
|
||||
fn centroids(&self, i: u32) -> &[Scalar<O>] {
|
||||
let s = i as usize * self.dims as usize;
|
||||
let e = (i + 1) as usize * self.dims as usize;
|
||||
&self.centroids[s..e]
|
||||
}
|
||||
}
|
||||
|
||||
pub fn make<O: Op, S: Source<O>>(path: &Path, options: IndexOptions, source: &S) -> IvfRam<O> {
|
||||
let VectorOptions { dims, .. } = options.vector;
|
||||
let IvfIndexingOptions {
|
||||
least_iterations,
|
||||
iterations,
|
||||
nlist,
|
||||
nsample,
|
||||
quantization: quantization_opts,
|
||||
} = options.indexing.clone().unwrap_ivf();
|
||||
let storage = Arc::new(StorageCollection::<O>::create(&path.join("raw"), source));
|
||||
let n = storage.len();
|
||||
let m = std::cmp::min(nsample, n);
|
||||
let f = sample(&mut thread_rng(), n as usize, m as usize).into_vec();
|
||||
let mut samples = Vec2::new(dims, m as usize);
|
||||
for i in 0..m {
|
||||
samples[i as usize].copy_from_slice(storage.vector(f[i as usize] as u32).to_vec().as_ref());
|
||||
O::elkan_k_means_normalize(&mut samples[i as usize]);
|
||||
}
|
||||
rayon::check();
|
||||
let mut k_means = ElkanKMeans::<O>::new(nlist as usize, samples);
|
||||
for _ in 0..least_iterations {
|
||||
rayon::check();
|
||||
k_means.iterate();
|
||||
}
|
||||
for _ in least_iterations..iterations {
|
||||
rayon::check();
|
||||
if k_means.iterate() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let centroids = k_means.finish();
|
||||
let mut idx = vec![0usize; n as usize];
|
||||
idx.par_iter_mut().enumerate().for_each(|(i, x)| {
|
||||
rayon::check();
|
||||
let vector = storage.vector(i as u32);
|
||||
let vector = O::elkan_k_means_normalize2(vector);
|
||||
let mut result = (F32::infinity(), 0);
|
||||
for i in 0..nlist as usize {
|
||||
let dis = O::elkan_k_means_distance2(vector.for_borrow(), ¢roids[i]);
|
||||
result = std::cmp::min(result, (dis, i));
|
||||
}
|
||||
*x = result.1;
|
||||
});
|
||||
let mut invlists_ids = vec![Vec::new(); nlist as usize];
|
||||
let mut invlists_payloads = vec![Vec::new(); nlist as usize];
|
||||
for i in 0..n {
|
||||
invlists_ids[idx[i as usize]].push(i);
|
||||
invlists_payloads[idx[i as usize]].push(storage.payload(i));
|
||||
}
|
||||
let mut ptr = vec![0usize; nlist as usize + 1];
|
||||
for i in 0..nlist {
|
||||
ptr[i as usize + 1] = ptr[i as usize] + invlists_ids[i as usize].len();
|
||||
}
|
||||
let ids = Vec::from_iter((0..nlist).flat_map(|i| &invlists_ids[i as usize]).copied());
|
||||
let payloads = Vec::from_iter(
|
||||
(0..nlist)
|
||||
.flat_map(|i| &invlists_payloads[i as usize])
|
||||
.copied(),
|
||||
);
|
||||
rayon::check();
|
||||
let residuals = {
|
||||
let mut residuals = Vec2::new(options.vector.dims, n as usize);
|
||||
residuals
|
||||
.par_chunks_mut(dims as usize)
|
||||
.enumerate()
|
||||
.for_each(|(i, v)| {
|
||||
for j in 0..dims {
|
||||
v[j as usize] = storage.vector(ids[i]).to_vec()[j as usize]
|
||||
- centroids[idx[ids[i] as usize]][j as usize];
|
||||
}
|
||||
});
|
||||
residuals
|
||||
};
|
||||
let quantization = ProductQuantization::create(
|
||||
&path.join("quantization"),
|
||||
options.clone(),
|
||||
quantization_opts,
|
||||
&residuals,
|
||||
¢roids,
|
||||
);
|
||||
IvfRam {
|
||||
storage,
|
||||
quantization,
|
||||
centroids,
|
||||
nlist,
|
||||
dims,
|
||||
ptr,
|
||||
payloads,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save<O: Op>(ram: IvfRam<O>, path: &Path) -> IvfMmap<O> {
|
||||
let centroids = MmapArray::create(
|
||||
&path.join("centroids"),
|
||||
(0..ram.nlist)
|
||||
.flat_map(|i| &ram.centroids[i as usize])
|
||||
.copied(),
|
||||
);
|
||||
let ptr = MmapArray::create(&path.join("ptr"), ram.ptr.iter().copied());
|
||||
let payloads = MmapArray::create(&path.join("payload"), ram.payloads.iter().copied());
|
||||
IvfMmap {
|
||||
storage: ram.storage,
|
||||
quantization: ram.quantization,
|
||||
dims: ram.dims,
|
||||
nlist: ram.nlist,
|
||||
centroids,
|
||||
ptr,
|
||||
payloads,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn open<O: Op>(path: &Path, options: IndexOptions) -> IvfMmap<O> {
|
||||
let storage = Arc::new(StorageCollection::open(&path.join("raw"), options.clone()));
|
||||
let quantization = ProductQuantization::open(
|
||||
&path.join("quantization"),
|
||||
options.clone(),
|
||||
options.indexing.clone().unwrap_ivf().quantization,
|
||||
&storage,
|
||||
);
|
||||
let centroids = MmapArray::open(&path.join("centroids"));
|
||||
let ptr = MmapArray::open(&path.join("ptr"));
|
||||
let payloads = MmapArray::open(&path.join("payload"));
|
||||
let IvfIndexingOptions { nlist, .. } = options.indexing.unwrap_ivf();
|
||||
IvfMmap {
|
||||
storage,
|
||||
quantization,
|
||||
dims: options.vector.dims,
|
||||
nlist,
|
||||
centroids,
|
||||
ptr,
|
||||
payloads,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn basic<O: Op>(
|
||||
mmap: &IvfMmap<O>,
|
||||
vector: Borrowed<'_, O>,
|
||||
nprobe: u32,
|
||||
mut filter: impl Filter,
|
||||
) -> BinaryHeap<Reverse<Element>> {
|
||||
let dense = vector.to_vec();
|
||||
let mut lists = Vec::with_capacity(mmap.nlist as usize);
|
||||
for i in 0..mmap.nlist {
|
||||
let centroid = mmap.centroids(i);
|
||||
let distance = O::product_quantization_dense_distance(&dense, centroid);
|
||||
lists.push((distance, i));
|
||||
}
|
||||
if nprobe < mmap.nlist {
|
||||
lists.select_nth_unstable(nprobe as usize);
|
||||
lists.truncate(nprobe as usize);
|
||||
}
|
||||
let runtime_table = mmap.quantization.init_query(vector.to_vec().as_ref());
|
||||
let mut result = BinaryHeap::new();
|
||||
for &(coarse_dis, key) in lists.iter() {
|
||||
let start = mmap.ptr[key as usize];
|
||||
let end = mmap.ptr[key as usize + 1];
|
||||
for j in start..end {
|
||||
let payload = mmap.payloads[j];
|
||||
if filter.check(payload) {
|
||||
let distance = mmap.quantization.distance_with_codes(
|
||||
vector,
|
||||
j as u32,
|
||||
mmap.centroids(key),
|
||||
key as usize,
|
||||
coarse_dis,
|
||||
&runtime_table,
|
||||
);
|
||||
result.push(Reverse(Element { distance, payload }));
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub fn vbase<'a, O: Op>(
|
||||
mmap: &'a IvfMmap<O>,
|
||||
vector: Borrowed<'a, O>,
|
||||
nprobe: u32,
|
||||
mut filter: impl Filter + 'a,
|
||||
) -> (Vec<Element>, Box<(dyn Iterator<Item = Element> + 'a)>) {
|
||||
let dense = vector.to_vec();
|
||||
let mut lists = Vec::with_capacity(mmap.nlist as usize);
|
||||
for i in 0..mmap.nlist {
|
||||
let centroid = mmap.centroids(i);
|
||||
let distance = O::product_quantization_dense_distance(&dense, centroid);
|
||||
lists.push((distance, i));
|
||||
}
|
||||
if nprobe < mmap.nlist {
|
||||
lists.select_nth_unstable(nprobe as usize);
|
||||
lists.truncate(nprobe as usize);
|
||||
}
|
||||
let runtime_table = mmap.quantization.init_query(vector.to_vec().as_ref());
|
||||
let mut result = Vec::new();
|
||||
for &(coarse_dis, key) in lists.iter() {
|
||||
let start = mmap.ptr[key as usize];
|
||||
let end = mmap.ptr[key as usize + 1];
|
||||
for j in start..end {
|
||||
let payload = mmap.payloads[j];
|
||||
if filter.check(payload) {
|
||||
let distance = mmap.quantization.distance_with_codes(
|
||||
vector,
|
||||
j as u32,
|
||||
mmap.centroids(key),
|
||||
key as usize,
|
||||
coarse_dis,
|
||||
&runtime_table,
|
||||
);
|
||||
result.push(Element { distance, payload });
|
||||
}
|
||||
}
|
||||
}
|
||||
(result, Box::new(std::iter::empty()))
|
||||
}
|
||||
|
||||
pub struct ProductQuantization<O: Op> {
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: Vec<Scalar<O>>,
|
||||
codes: MmapArray<u8>,
|
||||
precomputed_table: Vec<F32>,
|
||||
}
|
||||
|
||||
unsafe impl<O: Op> Send for ProductQuantization<O> {}
|
||||
unsafe impl<O: Op> Sync for ProductQuantization<O> {}
|
||||
|
||||
impl<O: Op> ProductQuantization<O> {
|
||||
pub fn codes(&self, i: u32) -> &[u8] {
|
||||
let width = self.dims.div_ceil(self.ratio);
|
||||
let s = i as usize * width as usize;
|
||||
let e = (i + 1) as usize * width as usize;
|
||||
&self.codes[s..e]
|
||||
}
|
||||
pub fn open(
|
||||
path: &Path,
|
||||
options: IndexOptions,
|
||||
quantization_options: QuantizationOptions,
|
||||
_: &Arc<StorageCollection<O>>,
|
||||
) -> Self {
|
||||
let QuantizationOptions::Product(quantization_options) = quantization_options else {
|
||||
unreachable!()
|
||||
};
|
||||
let centroids =
|
||||
serde_json::from_slice(&std::fs::read(path.join("centroids")).unwrap()).unwrap();
|
||||
let codes = MmapArray::open(&path.join("codes"));
|
||||
let precomputed_table =
|
||||
serde_json::from_slice(&std::fs::read(path.join("table")).unwrap()).unwrap();
|
||||
Self {
|
||||
dims: options.vector.dims,
|
||||
ratio: quantization_options.ratio as _,
|
||||
centroids,
|
||||
codes,
|
||||
precomputed_table,
|
||||
}
|
||||
}
|
||||
pub fn create(
|
||||
path: &Path,
|
||||
options: IndexOptions,
|
||||
quantization_options: QuantizationOptions,
|
||||
v2: &Vec2<Scalar<O>>,
|
||||
coarse_centroids: &Vec2<Scalar<O>>,
|
||||
) -> Self {
|
||||
create_dir(path).unwrap();
|
||||
let QuantizationOptions::Product(quantization_options) = quantization_options else {
|
||||
unreachable!()
|
||||
};
|
||||
let dims = options.vector.dims;
|
||||
let ratio = quantization_options.ratio as u32;
|
||||
let n = v2.len();
|
||||
let m = std::cmp::min(n, quantization_options.sample as usize);
|
||||
let samples = {
|
||||
let f = sample(&mut thread_rng(), n, m).into_vec();
|
||||
let mut samples = Vec2::new(dims, m);
|
||||
for i in 0..m {
|
||||
samples[i].copy_from_slice(&v2[f[i]]);
|
||||
}
|
||||
samples
|
||||
};
|
||||
let width = dims.div_ceil(ratio);
|
||||
// a temp layout (width * 256 * subdims) for par_chunks_mut
|
||||
let mut tmp_centroids = vec![Scalar::<O>::zero(); 256 * dims as usize];
|
||||
// this par_for parallelizes over sub quantizers
|
||||
tmp_centroids
|
||||
.par_chunks_mut(256 * ratio as usize)
|
||||
.enumerate()
|
||||
.for_each(|(i, v)| {
|
||||
// i is the index of subquantizer
|
||||
let subdims = std::cmp::min(ratio, dims - ratio * i as u32) as usize;
|
||||
let mut subsamples = Vec2::new(subdims as u32, m);
|
||||
for j in 0..m {
|
||||
let src = &samples[j][i * ratio as usize..][..subdims];
|
||||
subsamples[j].copy_from_slice(src);
|
||||
}
|
||||
let mut k_means = ElkanKMeans::<O::ProductQuantizationL2>::new(256, subsamples);
|
||||
for _ in 0..25 {
|
||||
if k_means.iterate() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let centroid = k_means.finish();
|
||||
for j in 0usize..=255 {
|
||||
v[j * subdims..][..subdims].copy_from_slice(¢roid[j]);
|
||||
}
|
||||
});
|
||||
// transform back to normal layout (256 * width * subdims)
|
||||
let mut centroids = vec![Scalar::<O>::zero(); 256 * dims as usize];
|
||||
centroids
|
||||
.par_chunks_mut(dims as usize)
|
||||
.enumerate()
|
||||
.for_each(|(i, v)| {
|
||||
for j in 0..width {
|
||||
let subdims = std::cmp::min(ratio, dims - ratio * j) as usize;
|
||||
v[(j * ratio) as usize..][..subdims].copy_from_slice(
|
||||
&tmp_centroids[(j * ratio) as usize * 256..][i * subdims..][..subdims],
|
||||
);
|
||||
}
|
||||
});
|
||||
let mut codes = vec![0u8; n * width as usize];
|
||||
codes
|
||||
.par_chunks_mut(width as usize)
|
||||
.enumerate()
|
||||
.for_each(|(id, v)| {
|
||||
let vector = v2[id].to_vec();
|
||||
let width = dims.div_ceil(ratio);
|
||||
for i in 0..width {
|
||||
let subdims = std::cmp::min(ratio, dims - ratio * i);
|
||||
let mut minimal = F32::infinity();
|
||||
let mut target = 0u8;
|
||||
let left = &vector[(i * ratio) as usize..][..subdims as usize];
|
||||
for j in 0u8..=255 {
|
||||
let right = ¢roids[j as usize * dims as usize..]
|
||||
[(i * ratio) as usize..][..subdims as usize];
|
||||
let dis = O::ProductQuantizationL2::product_quantization_dense_distance(
|
||||
left, right,
|
||||
);
|
||||
if dis < minimal {
|
||||
minimal = dis;
|
||||
target = j;
|
||||
}
|
||||
}
|
||||
v[i as usize] = target;
|
||||
}
|
||||
});
|
||||
sync_dir(path);
|
||||
std::fs::write(
|
||||
path.join("centroids"),
|
||||
serde_json::to_string(¢roids).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let codes = MmapArray::create(&path.join("codes"), codes.into_iter());
|
||||
// precompute_table
|
||||
let nlist = coarse_centroids.len();
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut precomputed_table = Vec::new();
|
||||
precomputed_table.resize(nlist * width as usize * 256, F32::zero());
|
||||
precomputed_table
|
||||
.par_chunks_mut(width as usize * 256)
|
||||
.enumerate()
|
||||
.for_each(|(i, v)| {
|
||||
let x_c = &coarse_centroids[i];
|
||||
for j in 0..width {
|
||||
let subdims = std::cmp::min(ratio, dims - ratio * j);
|
||||
let sub_x_c = &x_c[(j * ratio) as usize..][..subdims as usize];
|
||||
for k in 0usize..256 {
|
||||
let sub_x_r = ¢roids[k * dims as usize..][(j * ratio) as usize..]
|
||||
[..subdims as usize];
|
||||
v[j as usize * 256 + k] = squared_norm::<O>(subdims, sub_x_r)
|
||||
+ F32(2.0) * inner_product::<O>(subdims, sub_x_c, sub_x_r);
|
||||
}
|
||||
}
|
||||
});
|
||||
std::fs::write(
|
||||
path.join("table"),
|
||||
serde_json::to_string(&precomputed_table).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
Self {
|
||||
dims,
|
||||
ratio,
|
||||
centroids,
|
||||
codes,
|
||||
precomputed_table,
|
||||
}
|
||||
}
|
||||
|
||||
// compute term2 at query time
|
||||
pub fn init_query(&self, query: &[Scalar<O>]) -> Vec<F32> {
|
||||
match O::DISTANCE_KIND {
|
||||
DistanceKind::Cos => Vec::new(),
|
||||
DistanceKind::L2 | DistanceKind::Dot | DistanceKind::Jaccard => {
|
||||
let dims = self.dims;
|
||||
let ratio = self.ratio;
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut runtime_table = vec![F32::zero(); width as usize * 256];
|
||||
for i in 0..256 {
|
||||
for j in 0..width {
|
||||
let subdims = std::cmp::min(ratio, dims - ratio * j);
|
||||
let sub_query = &query[(j * ratio) as usize..][..subdims as usize];
|
||||
let centroid = &self.centroids[i * dims as usize..][(j * ratio) as usize..]
|
||||
[..subdims as usize];
|
||||
runtime_table[j as usize * 256 + i] =
|
||||
F32(-1.0) * inner_product::<O>(subdims, sub_query, centroid);
|
||||
}
|
||||
}
|
||||
runtime_table
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add up all terms given codes
|
||||
pub fn distance_with_codes(
|
||||
&self,
|
||||
lhs: Borrowed<'_, O>,
|
||||
rhs: u32,
|
||||
delta: &[Scalar<O>],
|
||||
key: usize,
|
||||
coarse_dis: F32,
|
||||
runtime_table: &[F32],
|
||||
) -> F32 {
|
||||
let codes = self.codes(rhs);
|
||||
let width = self.dims.div_ceil(self.ratio);
|
||||
let precomputed_table = &self.precomputed_table[key * width as usize * 256..];
|
||||
match O::DISTANCE_KIND {
|
||||
DistanceKind::Cos => self.distance_with_delta(lhs, rhs, delta),
|
||||
DistanceKind::L2 => {
|
||||
let mut result = coarse_dis;
|
||||
for i in 0..width {
|
||||
result += precomputed_table[i as usize * 256 + codes[i as usize] as usize]
|
||||
+ F32(2.0) * runtime_table[i as usize * 256 + codes[i as usize] as usize];
|
||||
}
|
||||
result
|
||||
}
|
||||
DistanceKind::Dot => {
|
||||
let mut result = coarse_dis;
|
||||
for i in 0..width {
|
||||
result += runtime_table[i as usize * 256 + codes[i as usize] as usize];
|
||||
}
|
||||
result
|
||||
}
|
||||
DistanceKind::Jaccard => {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn distance_with_delta(&self, lhs: Borrowed<'_, O>, rhs: u32, delta: &[Scalar<O>]) -> F32 {
|
||||
let dims = self.dims;
|
||||
let ratio = self.ratio;
|
||||
let rhs = self.codes(rhs);
|
||||
O::product_quantization_distance_with_delta(dims, ratio, &self.centroids, lhs, rhs, delta)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn squared_norm<O: Op>(dims: u32, vec: &[Scalar<O>]) -> F32 {
|
||||
let mut result = F32::zero();
|
||||
for i in 0..dims as usize {
|
||||
result += F32((vec[i] * vec[i]).to_f32());
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub fn inner_product<O: Op>(dims: u32, lhs: &[Scalar<O>], rhs: &[Scalar<O>]) -> F32 {
|
||||
let mut result = F32::zero();
|
||||
for i in 0..dims as usize {
|
||||
result += F32((lhs[i] * rhs[i]).to_f32());
|
||||
}
|
||||
result
|
||||
}
|
@ -1,35 +1,38 @@
|
||||
#![feature(trait_alias)]
|
||||
#![allow(clippy::len_without_is_empty)]
|
||||
#![allow(clippy::needless_range_loop)]
|
||||
|
||||
pub mod ivf_naive;
|
||||
pub mod ivf_pq;
|
||||
|
||||
use self::ivf_naive::IvfNaive;
|
||||
use self::ivf_pq::IvfPq;
|
||||
use crate::index::segments::growing::GrowingSegment;
|
||||
use crate::index::segments::sealed::SealedSegment;
|
||||
use crate::prelude::*;
|
||||
use base::index::*;
|
||||
use base::operator::*;
|
||||
use base::search::*;
|
||||
use elkan_k_means::operator::OperatorElkanKMeans;
|
||||
use quantization::operator::OperatorQuantization;
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::BinaryHeap;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use storage::operator::OperatorStorage;
|
||||
|
||||
pub enum Ivf<S: G> {
|
||||
Naive(IvfNaive<S>),
|
||||
Pq(IvfPq<S>),
|
||||
pub trait OperatorIvf = Operator + OperatorElkanKMeans + OperatorQuantization + OperatorStorage;
|
||||
|
||||
pub enum Ivf<O: OperatorIvf> {
|
||||
Naive(IvfNaive<O>),
|
||||
Pq(IvfPq<O>),
|
||||
}
|
||||
|
||||
impl<S: G> Ivf<S> {
|
||||
pub fn create(
|
||||
path: &Path,
|
||||
options: IndexOptions,
|
||||
sealed: Vec<Arc<SealedSegment<S>>>,
|
||||
growing: Vec<Arc<GrowingSegment<S>>>,
|
||||
) -> Self {
|
||||
impl<O: OperatorIvf> Ivf<O> {
|
||||
pub fn create<S: Source<O>>(path: &Path, options: IndexOptions, source: &S) -> Self {
|
||||
if matches!(
|
||||
options.indexing.clone().unwrap_ivf().quantization,
|
||||
QuantizationOptions::Product(_)
|
||||
) {
|
||||
Self::Pq(IvfPq::create(path, options, sealed, growing))
|
||||
Self::Pq(IvfPq::create(path, options, source))
|
||||
} else {
|
||||
Self::Naive(IvfNaive::create(path, options, sealed, growing))
|
||||
Self::Naive(IvfNaive::create(path, options, source))
|
||||
}
|
||||
}
|
||||
|
||||
@ -51,7 +54,7 @@ impl<S: G> Ivf<S> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vector(&self, i: u32) -> Borrowed<'_, S> {
|
||||
pub fn vector(&self, i: u32) -> Borrowed<'_, O> {
|
||||
match self {
|
||||
Ivf::Naive(x) => x.vector(i),
|
||||
Ivf::Pq(x) => x.vector(i),
|
||||
@ -67,7 +70,7 @@ impl<S: G> Ivf<S> {
|
||||
|
||||
pub fn basic(
|
||||
&self,
|
||||
vector: Borrowed<'_, S>,
|
||||
vector: Borrowed<'_, O>,
|
||||
opts: &SearchOptions,
|
||||
filter: impl Filter,
|
||||
) -> BinaryHeap<Reverse<Element>> {
|
||||
@ -79,7 +82,7 @@ impl<S: G> Ivf<S> {
|
||||
|
||||
pub fn vbase<'a>(
|
||||
&'a self,
|
||||
vector: Borrowed<'a, S>,
|
||||
vector: Borrowed<'a, O>,
|
||||
opts: &'a SearchOptions,
|
||||
filter: impl Filter + 'a,
|
||||
) -> (Vec<Element>, Box<(dyn Iterator<Item = Element> + 'a)>) {
|
@ -6,6 +6,7 @@ edition.workspace = true
|
||||
[dependencies]
|
||||
rand.workspace = true
|
||||
rustix.workspace = true
|
||||
|
||||
detect = { path = "../detect" }
|
||||
|
||||
[lints]
|
||||
|
17
crates/quantization/Cargo.toml
Normal file
17
crates/quantization/Cargo.toml
Normal file
@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "quantization"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
multiversion.workspace = true
|
||||
num-traits.workspace = true
|
||||
rand.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
||||
base = { path = "../base" }
|
||||
common = { path = "../common" }
|
||||
elkan_k_means = { path = "../elkan_k_means" }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
@ -1,3 +1,6 @@
|
||||
#![feature(avx512_target_feature)]
|
||||
|
||||
pub mod operator;
|
||||
pub mod product;
|
||||
pub mod scalar;
|
||||
pub mod trivial;
|
||||
@ -5,41 +8,26 @@ pub mod trivial;
|
||||
use self::product::ProductQuantization;
|
||||
use self::scalar::ScalarQuantization;
|
||||
use self::trivial::TrivialQuantization;
|
||||
use super::raw::Raw;
|
||||
use crate::prelude::*;
|
||||
use crate::operator::OperatorQuantization;
|
||||
use base::index::*;
|
||||
use base::operator::*;
|
||||
use base::scalar::*;
|
||||
use base::search::*;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub trait Quan<S: G> {
|
||||
fn create(
|
||||
path: &Path,
|
||||
options: IndexOptions,
|
||||
quantization_options: QuantizationOptions,
|
||||
raw: &Arc<Raw<S>>,
|
||||
permutation: Vec<u32>,
|
||||
) -> Self;
|
||||
fn open2(
|
||||
path: &Path,
|
||||
options: IndexOptions,
|
||||
quantization_options: QuantizationOptions,
|
||||
raw: &Arc<Raw<S>>,
|
||||
) -> Self;
|
||||
fn distance(&self, lhs: Borrowed<'_, S>, rhs: u32) -> F32;
|
||||
fn distance2(&self, lhs: u32, rhs: u32) -> F32;
|
||||
pub enum Quantization<O: OperatorQuantization, C: Collection<O>> {
|
||||
Trivial(TrivialQuantization<O, C>),
|
||||
Scalar(ScalarQuantization<O, C>),
|
||||
Product(ProductQuantization<O, C>),
|
||||
}
|
||||
|
||||
pub enum Quantization<S: G> {
|
||||
Trivial(TrivialQuantization<S>),
|
||||
Scalar(ScalarQuantization<S>),
|
||||
Product(ProductQuantization<S>),
|
||||
}
|
||||
|
||||
impl<S: G> Quantization<S> {
|
||||
impl<O: OperatorQuantization, C: Collection<O>> Quantization<O, C> {
|
||||
pub fn create(
|
||||
path: &Path,
|
||||
options: IndexOptions,
|
||||
quantization_options: QuantizationOptions,
|
||||
raw: &Arc<Raw<S>>,
|
||||
collection: &Arc<C>,
|
||||
permutation: Vec<u32>, // permutation is the mapping from placements to original ids
|
||||
) -> Self {
|
||||
match quantization_options {
|
||||
@ -47,21 +35,21 @@ impl<S: G> Quantization<S> {
|
||||
path,
|
||||
options,
|
||||
quantization_options,
|
||||
raw,
|
||||
collection,
|
||||
permutation,
|
||||
)),
|
||||
QuantizationOptions::Scalar(_) => Self::Scalar(ScalarQuantization::create(
|
||||
path,
|
||||
options,
|
||||
quantization_options,
|
||||
raw,
|
||||
collection,
|
||||
permutation,
|
||||
)),
|
||||
QuantizationOptions::Product(_) => Self::Product(ProductQuantization::create(
|
||||
path,
|
||||
options,
|
||||
quantization_options,
|
||||
raw,
|
||||
collection,
|
||||
permutation,
|
||||
)),
|
||||
}
|
||||
@ -71,31 +59,31 @@ impl<S: G> Quantization<S> {
|
||||
path: &Path,
|
||||
options: IndexOptions,
|
||||
quantization_options: QuantizationOptions,
|
||||
raw: &Arc<Raw<S>>,
|
||||
collection: &Arc<C>,
|
||||
) -> Self {
|
||||
match quantization_options {
|
||||
QuantizationOptions::Trivial(_) => Self::Trivial(TrivialQuantization::open2(
|
||||
QuantizationOptions::Trivial(_) => Self::Trivial(TrivialQuantization::open(
|
||||
path,
|
||||
options,
|
||||
quantization_options,
|
||||
raw,
|
||||
collection,
|
||||
)),
|
||||
QuantizationOptions::Scalar(_) => Self::Scalar(ScalarQuantization::open2(
|
||||
QuantizationOptions::Scalar(_) => Self::Scalar(ScalarQuantization::open(
|
||||
path,
|
||||
options,
|
||||
quantization_options,
|
||||
raw,
|
||||
collection,
|
||||
)),
|
||||
QuantizationOptions::Product(_) => Self::Product(ProductQuantization::open2(
|
||||
QuantizationOptions::Product(_) => Self::Product(ProductQuantization::open(
|
||||
path,
|
||||
options,
|
||||
quantization_options,
|
||||
raw,
|
||||
collection,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn distance(&self, lhs: Borrowed<'_, S>, rhs: u32) -> F32 {
|
||||
pub fn distance(&self, lhs: Borrowed<'_, O>, rhs: u32) -> F32 {
|
||||
use Quantization::*;
|
||||
match self {
|
||||
Trivial(x) => x.distance(lhs, rhs),
|
22
crates/quantization/src/operator.rs
Normal file
22
crates/quantization/src/operator.rs
Normal file
@ -0,0 +1,22 @@
|
||||
use crate::product::operator::OperatorProductQuantization;
|
||||
use crate::scalar::operator::OperatorScalarQuantization;
|
||||
use base::operator::*;
|
||||
|
||||
pub trait OperatorQuantization: OperatorScalarQuantization + OperatorProductQuantization {}
|
||||
|
||||
impl OperatorQuantization for BVecf32Cos {}
|
||||
impl OperatorQuantization for BVecf32Dot {}
|
||||
impl OperatorQuantization for BVecf32Jaccard {}
|
||||
impl OperatorQuantization for BVecf32L2 {}
|
||||
impl OperatorQuantization for SVecf32Cos {}
|
||||
impl OperatorQuantization for SVecf32Dot {}
|
||||
impl OperatorQuantization for SVecf32L2 {}
|
||||
impl OperatorQuantization for Vecf16Cos {}
|
||||
impl OperatorQuantization for Vecf16Dot {}
|
||||
impl OperatorQuantization for Vecf16L2 {}
|
||||
impl OperatorQuantization for Vecf32Cos {}
|
||||
impl OperatorQuantization for Vecf32Dot {}
|
||||
impl OperatorQuantization for Vecf32L2 {}
|
||||
impl OperatorQuantization for Veci8Cos {}
|
||||
impl OperatorQuantization for Veci8Dot {}
|
||||
impl OperatorQuantization for Veci8L2 {}
|
159
crates/quantization/src/product/mod.rs
Normal file
159
crates/quantization/src/product/mod.rs
Normal file
@ -0,0 +1,159 @@
|
||||
pub mod operator;
|
||||
|
||||
use self::operator::OperatorProductQuantization;
|
||||
use base::index::*;
|
||||
use base::operator::*;
|
||||
use base::scalar::*;
|
||||
use base::search::*;
|
||||
use base::vector::*;
|
||||
use common::dir_ops::sync_dir;
|
||||
use common::mmap_array::MmapArray;
|
||||
use common::vec2::Vec2;
|
||||
use elkan_k_means::ElkanKMeans;
|
||||
use num_traits::{Float, Zero};
|
||||
use rand::seq::index::sample;
|
||||
use rand::thread_rng;
|
||||
use std::marker::PhantomData;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct ProductQuantization<O: OperatorProductQuantization, C: Collection<O>> {
|
||||
dims: u32,
|
||||
ratio: u32,
|
||||
centroids: Vec<Scalar<O>>,
|
||||
codes: MmapArray<u8>,
|
||||
_maker: PhantomData<fn(C) -> C>,
|
||||
}
|
||||
|
||||
unsafe impl<O: OperatorProductQuantization, C: Collection<O>> Send for ProductQuantization<O, C> {}
|
||||
unsafe impl<O: OperatorProductQuantization, C: Collection<O>> Sync for ProductQuantization<O, C> {}
|
||||
|
||||
impl<O: OperatorProductQuantization, C: Collection<O>> ProductQuantization<O, C> {
|
||||
fn codes(&self, i: u32) -> &[u8] {
|
||||
let width = self.dims.div_ceil(self.ratio);
|
||||
let s = i as usize * width as usize;
|
||||
let e = (i + 1) as usize * width as usize;
|
||||
&self.codes[s..e]
|
||||
}
|
||||
}
|
||||
|
||||
impl<O: OperatorProductQuantization, C: Collection<O>> ProductQuantization<O, C> {
|
||||
pub fn create(
|
||||
path: &Path,
|
||||
options: IndexOptions,
|
||||
quantization_options: QuantizationOptions,
|
||||
collection: &Arc<C>,
|
||||
permutation: Vec<u32>, // permutation is the mapping from placements to original ids
|
||||
) -> Self {
|
||||
std::fs::create_dir(path).unwrap();
|
||||
let QuantizationOptions::Product(quantization_options) = quantization_options else {
|
||||
unreachable!()
|
||||
};
|
||||
let dims = options.vector.dims;
|
||||
let ratio = quantization_options.ratio as u32;
|
||||
let n = collection.len();
|
||||
let m = std::cmp::min(n, quantization_options.sample);
|
||||
let samples = {
|
||||
let f = sample(&mut thread_rng(), n as usize, m as usize).into_vec();
|
||||
let mut samples = Vec2::<Scalar<O>>::new(dims, m as usize);
|
||||
for i in 0..m {
|
||||
samples[i as usize]
|
||||
.copy_from_slice(collection.vector(f[i as usize] as u32).to_vec().as_ref());
|
||||
}
|
||||
samples
|
||||
};
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut centroids = vec![Scalar::<O>::zero(); 256 * dims as usize];
|
||||
for i in 0..width {
|
||||
let subdims = std::cmp::min(ratio, dims - ratio * i);
|
||||
let mut subsamples = Vec2::<Scalar<O>>::new(subdims, m as usize);
|
||||
for j in 0..m {
|
||||
let src = &samples[j as usize][(i * ratio) as usize..][..subdims as usize];
|
||||
subsamples[j as usize].copy_from_slice(src);
|
||||
}
|
||||
let mut k_means = ElkanKMeans::<O::ProductQuantizationL2>::new(256, subsamples);
|
||||
for _ in 0..25 {
|
||||
if k_means.iterate() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let centroid = k_means.finish();
|
||||
for j in 0u8..=255 {
|
||||
centroids[j as usize * dims as usize..][(i * ratio) as usize..][..subdims as usize]
|
||||
.copy_from_slice(¢roid[j as usize]);
|
||||
}
|
||||
}
|
||||
let codes_iter = (0..n).flat_map(|i| {
|
||||
let vector = collection.vector(permutation[i as usize]).to_vec();
|
||||
let width = dims.div_ceil(ratio);
|
||||
let mut result = Vec::with_capacity(width as usize);
|
||||
for i in 0..width {
|
||||
let subdims = std::cmp::min(ratio, dims - ratio * i);
|
||||
let mut minimal = F32::infinity();
|
||||
let mut target = 0u8;
|
||||
let left = &vector[(i * ratio) as usize..][..subdims as usize];
|
||||
for j in 0u8..=255 {
|
||||
let right = ¢roids[j as usize * dims as usize..][(i * ratio) as usize..]
|
||||
[..subdims as usize];
|
||||
let dis = O::product_quantization_l2_distance(left, right);
|
||||
if dis < minimal {
|
||||
minimal = dis;
|
||||
target = j;
|
||||
}
|
||||
}
|
||||
result.push(target);
|
||||
}
|
||||
result.into_iter()
|
||||
});
|
||||
sync_dir(path);
|
||||
std::fs::write(
|
||||
path.join("centroids"),
|
||||
serde_json::to_string(¢roids).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let codes = MmapArray::create(&path.join("codes"), codes_iter);
|
||||
Self {
|
||||
dims,
|
||||
ratio,
|
||||
centroids,
|
||||
codes,
|
||||
_maker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn open(
|
||||
path: &Path,
|
||||
options: IndexOptions,
|
||||
quantization_options: QuantizationOptions,
|
||||
_: &Arc<C>,
|
||||
) -> Self {
|
||||
let QuantizationOptions::Product(quantization_options) = quantization_options else {
|
||||
unreachable!()
|
||||
};
|
||||
let centroids =
|
||||
serde_json::from_slice(&std::fs::read(path.join("centroids")).unwrap()).unwrap();
|
||||
let codes = MmapArray::open(&path.join("codes"));
|
||||
Self {
|
||||
dims: options.vector.dims,
|
||||
ratio: quantization_options.ratio as _,
|
||||
centroids,
|
||||
codes,
|
||||
_maker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn distance(&self, lhs: Borrowed<'_, O>, rhs: u32) -> F32 {
|
||||
let dims = self.dims;
|
||||
let ratio = self.ratio;
|
||||
let rhs = self.codes(rhs);
|
||||
O::product_quantization_distance(dims, ratio, &self.centroids, lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn distance2(&self, lhs: u32, rhs: u32) -> F32 {
|
||||
let dims = self.dims;
|
||||
let ratio = self.ratio;
|
||||
let lhs = self.codes(lhs);
|
||||
let rhs = self.codes(rhs);
|
||||
O::product_quantization_distance2(dims, ratio, &self.centroids, lhs, rhs)
|
||||
}
|
||||
}
|
1041
crates/quantization/src/product/operator.rs
Normal file
1041
crates/quantization/src/product/operator.rs
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user