1
0
mirror of https://github.com/tensorchord/pgvecto.rs.git synced 2025-07-29 08:21:12 +03:00

refactor: replace heap of heaps with loser tree in merging results (#315)

Signed-off-by: usamoi <usamoi@outlook.com>
This commit is contained in:
Usamoi
2024-02-02 15:21:45 +08:00
committed by GitHub
parent be5a816810
commit f26ffba75d
5 changed files with 164 additions and 94 deletions

View File

@ -27,7 +27,7 @@ memoffset = "0.9.0"
arrayvec = { version = "0.7.3", features = ["serde"] }
memmap2 = "0.9.0"
rayon = "1.6.1"
uuid = { version = "1.6.1", features = ["serde"] }
uuid = { version = "1.6.1", features = ["v4", "serde"] }
arc-swap = "1.6.0"
multiversion = "0.7.3"

View File

@ -15,13 +15,12 @@ use crate::prelude::*;
use crate::utils::clean::clean;
use crate::utils::dir_ops::sync_dir;
use crate::utils::file_atomic::FileAtomic;
use crate::utils::iter::RefPeekable;
use crate::utils::tournament_tree::LoserTree;
use arc_swap::ArcSwap;
use crossbeam::atomic::AtomicCell;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::path::PathBuf;
@ -317,25 +316,13 @@ impl<S: G> IndexView<S> {
return Err(ServiceError::Unmatched);
}
struct Comparer(BinaryHeap<Reverse<Element>>);
struct Comparer(std::collections::BinaryHeap<Reverse<Element>>);
impl PartialEq for Comparer {
fn eq(&self, other: &Self) -> bool {
self.cmp(other).is_eq()
}
}
impl Iterator for Comparer {
type Item = Element;
impl Eq for Comparer {}
impl PartialOrd for Comparer {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Comparer {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0.peek().cmp(&other.0.peek()).reverse()
fn next(&mut self) -> Option<Self::Item> {
self.0.pop().map(|Reverse(x)| x)
}
}
@ -370,7 +357,7 @@ impl<S: G> IndexView<S> {
};
let n = self.sealed.len() + self.growing.len() + 1;
let mut heaps = BinaryHeap::with_capacity(1 + n);
let mut heaps = Vec::with_capacity(1 + n);
for (_, sealed) in self.sealed.iter() {
let p = sealed.basic(vector, opts, filter.clone());
heaps.push(Comparer(p));
@ -383,16 +370,13 @@ impl<S: G> IndexView<S> {
let p = write.basic(vector, opts, filter.clone());
heaps.push(Comparer(p));
}
Ok(std::iter::from_fn(move || {
while let Some(mut iter) = heaps.pop() {
if let Some(Reverse(x)) = iter.0.pop() {
heaps.push(iter);
if opts.prefilter_enable || self.delete.check(x.payload).is_some() {
return Some(Pointer::from_u48(x.payload >> 16));
}
}
let loser = LoserTree::new(heaps);
Ok(loser.filter_map(|x| {
if opts.prefilter_enable || self.delete.check(x.payload).is_some() {
Some(Pointer::from_u48(x.payload >> 16))
} else {
None
}
None
}))
}
pub fn vbase<'a, F: FnMut(Pointer) -> bool + Clone + 'a>(
@ -405,28 +389,6 @@ impl<S: G> IndexView<S> {
return Err(ServiceError::Unmatched);
}
struct Comparer<'a>(RefPeekable<Box<dyn Iterator<Item = Element> + 'a>>);
impl PartialEq for Comparer<'_> {
fn eq(&self, other: &Self) -> bool {
self.cmp(other).is_eq()
}
}
impl Eq for Comparer<'_> {}
impl PartialOrd for Comparer<'_> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Comparer<'_> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0.peek().cmp(&other.0.peek()).reverse()
}
}
struct Filtering<'a, F: 'a> {
enable: bool,
delete: &'a Delete,
@ -459,34 +421,31 @@ impl<S: G> IndexView<S> {
let n = self.sealed.len() + self.growing.len() + 1;
let mut alpha = Vec::new();
let mut beta = BinaryHeap::with_capacity(1 + n);
let mut beta = Vec::with_capacity(1 + n);
for (_, sealed) in self.sealed.iter() {
let (stage1, stage2) = sealed.vbase(vector, opts, filter.clone());
alpha.extend(stage1);
beta.push(Comparer(RefPeekable::new(stage2)));
beta.push(stage2);
}
for (_, growing) in self.growing.iter() {
let (stage1, stage2) = growing.vbase(vector, opts, filter.clone());
alpha.extend(stage1);
beta.push(Comparer(RefPeekable::new(stage2)));
beta.push(stage2);
}
if let Some((_, write)) = &self.write {
let (stage1, stage2) = write.vbase(vector, opts, filter.clone());
alpha.extend(stage1);
beta.push(Comparer(RefPeekable::new(stage2)));
beta.push(stage2);
}
alpha.sort_unstable();
beta.push(Comparer(RefPeekable::new(Box::new(alpha.into_iter()))));
Ok(std::iter::from_fn(move || {
while let Some(mut iter) = beta.pop() {
if let Some(x) = iter.0.next() {
beta.push(iter);
if opts.prefilter_enable || self.delete.check(x.payload).is_some() {
return Some(Pointer::from_u48(x.payload >> 16));
}
}
beta.push(Box::new(alpha.into_iter()));
let loser = LoserTree::new(beta);
Ok(loser.filter_map(|x| {
if opts.prefilter_enable || self.delete.check(x.payload).is_some() {
Some(Pointer::from_u48(x.payload >> 16))
} else {
None
}
None
}))
}
pub fn insert(

View File

@ -1,27 +0,0 @@
/// An iterator adapter that always keeps the next item buffered, so that it
/// can be inspected through a shared reference.
///
/// Unlike `std::iter::Peekable`, whose `peek` needs `&mut self`, the buffered
/// item is pulled eagerly (in `new` and after every `next`), which lets
/// `peek` take `&self`.
pub struct RefPeekable<I: Iterator> {
    /// The item that the next call to `next` will return, if any.
    peeked: Option<I::Item>,
    /// The wrapped iterator, positioned just after `peeked`.
    iter: I,
}

impl<I: Iterator> RefPeekable<I> {
    /// Wraps `iter`, immediately pulling its first item into the buffer.
    pub fn new(mut iter: I) -> RefPeekable<I> {
        RefPeekable {
            peeked: iter.next(),
            iter,
        }
    }

    /// Returns a reference to the item the next call to `next` will yield,
    /// or `None` once the adapter is exhausted.
    pub fn peek(&self) -> Option<&I::Item> {
        self.peeked.as_ref()
    }
}

impl<I: Iterator> Iterator for RefPeekable<I> {
    type Item = I::Item;

    #[inline]
    fn next(&mut self) -> Option<I::Item> {
        // An empty buffer means the wrapped iterator has already returned
        // `None`. Do not poll it again: a non-fused iterator could resume
        // yielding items, contradicting what `peek` just reported.
        let result = self.peeked.take()?;
        self.peeked = self.iter.next();
        Some(result)
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.peeked.is_none() {
            // The wrapped iterator is exhausted and is never polled again.
            return (0, Some(0));
        }
        // One buffered item plus whatever the wrapped iterator still holds.
        let (lower, upper) = self.iter.size_hint();
        (
            lower.saturating_add(1),
            upper.and_then(|u| u.checked_add(1)),
        )
    }
}

View File

@ -4,6 +4,6 @@ pub mod dir_ops;
pub mod element_heap;
pub mod file_atomic;
pub mod file_wal;
pub mod iter;
pub mod mmap_array;
pub mod tournament_tree;
pub mod vec2;

View File

@ -0,0 +1,138 @@
use std::cmp::Reverse;
/// A k-way merge of iterators, implemented as a loser tree (a tournament
/// tree in which every internal node remembers the loser of its match).
///
/// Each call to `next` yields the smallest buffered head among all inputs and
/// then replays only the matches on the path from that input's leaf to the
/// root. If every input iterator is sorted in ascending order, the merged
/// output is sorted in ascending order as well.
pub struct LoserTree<I, T> {
    /// The input iterators, indexed `0..n`.
    iterators: Vec<I>,
    /// The buffered head of every leaf, indexed `0..m` where
    /// `m = n.next_power_of_two()`. `None` marks an exhausted input or a
    /// padding leaf. Items are wrapped in `Reverse` so that the smallest item
    /// compares greatest and `None` compares smallest: the greater slot wins.
    heads: Vec<Option<Reverse<T>>>,
    /// Indexed `0..m`: `losers[0]` is the leaf of the current overall winner,
    /// and `losers[i]` for `1 <= i < m` is the leaf that lost at node `i`.
    losers: Vec<usize>,
}

impl<I> LoserTree<I, I::Item>
where
    I: Iterator,
    I::Item: Ord,
{
    /// Builds the tree, pulling the first item of every iterator (in order)
    /// and playing all initial matches bottom-up.
    ///
    /// An empty `iterators` vector is allowed and produces an empty merge.
    pub fn new(mut iterators: Vec<I>) -> Self {
        let n = iterators.len();
        let m = n.next_power_of_two();

        // Prime one head per input, then pad up to the power-of-two leaf count.
        let mut heads: Vec<Option<Reverse<I::Item>>> = Vec::with_capacity(m);
        heads.extend(iterators.iter_mut().map(|it| it.next().map(Reverse)));
        heads.resize_with(m, || None);

        // Temporary winner tree: slots `m..2m` are the leaves, slot `i` is the
        // winner of the match between slots `2i` and `2i + 1`.
        let mut winners = vec![usize::MAX; m];
        winners.extend(0..m);

        let mut losers = vec![usize::MAX; m];
        for node in (1..m).rev() {
            let (left, right) = (winners[2 * node], winners[2 * node + 1]);
            // On a tie the left child wins.
            let (loser, winner) = if heads[left] < heads[right] {
                (left, right)
            } else {
                (right, left)
            };
            losers[node] = loser;
            winners[node] = winner;
        }
        losers[0] = winners[1];

        Self {
            iterators,
            heads,
            losers,
        }
    }
}

impl<I> Iterator for LoserTree<I, I::Item>
where
    I: Iterator,
    I::Item: Ord,
{
    type Item = I::Item;

    /// Yields the smallest buffered head, refills that leaf from its
    /// iterator and replays the matches from the leaf up to the root.
    fn next(&mut self) -> Option<Self::Item> {
        // `losers` always has exactly `m` entries, so there is no need to
        // recompute the padded leaf count on every call.
        let m = self.losers.len();
        let champion = self.losers[0];
        // If even the overall winner is empty, every input is exhausted.
        let Reverse(result) = self.heads[champion].take()?;
        self.heads[champion] = self.iterators[champion].next().map(Reverse);

        let mut winner = champion;
        let mut node = (m + champion) >> 1;
        while node != 0 {
            // The stored loser takes over when it beats the climbing
            // candidate; the candidate then stays behind as the new loser.
            if self.heads[winner] < self.heads[self.losers[node]] {
                std::mem::swap(&mut winner, &mut self.losers[node]);
            }
            node >>= 1;
        }
        self.losers[0] = winner;
        Some(result)
    }

    /// Counts each live input as its buffered head plus its own size hint.
    /// Inputs whose head is empty have already returned `None` and are never
    /// polled again, so they contribute nothing.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let mut lower = 0usize;
        let mut upper = Some(0usize);
        for (head, iterator) in self.heads.iter().zip(&self.iterators) {
            if head.is_none() {
                continue;
            }
            let (lo, hi) = iterator.size_hint();
            lower = lower.saturating_add(lo).saturating_add(1);
            upper = match (upper, hi) {
                (Some(total), Some(rest)) => {
                    total.checked_add(rest).and_then(|sum| sum.checked_add(1))
                }
                _ => None,
            };
        }
        (lower, upper)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use rand::Rng;

    /// Asserts that `LoserTree` merges `seqs` into exactly the sequence
    /// produced by a naive reference merge.
    fn check(seqs: &[Vec<u32>]) {
        // Reference merge: repeatedly scan every input and take the smallest
        // head; among equal heads the input with the highest index is used.
        let expected = {
            let mut cursors: Vec<_> = seqs
                .iter()
                .map(|seq| seq.clone().into_iter().peekable())
                .collect();
            let mut merged = Vec::new();
            loop {
                // `min_by_key` keeps the first minimum it sees, so walking the
                // inputs back to front selects the highest index on ties.
                let smallest = cursors
                    .iter_mut()
                    .enumerate()
                    .rev()
                    .filter_map(|(i, cursor)| cursor.peek().map(|&head| (i, head)))
                    .min_by_key(|&(_, head)| head);
                let Some((i, head)) = smallest else { break };
                cursors[i].next();
                merged.push(head);
            }
            merged
        };
        let actual = {
            let inputs = seqs.iter().map(|seq| seq.iter().copied()).collect();
            LoserTree::new(inputs).collect::<Vec<_>>()
        };
        assert_eq!(expected, actual);
    }

    /// Hand-picked cases: no inputs, empty inputs, a single input, and
    /// inputs containing duplicates.
    #[test]
    fn test_hardcode() {
        check(&[]);
        check(&[vec![0, 2, 4], vec![1, 3, 5], vec![], vec![], vec![]]);
        check(&[vec![], vec![], vec![], vec![], vec![]]);
        check(&[vec![1, 1, 1, 1, 1, 1]]);
        check(&[vec![1, 2, 3, 4, 5, 6], vec![1, 2, 3, 4, 5, 6]]);
        check(&[vec![2, 2, 3, 3, 4, 4, 5], vec![1, 1, 5, 6, 6]]);
    }

    /// Randomised cases: ten rounds, each with fewer than 100 sorted inputs
    /// of fewer than 10000 values drawn from `0..100_000`.
    #[test]
    fn test_random() {
        /// Returns `len` random values, sorted ascending.
        fn sorted_values(len: usize) -> Vec<u32> {
            let mut rng = rand::thread_rng();
            let mut values: Vec<u32> = (0..len).map(|_| rng.gen_range(0..100_000)).collect();
            values.sort();
            values
        }

        /// Returns a random number of sorted inputs of random lengths.
        fn random_inputs() -> Vec<Vec<u32>> {
            let mut rng = rand::thread_rng();
            let count = rng.gen_range(0..100);
            (0..count)
                .map(|_| sorted_values(rng.gen_range(0..10000)))
                .collect()
        }

        for _ in 0..10 {
            check(&random_inputs());
        }
    }
}