From c50912e87d69b4c8c5835830e18f29d3a349c862 Mon Sep 17 00:00:00 2001 From: Usamoi Date: Fri, 15 Dec 2023 16:00:08 +0800 Subject: [PATCH] C code tests & avx512f f16 implement (#183) * test: add tests for c code Signed-off-by: usamoi * fix: relax EPSILON for tests Signed-off-by: usamoi --------- Signed-off-by: usamoi --- .github/workflows/check.yml | 2 +- Cargo.lock | 25 +++- crates/c/Cargo.toml | 6 +- crates/c/src/c.c | 78 ++++++++++- crates/c/src/c.h | 3 + crates/c/src/c.rs | 17 +-- crates/c/tests/x86_64.rs | 126 ++++++++++++++++++ crates/detect/Cargo.toml | 8 ++ .../src/utils/detect.rs => detect/src/lib.rs} | 0 .../src/utils/detect => detect/src}/x86_64.rs | 2 +- crates/service/Cargo.toml | 3 +- crates/service/src/prelude/global/f16.rs | 36 ++++- crates/service/src/utils/mod.rs | 1 - 13 files changed, 276 insertions(+), 31 deletions(-) create mode 100644 crates/c/tests/x86_64.rs create mode 100644 crates/detect/Cargo.toml rename crates/{service/src/utils/detect.rs => detect/src/lib.rs} (100%) rename crates/{service/src/utils/detect => detect/src}/x86_64.rs (98%) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 2e7ed88..76ccaa2 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -101,7 +101,7 @@ jobs: cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu - name: Test run: | - cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu + cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu -- --nocapture - name: Install release run: ./scripts/ci_install.sh - name: Sqllogictest diff --git a/Cargo.lock b/Cargo.lock index 6a57f89..41caf30 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -486,7 +486,9 @@ name = "c" version = "0.0.0" dependencies = [ "cc", + "detect", "half 2.3.1", + "rand", ] [[package]] @@ 
-837,6 +839,14 @@ dependencies = [ "serde", ] +[[package]] +name = "detect" +version = "0.0.0" +dependencies = [ + "ctor", + "std_detect", +] + [[package]] name = "diff" version = "0.1.13" @@ -1263,6 +1273,8 @@ dependencies = [ "cfg-if", "crunchy", "num-traits", + "rand", + "rand_distr", "serde", ] @@ -2436,6 +2448,16 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + [[package]] name = "rand_xorshift" version = "0.3.0" @@ -2808,8 +2830,8 @@ dependencies = [ "c", "crc32fast", "crossbeam", - "ctor", "dashmap", + "detect", "half 2.3.1", "libc", "log", @@ -2824,7 +2846,6 @@ dependencies = [ "serde", "serde_json", "serde_with", - "std_detect", "tempfile", "thiserror", "ulock-sys", diff --git a/crates/c/Cargo.toml b/crates/c/Cargo.toml index 5dc084e..f0f0274 100644 --- a/crates/c/Cargo.toml +++ b/crates/c/Cargo.toml @@ -3,8 +3,10 @@ name = "c" version.workspace = true edition.workspace = true -[dependencies] -half = { version = "~2.3", features = ["use-intrinsics"] } +[dev-dependencies] +half = { version = "~2.3", features = ["use-intrinsics", "rand_distr"] } +detect = { path = "../detect" } +rand = "0.8.5" [build-dependencies] cc = "1.0" diff --git a/crates/c/src/c.c b/crates/c/src/c.c index e41f282..2a3e434 100644 --- a/crates/c/src/c.c +++ b/crates/c/src/c.c @@ -29,8 +29,12 @@ v_f16_cosine_avx512fp16(_Float16 *a, _Float16 *b, size_t n) { xx = _mm512_fmadd_ph(x, x, xx); yy = _mm512_fmadd_ph(y, y, yy); } - return (float)(_mm512_reduce_add_ph(xy) / - sqrt(_mm512_reduce_add_ph(xx) * _mm512_reduce_add_ph(yy))); + { + float rxy = _mm512_reduce_add_ph(xy); + float rxx = _mm512_reduce_add_ph(xx); + float ryy = _mm512_reduce_add_ph(yy); + return rxy / sqrt(rxx * ryy); + } } __attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float @@ -74,6 
+78,76 @@ v_f16_sl2_avx512fp16(_Float16 *a, _Float16 *b, size_t n) { return (float)_mm512_reduce_add_ph(dd); } +__attribute__((target("arch=x86-64-v4"))) extern float +v_f16_cosine_v4(_Float16 *a, _Float16 *b, size_t n) { /* Cosine similarity of two n-element f16 vectors: halves are widened to f32 and accumulated in three 16-lane sums (x*y, x*x, y*y). */ + __m512 xy = _mm512_set1_ps(0); + __m512 xx = _mm512_set1_ps(0); + __m512 yy = _mm512_set1_ps(0); + + while (n >= 16) { /* full 16-element chunks */ + __m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a)); + __m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b)); + a += 16, b += 16, n -= 16; + xy = _mm512_fmadd_ps(x, y, xy); + xx = _mm512_fmadd_ps(x, x, xx); + yy = _mm512_fmadd_ps(y, y, yy); + } + if (n > 0) { /* tail of fewer than 16 elements */ + __mmask16 mask = _bzhi_u32(0xFFFF, n); /* keeps the low n bits, so the maskz loads below read only the remaining elements and zero the other lanes */ + __m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a)); + __m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b)); + xy = _mm512_fmadd_ps(x, y, xy); + xx = _mm512_fmadd_ps(x, x, xx); + yy = _mm512_fmadd_ps(y, y, yy); + } + { /* horizontal sums, then xy / sqrt(xx * yy); no guard here for a zero norm */ + float rxy = _mm512_reduce_add_ps(xy); + float rxx = _mm512_reduce_add_ps(xx); + float ryy = _mm512_reduce_add_ps(yy); + return rxy / sqrt(rxx * ryy); + } +} + +__attribute__((target("arch=x86-64-v4"))) extern float +v_f16_dot_v4(_Float16 *a, _Float16 *b, size_t n) { /* Dot product of two n-element f16 vectors, accumulated in f32. */ + __m512 xy = _mm512_set1_ps(0); + + while (n >= 16) { /* full 16-element chunks */ + __m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a)); + __m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b)); + a += 16, b += 16, n -= 16; + xy = _mm512_fmadd_ps(x, y, xy); + } + if (n > 0) { /* tail of fewer than 16 elements */ + __mmask16 mask = _bzhi_u32(0xFFFF, n); /* low n bits set; unselected lanes load as zero and add nothing to the sum */ + __m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a)); + __m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b)); + xy = _mm512_fmadd_ps(x, y, xy); + } + return _mm512_reduce_add_ps(xy); +} + +__attribute__((target("arch=x86-64-v4"))) extern float +v_f16_sl2_v4(_Float16 *a, _Float16 *b, size_t n) { /* Sum of squared element-wise differences of two n-element f16 vectors, accumulated in f32. */ + __m512 dd = _mm512_set1_ps(0); + + while (n >= 16) { /* full 16-element chunks */ + __m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a)); + __m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b)); + a += 16, b += 16, n -= 16; + __m512 d = _mm512_sub_ps(x, y); + dd = _mm512_fmadd_ps(d,
d, dd); + } + if (n > 0) { /* tail of fewer than 16 elements */ + __mmask16 mask = _bzhi_u32(0xFFFF, n); /* low n bits set; unselected lanes are zero in both inputs, so their difference is zero */ + __m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a)); + __m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b)); + __m512 d = _mm512_sub_ps(x, y); + dd = _mm512_fmadd_ps(d, d, dd); + } + return _mm512_reduce_add_ps(dd); +} + __attribute__((target("arch=x86-64-v3"))) extern float v_f16_cosine_v3(_Float16 *a, _Float16 *b, size_t n) { float xy = 0; diff --git a/crates/c/src/c.h b/crates/c/src/c.h index d50c3d7..26d216f 100644 --- a/crates/c/src/c.h +++ b/crates/c/src/c.h @@ -6,6 +6,9 @@ extern float v_f16_cosine_avx512fp16(_Float16 *, _Float16 *, size_t n); extern float v_f16_dot_avx512fp16(_Float16 *, _Float16 *, size_t n); extern float v_f16_sl2_avx512fp16(_Float16 *, _Float16 *, size_t n); +extern float v_f16_cosine_v4(_Float16 *, _Float16 *, size_t n); +extern float v_f16_dot_v4(_Float16 *, _Float16 *, size_t n); +extern float v_f16_sl2_v4(_Float16 *, _Float16 *, size_t n); extern float v_f16_cosine_v3(_Float16 *, _Float16 *, size_t n); extern float v_f16_dot_v3(_Float16 *, _Float16 *, size_t n); extern float v_f16_sl2_v3(_Float16 *, _Float16 *, size_t n); diff --git a/crates/c/src/c.rs b/crates/c/src/c.rs index a4ac2c2..c2ee4e9 100644 --- a/crates/c/src/c.rs +++ b/crates/c/src/c.rs @@ -4,21 +4,10 @@ extern "C" { pub fn v_f16_cosine_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32; pub fn v_f16_dot_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32; pub fn v_f16_sl2_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_cosine_v4(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_dot_v4(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_sl2_v4(a: *const u16, b: *const u16, n: usize) -> f32; pub fn v_f16_cosine_v3(a: *const u16, b: *const u16, n: usize) -> f32; pub fn v_f16_dot_v3(a: *const u16, b: *const u16, n: usize) -> f32; pub fn v_f16_sl2_v3(a: *const u16, b: *const u16, n: usize) -> f32; } - -// `compiler_builtin`
defines `__extendhfsf2` with integer calling convention. -// However C compilers links `__extendhfsf2` with floating calling convention. -// The code should be removed once Rust offically supports `f16`. - -#[cfg(target_arch = "x86_64")] -#[no_mangle] -#[linkage = "external"] -extern "C" fn __extendhfsf2(f: f64) -> f32 { - unsafe { - let f: half::f16 = std::mem::transmute_copy(&f); - f.to_f32() - } -} diff --git a/crates/c/tests/x86_64.rs b/crates/c/tests/x86_64.rs new file mode 100644 index 0000000..0e981a9 --- /dev/null +++ b/crates/c/tests/x86_64.rs @@ -0,0 +1,126 @@ +#![cfg(target_arch = "x86_64")] + +#[test] +fn test_v_f16_cosine() { /* Compares every C cosine kernel the running CPU supports with a scalar f32 reference on random f16 input. */ + const EPSILON: f32 = f16::EPSILON.to_f32_const(); + use half::f16; + unsafe fn v_f16_cosine(a: *const u16, b: *const u16, n: usize) -> f32 { /* scalar reference; a and b must each be valid for n f16 reads */ + let mut xy = 0.0f32; + let mut xx = 0.0f32; + let mut yy = 0.0f32; + for i in 0..n { + let x = a.add(i).cast::<f16>().read().to_f32(); + let y = b.add(i).cast::<f16>().read().to_f32(); + xy += x * y; + xx += x * x; + yy += y * y; + } + xy / (xx * yy).sqrt() + } + let n = 4000; + let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>(); + let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>(); + let r = unsafe { v_f16_cosine(a.as_ptr().cast(), b.as_ptr().cast(), n) }; + if detect::x86_64::detect_avx512fp16() { + println!("detected avx512fp16"); + let c = unsafe { c::v_f16_cosine_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) }; + assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); + } else { + println!("detected no avx512fp16, skipped"); + } + if detect::x86_64::detect_v4() { + println!("detected v4"); + let c = unsafe { c::v_f16_cosine_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) }; + assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); + } else { + println!("detected no v4, skipped"); + } + if detect::x86_64::detect_v3() { + println!("detected v3"); + let c = unsafe { c::v_f16_cosine_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) }; + assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); + }
else { + println!("detected no v3, skipped"); + } +} + +#[test] +fn test_v_f16_dot() { /* Compares every C dot-product kernel the running CPU supports with a scalar f32 reference on random f16 input. */ + const EPSILON: f32 = 1.0f32; + use half::f16; + unsafe fn v_f16_dot(a: *const u16, b: *const u16, n: usize) -> f32 { /* scalar reference; a and b must each be valid for n f16 reads */ + let mut xy = 0.0f32; + for i in 0..n { + let x = a.add(i).cast::<f16>().read().to_f32(); + let y = b.add(i).cast::<f16>().read().to_f32(); + xy += x * y; + } + xy + } + let n = 4000; + let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>(); + let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>(); + let r = unsafe { v_f16_dot(a.as_ptr().cast(), b.as_ptr().cast(), n) }; + if detect::x86_64::detect_avx512fp16() { + println!("detected avx512fp16"); + let c = unsafe { c::v_f16_dot_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) }; + assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); + } else { + println!("detected no avx512fp16, skipped"); + } + if detect::x86_64::detect_v4() { + println!("detected v4"); + let c = unsafe { c::v_f16_dot_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) }; + assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); + } else { + println!("detected no v4, skipped"); + } + if detect::x86_64::detect_v3() { + println!("detected v3"); + let c = unsafe { c::v_f16_dot_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) }; + assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); + } else { + println!("detected no v3, skipped"); + } +} + +#[test] +fn test_v_f16_sl2() { /* Compares every C squared-L2 kernel the running CPU supports with a scalar f32 reference on random f16 input. */ + const EPSILON: f32 = 1.0f32; + use half::f16; + unsafe fn v_f16_sl2(a: *const u16, b: *const u16, n: usize) -> f32 { /* scalar reference; a and b must each be valid for n f16 reads */ + let mut dd = 0.0f32; + for i in 0..n { + let x = a.add(i).cast::<f16>().read().to_f32(); + let y = b.add(i).cast::<f16>().read().to_f32(); + let d = x - y; + dd += d * d; + } + dd + } + let n = 4000; + let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>(); + let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>(); + let r = unsafe { v_f16_sl2(a.as_ptr().cast(), b.as_ptr().cast(), n) }; + if detect::x86_64::detect_avx512fp16() { + println!("detected avx512fp16"); + let c = unsafe {
c::v_f16_sl2_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) }; + assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); + } else { + println!("detected no avx512fp16, skipped"); + } + if detect::x86_64::detect_v4() { + println!("detected v4"); + let c = unsafe { c::v_f16_sl2_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) }; + assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); + } else { + println!("detected no v4, skipped"); + } + if detect::x86_64::detect_v3() { + println!("detected v3"); + let c = unsafe { c::v_f16_sl2_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) }; + assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); + } else { + println!("detected no v3, skipped"); + } +} diff --git a/crates/detect/Cargo.toml b/crates/detect/Cargo.toml new file mode 100644 index 0000000..758effe --- /dev/null +++ b/crates/detect/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "detect" +version.workspace = true +edition.workspace = true + +[dependencies] +std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "avx512fp16" } +ctor = "0.2.6" diff --git a/crates/service/src/utils/detect.rs b/crates/detect/src/lib.rs similarity index 100% rename from crates/service/src/utils/detect.rs rename to crates/detect/src/lib.rs diff --git a/crates/service/src/utils/detect/x86_64.rs b/crates/detect/src/x86_64.rs similarity index 98% rename from crates/service/src/utils/detect/x86_64.rs rename to crates/detect/src/x86_64.rs index 5bd3c87..b9790a8 100644 --- a/crates/service/src/utils/detect/x86_64.rs +++ b/crates/detect/src/x86_64.rs @@ -33,7 +33,7 @@ fn ctor_v4() { ATOMIC_V4.store(test_v4(), Ordering::Relaxed); } -pub fn _detect_v4() -> bool { +pub fn detect_v4() -> bool { ATOMIC_V4.load(Ordering::Relaxed) } diff --git a/crates/service/Cargo.toml b/crates/service/Cargo.toml index 3671570..150619d 100644 --- a/crates/service/Cargo.toml +++ b/crates/service/Cargo.toml @@ -16,7 +16,7 @@ bincode.workspace = true half.workspace = true num-traits.workspace = true c = {
path = "../c" } -std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "avx512fp16" } +detect = { path = "../detect" } rand = "0.8.5" crc32fast = "1.3.2" crossbeam = "0.8.2" @@ -32,7 +32,6 @@ arc-swap = "1.6.0" bytemuck = { version = "1.14.0", features = ["extern_crate_alloc"] } serde_with = "3.4.0" multiversion = "0.7.3" -ctor = "0.2.6" [target.'cfg(target_os = "macos")'.dependencies] ulock-sys = "0.1.0" diff --git a/crates/service/src/prelude/global/f16.rs b/crates/service/src/prelude/global/f16.rs index 2f7e13c..1452937 100644 --- a/crates/service/src/prelude/global/f16.rs +++ b/crates/service/src/prelude/global/f16.rs @@ -22,7 +22,7 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { xy / (x2 * y2).sqrt() } #[cfg(target_arch = "x86_64")] - if crate::utils::detect::x86_64::detect_avx512fp16() { + if detect::x86_64::detect_avx512fp16() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -30,7 +30,15 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { } } #[cfg(target_arch = "x86_64")] - if crate::utils::detect::x86_64::detect_v3() { + if detect::x86_64::detect_v4() { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { + return c::v_f16_cosine_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + } + } + #[cfg(target_arch = "x86_64")] + if detect::x86_64::detect_v3() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -58,7 +66,7 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { xy } #[cfg(target_arch = "x86_64")] - if crate::utils::detect::x86_64::detect_avx512fp16() { + if detect::x86_64::detect_avx512fp16() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -66,7 +74,15 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { } } #[cfg(target_arch = "x86_64")] - if crate::utils::detect::x86_64::detect_v3() { + if detect::x86_64::detect_v4() { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { + return c::v_f16_dot_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + 
} + } + #[cfg(target_arch = "x86_64")] + if detect::x86_64::detect_v3() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -95,7 +111,7 @@ pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { d2 } #[cfg(target_arch = "x86_64")] - if crate::utils::detect::x86_64::detect_avx512fp16() { + if detect::x86_64::detect_avx512fp16() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -103,7 +119,15 @@ pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { } } #[cfg(target_arch = "x86_64")] - if crate::utils::detect::x86_64::detect_v3() { + if detect::x86_64::detect_v4() { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { + return c::v_f16_sl2_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + } + } + #[cfg(target_arch = "x86_64")] + if detect::x86_64::detect_v3() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { diff --git a/crates/service/src/utils/mod.rs b/crates/service/src/utils/mod.rs index e422424..55f717a 100644 --- a/crates/service/src/utils/mod.rs +++ b/crates/service/src/utils/mod.rs @@ -1,6 +1,5 @@ pub mod cells; pub mod clean; -pub mod detect; pub mod dir_ops; pub mod file_atomic; pub mod file_wal;