1
0
mirror of https://github.com/tensorchord/pgvecto.rs.git synced 2025-07-30 19:23:05 +03:00

C code tests & avx512f f16 implement (#183)

* test: add tests for c code

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: relax EPSILON for tests

Signed-off-by: usamoi <usamoi@outlook.com>

---------

Signed-off-by: usamoi <usamoi@outlook.com>
This commit is contained in:
Usamoi
2023-12-15 16:00:08 +08:00
committed by GitHub
parent 2869fbd44c
commit c50912e87d
13 changed files with 276 additions and 31 deletions

View File

@ -101,7 +101,7 @@ jobs:
cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu
- name: Test
run: |
cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu
cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu -- --nocapture
- name: Install release
run: ./scripts/ci_install.sh
- name: Sqllogictest

25
Cargo.lock generated
View File

@ -486,7 +486,9 @@ name = "c"
version = "0.0.0"
dependencies = [
"cc",
"detect",
"half 2.3.1",
"rand",
]
[[package]]
@ -837,6 +839,14 @@ dependencies = [
"serde",
]
[[package]]
name = "detect"
version = "0.0.0"
dependencies = [
"ctor",
"std_detect",
]
[[package]]
name = "diff"
version = "0.1.13"
@ -1263,6 +1273,8 @@ dependencies = [
"cfg-if",
"crunchy",
"num-traits",
"rand",
"rand_distr",
"serde",
]
@ -2436,6 +2448,16 @@ dependencies = [
"getrandom",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31"
dependencies = [
"num-traits",
"rand",
]
[[package]]
name = "rand_xorshift"
version = "0.3.0"
@ -2808,8 +2830,8 @@ dependencies = [
"c",
"crc32fast",
"crossbeam",
"ctor",
"dashmap",
"detect",
"half 2.3.1",
"libc",
"log",
@ -2824,7 +2846,6 @@ dependencies = [
"serde",
"serde_json",
"serde_with",
"std_detect",
"tempfile",
"thiserror",
"ulock-sys",

View File

@ -3,8 +3,10 @@ name = "c"
version.workspace = true
edition.workspace = true
[dependencies]
half = { version = "~2.3", features = ["use-intrinsics"] }
[dev-dependencies]
half = { version = "~2.3", features = ["use-intrinsics", "rand_distr"] }
detect = { path = "../detect" }
rand = "0.8.5"
[build-dependencies]
cc = "1.0"

View File

@ -29,8 +29,12 @@ v_f16_cosine_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
xx = _mm512_fmadd_ph(x, x, xx);
yy = _mm512_fmadd_ph(y, y, yy);
}
return (float)(_mm512_reduce_add_ph(xy) /
sqrt(_mm512_reduce_add_ph(xx) * _mm512_reduce_add_ph(yy)));
{
float rxy = _mm512_reduce_add_ph(xy);
float rxx = _mm512_reduce_add_ph(xx);
float ryy = _mm512_reduce_add_ph(yy);
return rxy / sqrt(rxx * ryy);
}
}
__attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float
@ -74,6 +78,76 @@ v_f16_sl2_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
return (float)_mm512_reduce_add_ph(dd);
}
__attribute__((target("arch=x86-64-v4"))) extern float
v_f16_cosine_v4(_Float16 *a, _Float16 *b, size_t n) {
__m512 xy = _mm512_set1_ps(0);
__m512 xx = _mm512_set1_ps(0);
__m512 yy = _mm512_set1_ps(0);
while (n >= 16) {
__m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
__m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
a += 16, b += 16, n -= 16;
xy = _mm512_fmadd_ps(x, y, xy);
xx = _mm512_fmadd_ps(x, x, xx);
yy = _mm512_fmadd_ps(y, y, yy);
}
if (n > 0) {
__mmask16 mask = _bzhi_u32(0xFFFF, n);
__m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
__m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
xy = _mm512_fmadd_ps(x, y, xy);
xx = _mm512_fmadd_ps(x, x, xx);
yy = _mm512_fmadd_ps(y, y, yy);
}
{
float rxy = _mm512_reduce_add_ps(xy);
float rxx = _mm512_reduce_add_ps(xx);
float ryy = _mm512_reduce_add_ps(yy);
return rxy / sqrt(rxx * ryy);
}
}
__attribute__((target("arch=x86-64-v4"))) extern float
v_f16_dot_v4(_Float16 *a, _Float16 *b, size_t n) {
__m512 xy = _mm512_set1_ps(0);
while (n >= 16) {
__m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
__m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
a += 16, b += 16, n -= 16;
xy = _mm512_fmadd_ps(x, y, xy);
}
if (n > 0) {
__mmask16 mask = _bzhi_u32(0xFFFF, n);
__m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
__m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
xy = _mm512_fmadd_ps(x, y, xy);
}
return _mm512_reduce_add_ps(xy);
}
__attribute__((target("arch=x86-64-v4"))) extern float
v_f16_sl2_v4(_Float16 *a, _Float16 *b, size_t n) {
__m512 dd = _mm512_set1_ps(0);
while (n >= 16) {
__m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
__m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
a += 16, b += 16, n -= 16;
__m512 d = _mm512_sub_ps(x, y);
dd = _mm512_fmadd_ps(d, d, dd);
}
if (n > 0) {
__mmask16 mask = _bzhi_u32(0xFFFF, n);
__m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
__m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
__m512 d = _mm512_sub_ps(x, y);
dd = _mm512_fmadd_ps(d, d, dd);
}
return _mm512_reduce_add_ps(dd);
}
__attribute__((target("arch=x86-64-v3"))) extern float
v_f16_cosine_v3(_Float16 *a, _Float16 *b, size_t n) {
float xy = 0;

View File

@ -6,6 +6,9 @@
extern float v_f16_cosine_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_dot_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_sl2_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_cosine_v4(_Float16 *, _Float16 *, size_t n);
extern float v_f16_dot_v4(_Float16 *, _Float16 *, size_t n);
extern float v_f16_sl2_v4(_Float16 *, _Float16 *, size_t n);
extern float v_f16_cosine_v3(_Float16 *, _Float16 *, size_t n);
extern float v_f16_dot_v3(_Float16 *, _Float16 *, size_t n);
extern float v_f16_sl2_v3(_Float16 *, _Float16 *, size_t n);

View File

@ -4,21 +4,10 @@ extern "C" {
pub fn v_f16_cosine_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_dot_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_sl2_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_cosine_v4(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_dot_v4(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_sl2_v4(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_cosine_v3(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_dot_v3(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_sl2_v3(a: *const u16, b: *const u16, n: usize) -> f32;
}
// `compiler_builtin` defines `__extendhfsf2` with integer calling convention.
// However C compilers links `__extendhfsf2` with floating calling convention.
// The code should be removed once Rust offically supports `f16`.
#[cfg(target_arch = "x86_64")]
#[no_mangle]
#[linkage = "external"]
extern "C" fn __extendhfsf2(f: f64) -> f32 {
unsafe {
let f: half::f16 = std::mem::transmute_copy(&f);
f.to_f32()
}
}

126
crates/c/tests/x86_64.rs Normal file
View File

@ -0,0 +1,126 @@
#![cfg(target_arch = "x86_64")]
#[test]
fn test_v_f16_cosine() {
const EPSILON: f32 = f16::EPSILON.to_f32_const();
use half::f16;
unsafe fn v_f16_cosine(a: *const u16, b: *const u16, n: usize) -> f32 {
let mut xy = 0.0f32;
let mut xx = 0.0f32;
let mut yy = 0.0f32;
for i in 0..n {
let x = a.add(i).cast::<f16>().read().to_f32();
let y = b.add(i).cast::<f16>().read().to_f32();
xy += x * y;
xx += x * x;
yy += y * y;
}
xy / (xx * yy).sqrt()
}
let n = 4000;
let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let r = unsafe { v_f16_cosine(a.as_ptr().cast(), b.as_ptr().cast(), n) };
if detect::x86_64::detect_avx512fp16() {
println!("detected avx512fp16");
let c = unsafe { c::v_f16_cosine_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no avx512fp16, skipped");
}
if detect::x86_64::detect_v4() {
println!("detected v4");
let c = unsafe { c::v_f16_cosine_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v4, skipped");
}
if detect::x86_64::detect_v3() {
println!("detected v3");
let c = unsafe { c::v_f16_cosine_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v3, skipped");
}
}
#[test]
fn test_v_f16_dot() {
const EPSILON: f32 = 1.0f32;
use half::f16;
unsafe fn v_f16_dot(a: *const u16, b: *const u16, n: usize) -> f32 {
let mut xy = 0.0f32;
for i in 0..n {
let x = a.add(i).cast::<f16>().read().to_f32();
let y = b.add(i).cast::<f16>().read().to_f32();
xy += x * y;
}
xy
}
let n = 4000;
let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let r = unsafe { v_f16_dot(a.as_ptr().cast(), b.as_ptr().cast(), n) };
if detect::x86_64::detect_avx512fp16() {
println!("detected avx512fp16");
let c = unsafe { c::v_f16_dot_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no avx512fp16, skipped");
}
if detect::x86_64::detect_v4() {
println!("detected v4");
let c = unsafe { c::v_f16_dot_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v4, skipped");
}
if detect::x86_64::detect_v3() {
println!("detected v3");
let c = unsafe { c::v_f16_dot_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v3, skipped");
}
}
#[test]
fn test_v_f16_sl2() {
const EPSILON: f32 = 1.0f32;
use half::f16;
unsafe fn v_f16_sl2(a: *const u16, b: *const u16, n: usize) -> f32 {
let mut dd = 0.0f32;
for i in 0..n {
let x = a.add(i).cast::<f16>().read().to_f32();
let y = b.add(i).cast::<f16>().read().to_f32();
let d = x - y;
dd += d * d;
}
dd
}
let n = 4000;
let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
let r = unsafe { v_f16_sl2(a.as_ptr().cast(), b.as_ptr().cast(), n) };
if detect::x86_64::detect_avx512fp16() {
println!("detected avx512fp16");
let c = unsafe { c::v_f16_sl2_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no avx512fp16, skipped");
}
if detect::x86_64::detect_v4() {
println!("detected v4");
let c = unsafe { c::v_f16_sl2_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v4, skipped");
}
if detect::x86_64::detect_v3() {
println!("detected v3");
let c = unsafe { c::v_f16_sl2_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
} else {
println!("detected no v3, skipped");
}
}

8
crates/detect/Cargo.toml Normal file
View File

@ -0,0 +1,8 @@
[package]
name = "detect"
version.workspace = true
edition.workspace = true
[dependencies]
std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "avx512fp16" }
ctor = "0.2.6"

View File

@ -33,7 +33,7 @@ fn ctor_v4() {
ATOMIC_V4.store(test_v4(), Ordering::Relaxed);
}
pub fn _detect_v4() -> bool {
pub fn detect_v4() -> bool {
ATOMIC_V4.load(Ordering::Relaxed)
}

View File

@ -16,7 +16,7 @@ bincode.workspace = true
half.workspace = true
num-traits.workspace = true
c = { path = "../c" }
std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "avx512fp16" }
detect = { path = "../detect" }
rand = "0.8.5"
crc32fast = "1.3.2"
crossbeam = "0.8.2"
@ -32,7 +32,6 @@ arc-swap = "1.6.0"
bytemuck = { version = "1.14.0", features = ["extern_crate_alloc"] }
serde_with = "3.4.0"
multiversion = "0.7.3"
ctor = "0.2.6"
[target.'cfg(target_os = "macos")'.dependencies]
ulock-sys = "0.1.0"

View File

@ -22,7 +22,7 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 {
xy / (x2 * y2).sqrt()
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_avx512fp16() {
if detect::x86_64::detect_avx512fp16() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
@ -30,7 +30,15 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 {
}
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_v3() {
if detect::x86_64::detect_v4() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_cosine_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
#[cfg(target_arch = "x86_64")]
if detect::x86_64::detect_v3() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
@ -58,7 +66,7 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 {
xy
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_avx512fp16() {
if detect::x86_64::detect_avx512fp16() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
@ -66,7 +74,15 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 {
}
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_v3() {
if detect::x86_64::detect_v4() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_dot_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
#[cfg(target_arch = "x86_64")]
if detect::x86_64::detect_v3() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
@ -95,7 +111,7 @@ pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 {
d2
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_avx512fp16() {
if detect::x86_64::detect_avx512fp16() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
@ -103,7 +119,15 @@ pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 {
}
}
#[cfg(target_arch = "x86_64")]
if crate::utils::detect::x86_64::detect_v3() {
if detect::x86_64::detect_v4() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {
return c::v_f16_sl2_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
}
}
#[cfg(target_arch = "x86_64")]
if detect::x86_64::detect_v3() {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
unsafe {

View File

@ -1,6 +1,5 @@
pub mod cells;
pub mod clean;
pub mod detect;
pub mod dir_ops;
pub mod file_atomic;
pub mod file_wal;