You've already forked pgvecto.rs
mirror of
https://github.com/tensorchord/pgvecto.rs.git
synced 2025-07-29 08:21:12 +03:00
refactor: replace heap of heaps with loser tree in merging results (#315)
Signed-off-by: usamoi <usamoi@outlook.com>
This commit is contained in:
@ -27,7 +27,7 @@ memoffset = "0.9.0"
|
||||
arrayvec = { version = "0.7.3", features = ["serde"] }
|
||||
memmap2 = "0.9.0"
|
||||
rayon = "1.6.1"
|
||||
uuid = { version = "1.6.1", features = ["serde"] }
|
||||
uuid = { version = "1.6.1", features = ["v4", "serde"] }
|
||||
arc-swap = "1.6.0"
|
||||
multiversion = "0.7.3"
|
||||
|
||||
|
@ -15,13 +15,12 @@ use crate::prelude::*;
|
||||
use crate::utils::clean::clean;
|
||||
use crate::utils::dir_ops::sync_dir;
|
||||
use crate::utils::file_atomic::FileAtomic;
|
||||
use crate::utils::iter::RefPeekable;
|
||||
use crate::utils::tournament_tree::LoserTree;
|
||||
use arc_swap::ArcSwap;
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
use parking_lot::Mutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::BinaryHeap;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::path::PathBuf;
|
||||
@ -317,25 +316,13 @@ impl<S: G> IndexView<S> {
|
||||
return Err(ServiceError::Unmatched);
|
||||
}
|
||||
|
||||
struct Comparer(BinaryHeap<Reverse<Element>>);
|
||||
struct Comparer(std::collections::BinaryHeap<Reverse<Element>>);
|
||||
|
||||
impl PartialEq for Comparer {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.cmp(other).is_eq()
|
||||
}
|
||||
}
|
||||
impl Iterator for Comparer {
|
||||
type Item = Element;
|
||||
|
||||
impl Eq for Comparer {}
|
||||
|
||||
impl PartialOrd for Comparer {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for Comparer {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.0.peek().cmp(&other.0.peek()).reverse()
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.0.pop().map(|Reverse(x)| x)
|
||||
}
|
||||
}
|
||||
|
||||
@ -370,7 +357,7 @@ impl<S: G> IndexView<S> {
|
||||
};
|
||||
|
||||
let n = self.sealed.len() + self.growing.len() + 1;
|
||||
let mut heaps = BinaryHeap::with_capacity(1 + n);
|
||||
let mut heaps = Vec::with_capacity(1 + n);
|
||||
for (_, sealed) in self.sealed.iter() {
|
||||
let p = sealed.basic(vector, opts, filter.clone());
|
||||
heaps.push(Comparer(p));
|
||||
@ -383,16 +370,13 @@ impl<S: G> IndexView<S> {
|
||||
let p = write.basic(vector, opts, filter.clone());
|
||||
heaps.push(Comparer(p));
|
||||
}
|
||||
Ok(std::iter::from_fn(move || {
|
||||
while let Some(mut iter) = heaps.pop() {
|
||||
if let Some(Reverse(x)) = iter.0.pop() {
|
||||
heaps.push(iter);
|
||||
if opts.prefilter_enable || self.delete.check(x.payload).is_some() {
|
||||
return Some(Pointer::from_u48(x.payload >> 16));
|
||||
}
|
||||
}
|
||||
let loser = LoserTree::new(heaps);
|
||||
Ok(loser.filter_map(|x| {
|
||||
if opts.prefilter_enable || self.delete.check(x.payload).is_some() {
|
||||
Some(Pointer::from_u48(x.payload >> 16))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
None
|
||||
}))
|
||||
}
|
||||
pub fn vbase<'a, F: FnMut(Pointer) -> bool + Clone + 'a>(
|
||||
@ -405,28 +389,6 @@ impl<S: G> IndexView<S> {
|
||||
return Err(ServiceError::Unmatched);
|
||||
}
|
||||
|
||||
struct Comparer<'a>(RefPeekable<Box<dyn Iterator<Item = Element> + 'a>>);
|
||||
|
||||
impl PartialEq for Comparer<'_> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.cmp(other).is_eq()
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for Comparer<'_> {}
|
||||
|
||||
impl PartialOrd for Comparer<'_> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for Comparer<'_> {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.0.peek().cmp(&other.0.peek()).reverse()
|
||||
}
|
||||
}
|
||||
|
||||
struct Filtering<'a, F: 'a> {
|
||||
enable: bool,
|
||||
delete: &'a Delete,
|
||||
@ -459,34 +421,31 @@ impl<S: G> IndexView<S> {
|
||||
|
||||
let n = self.sealed.len() + self.growing.len() + 1;
|
||||
let mut alpha = Vec::new();
|
||||
let mut beta = BinaryHeap::with_capacity(1 + n);
|
||||
let mut beta = Vec::with_capacity(1 + n);
|
||||
for (_, sealed) in self.sealed.iter() {
|
||||
let (stage1, stage2) = sealed.vbase(vector, opts, filter.clone());
|
||||
alpha.extend(stage1);
|
||||
beta.push(Comparer(RefPeekable::new(stage2)));
|
||||
beta.push(stage2);
|
||||
}
|
||||
for (_, growing) in self.growing.iter() {
|
||||
let (stage1, stage2) = growing.vbase(vector, opts, filter.clone());
|
||||
alpha.extend(stage1);
|
||||
beta.push(Comparer(RefPeekable::new(stage2)));
|
||||
beta.push(stage2);
|
||||
}
|
||||
if let Some((_, write)) = &self.write {
|
||||
let (stage1, stage2) = write.vbase(vector, opts, filter.clone());
|
||||
alpha.extend(stage1);
|
||||
beta.push(Comparer(RefPeekable::new(stage2)));
|
||||
beta.push(stage2);
|
||||
}
|
||||
alpha.sort_unstable();
|
||||
beta.push(Comparer(RefPeekable::new(Box::new(alpha.into_iter()))));
|
||||
Ok(std::iter::from_fn(move || {
|
||||
while let Some(mut iter) = beta.pop() {
|
||||
if let Some(x) = iter.0.next() {
|
||||
beta.push(iter);
|
||||
if opts.prefilter_enable || self.delete.check(x.payload).is_some() {
|
||||
return Some(Pointer::from_u48(x.payload >> 16));
|
||||
}
|
||||
}
|
||||
beta.push(Box::new(alpha.into_iter()));
|
||||
let loser = LoserTree::new(beta);
|
||||
Ok(loser.filter_map(|x| {
|
||||
if opts.prefilter_enable || self.delete.check(x.payload).is_some() {
|
||||
Some(Pointer::from_u48(x.payload >> 16))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
None
|
||||
}))
|
||||
}
|
||||
pub fn insert(
|
||||
|
@ -1,27 +0,0 @@
|
||||
/// Iterator adapter that always holds the next item of `iter` in `peeked`,
/// so the upcoming item can be inspected through a shared reference.
///
/// Unlike `std::iter::Peekable::peek`, [`RefPeekable::peek`] takes `&self`,
/// which is possible because the look-ahead item is fetched eagerly.
pub struct RefPeekable<I: Iterator> {
    /// The item the next call to `next` will return; `None` once exhausted.
    peeked: Option<I::Item>,
    /// The wrapped iterator, always one item ahead of what was handed out.
    iter: I,
}

impl<I: Iterator> RefPeekable<I> {
    /// Wraps `iter`, immediately pulling its first item into the look-ahead slot.
    pub fn new(mut iter: I) -> RefPeekable<I> {
        let peeked = iter.next();
        RefPeekable { peeked, iter }
    }

    /// Returns a reference to the item the next call to `next` will yield,
    /// or `None` when the adapter is exhausted.
    pub fn peek(&self) -> Option<&I::Item> {
        self.peeked.as_ref()
    }
}

impl<I: Iterator> Iterator for RefPeekable<I> {
    type Item = I::Item;

    /// Hands out the buffered item and refills the look-ahead slot.
    ///
    /// Once the look-ahead slot is empty the wrapped iterator has already
    /// returned `None`, so it is not polled again; the adapter keeps
    /// returning `None` from then on.
    #[inline]
    fn next(&mut self) -> Option<I::Item> {
        // An empty slot means the inner iterator already reported exhaustion:
        // return early instead of polling it a second time.
        let result = self.peeked.take()?;
        self.peeked = self.iter.next();
        Some(result)
    }
}
|
@ -4,6 +4,6 @@ pub mod dir_ops;
|
||||
pub mod element_heap;
|
||||
pub mod file_atomic;
|
||||
pub mod file_wal;
|
||||
pub mod iter;
|
||||
pub mod mmap_array;
|
||||
pub mod tournament_tree;
|
||||
pub mod vec2;
|
||||
|
138
crates/service/src/utils/tournament_tree.rs
Normal file
138
crates/service/src/utils/tournament_tree.rs
Normal file
@ -0,0 +1,138 @@
|
||||
use std::cmp::Reverse;
|
||||
|
||||
/// K-way merge over `n` iterators, implemented as a loser (tournament) tree.
///
/// Each call to `next` yields the smallest of the current head items and
/// replays only the path from that leaf to the root. The merged output is
/// ascending as a whole only if every input iterator is itself ascending.
///
/// Heads are stored as `Option<Reverse<T>>`: under that ordering `None`
/// (an exhausted input or a padding leaf) compares lowest and the smallest
/// `T` compares highest, so the "greatest" slot is always the next item.
pub struct LoserTree<I, T> {
    // 0..n: the inputs being merged.
    iterators: Vec<I>,
    // 0..m: current head of each input; slots n..m are padding and stay `None`.
    heads: Vec<Option<Reverse<T>>>,
    // 0..m, m = n.next_power_of_two(): `losers[0]` is the leaf index of the
    // overall winner, `losers[i]` (i >= 1) is the leaf that lost at node `i`.
    losers: Vec<usize>,
}

impl<I> LoserTree<I, I::Item>
where
    I: Iterator,
    I::Item: Ord,
{
    /// Builds the tree, pulling the first item from every iterator in order.
    ///
    /// An empty `iterators` vector is valid and produces an empty merge.
    pub fn new(mut iterators: Vec<I>) -> Self {
        let n = iterators.len();
        // `0usize.next_power_of_two()` is 1, so there is always at least one leaf.
        let m = n.next_power_of_two();

        let mut heads = Vec::with_capacity(m);
        heads.extend(iterators.iter_mut().map(|it| it.next().map(Reverse)));
        heads.resize_with(m, || None);

        // Scratch tournament: leaves live at m..2m, internal nodes at 1..m.
        let mut winners = vec![usize::MAX; 2 * m];
        for i in 0..m {
            winners[m + i] = i;
        }

        let mut losers = vec![usize::MAX; m];
        for i in (1..m).rev() {
            let (l, r) = (winners[i << 1], winners[i << 1 | 1]);
            // On a tie the left child wins.
            (losers[i], winners[i]) = if heads[l] < heads[r] { (l, r) } else { (r, l) };
        }
        losers[0] = winners[1];

        Self {
            iterators,
            heads,
            losers,
        }
    }
}

impl<I> Iterator for LoserTree<I, I::Item>
where
    I: Iterator,
    I::Item: Ord,
{
    type Item = I::Item;

    /// Returns the smallest head item, refills that leaf from its iterator
    /// and replays the matches on the path from the leaf up to the root.
    fn next(&mut self) -> Option<Self::Item> {
        // The head buffer was sized to the padded leaf count in `new`.
        let m = self.heads.len();
        let r = self.losers[0];
        // If the winner's slot is empty every slot is empty: the merge is
        // done, and no iterator is polled again.
        let Reverse(result) = self.heads[r].take()?;
        self.heads[r] = self.iterators[r].next().map(Reverse);

        let mut v = r;
        let mut i = (m + r) >> 1;
        while i != 0 {
            // The stored loser beats the current candidate: they trade places.
            if self.heads[v] < self.heads[self.losers[i]] {
                std::mem::swap(&mut v, &mut self.losers[i]);
            }
            i >>= 1;
        }
        self.losers[0] = v;
        Some(result)
    }

    /// Counts the buffered heads plus whatever their iterators still report.
    ///
    /// Inputs whose head slot is empty are ignored because `next` never
    /// polls them again.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iterators
            .iter()
            .zip(&self.heads)
            .filter(|(_, head)| head.is_some())
            .fold((0usize, Some(0usize)), |(lo, hi), (it, _)| {
                let (l, h) = it.size_hint();
                let lower = lo.saturating_add(l).saturating_add(1);
                let upper = match (hi, h) {
                    (Some(a), Some(b)) => a.checked_add(b).and_then(|s| s.checked_add(1)),
                    _ => None,
                };
                (lower, upper)
            })
    }
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use super::*;
    use rand::Rng;

    /// Asserts that merging `seqs` through `LoserTree` yields every element
    /// of every sequence in ascending order.
    ///
    /// Each sequence must already be ascending. The expected output is then
    /// simply all elements sorted, which avoids rescanning every sequence
    /// for each output element.
    fn check(seqs: &[Vec<u32>]) {
        assert!(
            seqs.iter().all(|seq| seq.windows(2).all(|w| w[0] <= w[1])),
            "test inputs must be sorted"
        );
        let expected = {
            let mut all = seqs.iter().flatten().copied().collect::<Vec<u32>>();
            all.sort_unstable();
            all
        };
        let loser_tree = {
            let iterators = seqs.iter().map(|x| x.iter().copied()).collect();
            LoserTree::new(iterators).collect::<Vec<_>>()
        };
        assert_eq!(expected, loser_tree);
    }

    /// Hand-picked cases: no inputs, only empty inputs, a single input,
    /// identical inputs and inputs with repeated values.
    #[test]
    fn test_hardcode() {
        check(&[]);
        check(&[vec![0, 2, 4], vec![1, 3, 5], vec![], vec![], vec![]]);
        check(&[vec![], vec![], vec![], vec![], vec![]]);
        check(&[vec![1, 1, 1, 1, 1, 1]]);
        check(&[vec![1, 2, 3, 4, 5, 6], vec![1, 2, 3, 4, 5, 6]]);
        check(&[vec![2, 2, 3, 3, 4, 4, 5], vec![1, 1, 5, 6, 6]]);
    }

    /// Ten rounds of randomly sized, randomly filled sorted inputs.
    #[test]
    fn test_random() {
        // A sorted vector of `n` random values below 100_000.
        fn vec(n: usize) -> Vec<u32> {
            let mut vec = vec![0u32; n];
            vec.fill_with(|| rand::thread_rng().gen_range(0..100_000));
            vec.sort();
            vec
        }

        // Fewer than 100 sorted vectors, each with fewer than 10000 elements.
        fn vecs() -> Vec<Vec<u32>> {
            let m = rand::thread_rng().gen_range(0..100);
            let mut vecs = Vec::new();
            for _ in 0..m {
                let n = rand::thread_rng().gen_range(0..10000);
                vecs.push(vec(n));
            }
            vecs
        }

        for _ in 0..10 {
            check(&vecs());
        }
    }
}
|
Reference in New Issue
Block a user