You've already forked pgvecto.rs
mirror of
https://github.com/tensorchord/pgvecto.rs.git
synced 2025-07-29 08:21:12 +03:00
fix: prevent overflow at kmeans (#287)
* fix: prevent overflow at kmeans Signed-off-by: cutecutecat <junyuchen@tensorchord.ai> * fix Signed-off-by: cutecutecat <junyuchen@tensorchord.ai> --------- Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
This commit is contained in:
@ -83,6 +83,36 @@ impl<S: G> ElkanKMeans<S> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Quick approach if we have little data
|
||||
fn quick_centroids(&mut self) -> bool {
|
||||
let c = self.c;
|
||||
let samples = &self.samples;
|
||||
let rand = &mut self.rand;
|
||||
let centroids = &mut self.centroids;
|
||||
let n = samples.len();
|
||||
let dims = samples.dims();
|
||||
let sorted_index = samples.argsort();
|
||||
for i in 0..n {
|
||||
let index = sorted_index.get(i).unwrap();
|
||||
let last = sorted_index.get(std::cmp::max(i, 1) - 1).unwrap();
|
||||
if *index == 0 || samples[*last] != samples[*index] {
|
||||
centroids[i].copy_from_slice(&samples[*index]);
|
||||
} else {
|
||||
let rand_centroids: Vec<_> = (0..dims)
|
||||
.map(|_| S::Scalar::from_f32(rand.gen_range(0.0..1.0f32)))
|
||||
.collect();
|
||||
centroids[i].copy_from_slice(rand_centroids.as_slice());
|
||||
}
|
||||
}
|
||||
for i in n..c {
|
||||
let rand_centroids: Vec<_> = (0..dims)
|
||||
.map(|_| S::Scalar::from_f32(rand.gen_range(0.0..1.0f32)))
|
||||
.collect();
|
||||
centroids[i].copy_from_slice(rand_centroids.as_slice());
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
pub fn iterate(&mut self) -> bool {
|
||||
let c = self.c;
|
||||
let dims = self.dims;
|
||||
@ -94,6 +124,9 @@ impl<S: G> ElkanKMeans<S> {
|
||||
let upperbound = &mut self.upperbound;
|
||||
let mut change = 0;
|
||||
let n = samples.len();
|
||||
if n <= c {
|
||||
return self.quick_centroids();
|
||||
}
|
||||
|
||||
// Step 1
|
||||
let mut dist0 = Square::new(c, c);
|
||||
|
@ -20,6 +20,11 @@ impl<S: G> Vec2<S> {
|
||||
pub fn len(&self) -> usize {
|
||||
self.v.len() / self.dims as usize
|
||||
}
|
||||
pub fn argsort(&self) -> Vec<usize> {
|
||||
let mut index: Vec<usize> = (0..self.len()).collect();
|
||||
index.sort_by_key(|i| &self[*i]);
|
||||
index
|
||||
}
|
||||
pub fn copy_within(&mut self, i: usize, j: usize) {
|
||||
assert!(i < self.len() && j < self.len());
|
||||
unsafe {
|
||||
|
Reference in New Issue
Block a user