From eb87bfa117c95419b26fe7c929290d439a8206ad Mon Sep 17 00:00:00 2001
From: Will Jones
Date: Thu, 22 Aug 2024 09:00:49 -0700
Subject: [PATCH] perf: stable row id prefilter (#2706)

Cache the row id mask prefilter and make it faster to construct.
---
 rust/lance-core/src/utils/deletion.rs  |  20 +++-
 rust/lance-core/src/utils/mask.rs      |  14 ++-
 rust/lance-table/src/rowids.rs         | 153 ++++++++++++++++++++++---
 rust/lance-table/src/rowids/segment.rs |  58 +++++++++
 rust/lance/src/dataset/optimize.rs     |  14 ++-
 rust/lance/src/index/prefilter.rs      |  96 +++++++++------
 6 files changed, 295 insertions(+), 60 deletions(-)

diff --git a/rust/lance-core/src/utils/deletion.rs b/rust/lance-core/src/utils/deletion.rs
index 80ed958663..b847df0079 100644
--- a/rust/lance-core/src/utils/deletion.rs
+++ b/rust/lance-core/src/utils/deletion.rs
@@ -60,6 +60,18 @@ impl DeletionVector {
         }
     }
 
+    /// Iterate over the deleted positions without consuming the vector.
+    ///
+    /// The `Set` variant yields positions in the set's own iteration order, so
+    /// callers that need them ascending have to sort the output themselves.
+    pub fn iter(&self) -> Box<dyn Iterator<Item = u32> + Send + '_> {
+        match self {
+            Self::NoDeletions => Box::new(std::iter::empty()),
+            Self::Set(set) => Box::new(set.iter().copied()),
+            Self::Bitmap(bitmap) => Box::new(bitmap.iter()),
+        }
+    }
+
     pub fn into_sorted_iter(self) -> Box<dyn Iterator<Item = u32> + Send + 'static> {
         match self {
             Self::NoDeletions => Box::new(std::iter::empty()),
@@ -183,7 +195,13 @@ impl IntoIterator for DeletionVector {
     fn into_iter(self) -> Self::IntoIter {
         match self {
             Self::NoDeletions => Box::new(std::iter::empty()),
-            Self::Set(set) => Box::new(set.into_iter()),
+            Self::Set(set) => {
+                // In many cases, it's much better if this is sorted. It's
+                // guaranteed to be small, so the cost is low.
+                let mut sorted = set.into_iter().collect::<Vec<_>>();
+                sorted.sort();
+                Box::new(sorted.into_iter())
+            }
             Self::Bitmap(bitmap) => Box::new(bitmap.into_iter()),
         }
     }
diff --git a/rust/lance-core/src/utils/mask.rs b/rust/lance-core/src/utils/mask.rs
index fb92b492cc..1f2e7213c7 100644
--- a/rust/lance-core/src/utils/mask.rs
+++ b/rust/lance-core/src/utils/mask.rs
@@ -9,6 +9,7 @@ use std::{collections::BTreeMap, io::Read};
 use arrow_array::{Array, BinaryArray, GenericBinaryArray};
 use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer};
 use byteorder::{ReadBytesExt, WriteBytesExt};
+use deepsize::DeepSizeOf;
 use roaring::RoaringBitmap;
 
 use crate::Result;
@@ -23,7 +24,7 @@ use super::address::RowAddress;
 ///
 /// If both the allow_list and the block_list are None (the default) then
 /// all row ids are selected
-#[derive(Clone, Debug, Default)]
+#[derive(Clone, Debug, Default, DeepSizeOf)]
 pub struct RowIdMask {
     /// If Some then only these row ids are selected
     pub allow_list: Option<RowIdTreeMap>,
@@ -273,7 +274,7 @@ impl std::ops::BitOr for RowIdMask {
 ///
 /// This is similar to a [RoaringTreemap] but it is optimized for the case where
 /// entire fragments are selected or deselected.
-#[derive(Clone, Debug, Default, PartialEq)]
+#[derive(Clone, Debug, Default, PartialEq, DeepSizeOf)]
 pub struct RowIdTreeMap {
     /// The contents of the set. If there is a pair (k, Full) then the entire
     /// fragment k is selected. If there is a pair (k, Partial(v)) then the
@@ -287,6 +288,15 @@ enum RowIdSelection {
     Partial(RoaringBitmap),
 }
 
+impl DeepSizeOf for RowIdSelection {
+    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
+        match self {
+            Self::Full => 0,
+            Self::Partial(bitmap) => bitmap.serialized_size(),
+        }
+    }
+}
+
 impl RowIdTreeMap {
     /// Create an empty set
     pub fn new() -> Self {
diff --git a/rust/lance-table/src/rowids.rs b/rust/lance-table/src/rowids.rs
index 33d8f172b9..aede5f1d3f 100644
--- a/rust/lance-table/src/rowids.rs
+++ b/rust/lance-table/src/rowids.rs
@@ -153,23 +153,65 @@ impl RowIdSequence {
     }
 
     /// Delete row ids by position.
-    pub fn mask(&mut self, positions: impl IntoIterator<Item = usize>) -> Result<()> {
-        let row_ids = positions
-            .into_iter()
-            .map(|pos| {
-                self.get(pos).ok_or_else(|| {
-                    Error::invalid_input(
-                        format!(
-                            "position out of bounds: {} on sequence of length {}",
-                            pos,
-                            self.len()
-                        ),
-                        location!(),
-                    )
-                })
-            })
-            .collect::<Result<Vec<_>>>()?;
-        self.delete(row_ids);
+    ///
+    /// `positions` are offsets into the whole sequence and must be strictly
+    /// ascending: segments are visited once, front to back, and each position
+    /// is rebased onto the segment that contains it.
+    ///
+    /// Returns an invalid-input error if a position is out of order,
+    /// duplicated, or past the end of the sequence. Segments visited before
+    /// the error was detected have already been masked.
+    pub fn mask(&mut self, positions: impl IntoIterator<Item = u32>) -> Result<()> {
+        let mut local_positions = Vec::new();
+        let mut positions_iter = positions.into_iter();
+        let mut curr_position = positions_iter.next();
+        let mut prev_position: Option<u32> = None;
+        let mut offset = 0;
+        let mut cutoff = 0;
+
+        for segment in &mut self.0 {
+            // Make vector of local positions
+            cutoff += segment.len() as u32;
+            while let Some(position) = curr_position {
+                if position >= cutoff {
+                    break;
+                }
+                // An out-of-order position would underflow the subtraction
+                // below or hand the segment an unsorted slice.
+                if let Some(prev) = prev_position.filter(|prev| position <= *prev) {
+                    return Err(Error::invalid_input(
+                        format!(
+                            "positions must be strictly ascending: {} follows {}",
+                            position, prev
+                        ),
+                        location!(),
+                    ));
+                }
+                prev_position = Some(position);
+                local_positions.push(position - offset);
+                curr_position = positions_iter.next();
+            }
+
+            if !local_positions.is_empty() {
+                segment.mask(&local_positions);
+                local_positions.clear();
+            }
+            offset = cutoff;
+        }
+
+        self.0.retain(|segment| segment.len() != 0);
+
+        // Whatever is left was not covered by any segment.
+        if let Some(position) = curr_position {
+            return Err(Error::invalid_input(
+                format!(
+                    "position out of bounds: {} on sequence of length {}",
+                    position, cutoff
+                ),
+                location!(),
+            ));
+        }
+
         Ok(())
     }
 
@@ -753,4 +795,81 @@ mod test {
             .collect::<RowIdTreeMap>();
         assert_eq!(tree_map, expected);
     }
+
+    #[test]
+    fn test_row_id_mask() {
+        // 0, 1, 2, 3, 4
+        // 50, 51, 52, 55, 56, 57, 58, 59
+        // 7, 9
+        // 10, 12, 14
+        // 35, 39
+        let sequence = RowIdSequence(vec![
+            U64Segment::Range(0..5),
+            U64Segment::RangeWithHoles {
+                range: 50..60,
+                holes: vec![53, 54].into(),
+            },
+            U64Segment::SortedArray(vec![7, 9].into()),
+            U64Segment::RangeWithBitmap {
+                range: 10..15,
+                bitmap: [true, false, true, false, true].as_slice().into(),
+            },
+            U64Segment::Array(vec![35, 39].into()),
+        ]);
+
+        // Masking one in each segment
+        let values_to_remove = [4, 55, 7, 12, 39];
+        let positions_to_remove = sequence
+            .iter()
+            .enumerate()
+            .filter_map(|(i, val)| {
+                if values_to_remove.contains(&val) {
+                    Some(i as u32)
+                } else {
+                    None
+                }
+            })
+            .collect::<Vec<_>>();
+        let mut sequence = sequence;
+        sequence.mask(positions_to_remove).unwrap();
+        let expected = RowIdSequence(vec![
+            U64Segment::Range(0..4),
+            U64Segment::RangeWithBitmap {
+                range: 50..60,
+                bitmap: [
+                    true, true, true, false, false, false, true, true, true, true,
+                ]
+                .as_slice()
+                .into(),
+            },
+            U64Segment::Range(9..10),
+            U64Segment::RangeWithBitmap {
+                range: 10..15,
+                bitmap: [true, false, false, false, true].as_slice().into(),
+            },
+            U64Segment::Array(vec![35].into()),
+        ]);
+        assert_eq!(sequence, expected);
+    }
+
+    #[test]
+    fn test_row_id_mask_everything() {
+        let mut sequence = RowIdSequence(vec![
+            U64Segment::Range(0..5),
+            U64Segment::SortedArray(vec![7, 9].into()),
+        ]);
+        sequence.mask(0..sequence.len() as u32).unwrap();
+        let expected = RowIdSequence(vec![]);
+        assert_eq!(sequence, expected);
+    }
+
+    #[test]
+    fn test_row_id_mask_rejects_bad_positions() {
+        let sequence = RowIdSequence(vec![U64Segment::Range(0..5)]);
+        // Past the end of the sequence.
+        assert!(sequence.clone().mask([5]).is_err());
+        // Not strictly ascending.
+        assert!(sequence.clone().mask([3, 1]).is_err());
+        assert!(sequence.clone().mask([2, 2]).is_err());
+    }
 }
diff --git a/rust/lance-table/src/rowids/segment.rs b/rust/lance-table/src/rowids/segment.rs
index 6a78033807..7e4fa51b0a 100644
--- a/rust/lance-table/src/rowids/segment.rs
+++ b/rust/lance-table/src/rowids/segment.rs
@@ -392,6 +392,64 @@ impl U64Segment {
         let stats = Self::compute_stats(make_new_iter());
         Self::from_stats_and_sequence(stats, make_new_iter())
     }
+
+    /// Remove the values at the given positions from this segment.
+    ///
+    /// `positions` are indices into this segment, not values. They must be
+    /// sorted in ascending order, unique, and within bounds; the scans below
+    /// rely on that ordering.
+    pub fn mask(&mut self, positions: &[u32]) {
+        if positions.is_empty() {
+            return;
+        }
+        if positions.len() == self.len() {
+            *self = Self::Range(0..0);
+            return;
+        }
+        let count = (self.len() - positions.len()) as u64;
+        let sorted = match self {
+            Self::Range(_) => true,
+            Self::RangeWithHoles { .. } => true,
+            Self::RangeWithBitmap { .. } => true,
+            Self::SortedArray(_) => true,
+            Self::Array(_) => false,
+        };
+        // To get minimum, need to find the first value that is not masked.
+        let first_unmasked = (0..self.len())
+            .zip(positions.iter().cycle())
+            .find(|(sequential_i, i)| **i != *sequential_i as u32)
+            .map(|(sequential_i, _)| sequential_i)
+            .unwrap();
+        let min = self.get(first_unmasked).unwrap();
+
+        let last_unmasked = (0..self.len())
+            .rev()
+            .zip(positions.iter().rev().cycle())
+            .filter(|(sequential_i, i)| **i != *sequential_i as u32)
+            .map(|(sequential_i, _)| sequential_i)
+            .next()
+            .unwrap();
+        let max = self.get(last_unmasked).unwrap();
+
+        let stats = SegmentStats {
+            min,
+            max,
+            count,
+            sorted,
+        };
+
+        let mut positions = positions.iter().copied().peekable();
+        let sequence = self.iter().enumerate().filter_map(move |(i, val)| {
+            if let Some(next_pos) = positions.peek() {
+                if *next_pos == i as u32 {
+                    positions.next();
+                    return None;
+                }
+            }
+            Some(val)
+        });
+        *self = Self::from_stats_and_sequence(stats, sequence)
+    }
 }
 
 #[cfg(test)]
diff --git a/rust/lance/src/dataset/optimize.rs b/rust/lance/src/dataset/optimize.rs
index 3e2294dcfe..e6739b4c49 100644
--- a/rust/lance/src/dataset/optimize.rs
+++ b/rust/lance/src/dataset/optimize.rs
@@ -738,13 +738,25 @@ async fn rechunk_stable_row_ids(
             let deletions = read_deletion_file(&dataset.base, frag, dataset.object_store()).await?;
             if let Some(deletions) = deletions {
                 let mut new_seq = seq.as_ref().clone();
-                new_seq.mask(deletions.into_iter().map(|x| x as usize))?;
+                new_seq.mask(deletions.into_iter())?;
                 *seq = Arc::new(new_seq);
             }
             Ok::<(), crate::Error>(())
         })
         .await?;
 
+    debug_assert_eq!(
+        { old_sequences.iter().map(|(_, seq)| seq.len()).sum::<u64>() },
+        {
+            new_fragments
+                .iter()
+                .map(|frag| frag.physical_rows.unwrap() as u64)
+                .sum::<u64>()
+        },
+        "{:?}",
+        old_sequences
+    );
+
     let new_sequences = lance_table::rowids::rechunk_sequences(
         old_sequences
             .into_iter()
diff --git a/rust/lance/src/index/prefilter.rs b/rust/lance/src/index/prefilter.rs
index fe2a5bc6eb..e60ad6d65f 100644
--- a/rust/lance/src/index/prefilter.rs
+++ b/rust/lance/src/index/prefilter.rs
@@ -6,6 +6,7 @@
 //! Based on the query, we might have information about which fragment ids and
 //! row ids can be excluded from the search.
 
+use std::borrow::Cow;
 use std::cell::OnceCell;
 use std::collections::HashMap;
 use std::sync::Arc;
@@ -17,10 +18,13 @@ use futures::stream;
 use futures::FutureExt;
 use futures::StreamExt;
 use futures::TryStreamExt;
+use lance_core::utils::deletion::DeletionVector;
 use lance_core::utils::mask::RowIdMask;
 use lance_core::utils::mask::RowIdTreeMap;
+use lance_core::utils::tokio::spawn_cpu;
 use lance_table::format::Fragment;
 use lance_table::format::Index;
+use lance_table::rowids::RowIdSequence;
 use roaring::RoaringBitmap;
 use tokio::join;
 use tracing::instrument;
@@ -119,49 +123,63 @@ impl DatasetPreFilter {
         // This can only be computed as an allow list, since we have no idea
         // what the row ids were in the missing fragments.
 
-        // For each fragment, compute which row ids are still in use.
-        let dataset_ref = dataset.as_ref();
+        let path = dataset
+            .base
+            .child(format!("row_id_mask{}", dataset.manifest().version));
+
+        let session = dataset.session();
+
+        async fn load_row_ids_and_deletions(
+            dataset: &Dataset,
+        ) -> Result<Vec<(Arc<RowIdSequence>, Option<Arc<DeletionVector>>)>> {
+            stream::iter(dataset.get_fragments())
+                .map(|frag| async move {
+                    let row_ids = load_row_id_sequence(dataset, frag.metadata());
+                    let deletion_vector = frag.get_deletion_vector();
+                    let (row_ids, deletion_vector) = join!(row_ids, deletion_vector);
+                    Ok::<_, crate::Error>((row_ids?, deletion_vector?))
+                })
+                .buffer_unordered(dataset.object_store().io_parallelism()? as usize)
+                .try_collect::<Vec<_>>()
+                .await
+        }
 
-        let row_ids_and_deletions = stream::iter(dataset.get_fragments())
-            .map(|frag| async move {
-                let row_ids = load_row_id_sequence(dataset_ref, frag.metadata());
-                let deletion_vector = frag.get_deletion_vector();
-                let (row_ids, deletion_vector) = join!(row_ids, deletion_vector);
-                Ok::<_, crate::Error>((row_ids?, deletion_vector?))
-            })
-            .buffer_unordered(10)
-            .try_collect::<Vec<_>>()
-            .await?;
-
-        // The process of computing the final mask is CPU-bound, so we spawn it
-        // on a blocking thread.
-        let allow_list = tokio::task::spawn_blocking(move || {
-            row_ids_and_deletions.into_iter().fold(
-                RowIdTreeMap::new(),
-                |mut allow_list, (row_ids, deletion_vector)| {
-                    let row_ids = if let Some(deletion_vector) = deletion_vector {
-                        // We have to mask the row ids
-                        row_ids.as_ref().iter().enumerate().fold(
+        session
+            .file_metadata_cache
+            .get_or_insert(&path, move |_| {
+                let dataset = dataset.clone();
+                async move {
+                    let row_ids_and_deletions = load_row_ids_and_deletions(&dataset).await?;
+
+                    // The process of computing the final mask is CPU-bound, so we spawn it
+                    // on a blocking thread.
+                    let allow_list = spawn_cpu(move || {
+                        row_ids_and_deletions.into_iter().try_fold(
                             RowIdTreeMap::new(),
-                            |mut allow_list, (idx, row_id)| {
-                                if !deletion_vector.contains(idx as u32) {
-                                    allow_list.insert(row_id);
-                                }
-                                allow_list
+                            |mut allow_list, (row_ids, deletion_vector)| {
+                                let seq = if let Some(deletion_vector) = deletion_vector {
+                                    // `mask` needs ascending positions, but the `Set`
+                                    // variant is not iterated in sorted order.
+                                    let mut deleted = deletion_vector.iter().collect::<Vec<_>>();
+                                    deleted.sort_unstable();
+                                    let mut row_ids = row_ids.as_ref().clone();
+                                    row_ids.mask(deleted)?;
+                                    Cow::Owned(row_ids)
+                                } else {
+                                    Cow::Borrowed(row_ids.as_ref())
+                                };
+                                let treemap = RowIdTreeMap::from(seq.as_ref());
+                                allow_list |= treemap;
+                                Ok(allow_list)
                             },
-                        )
-                    } else {
-                        // Can do a direct translation
-                        RowIdTreeMap::from(row_ids.as_ref())
-                    };
-                    allow_list |= row_ids;
-                    allow_list
-                },
-            )
-        })
-        .await?;
+                        )
+                    })
+                    .await?;
 
-        Ok(Arc::new(RowIdMask::from_allowed(allow_list)))
+                    Ok(RowIdMask::from_allowed(allow_list))
+                }
+            })
+            .await
     }
 
     /// Creates a task to load mask to filter out deleted rows.