mirror of https://github.com/tensorchord/pgvecto.rs.git synced 2025-07-30 19:23:05 +03:00

C code tests & avx512f f16 implement (#183)

* test: add tests for c code

Signed-off-by: usamoi <usamoi@outlook.com>

* fix: relax EPSILON for tests

Signed-off-by: usamoi <usamoi@outlook.com>

---------

Signed-off-by: usamoi <usamoi@outlook.com>
Usamoi, 2023-12-15 16:00:08 +08:00, committed by GitHub
parent 2869fbd44c
commit c50912e87d
13 changed files with 276 additions and 31 deletions
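In short: this commit adds scalar-reference tests for the f16 distance kernels in the `c` crate, implements x86-64-v4 variants (AVX-512 without AVX512FP16) of the cosine/dot/squared-L2 functions with masked tail loads, declares them in the C header and the Rust bindings, removes the `__extendhfsf2` shim, and moves `half` from a dependency to a dev-dependency.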


@@ -3,8 +3,10 @@ name = "c"
version.workspace = true
edition.workspace = true
-[dependencies]
-half = { version = "~2.3", features = ["use-intrinsics"] }
+[dev-dependencies]
+half = { version = "~2.3", features = ["use-intrinsics", "rand_distr"] }
+detect = { path = "../detect" }
+rand = "0.8.5"
[build-dependencies]
cc = "1.0"
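The new `rand` and `rand_distr` entries exist only to drive the tests below. A minimal sketch of what they enable (hypothetical helper, assuming `half`'s `rand_distr` feature wires `f16` into `rand`'s `Standard` distribution):

use half::f16;

// Draw n random f16 values; `rand::random::<f16>()` compiles because the
// `rand_distr` feature of `half` implements rand's Standard distribution.
fn random_f16_vec(n: usize) -> Vec<f16> {
    (0..n).map(|_| rand::random::<f16>()).collect()
}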


@@ -29,8 +29,12 @@ v_f16_cosine_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
xx = _mm512_fmadd_ph(x, x, xx);
yy = _mm512_fmadd_ph(y, y, yy);
}
-return (float)(_mm512_reduce_add_ph(xy) /
-               sqrt(_mm512_reduce_add_ph(xx) * _mm512_reduce_add_ph(yy)));
+{
+  float rxy = _mm512_reduce_add_ph(xy);
+  float rxx = _mm512_reduce_add_ph(xx);
+  float ryy = _mm512_reduce_add_ph(yy);
+  return rxy / sqrt(rxx * ryy);
+}
}
__attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float
@@ -74,6 +78,76 @@ v_f16_sl2_avx512fp16(_Float16 *a, _Float16 *b, size_t n) {
return (float)_mm512_reduce_add_ph(dd);
}
+__attribute__((target("arch=x86-64-v4"))) extern float
+v_f16_cosine_v4(_Float16 *a, _Float16 *b, size_t n) {
+  __m512 xy = _mm512_set1_ps(0);
+  __m512 xx = _mm512_set1_ps(0);
+  __m512 yy = _mm512_set1_ps(0);
+  while (n >= 16) {
+    __m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
+    __m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
+    a += 16, b += 16, n -= 16;
+    xy = _mm512_fmadd_ps(x, y, xy);
+    xx = _mm512_fmadd_ps(x, x, xx);
+    yy = _mm512_fmadd_ps(y, y, yy);
+  }
+  if (n > 0) {
+    __mmask16 mask = _bzhi_u32(0xFFFF, n);
+    __m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
+    __m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
+    xy = _mm512_fmadd_ps(x, y, xy);
+    xx = _mm512_fmadd_ps(x, x, xx);
+    yy = _mm512_fmadd_ps(y, y, yy);
+  }
+  {
+    float rxy = _mm512_reduce_add_ps(xy);
+    float rxx = _mm512_reduce_add_ps(xx);
+    float ryy = _mm512_reduce_add_ps(yy);
+    return rxy / sqrt(rxx * ryy);
+  }
+}
+__attribute__((target("arch=x86-64-v4"))) extern float
+v_f16_dot_v4(_Float16 *a, _Float16 *b, size_t n) {
+  __m512 xy = _mm512_set1_ps(0);
+  while (n >= 16) {
+    __m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
+    __m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
+    a += 16, b += 16, n -= 16;
+    xy = _mm512_fmadd_ps(x, y, xy);
+  }
+  if (n > 0) {
+    __mmask16 mask = _bzhi_u32(0xFFFF, n);
+    __m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
+    __m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
+    xy = _mm512_fmadd_ps(x, y, xy);
+  }
+  return _mm512_reduce_add_ps(xy);
+}
+__attribute__((target("arch=x86-64-v4"))) extern float
+v_f16_sl2_v4(_Float16 *a, _Float16 *b, size_t n) {
+  __m512 dd = _mm512_set1_ps(0);
+  while (n >= 16) {
+    __m512 x = _mm512_cvtph_ps(_mm256_loadu_epi16(a));
+    __m512 y = _mm512_cvtph_ps(_mm256_loadu_epi16(b));
+    a += 16, b += 16, n -= 16;
+    __m512 d = _mm512_sub_ps(x, y);
+    dd = _mm512_fmadd_ps(d, d, dd);
+  }
+  if (n > 0) {
+    __mmask16 mask = _bzhi_u32(0xFFFF, n);
+    __m512 x = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
+    __m512 y = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
+    __m512 d = _mm512_sub_ps(x, y);
+    dd = _mm512_fmadd_ps(d, d, dd);
+  }
+  return _mm512_reduce_add_ps(dd);
+}
__attribute__((target("arch=x86-64-v3"))) extern float
v_f16_cosine_v3(_Float16 *a, _Float16 *b, size_t n) {
float xy = 0;
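A note on the tail handling in the `_v4` kernels above: after the main loop, 0 < n < 16 elements remain. `_bzhi_u32(0xFFFF, n)` zeroes all mask bits at positions >= n, and `_mm256_maskz_loadu_epi16` zero-fills the masked-off lanes, so the leftover lanes contribute nothing to the FMA accumulators. A minimal Rust model of the mask arithmetic (hypothetical helper, not part of the commit):

// Models _bzhi_u32(0xFFFF, n) for the 16-lane tail (0 < n < 16):
// bits 0..n are set, so exactly the first n f16 lanes are loaded.
fn tail_mask(n: u32) -> u16 {
    assert!(n < 16);
    ((1u32 << n) - 1) as u16
}

fn main() {
    assert_eq!(tail_mask(5), 0b0000_0000_0001_1111);
    assert_eq!(tail_mask(15), 0x7FFF);
}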


@@ -6,6 +6,9 @@
extern float v_f16_cosine_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_dot_avx512fp16(_Float16 *, _Float16 *, size_t n);
extern float v_f16_sl2_avx512fp16(_Float16 *, _Float16 *, size_t n);
+extern float v_f16_cosine_v4(_Float16 *, _Float16 *, size_t n);
+extern float v_f16_dot_v4(_Float16 *, _Float16 *, size_t n);
+extern float v_f16_sl2_v4(_Float16 *, _Float16 *, size_t n);
extern float v_f16_cosine_v3(_Float16 *, _Float16 *, size_t n);
extern float v_f16_dot_v3(_Float16 *, _Float16 *, size_t n);
extern float v_f16_sl2_v3(_Float16 *, _Float16 *, size_t n);


@@ -4,21 +4,10 @@ extern "C" {
pub fn v_f16_cosine_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_dot_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_sl2_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32;
+pub fn v_f16_cosine_v4(a: *const u16, b: *const u16, n: usize) -> f32;
+pub fn v_f16_dot_v4(a: *const u16, b: *const u16, n: usize) -> f32;
+pub fn v_f16_sl2_v4(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_cosine_v3(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_dot_v3(a: *const u16, b: *const u16, n: usize) -> f32;
pub fn v_f16_sl2_v3(a: *const u16, b: *const u16, n: usize) -> f32;
}
-// `compiler_builtin` defines `__extendhfsf2` with integer calling convention.
-// However C compilers links `__extendhfsf2` with floating calling convention.
-// The code should be removed once Rust offically supports `f16`.
-#[cfg(target_arch = "x86_64")]
-#[no_mangle]
-#[linkage = "external"]
-extern "C" fn __extendhfsf2(f: f64) -> f32 {
-    unsafe {
-        let f: half::f16 = std::mem::transmute_copy(&f);
-        f.to_f32()
-    }
-}
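For context, the bindings take `*const u16` because Rust has no stable `f16` type, so callers reinterpret `half::f16` buffers. A sketch of a call site (hypothetical wrapper; the actual call sites are the new tests below):

use half::f16;

// Assumes equal-length slices and that detect::x86_64::detect_v4()
// already returned true on this machine.
fn dot_f16_v4(a: &[f16], b: &[f16]) -> f32 {
    assert_eq!(a.len(), b.len());
    unsafe { c::v_f16_dot_v4(a.as_ptr().cast(), b.as_ptr().cast(), a.len()) }
}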

crates/c/tests/x86_64.rs (new file, 126 lines)

@@ -0,0 +1,126 @@
#![cfg(target_arch = "x86_64")]

#[test]
fn test_v_f16_cosine() {
    use half::f16;
    const EPSILON: f32 = f16::EPSILON.to_f32_const();
    unsafe fn v_f16_cosine(a: *const u16, b: *const u16, n: usize) -> f32 {
        let mut xy = 0.0f32;
        let mut xx = 0.0f32;
        let mut yy = 0.0f32;
        for i in 0..n {
            let x = a.add(i).cast::<f16>().read().to_f32();
            let y = b.add(i).cast::<f16>().read().to_f32();
            xy += x * y;
            xx += x * x;
            yy += y * y;
        }
        xy / (xx * yy).sqrt()
    }
    let n = 4000;
    let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
    let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
    let r = unsafe { v_f16_cosine(a.as_ptr().cast(), b.as_ptr().cast(), n) };
    if detect::x86_64::detect_avx512fp16() {
        println!("detected avx512fp16");
        let c = unsafe { c::v_f16_cosine_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no avx512fp16, skipped");
    }
    if detect::x86_64::detect_v4() {
        println!("detected v4");
        let c = unsafe { c::v_f16_cosine_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no v4, skipped");
    }
    if detect::x86_64::detect_v3() {
        println!("detected v3");
        let c = unsafe { c::v_f16_cosine_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no v3, skipped");
    }
}
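The tolerance differs by test on purpose: cosine similarity is normalized, so `f16::EPSILON` (2^-10, about 0.00098) works as an absolute bound, while the dot and squared-L2 tests below accumulate 4000 unnormalized terms and use the relaxed bound of 1.0 from the "relax EPSILON" fix. Assuming the random samples are uniform in [0, 1), a dot product of this size is on the order of 10^3, so 1.0 is still roughly a 0.1% relative tolerance.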
#[test]
fn test_v_f16_dot() {
    use half::f16;
    const EPSILON: f32 = 1.0f32;
    unsafe fn v_f16_dot(a: *const u16, b: *const u16, n: usize) -> f32 {
        let mut xy = 0.0f32;
        for i in 0..n {
            let x = a.add(i).cast::<f16>().read().to_f32();
            let y = b.add(i).cast::<f16>().read().to_f32();
            xy += x * y;
        }
        xy
    }
    let n = 4000;
    let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
    let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
    let r = unsafe { v_f16_dot(a.as_ptr().cast(), b.as_ptr().cast(), n) };
    if detect::x86_64::detect_avx512fp16() {
        println!("detected avx512fp16");
        let c = unsafe { c::v_f16_dot_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no avx512fp16, skipped");
    }
    if detect::x86_64::detect_v4() {
        println!("detected v4");
        let c = unsafe { c::v_f16_dot_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no v4, skipped");
    }
    if detect::x86_64::detect_v3() {
        println!("detected v3");
        let c = unsafe { c::v_f16_dot_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no v3, skipped");
    }
}

#[test]
fn test_v_f16_sl2() {
    use half::f16;
    const EPSILON: f32 = 1.0f32;
    unsafe fn v_f16_sl2(a: *const u16, b: *const u16, n: usize) -> f32 {
        let mut dd = 0.0f32;
        for i in 0..n {
            let x = a.add(i).cast::<f16>().read().to_f32();
            let y = b.add(i).cast::<f16>().read().to_f32();
            let d = x - y;
            dd += d * d;
        }
        dd
    }
    let n = 4000;
    let a = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
    let b = (0..n).map(|_| rand::random::<f16>()).collect::<Vec<_>>();
    let r = unsafe { v_f16_sl2(a.as_ptr().cast(), b.as_ptr().cast(), n) };
    if detect::x86_64::detect_avx512fp16() {
        println!("detected avx512fp16");
        let c = unsafe { c::v_f16_sl2_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no avx512fp16, skipped");
    }
    if detect::x86_64::detect_v4() {
        println!("detected v4");
        let c = unsafe { c::v_f16_sl2_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no v4, skipped");
    }
    if detect::x86_64::detect_v3() {
        println!("detected v3");
        let c = unsafe { c::v_f16_sl2_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) };
        assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}.");
    } else {
        println!("detected no v3, skipped");
    }
}
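Each test compares the SIMD kernels against the scalar reference only when the corresponding CPU feature is detected at runtime, printing "skipped" otherwise, so `cargo test -p c` passes on machines without AVX-512; exercising the avx512fp16 path on common hardware typically requires an emulator such as Intel SDE.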