1
0
mirror of https://github.com/tensorchord/pgvecto.rs.git synced 2025-07-29 08:21:12 +03:00

fix: prevent overflow at kmeans (#287)

* fix: prevent overflow at kmeans

Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>

* fix

Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>

---------

Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
This commit is contained in:
cutecutecat
2024-01-22 11:24:01 +08:00
committed by GitHub
parent 01a5774eaf
commit c95f99f846
2 changed files with 38 additions and 0 deletions

View File

@ -83,6 +83,36 @@ impl<S: G> ElkanKMeans<S> {
}
}
/// Quick approach if we have little data
fn quick_centroids(&mut self) -> bool {
let c = self.c;
let samples = &self.samples;
let rand = &mut self.rand;
let centroids = &mut self.centroids;
let n = samples.len();
let dims = samples.dims();
let sorted_index = samples.argsort();
for i in 0..n {
let index = sorted_index.get(i).unwrap();
let last = sorted_index.get(std::cmp::max(i, 1) - 1).unwrap();
if *index == 0 || samples[*last] != samples[*index] {
centroids[i].copy_from_slice(&samples[*index]);
} else {
let rand_centroids: Vec<_> = (0..dims)
.map(|_| S::Scalar::from_f32(rand.gen_range(0.0..1.0f32)))
.collect();
centroids[i].copy_from_slice(rand_centroids.as_slice());
}
}
for i in n..c {
let rand_centroids: Vec<_> = (0..dims)
.map(|_| S::Scalar::from_f32(rand.gen_range(0.0..1.0f32)))
.collect();
centroids[i].copy_from_slice(rand_centroids.as_slice());
}
true
}
pub fn iterate(&mut self) -> bool {
let c = self.c;
let dims = self.dims;
@ -94,6 +124,9 @@ impl<S: G> ElkanKMeans<S> {
let upperbound = &mut self.upperbound;
let mut change = 0;
let n = samples.len();
if n <= c {
return self.quick_centroids();
}
// Step 1
let mut dist0 = Square::new(c, c);

View File

@ -20,6 +20,11 @@ impl<S: G> Vec2<S> {
pub fn len(&self) -> usize {
self.v.len() / self.dims as usize
}
pub fn argsort(&self) -> Vec<usize> {
let mut index: Vec<usize> = (0..self.len()).collect();
index.sort_by_key(|i| &self[*i]);
index
}
pub fn copy_within(&mut self, i: usize, j: usize) {
assert!(i < self.len() && j < self.len());
unsafe {