1
0
mirror of https://github.com/tensorchord/pgvecto.rs.git synced 2025-08-01 06:46:52 +03:00

C code tests & avx512f f16 implement (#183)

* test: add tests for c code

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: relax EPSILON for tests

Signed-off-by: usamoi <usamoi@outlook.com>

---------

Signed-off-by: usamoi <usamoi@outlook.com>
This commit is contained in:
Usamoi
2023-12-15 16:00:08 +08:00
committed by GitHub
parent 2869fbd44c
commit c50912e87d
13 changed files with 276 additions and 31 deletions

View File

@ -101,7 +101,7 @@ jobs:
cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu
- name: Test - name: Test
run: | run: |
cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu -- --nocapture
- name: Install release - name: Install release
run: ./scripts/ci_install.sh run: ./scripts/ci_install.sh
- name: Sqllogictest - name: Sqllogictest

25
Cargo.lock generated
View File

@ -486,7 +486,9 @@ name = "c"
version = "0.0.0" version = "0.0.0"
dependencies = [ dependencies = [
"cc", "cc",
"detect",
"half 2.3.1", "half 2.3.1",
"rand",
] ]
[[package]] [[package]]
@ -837,6 +839,14 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "detect"
version = "0.0.0"
dependencies = [
"ctor",
"std_detect",
]
[[package]] [[package]]
name = "diff" name = "diff"
version = "0.1.13" version = "0.1.13"
@ -1263,6 +1273,8 @@ dependencies = [
"cfg-if", "cfg-if",
"crunchy", "crunchy",
"num-traits", "num-traits",
"rand",
"rand_distr",
"serde", "serde",
] ]
@ -2436,6 +2448,16 @@ dependencies = [
"getrandom", "getrandom",
] ]
[[package]]
name = "rand_distr"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31"
dependencies = [
"num-traits",
"rand",
]
[[package]] [[package]]
name = "rand_xorshift" name = "rand_xorshift"
version = "0.3.0" version = "0.3.0"
@ -2808,8 +2830,8 @@ dependencies = [
"c", "c",
"crc32fast", "crc32fast",
"crossbeam", "crossbeam",
"ctor",
"dashmap", "dashmap",
"detect",
"half 2.3.1", "half 2.3.1",
"libc", "libc",
"log", "log",
@ -2824,7 +2846,6 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"serde_with", "serde_with",
"std_detect",
"tempfile", "tempfile",
"thiserror", "thiserror",
"ulock-sys", "ulock-sys",

View File

@ -3,8 +3,10 @@ name = "c"
version.workspace = true version.workspace = true
edition.workspace = true edition.workspace = true
[dependencies] [dev-dependencies]
half = { version = "~2.3", features = ["use-intrinsics"] } half = { version = "~2.3", features = ["use-intrinsics", "rand_distr"] }
detect = { path = "../detect" }
rand = "0.8.5"
[build-dependencies] [build-dependencies]
cc = "1.0" cc = "1.0"

View File

@ -29,8 +29,12 @@ v_f16_cosine_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
xx = _mm512_fmadd_ph(x, x, xx); xx = _mm512_fmadd_ph(x, x, xx);
yy = _mm512_fmadd_ph(y, y, yy); yy = _mm512_fmadd_ph(y, y, yy);
} }
return (float)(_mm512_reduce_add_ph(xy) / {
sqrt(_mm512_reduce_add_ph(xx) * _mm512_reduce_add_ph(yy))); float rxy = _mm512_reduce_add_ph(xy);
float rxx = _mm512_reduce_add_ph(xx);
float ryy = _mm512_reduce_add_ph(yy);
return rxy / sqrt(rxx * ryy);
}
} }
__attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float __attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float
@ -74,6 +78,76 @@ v_f16_sl2_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
return (float)_mm512_reduce_add_ph(dd); return (float)_mm512_reduce_add_ph(dd);
} }
__attribute__((target("arch=x86-64-v4"))) extern float
v_f16_cosine_v4(_Float16 *a, _Float16 *b, size_t n) {
__m512 xy = _mm512_set1_ps(0);
__m512 xx = _mm512_set1_ps(0);
__m512 yy = _mm512_set1_ps(0);
while (n >= 16) {
__m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
__m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
a += 16, b += 16, n -= 16;
xy = _mm512_fmadd_ps(x, y, xy);
xx = _mm512_fmadd_ps(x, x, xx);
yy = _mm512_fmadd_ps(y, y, yy);
}
if (n > 0) {
__mmask16 mask = _bzhi_u32(0xFFFF, n);
__m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
__m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
xy = _mm512_fmadd_ps(x, y, xy);
xx = _mm512_fmadd_ps(x, x, xx);
yy = _mm512_fmadd_ps(y, y, yy);
}
{
float rxy = _mm512_reduce_add_ps(xy);
float rxx = _mm512_reduce_add_ps(xx);
float ryy = _mm512_reduce_add_ps(yy);
return rxy / sqrt(rxx * ryy);
}
}
__attribute__((target("arch=x86-64-v4"))) extern float
v_f16_dot_v4(_Float16 *a, _Float16 *b, size_t n) {
__m512 xy = _mm512_set1_ps(0);
while (n >= 16) {
__m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
__m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
a += 16, b += 16, n -= 16;
xy = _mm512_fmadd_ps(x, y, xy);
}
if (n > 0) {
__mmask16 mask = _bzhi_u32(0xFFFF, n);
__m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
__m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
xy = _mm512_fmadd_ps(x, y, xy);
}
return _mm512_reduce_add_ps(xy);
}
__attribute__((target("arch=x86-64-v4"))) extern float
v_f16_sl2_v4(_Float16 *a, _Float16 *b, size_t n) {
__m512 dd = _mm512_set1_ps(0);
while (n >= 16) {
__m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
__m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
a += 16, b += 16, n -= 16;
__m512 d = _mm512_sub_ps(x, y);
dd = _mm512_fmadd_ps(d, d, dd);
}
if (n > 0) {
__mmask16 mask = _bzhi_u32(0xFFFF, n);
__m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
__m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
__m512 d = _mm512_sub_ps(x, y);
dd = _mm512_fmadd_ps(d, d, dd);
}
return _mm512_reduce_add_ps(dd);
}
__attribute__((target("arch=x86-64-v3"))) extern float __attribute__((target("arch=x86-64-v3"))) extern float
v_f16_cosine_v3(_Float16 *a, _Float16 *b, size_t n) { v_f16_cosine_v3(_Float16 *a, _Float16 *b, size_t n) {
float xy = 0; float xy = 0;

View File

@ -6,6 +6,9 @@
extern float v_f16_cosine_avx512fp16(_Float16 *, _Float16 *, size_t n); extern float v_f16_cosine_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_dot_avx512fp16(_Float16 *, _Float16 *, size_t n); extern float v_f16_dot_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_sl2_avx512fp16(_Float16 *, _Float16 *, size_t n); extern float v_f16_sl2_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_cosine_v4(_Float16 *, _Float16 *, size_t n);
extern float v_f16_dot_v4(_Float16 *, _Float16 *, size_t n);
extern float v_f16_sl2_v4(_Float16 *, _Float16 *, size_t n);
extern float v_f16_cosine_v3(_Float16 *, _Float16 *, size_t n); extern float v_f16_cosine_v3(_Float16 *, _Float16 *, size_t n);
extern float v_f16_dot_v3(_Float16 *, _Float16 *, size_t n); extern float v_f16_dot_v3(_Float16 *, _Float16 *, size_t n);
extern float v_f16_sl2_v3(_Float16 *, _Float16 *, size_t n); extern float v_f16_sl2_v3(_Float16 *, _Float16 *, size_t n);

View File

@ -4,21 +4,10 @@ extern "C" {
pub fn v_f16_cosine_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32; pub fn v_f16_cosine_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_dot_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32; pub fn v_f16_dot_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_sl2_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32; pub fn v_f16_sl2_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_cosine_v4(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_dot_v4(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_sl2_v4(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_cosine_v3(a: *const u16, b: *const u16, n: usize) -> f32; pub fn v_f16_cosine_v3(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_dot_v3(a: *const u16, b: *const u16, n: usize) -> f32; pub fn v_f16_dot_v3(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_sl2_v3(a: *const u16, b: *const u16, n: usize) -> f32; pub fn v_f16_sl2_v3(a: *const u16, b: *const u16, n: usize) -> f32;
} }
// `compiler_builtin` defines `__extendhfsf2` with integer calling convention.
// However C compilers links `__extendhfsf2` with floating calling convention.
// The code should be removed once Rust offically supports `f16`.
#[cfg(target_arch = "x86_64")]
#[no_mangle]
#[linkage = "external"]
extern "C" fn __extendhfsf2(f: f64) -> f32 {
unsafe {
let f: half::f16 = std::mem::transmute_copy(&f);
f.to_f32()
}
}

126
crates/c/tests/x86_64.rs Normal file
View File

@ -0,0 +1,126 @@
#![cfg(target_arch = "x86_64")]
#[test]
fn test_v_f16_cosine() {
const EPSILON: f32 = f16::EPSILON.to_f32_const();
use half::f16;
unsafe fn v_f16_cosine(a: *const u16, b: *const u16, n: usize) -> f32 {
let mut xy = 0.0f32;
let mut xx = 0.0f32;
let mut yy = 0.0f32;
for i in 0..n {
let x = a.add(i).cast::<f16>().read().to_f32();
let y = b.add(i).cast::<f16>().read().to_f32();
xy += x * y;
xx += x * x;
yy += y * y;
}
xy / (xx * yy).sqrt()
}
let n = 4000;
let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let r = unsafe { v_f16_cosine(a.as_ptr().cast(), b.as_ptr().cast(), n) };
if detect::x86_64::detect_avx512fp16() {
println!("detected avx512fp16");
let c = unsafe { c::v_f16_cosine_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no avx512fp16, skipped");
}
if detect::x86_64::detect_v4() {
println!("detected v4");
let c = unsafe { c::v_f16_cosine_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v4, skipped");
}
if detect::x86_64::detect_v3() {
println!("detected v3");
let c = unsafe { c::v_f16_cosine_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v3, skipped");
}
}
#[test]
fn test_v_f16_dot() {
const EPSILON: f32 = 1.0f32;
use half::f16;
unsafe fn v_f16_dot(a: *const u16, b: *const u16, n: usize) -> f32 {
let mut xy = 0.0f32;
for i in 0..n {
let x = a.add(i).cast::<f16>().read().to_f32();
let y = b.add(i).cast::<f16>().read().to_f32();
xy += x * y;
}
xy
}
let n = 4000;
let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let r = unsafe { v_f16_dot(a.as_ptr().cast(), b.as_ptr().cast(), n) };
if detect::x86_64::detect_avx512fp16() {
println!("detected avx512fp16");
let c = unsafe { c::v_f16_dot_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no avx512fp16, skipped");
}
if detect::x86_64::detect_v4() {
println!("detected v4");
let c = unsafe { c::v_f16_dot_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v4, skipped");
}
if detect::x86_64::detect_v3() {
println!("detected v3");
let c = unsafe { c::v_f16_dot_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v3, skipped");
}
}
#[test]
fn test_v_f16_sl2() {
const EPSILON: f32 = 1.0f32;
use half::f16;
unsafe fn v_f16_sl2(a: *const u16, b: *const u16, n: usize) -> f32 {
let mut dd = 0.0f32;
for i in 0..n {
let x = a.add(i).cast::<f16>().read().to_f32();
let y = b.add(i).cast::<f16>().read().to_f32();
let d = x - y;
dd += d * d;
}
dd
}
let n = 4000;
let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let r = unsafe { v_f16_sl2(a.as_ptr().cast(), b.as_ptr().cast(), n) };
if detect::x86_64::detect_avx512fp16() {
println!("detected avx512fp16");
let c = unsafe { c::v_f16_sl2_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no avx512fp16, skipped");
}
if detect::x86_64::detect_v4() {
println!("detected v4");
let c = unsafe { c::v_f16_sl2_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v4, skipped");
}
if detect::x86_64::detect_v3() {
println!("detected v3");
let c = unsafe { c::v_f16_sl2_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v3, skipped");
}
}

8
crates/detect/Cargo.toml Normal file
View File

@ -0,0 +1,8 @@
[package]
name = "detect"
version.workspace = true
edition.workspace = true
[dependencies]
std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "avx512fp16" }
ctor = "0.2.6"

View File

@ -33,7 +33,7 @@ fn ctor_v4() {
ATOMIC_V4.store(test_v4(), Ordering::Relaxed); ATOMIC_V4.store(test_v4(), Ordering::Relaxed);
} }
pub fn _detect_v4() -> bool { pub fn detect_v4() -> bool {
ATOMIC_V4.load(Ordering::Relaxed) ATOMIC_V4.load(Ordering::Relaxed)
} }

View File

@ -16,7 +16,7 @@ bincode.workspace = true
half.workspace = true half.workspace = true
num-traits.workspace = true num-traits.workspace = true
c = { path = "../c" } c = { path = "../c" }
std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "avx512fp16" } detect = { path = "../detect" }
rand = "0.8.5" rand = "0.8.5"
crc32fast = "1.3.2" crc32fast = "1.3.2"
crossbeam = "0.8.2" crossbeam = "0.8.2"
@ -32,7 +32,6 @@ arc-swap = "1.6.0"
bytemuck = { version = "1.14.0", features = ["extern_crate_alloc"] } bytemuck = { version = "1.14.0", features = ["extern_crate_alloc"] }
serde_with = "3.4.0" serde_with = "3.4.0"
multiversion = "0.7.3" multiversion = "0.7.3"
ctor = "0.2.6"
[target.'cfg(target_os = "macos")'.dependencies] [target.'cfg(target_os = "macos")'.dependencies]
ulock-sys = "0.1.0" ulock-sys = "0.1.0"

View File

@ -22,7 +22,7 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 {
xy / (x2 * y2).sqrt() xy / (x2 * y2).sqrt()
} }
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_avx512fp16() { if detect::x86_64::detect_avx512fp16() {
assert!(lhs.len() == rhs.len()); assert!(lhs.len() == rhs.len());
let n = lhs.len(); let n = lhs.len();
unsafe { unsafe {
@ -30,7 +30,15 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 {
} }
} }
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_v3() { if detect::x86_64::detect_v4() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_cosine_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
#[cfg(target_arch = "x86_64")]
if detect::x86_64::detect_v3() {
assert!(lhs.len() == rhs.len()); assert!(lhs.len() == rhs.len());
let n = lhs.len(); let n = lhs.len();
unsafe { unsafe {
@ -58,7 +66,7 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 {
xy xy
} }
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_avx512fp16() { if detect::x86_64::detect_avx512fp16() {
assert!(lhs.len() == rhs.len()); assert!(lhs.len() == rhs.len());
let n = lhs.len(); let n = lhs.len();
unsafe { unsafe {
@ -66,7 +74,15 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 {
} }
} }
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_v3() { if detect::x86_64::detect_v4() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_dot_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
#[cfg(target_arch = "x86_64")]
if detect::x86_64::detect_v3() {
assert!(lhs.len() == rhs.len()); assert!(lhs.len() == rhs.len());
let n = lhs.len(); let n = lhs.len();
unsafe { unsafe {
@ -95,7 +111,7 @@ pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 {
d2 d2
} }
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_avx512fp16() { if detect::x86_64::detect_avx512fp16() {
assert!(lhs.len() == rhs.len()); assert!(lhs.len() == rhs.len());
let n = lhs.len(); let n = lhs.len();
unsafe { unsafe {
@ -103,7 +119,15 @@ pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 {
} }
} }
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_v3() { if detect::x86_64::detect_v4() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_sl2_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
#[cfg(target_arch = "x86_64")]
if detect::x86_64::detect_v3() {
assert!(lhs.len() == rhs.len()); assert!(lhs.len() == rhs.len());
let n = lhs.len(); let n = lhs.len();
unsafe { unsafe {

View File

@ -1,6 +1,5 @@
pub mod cells; pub mod cells;
pub mod clean; pub mod clean;
pub mod detect;
pub mod dir_ops; pub mod dir_ops;
pub mod file_atomic; pub mod file_atomic;
pub mod file_wal; pub mod file_wal;