mirror of https://github.com/tensorchord/pgvecto.rs.git, synced 2025-04-18 21:44:00 +03:00
fix: add v4, v3 for vecf32 dot, sl2 (#559)
Signed-off-by: usamoi <usamoi@outlook.com>
parent 94e4e2f970, commit 813c04d797
.gitignore (vendored)
@@ -10,3 +10,4 @@ __pycache__
 rustc-ice-*.txt
 build
 .intentionally-empty-file.o
+.venv
@@ -213,7 +213,128 @@ impl<'a> PartialOrd for Vecf32Borrowed<'a> {
     }
 }

-#[detect::multiversion(v4, v3, v2, neon, fallback)]
+#[cfg(target_arch = "x86_64")]
+#[detect::target_cpu(enable = "v4")]
+unsafe fn dot_v4(lhs: &[F32], rhs: &[F32]) -> F32 {
+    assert!(lhs.len() == rhs.len());
+    use std::arch::x86_64::*;
+    unsafe {
+        let mut n = lhs.len() as u32;
+        let mut a = lhs.as_ptr();
+        let mut b = rhs.as_ptr();
+        let mut xy = _mm512_set1_ps(0.0);
+        while n >= 16 {
+            let x = _mm512_loadu_ps(a.cast());
+            let y = _mm512_loadu_ps(b.cast());
+            a = a.add(16);
+            b = b.add(16);
+            n -= 16;
+            xy = _mm512_fmadd_ps(x, y, xy);
+        }
+        if n > 0 {
+            let mask = _bzhi_u32(0xFFFF, n) as u16;
+            let x = _mm512_maskz_loadu_ps(mask, a.cast());
+            let y = _mm512_maskz_loadu_ps(mask, b.cast());
+            xy = _mm512_fmadd_ps(x, y, xy);
+        }
+        F32(_mm512_reduce_add_ps(xy))
+    }
+}
+
+#[cfg(all(target_arch = "x86_64", test))]
+#[test]
+fn dot_v4_test() {
+    const EPSILON: F32 = F32(2.0);
+    detect::init();
+    if !detect::v4::detect() {
+        println!("test {} ... skipped (v4)", module_path!());
+        return;
+    }
+    for _ in 0..300 {
+        let n = 4010;
+        let lhs = (0..n).map(|_| F32(rand::random::<_>())).collect::<Vec<_>>();
+        let rhs = (0..n).map(|_| F32(rand::random::<_>())).collect::<Vec<_>>();
+        for z in 3990..4010 {
+            let lhs = &lhs[..z];
+            let rhs = &rhs[..z];
+            let specialized = unsafe { dot_v4(&lhs, &rhs) };
+            let fallback = unsafe { dot_fallback(&lhs, &rhs) };
+            assert!(
+                (specialized - fallback).abs() < EPSILON,
+                "specialized = {specialized}, fallback = {fallback}."
+            );
+        }
+    }
+}
+
+#[cfg(target_arch = "x86_64")]
+#[detect::target_cpu(enable = "v3")]
+unsafe fn dot_v3(lhs: &[F32], rhs: &[F32]) -> F32 {
+    assert!(lhs.len() == rhs.len());
+    use std::arch::x86_64::*;
+    unsafe {
+        let mut n = lhs.len() as u32;
+        let mut a = lhs.as_ptr();
+        let mut b = rhs.as_ptr();
+        let mut xy = _mm256_set1_ps(0.0);
+        while n >= 8 {
+            let x = _mm256_loadu_ps(a.cast());
+            let y = _mm256_loadu_ps(b.cast());
+            a = a.add(8);
+            b = b.add(8);
+            n -= 8;
+            xy = _mm256_fmadd_ps(x, y, xy);
+        }
+        #[inline]
+        #[detect::target_cpu(enable = "v3")]
+        unsafe fn _mm256_reduce_add_ps(mut x: __m256) -> f32 {
+            unsafe {
+                x = _mm256_add_ps(x, _mm256_permute2f128_ps::<1>(x, x));
+                x = _mm256_hadd_ps(x, x);
+                x = _mm256_hadd_ps(x, x);
+                _mm256_cvtss_f32(x)
+            }
+        }
+        let mut xy = F32(_mm256_reduce_add_ps(xy));
+        while n > 0 {
+            let x = a.read();
+            let y = b.read();
+            a = a.add(1);
+            b = b.add(1);
+            n -= 1;
+            xy += x * y;
+        }
+        xy
+    }
+}
+
+#[cfg(all(target_arch = "x86_64", test))]
+#[test]
+fn dot_v3_test() {
+    const EPSILON: F32 = F32(2.0);
+    detect::init();
+    if !detect::v3::detect() {
+        println!("test {} ... skipped (v3)", module_path!());
+        return;
+    }
+    for _ in 0..300 {
+        let n = 4010;
+        let lhs = (0..n).map(|_| F32(rand::random::<_>())).collect::<Vec<_>>();
+        let rhs = (0..n).map(|_| F32(rand::random::<_>())).collect::<Vec<_>>();
+        for z in 3990..4010 {
+            let lhs = &lhs[..z];
+            let rhs = &rhs[..z];
+            let specialized = unsafe { dot_v3(&lhs, &rhs) };
+            let fallback = unsafe { dot_fallback(&lhs, &rhs) };
+            assert!(
+                (specialized - fallback).abs() < EPSILON,
+                "specialized = {specialized}, fallback = {fallback}."
+            );
+        }
+    }
+}
+
+#[detect::multiversion(v4 = import, v3 = import, v2, neon, fallback = export)]
 pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 {
     assert!(lhs.len() == rhs.len());
     let n = lhs.len();
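The v4 kernel above leans on two AVX-512 tricks: a 16-lane fused-multiply-add main loop, and a masked tail load built from `_bzhi_u32` plus `_mm512_maskz_loadu_ps`, which zero-fills the unused lanes so no scalar remainder loop is needed. Below is a minimal standalone sketch of the same shape on plain `f32` slices, using `std::arch` and `is_x86_feature_detected!` in place of the crate's `F32` wrapper and `detect` macros; `dot_avx512` and the dispatch function are illustrative names, not part of this commit, and `avx512f` plus `bmi2` only approximate the crate's "v4" CPU level.

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "bmi2")]
unsafe fn dot_avx512(lhs: &[f32], rhs: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    assert_eq!(lhs.len(), rhs.len());
    unsafe {
        let mut n = lhs.len() as u32;
        let mut a = lhs.as_ptr();
        let mut b = rhs.as_ptr();
        // Main loop: 16 lanes of x * y accumulated with fused multiply-add.
        let mut xy = _mm512_setzero_ps();
        while n >= 16 {
            let x = _mm512_loadu_ps(a);
            let y = _mm512_loadu_ps(b);
            a = a.add(16);
            b = b.add(16);
            n -= 16;
            xy = _mm512_fmadd_ps(x, y, xy);
        }
        if n > 0 {
            // _bzhi_u32 keeps the low n bits of 0xFFFF: a lane mask for the
            // tail. The masked load zero-fills dead lanes, so they add nothing.
            let mask = _bzhi_u32(0xFFFF, n) as u16;
            let x = _mm512_maskz_loadu_ps(mask, a);
            let y = _mm512_maskz_loadu_ps(mask, b);
            xy = _mm512_fmadd_ps(x, y, xy);
        }
        _mm512_reduce_add_ps(xy)
    }
}

#[cfg(target_arch = "x86_64")]
fn dot(lhs: &[f32], rhs: &[f32]) -> f32 {
    // Runtime dispatch standing in for the crate's detect machinery.
    if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("bmi2") {
        return unsafe { dot_avx512(lhs, rhs) };
    }
    lhs.iter().zip(rhs).map(|(x, y)| x * y).sum()
}

The v3 path has no single-instruction 256-bit reduce, which is why the commit defines its own `_mm256_reduce_add_ps`: fold the high 128-bit half onto the low half, then collapse the remaining partial sums with two horizontal adds. An equivalent standalone helper (illustrative name, plain AVX) might look like:

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn reduce_add_avx(mut x: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;
    unsafe {
        // Add the swapped 128-bit halves onto each other...
        x = _mm256_add_ps(x, _mm256_permute2f128_ps::<1>(x, x));
        // ...then two horizontal adds collapse four partial sums to one.
        x = _mm256_hadd_ps(x, x);
        x = _mm256_hadd_ps(x, x);
        _mm256_cvtss_f32(x)
    }
}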
@@ -224,7 +345,131 @@ pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 {
     xy
 }

-#[detect::multiversion(v4, v3, v2, neon, fallback)]
+#[cfg(target_arch = "x86_64")]
+#[detect::target_cpu(enable = "v4")]
+unsafe fn sl2_v4(lhs: &[F32], rhs: &[F32]) -> F32 {
+    assert!(lhs.len() == rhs.len());
+    use std::arch::x86_64::*;
+    unsafe {
+        let mut n = lhs.len() as u32;
+        let mut a = lhs.as_ptr();
+        let mut b = rhs.as_ptr();
+        let mut dd = _mm512_set1_ps(0.0);
+        while n >= 16 {
+            let x = _mm512_loadu_ps(a.cast());
+            let y = _mm512_loadu_ps(b.cast());
+            a = a.add(16);
+            b = b.add(16);
+            n -= 16;
+            let d = _mm512_sub_ps(x, y);
+            dd = _mm512_fmadd_ps(d, d, dd);
+        }
+        if n > 0 {
+            let mask = _bzhi_u32(0xFFFF, n) as u16;
+            let x = _mm512_maskz_loadu_ps(mask, a.cast());
+            let y = _mm512_maskz_loadu_ps(mask, b.cast());
+            let d = _mm512_sub_ps(x, y);
+            dd = _mm512_fmadd_ps(d, d, dd);
+        }
+        F32(_mm512_reduce_add_ps(dd))
+    }
+}
+
+#[cfg(all(target_arch = "x86_64", test))]
+#[test]
+fn sl2_v4_test() {
+    const EPSILON: F32 = F32(2.0);
+    detect::init();
+    if !detect::v4::detect() {
+        println!("test {} ... skipped (v4)", module_path!());
+        return;
+    }
+    for _ in 0..300 {
+        let n = 4010;
+        let lhs = (0..n).map(|_| F32(rand::random::<_>())).collect::<Vec<_>>();
+        let rhs = (0..n).map(|_| F32(rand::random::<_>())).collect::<Vec<_>>();
+        for z in 3990..4010 {
+            let lhs = &lhs[..z];
+            let rhs = &rhs[..z];
+            let specialized = unsafe { sl2_v4(&lhs, &rhs) };
+            let fallback = unsafe { sl2_fallback(&lhs, &rhs) };
+            assert!(
+                (specialized - fallback).abs() < EPSILON,
+                "specialized = {specialized}, fallback = {fallback}."
+            );
+        }
+    }
+}
+
+#[cfg(target_arch = "x86_64")]
+#[detect::target_cpu(enable = "v3")]
+unsafe fn sl2_v3(lhs: &[F32], rhs: &[F32]) -> F32 {
+    assert!(lhs.len() == rhs.len());
+    use std::arch::x86_64::*;
+    unsafe {
+        let mut n = lhs.len() as u32;
+        let mut a = lhs.as_ptr();
+        let mut b = rhs.as_ptr();
+        let mut dd = _mm256_set1_ps(0.0);
+        while n >= 8 {
+            let x = _mm256_loadu_ps(a.cast());
+            let y = _mm256_loadu_ps(b.cast());
+            a = a.add(8);
+            b = b.add(8);
+            n -= 8;
+            let d = _mm256_sub_ps(x, y);
+            dd = _mm256_fmadd_ps(d, d, dd);
+        }
+        #[inline]
+        #[detect::target_cpu(enable = "v3")]
+        unsafe fn _mm256_reduce_add_ps(mut x: __m256) -> f32 {
+            unsafe {
+                x = _mm256_add_ps(x, _mm256_permute2f128_ps::<1>(x, x));
+                x = _mm256_hadd_ps(x, x);
+                x = _mm256_hadd_ps(x, x);
+                _mm256_cvtss_f32(x)
+            }
+        }
+        let mut rdd = F32(_mm256_reduce_add_ps(dd));
+        while n > 0 {
+            let x = a.read();
+            let y = b.read();
+            a = a.add(1);
+            b = b.add(1);
+            n -= 1;
+            rdd += (x - y) * (x - y);
+        }
+        rdd
+    }
+}
+
+#[cfg(all(target_arch = "x86_64", test))]
+#[test]
+fn sl2_v3_test() {
+    const EPSILON: F32 = F32(2.0);
+    detect::init();
+    if !detect::v3::detect() {
+        println!("test {} ... skipped (v3)", module_path!());
+        return;
+    }
+    for _ in 0..300 {
+        let n = 4010;
+        let lhs = (0..n).map(|_| F32(rand::random::<_>())).collect::<Vec<_>>();
+        let rhs = (0..n).map(|_| F32(rand::random::<_>())).collect::<Vec<_>>();
+        for z in 3990..4010 {
+            let lhs = &lhs[..z];
+            let rhs = &rhs[..z];
+            let specialized = unsafe { sl2_v3(&lhs, &rhs) };
+            let fallback = unsafe { sl2_fallback(&lhs, &rhs) };
+            assert!(
+                (specialized - fallback).abs() < EPSILON,
+                "specialized = {specialized}, fallback = {fallback}."
+            );
+        }
+    }
+}
+
+#[detect::multiversion(v4 = import, v3 = import, v2, neon, fallback = export)]
 pub fn sl2(lhs: &[F32], rhs: &[F32]) -> F32 {
     assert!(lhs.len() == rhs.len());
     let n = lhs.len();
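Each test harness sweeps lengths 3990..4010, which exercises every tail size around the 16-lane and 8-lane boundaries, and accepts an absolute error of EPSILON = 2.0 because the SIMD and scalar paths sum in different orders. For reference, here is an illustrative scalar sketch on plain `f32` of what the two kernels compute (the crate's actual `dot_fallback`/`sl2_fallback` are the fallback versions apparently exported via `fallback = export`, not these functions): `dot` is the inner product, `sl2` the squared Euclidean distance, with no square root taken.

fn dot_scalar(lhs: &[f32], rhs: &[f32]) -> f32 {
    // Inner product: sum of elementwise products.
    assert_eq!(lhs.len(), rhs.len());
    lhs.iter().zip(rhs).map(|(x, y)| x * y).sum()
}

fn sl2_scalar(lhs: &[f32], rhs: &[f32]) -> f32 {
    // Squared L2 distance: sum of squared differences.
    assert_eq!(lhs.len(), rhs.len());
    lhs.iter().zip(rhs).map(|(x, y)| (x - y) * (x - y)).sum()
}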
@@ -41,12 +41,12 @@ impl ThreadPoolBuilder {
             builder: self.builder.num_threads(num_threads),
         }
     }
-    pub fn build_scoped(
+    pub fn build_scoped<R>(
         self,
-        f: impl FnOnce(&ThreadPool),
-    ) -> Result<(), rayon::ThreadPoolBuildError> {
+        f: impl FnOnce(&ThreadPool) -> R,
+    ) -> Result<Option<R>, rayon::ThreadPoolBuildError> {
         let stop = Arc::new(AtomicBool::new(false));
-        match std::panic::catch_unwind(AssertUnwindSafe(|| {
+        let x = match std::panic::catch_unwind(AssertUnwindSafe(|| {
             self.builder
                 .start_handler({
                     let stop = stop.clone();
@@ -71,15 +71,15 @@ impl ThreadPoolBuilder {
                 },
             )
         })) {
-            Ok(Ok(())) => (),
+            Ok(Ok(r)) => Some(r),
             Ok(Err(e)) => return Err(e),
-            Err(e) if e.downcast_ref::<CheckPanic>().is_some() => (),
+            Err(e) if e.downcast_ref::<CheckPanic>().is_some() => None,
             Err(e) => std::panic::resume_unwind(e),
-        }
+        };
         if Arc::strong_count(&stop) > 1 {
             panic!("Thread leak detected.");
         }
-        Ok(())
+        Ok(x)
     }
 }
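The `ThreadPoolBuilder` change makes `build_scoped` generic over the closure's return type: on a normal run the value comes back as `Ok(Some(r))`, the `CheckPanic` teardown path yields `Ok(None)`, and pool build errors stay `Err(e)`. A hedged usage sketch, assuming the surrounding crate's API (`ThreadPoolBuilder::new()` is an assumption here; only `num_threads` and `build_scoped` appear in this diff):

fn run() -> Result<(), rayon::ThreadPoolBuildError> {
    // Hypothetical caller of the new signature.
    let result = ThreadPoolBuilder::new()
        .num_threads(4)
        .build_scoped(|_pool| {
            // The closure's value is now threaded back out of build_scoped.
            (0..1000u64).sum::<u64>()
        })?;
    match result {
        Some(sum) => println!("pool finished, sum = {sum}"),
        None => println!("pool was stopped before the closure completed"),
    }
    Ok(())
}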