mirror of https://github.com/tensorchord/pgvecto.rs.git synced 2025-04-18 21:44:00 +03:00

fix: add v4, v3 for vecf32 dot, sl2 (#559)

Signed-off-by: usamoi <usamoi@outlook.com>
usamoi 2024-08-14 19:12:40 +08:00 committed by GitHub
parent 94e4e2f970
commit 813c04d797
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 256 additions and 10 deletions

.gitignore vendored

@@ -10,3 +10,4 @@ __pycache__
rustc-ice-*.txt
build
.intentionally-empty-file.o
.venv


@@ -213,7 +213,128 @@ impl<'a> PartialOrd for Vecf32Borrowed<'a> {
}
}
#[detect::multiversion(v4, v3, v2, neon, fallback)]
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v4")]
unsafe fn dot_v4(lhs: &[F32], rhs: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
use std::arch::x86_64::*;
unsafe {
let mut n = lhs.len() as u32;
let mut a = lhs.as_ptr();
let mut b = rhs.as_ptr();
let mut xy = _mm512_set1_ps(0.0);
while n >= 16 {
let x = _mm512_loadu_ps(a.cast());
let y = _mm512_loadu_ps(b.cast());
a = a.add(16);
b = b.add(16);
n -= 16;
xy = _mm512_fmadd_ps(x, y, xy);
}
if n > 0 {
let mask = _bzhi_u32(0xFFFF, n) as u16;
let x = _mm512_maskz_loadu_ps(mask, a.cast());
let y = _mm512_maskz_loadu_ps(mask, b.cast());
xy = _mm512_fmadd_ps(x, y, xy);
}
F32(_mm512_reduce_add_ps(xy))
}
}
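Note on the tail handling above: the v4 kernel finishes the ragged remainder with an AVX-512 masked load, where _bzhi_u32(0xFFFF, n) keeps only the lowest n of the 16 mask bits, so out-of-range lanes load as 0.0 and contribute nothing to the FMA. A portable sketch of the same mask computation, for illustration only (the kernel itself uses the BMI2 intrinsic):

// Illustrative only: builds the same 16-lane tail mask as _bzhi_u32(0xFFFF, n)
// for n < 16, without requiring BMI2.
fn tail_mask(n: u32) -> u16 {
    debug_assert!(n < 16);
    ((1u32 << n) - 1) as u16 // e.g. n = 5 -> 0b0000_0000_0001_1111
}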
#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn dot_v4_test() {
const EPSILON: F32 = F32(2.0);
detect::init();
if !detect::v4::detect() {
println!("test {} ... skipped (v4)", module_path!());
return;
}
for _ in 0..300 {
let n = 4010;
let lhs = (0..n).map(|_| F32(rand::random::<_>())).collect::<Vec<_>>();
let rhs = (0..n).map(|_| F32(rand::random::<_>())).collect::<Vec<_>>();
for z in 3990..4010 {
let lhs = &lhs[..z];
let rhs = &rhs[..z];
let specialized = unsafe { dot_v4(&lhs, &rhs) };
let fallback = unsafe { dot_fallback(&lhs, &rhs) };
assert!(
(specialized - fallback).abs() < EPSILON,
"specialized = {specialized}, fallback = {fallback}."
);
}
}
}
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v3")]
unsafe fn dot_v3(lhs: &[F32], rhs: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
use std::arch::x86_64::*;
unsafe {
let mut n = lhs.len() as u32;
let mut a = lhs.as_ptr();
let mut b = rhs.as_ptr();
let mut xy = _mm256_set1_ps(0.0);
while n >= 8 {
let x = _mm256_loadu_ps(a.cast());
let y = _mm256_loadu_ps(b.cast());
a = a.add(8);
b = b.add(8);
n -= 8;
xy = _mm256_fmadd_ps(x, y, xy);
}
#[inline]
#[detect::target_cpu(enable = "v3")]
unsafe fn _mm256_reduce_add_ps(mut x: __m256) -> f32 {
unsafe {
x = _mm256_add_ps(x, _mm256_permute2f128_ps(x, x, 1));
x = _mm256_hadd_ps(x, x);
x = _mm256_hadd_ps(x, x);
_mm256_cvtss_f32(x)
}
}
let mut xy = F32(_mm256_reduce_add_ps(xy));
while n > 0 {
let x = a.read();
let y = b.read();
a = a.add(1);
b = b.add(1);
n -= 1;
xy += x * y;
}
xy
}
}
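AVX2 has no single horizontal-add across a 256-bit register, which is why the v3 kernel inlines its own _mm256_reduce_add_ps helper and then drains the remaining elements with a scalar loop. Spelled out on a plain array, the reduction's data flow is (illustration only, not part of the patch):

// Fold the upper 128-bit half onto the lower one (permute2f128 + add),
// then two hadd steps collapse the four partial sums.
fn reduce_add_8(x: [f32; 8]) -> f32 {
    let fold = [x[0] + x[4], x[1] + x[5], x[2] + x[6], x[3] + x[7]];
    (fold[0] + fold[1]) + (fold[2] + fold[3])
}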
#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn dot_v3_test() {
const EPSILON: F32 = F32(2.0);
detect::init();
if !detect::v3::detect() {
println!("test {} ... skipped (v3)", module_path!());
return;
}
for _ in 0..300 {
let n = 4010;
let lhs = (0..n).map(|_| F32(rand::random::<_>())).collect::<Vec<_>>();
let rhs = (0..n).map(|_| F32(rand::random::<_>())).collect::<Vec<_>>();
for z in 3990..4010 {
let lhs = &lhs[..z];
let rhs = &rhs[..z];
let specialized = unsafe { dot_v3(&lhs, &rhs) };
let fallback = unsafe { dot_fallback(&lhs, &rhs) };
assert!(
(specialized - fallback).abs() < EPSILON,
"specialized = {specialized}, fallback = {fallback}."
);
}
}
}
#[detect::multiversion(v4 = import, v3 = import, v2, neon, fallback = export)]
pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
@@ -224,7 +345,131 @@ pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 {
xy
}
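The tests above compare each kernel against dot_fallback, whose body lies outside this hunk. As a rough reference for the quantity being computed, a plain scalar inner product over f32 (simplified; the crate works on its F32 wrapper, and this is not the actual dot_fallback body) looks like:

// Scalar reference for the inner product; simplified to plain f32.
fn dot_scalar(lhs: &[f32], rhs: &[f32]) -> f32 {
    assert_eq!(lhs.len(), rhs.len());
    let mut xy = 0.0f32;
    for (&x, &y) in lhs.iter().zip(rhs.iter()) {
        xy += x * y;
    }
    xy
}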
#[detect::multiversion(v4, v3, v2, neon, fallback)]
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v4")]
unsafe fn sl2_v4(lhs: &[F32], rhs: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
use std::arch::x86_64::*;
unsafe {
let mut n = lhs.len() as u32;
let mut a = lhs.as_ptr();
let mut b = rhs.as_ptr();
let mut dd = _mm512_set1_ps(0.0);
while n >= 16 {
let x = _mm512_loadu_ps(a.cast());
let y = _mm512_loadu_ps(b.cast());
a = a.add(16);
b = b.add(16);
n -= 16;
let d = _mm512_sub_ps(x, y);
dd = _mm512_fmadd_ps(d, d, dd);
}
if n > 0 {
let mask = _bzhi_u32(0xFFFF, n) as u16;
let x = _mm512_maskz_loadu_ps(mask, a.cast());
let y = _mm512_maskz_loadu_ps(mask, b.cast());
let d = _mm512_sub_ps(x, y);
dd = _mm512_fmadd_ps(d, d, dd);
}
F32(_mm512_reduce_add_ps(dd))
}
}
#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn sl2_v4_test() {
const EPSILON: F32 = F32(2.0);
detect::init();
if !detect::v4::detect() {
println!("test {} ... skipped (v4)", module_path!());
return;
}
for _ in 0..300 {
let n = 4010;
let lhs = (0..n).map(|_| F32(rand::random::<_>())).collect::<Vec<_>>();
let rhs = (0..n).map(|_| F32(rand::random::<_>())).collect::<Vec<_>>();
for z in 3990..4010 {
let lhs = &lhs[..z];
let rhs = &rhs[..z];
let specialized = unsafe { sl2_v4(&lhs, &rhs) };
let fallback = unsafe { sl2_fallback(&lhs, &rhs) };
assert!(
(specialized - fallback).abs() < EPSILON,
"specialized = {specialized}, fallback = {fallback}."
);
}
}
}
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v3")]
unsafe fn sl2_v3(lhs: &[F32], rhs: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
use std::arch::x86_64::*;
unsafe {
let mut n = lhs.len() as u32;
let mut a = lhs.as_ptr();
let mut b = rhs.as_ptr();
let mut dd = _mm256_set1_ps(0.0);
while n >= 8 {
let x = _mm256_loadu_ps(a.cast());
let y = _mm256_loadu_ps(b.cast());
a = a.add(8);
b = b.add(8);
n -= 8;
let d = _mm256_sub_ps(x, y);
dd = _mm256_fmadd_ps(d, d, dd);
}
#[inline]
#[detect::target_cpu(enable = "v3")]
unsafe fn _mm256_reduce_add_ps(mut x: __m256) -> f32 {
unsafe {
x = _mm256_add_ps(x, _mm256_permute2f128_ps(x, x, 1));
x = _mm256_hadd_ps(x, x);
x = _mm256_hadd_ps(x, x);
_mm256_cvtss_f32(x)
}
}
let mut rdd = F32(_mm256_reduce_add_ps(dd));
while n > 0 {
let x = a.read();
let y = b.read();
a = a.add(1);
b = b.add(1);
n -= 1;
rdd += (x - y) * (x - y);
}
rdd
}
}
#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn sl2_v3_test() {
const EPSILON: F32 = F32(2.0);
detect::init();
if !detect::v3::detect() {
println!("test {} ... skipped (v3)", module_path!());
return;
}
for _ in 0..300 {
let n = 4010;
let lhs = (0..n).map(|_| F32(rand::random::<_>())).collect::<Vec<_>>();
let rhs = (0..n).map(|_| F32(rand::random::<_>())).collect::<Vec<_>>();
for z in 3990..4010 {
let lhs = &lhs[..z];
let rhs = &rhs[..z];
let specialized = unsafe { sl2_v3(&lhs, &rhs) };
let fallback = unsafe { sl2_fallback(&lhs, &rhs) };
assert!(
(specialized - fallback).abs() < EPSILON,
"specialized = {specialized}, fallback = {fallback}."
);
}
}
}
#[detect::multiversion(v4 = import, v3 = import, v2, neon, fallback = export)]
pub fn sl2(lhs: &[F32], rhs: &[F32]) -> F32 {
assert!(lhs.len() == rhs.len());
let n = lhs.len();
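sl2 is the squared Euclidean (L2) distance; the v4 and v3 kernels above accumulate d*d per lane and reduce at the end. A scalar reference for the same quantity (again simplified to plain f32, and not the actual sl2_fallback body):

// Scalar reference for squared L2 distance.
fn sl2_scalar(lhs: &[f32], rhs: &[f32]) -> f32 {
    assert_eq!(lhs.len(), rhs.len());
    let mut dd = 0.0f32;
    for (&x, &y) in lhs.iter().zip(rhs.iter()) {
        let d = x - y;
        dd += d * d;
    }
    dd
}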


@@ -41,12 +41,12 @@ impl ThreadPoolBuilder {
builder: self.builder.num_threads(num_threads),
}
}
pub fn build_scoped(
pub fn build_scoped<R>(
self,
f: impl FnOnce(&ThreadPool),
) -> Result<(), rayon::ThreadPoolBuildError> {
f: impl FnOnce(&ThreadPool) -> R,
) -> Result<Option<R>, rayon::ThreadPoolBuildError> {
let stop = Arc::new(AtomicBool::new(false));
match std::panic::catch_unwind(AssertUnwindSafe(|| {
let x = match std::panic::catch_unwind(AssertUnwindSafe(|| {
self.builder
.start_handler({
let stop = stop.clone();
@ -71,15 +71,15 @@ impl ThreadPoolBuilder {
},
)
})) {
Ok(Ok(())) => (),
Ok(Ok(r)) => Some(r),
Ok(Err(e)) => return Err(e),
Err(e) if e.downcast_ref::<CheckPanic>().is_some() => (),
Err(e) if e.downcast_ref::<CheckPanic>().is_some() => None,
Err(e) => std::panic::resume_unwind(e),
}
};
if Arc::strong_count(&stop) > 1 {
panic!("Thread leak detected.");
}
Ok(())
Ok(x)
}
}
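With the change above, build_scoped threads a value out of the closure: Ok(Some(r)) on normal completion, Ok(None) when the pool was torn down via the CheckPanic path. A hedged usage sketch, assuming a ThreadPoolBuilder::new() constructor alongside the num_threads builder shown in this diff:

// Hypothetical caller; ThreadPoolBuilder::new() is assumed, the rest mirrors
// the API in this diff. `?` propagates rayon::ThreadPoolBuildError.
fn example() -> Result<(), rayon::ThreadPoolBuildError> {
    let summed: Option<u64> = ThreadPoolBuilder::new()
        .num_threads(4)
        .build_scoped(|_pool| {
            // ... run work on the pool and hand a value back to the caller ...
            42u64
        })?;
    // None means the pool stopped early via the CheckPanic path.
    let _ = summed;
    Ok(())
}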