mirror of https://github.com/tensorchord/pgvecto.rs.git
C code tests & avx512f f16 implement (#183)
* test: add tests for c code
* fix: relax EPSILON for tests

Signed-off-by: usamoi <usamoi@outlook.com>
@@ -3,8 +3,10 @@ name = "c"
version.workspace = true
edition.workspace = true

[dependencies]
half = { version = "~2.3", features = ["use-intrinsics"] }
[dev-dependencies]
half = { version = "~2.3", features = ["use-intrinsics", "rand_distr"] }
detect = { path = "../detect" }
rand = "0.8.5"

[build-dependencies]
cc = "1.0"
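A note on the new dev-dependencies: `half`'s `use-intrinsics` feature selects hardware conversion intrinsics where available, and its `rand_distr` feature is what lets the tests below draw random `f16` values via `rand::random::<f16>()`; `detect` supplies the runtime CPU-feature checks the tests branch on.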
@@ -29,8 +29,12 @@ v_f16_cosine_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
     xx = _mm512_fmadd_ph(x, x, xx);
     yy = _mm512_fmadd_ph(y, y, yy);
   }
-  return (float)(_mm512_reduce_add_ph(xy) /
-                 sqrt(_mm512_reduce_add_ph(xx) * _mm512_reduce_add_ph(yy)));
+  {
+    float rxy = _mm512_reduce_add_ph(xy);
+    float rxx = _mm512_reduce_add_ph(xx);
+    float ryy = _mm512_reduce_add_ph(yy);
+    return rxy / sqrt(rxx * ryy);
+  }
 }
 
 __attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float
@@ -74,6 +78,76 @@ v_f16_sl2_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
   return (float)_mm512_reduce_add_ph(dd);
 }
 
+__attribute__((target("arch=x86-64-v4"))) extern float
+v_f16_cosine_v4(_Float16 *a, _Float16 *b, size_t n) {
+  __m512 xy = _mm512_set1_ps(0);
+  __m512 xx = _mm512_set1_ps(0);
+  __m512 yy = _mm512_set1_ps(0);
+
+  while (n >= 16) {
+    __m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
+    __m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
+    a += 16, b += 16, n -= 16;
+    xy = _mm512_fmadd_ps(x, y, xy);
+    xx = _mm512_fmadd_ps(x, x, xx);
+    yy = _mm512_fmadd_ps(y, y, yy);
+  }
+  if (n > 0) {
+    __mmask16 mask = _bzhi_u32(0xFFFF, n);
+    __m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
+    __m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
+    xy = _mm512_fmadd_ps(x, y, xy);
+    xx = _mm512_fmadd_ps(x, x, xx);
+    yy = _mm512_fmadd_ps(y, y, yy);
+  }
+  {
+    float rxy = _mm512_reduce_add_ps(xy);
+    float rxx = _mm512_reduce_add_ps(xx);
+    float ryy = _mm512_reduce_add_ps(yy);
+    return rxy / sqrt(rxx * ryy);
+  }
+}
+
+__attribute__((target("arch=x86-64-v4"))) extern float
+v_f16_dot_v4(_Float16 *a, _Float16 *b, size_t n) {
+  __m512 xy = _mm512_set1_ps(0);
+
+  while (n >= 16) {
+    __m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
+    __m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
+    a += 16, b += 16, n -= 16;
+    xy = _mm512_fmadd_ps(x, y, xy);
+  }
+  if (n > 0) {
+    __mmask16 mask = _bzhi_u32(0xFFFF, n);
+    __m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
+    __m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
+    xy = _mm512_fmadd_ps(x, y, xy);
+  }
+  return _mm512_reduce_add_ps(xy);
+}
+
+__attribute__((target("arch=x86-64-v4"))) extern float
+v_f16_sl2_v4(_Float16 *a, _Float16 *b, size_t n) {
+  __m512 dd = _mm512_set1_ps(0);
+
+  while (n >= 16) {
+    __m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
+    __m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
+    a += 16, b += 16, n -= 16;
+    __m512 d = _mm512_sub_ps(x, y);
+    dd = _mm512_fmadd_ps(d, d, dd);
+  }
+  if (n > 0) {
+    __mmask16 mask = _bzhi_u32(0xFFFF, n);
+    __m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
+    __m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
+    __m512 d = _mm512_sub_ps(x, y);
+    dd = _mm512_fmadd_ps(d, d, dd);
+  }
+  return _mm512_reduce_add_ps(dd);
+}
+
 __attribute__((target("arch=x86-64-v3"))) extern float
 v_f16_cosine_v3(_Float16 *a, _Float16 *b, size_t n) {
   float xy = 0;
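All three new `_v4` kernels share one pattern: load 16 `f16` lanes, widen them to `f32` with `_mm512_cvtph_ps`, accumulate with FMA, and handle the final sub-16-element tail with a masked load instead of a scalar loop. A minimal standalone sketch of the masking trick (illustrative only, not part of the commit; assumes a BMI2-capable build, e.g. compiled with `-mbmi2`):

#include <immintrin.h>
#include <stdio.h>

int main(void) {
  // _bzhi_u32(0xFFFF, n) zeroes every bit at position >= n, leaving exactly
  // the low n bits set: one mask bit per remaining tail element.
  for (unsigned n = 0; n <= 16; n += 4)
    printf("n = %2u -> mask = 0x%04x\n", n, _bzhi_u32(0xFFFF, n));
  return 0;
}

`_mm256_maskz_loadu_epi16` then zero-fills the unselected lanes, and zeroed lanes are harmless to the dot-product and sum-of-squares accumulators.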
@@ -6,6 +6,9 @@
 extern float v_f16_cosine_avx512fp16(_Float16 *, _Float16 *, size_t n);
 extern float v_f16_dot_avx512fp16(_Float16 *, _Float16 *, size_t n);
 extern float v_f16_sl2_avx512fp16(_Float16 *, _Float16 *, size_t n);
+extern float v_f16_cosine_v4(_Float16 *, _Float16 *, size_t n);
+extern float v_f16_dot_v4(_Float16 *, _Float16 *, size_t n);
+extern float v_f16_sl2_v4(_Float16 *, _Float16 *, size_t n);
 extern float v_f16_cosine_v3(_Float16 *, _Float16 *, size_t n);
 extern float v_f16_dot_v3(_Float16 *, _Float16 *, size_t n);
 extern float v_f16_sl2_v3(_Float16 *, _Float16 *, size_t n);
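The header now declares three tiers per distance function. In the repository the choice between them is made from Rust through the `detect` crate, as the new tests show; purely as an illustrative sketch of the same idea in C (not part of the commit, and the feature strings passed to `__builtin_cpu_supports` are assumptions that require a recent GCC or Clang):

extern float v_f16_cosine_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_cosine_v4(_Float16 *, _Float16 *, size_t n);
extern float v_f16_cosine_v3(_Float16 *, _Float16 *, size_t n);

static float v_f16_cosine(_Float16 *a, _Float16 *b, size_t n) {
  if (__builtin_cpu_supports("avx512fp16")) /* widest: native fp16 FMA */
    return v_f16_cosine_avx512fp16(a, b, n);
  if (__builtin_cpu_supports("x86-64-v4")) /* AVX-512F via f32 widening */
    return v_f16_cosine_v4(a, b, n);
  return v_f16_cosine_v3(a, b, n); /* AVX2 baseline */
}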
@@ -4,21 +4,10 @@ extern "C" {
     pub fn v_f16_cosine_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
     pub fn v_f16_dot_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
     pub fn v_f16_sl2_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
+    pub fn v_f16_cosine_v4(a: *const u16, b: *const u16, n: usize) -> f32;
+    pub fn v_f16_dot_v4(a: *const u16, b: *const u16, n: usize) -> f32;
+    pub fn v_f16_sl2_v4(a: *const u16, b: *const u16, n: usize) -> f32;
     pub fn v_f16_cosine_v3(a: *const u16, b: *const u16, n: usize) -> f32;
     pub fn v_f16_dot_v3(a: *const u16, b: *const u16, n: usize) -> f32;
     pub fn v_f16_sl2_v3(a: *const u16, b: *const u16, n: usize) -> f32;
 }
-
-// `compiler_builtins` defines `__extendhfsf2` with an integer calling convention.
-// However, C compilers link `__extendhfsf2` with a floating-point calling convention.
-// The code should be removed once Rust officially supports `f16`.
-
-#[cfg(target_arch = "x86_64")]
-#[no_mangle]
-#[linkage = "external"]
-extern "C" fn __extendhfsf2(f: f64) -> f32 {
-    unsafe {
-        let f: half::f16 = std::mem::transmute_copy(&f);
-        f.to_f32()
-    }
-}
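For context on the deleted block: on targets without native fp16 conversion instructions, a C compiler lowers a `_Float16` to `float` widening into a call to the runtime helper `__extendhfsf2`, while Rust's `compiler_builtins` exports that symbol with an integer calling convention, hence the old shim. A one-line illustration of where the libcall comes from (not from the repository):

// A C compiler without fp16 codegen support may emit, for this function,
// a call to the runtime helper: float __extendhfsf2(_Float16).
float widen(_Float16 x) { return (float)x; }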
crates/c/tests/x86_64.rs (new file, 126 lines)
@@ -0,0 +1,126 @@
#![cfg(target_arch = "x86_64")]

#[test]
fn test_v_f16_cosine() {
    use half::f16;
    const EPSILON: f32 = f16::EPSILON.to_f32_const();
    unsafe fn v_f16_cosine(a: *const u16, b: *const u16, n: usize) -> f32 {
        let mut xy = 0.0f32;
        let mut xx = 0.0f32;
        let mut yy = 0.0f32;
        for i in 0..n {
            let x = a.add(i).cast::<f16>().read().to_f32();
            let y = b.add(i).cast::<f16>().read().to_f32();
            xy += x * y;
            xx += x * x;
            yy += y * y;
        }
        xy / (xx * yy).sqrt()
    }
    let n = 4000;
    let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
    let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
    let r = unsafe { v_f16_cosine(a.as_ptr().cast(), b.as_ptr().cast(), n) };
    if detect::x86_64::detect_avx512fp16() {
        println!("detected avx512fp16");
        let c = unsafe { c::v_f16_cosine_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no avx512fp16, skipped");
    }
    if detect::x86_64::detect_v4() {
        println!("detected v4");
        let c = unsafe { c::v_f16_cosine_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no v4, skipped");
    }
    if detect::x86_64::detect_v3() {
        println!("detected v3");
        let c = unsafe { c::v_f16_cosine_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no v3, skipped");
    }
}

#[test]
fn test_v_f16_dot() {
    use half::f16;
    const EPSILON: f32 = 1.0f32;
    unsafe fn v_f16_dot(a: *const u16, b: *const u16, n: usize) -> f32 {
        let mut xy = 0.0f32;
        for i in 0..n {
            let x = a.add(i).cast::<f16>().read().to_f32();
            let y = b.add(i).cast::<f16>().read().to_f32();
            xy += x * y;
        }
        xy
    }
    let n = 4000;
    let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
    let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
    let r = unsafe { v_f16_dot(a.as_ptr().cast(), b.as_ptr().cast(), n) };
    if detect::x86_64::detect_avx512fp16() {
        println!("detected avx512fp16");
        let c = unsafe { c::v_f16_dot_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no avx512fp16, skipped");
    }
    if detect::x86_64::detect_v4() {
        println!("detected v4");
        let c = unsafe { c::v_f16_dot_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no v4, skipped");
    }
    if detect::x86_64::detect_v3() {
        println!("detected v3");
        let c = unsafe { c::v_f16_dot_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no v3, skipped");
    }
}

#[test]
fn test_v_f16_sl2() {
    use half::f16;
    const EPSILON: f32 = 1.0f32;
    unsafe fn v_f16_sl2(a: *const u16, b: *const u16, n: usize) -> f32 {
        let mut dd = 0.0f32;
        for i in 0..n {
            let x = a.add(i).cast::<f16>().read().to_f32();
            let y = b.add(i).cast::<f16>().read().to_f32();
            let d = x - y;
            dd += d * d;
        }
        dd
    }
    let n = 4000;
    let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
    let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
    let r = unsafe { v_f16_sl2(a.as_ptr().cast(), b.as_ptr().cast(), n) };
    if detect::x86_64::detect_avx512fp16() {
        println!("detected avx512fp16");
        let c = unsafe { c::v_f16_sl2_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no avx512fp16, skipped");
    }
    if detect::x86_64::detect_v4() {
        println!("detected v4");
        let c = unsafe { c::v_f16_sl2_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no v4, skipped");
    }
    if detect::x86_64::detect_v3() {
        println!("detected v3");
        let c = unsafe { c::v_f16_sl2_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no v3, skipped");
    }
}
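A closing note on the tolerances: the cosine test checks against `f16::EPSILON`, while the dot and squared-L2 tests use an absolute tolerance of 1.0 (the "relax EPSILON" fix from the commit message), presumably because those results are unnormalized sums over n = 4000 terms, so rounding differences between the SIMD and scalar accumulation orders can far exceed machine epsilon. Each test only exercises the code paths the host CPU reports and prints a skip message for the rest, so running something like `cargo test -p c` from the workspace root (assuming the crate name `c` from the Cargo.toml hunk above) should pass even on machines without AVX-512.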