diff --git a/.github/workflows/ci-benchmarks.yml b/.github/workflows/ci-benchmarks.yml
index 90fc72af07..1b87ec69e0 100644
--- a/.github/workflows/ci-benchmarks.yml
+++ b/.github/workflows/ci-benchmarks.yml
@@ -1,6 +1,7 @@
 name: Run Regression Benchmarks
 on:
+  workflow_dispatch:
   push:
     branches:
       - main
diff --git a/java/core/lance-jni/src/blocking_dataset.rs b/java/core/lance-jni/src/blocking_dataset.rs
index 2e763afca5..322156d5a2 100644
--- a/java/core/lance-jni/src/blocking_dataset.rs
+++ b/java/core/lance-jni/src/blocking_dataset.rs
@@ -22,15 +22,16 @@ use arrow::datatypes::Schema;
 use arrow::ffi::FFI_ArrowSchema;
 use arrow::ffi_stream::ArrowArrayStreamReader;
 use arrow::ffi_stream::FFI_ArrowArrayStream;
+use arrow::ipc::writer::StreamWriter;
 use arrow::record_batch::RecordBatchIterator;
 use arrow_schema::DataType;
 use jni::objects::{JMap, JString, JValue};
-use jni::sys::jlong;
 use jni::sys::{jboolean, jint};
+use jni::sys::{jbyteArray, jlong};
 use jni::{objects::JObject, JNIEnv};
 use lance::dataset::builder::DatasetBuilder;
 use lance::dataset::transaction::Operation;
-use lance::dataset::{ColumnAlteration, Dataset, ReadParams, WriteParams};
+use lance::dataset::{ColumnAlteration, Dataset, ProjectionRequest, ReadParams, WriteParams};
 use lance::io::{ObjectStore, ObjectStoreParams};
 use lance::table::format::Fragment;
 use lance::table::format::Index;
@@ -683,6 +684,59 @@ fn inner_list_indexes<'local>(
     Ok(array_list)
 }
 
+#[no_mangle]
+pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeTake(
+    mut env: JNIEnv,
+    java_dataset: JObject,
+    indices_obj: JObject, // List<Long>
+    columns_obj: JObject, // List<String>
+) -> jbyteArray {
+    match inner_take(&mut env, java_dataset, indices_obj, columns_obj) {
+        Ok(byte_array) => byte_array,
+        Err(e) => {
+            let _ = env.throw_new("java/lang/RuntimeException", format!("{:?}", e));
+            std::ptr::null_mut()
+        }
+    }
+}
+
+fn inner_take(
+    env: &mut JNIEnv,
+    java_dataset: JObject,
+    indices_obj: JObject, // List<Long>
+    columns_obj: JObject, // List<String>
+) -> Result<jbyteArray> {
+    let indices: Vec<i64> = env.get_longs(&indices_obj)?;
+    let indices_u64: Vec<u64> = indices.iter().map(|&x| x as u64).collect();
+    let indices_slice: &[u64] = &indices_u64;
+    let columns: Vec<String> = env.get_strings(&columns_obj)?;
+
+    let result = {
+        let dataset_guard =
+            unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?;
+        let dataset = &dataset_guard.inner;
+
+        let projection = ProjectionRequest::from_columns(columns, dataset.schema());
+
+        match RT.block_on(dataset.take(indices_slice, projection)) {
+            Ok(res) => res,
+            Err(e) => {
+                return Err(e.into());
+            }
+        }
+    };
+
+    let mut buffer = Vec::new();
+    {
+        let mut writer = StreamWriter::try_new(&mut buffer, &result.schema())?;
+        writer.write(&result)?;
+        writer.finish()?;
+    }
+
+    let byte_array = env.byte_array_from_slice(&buffer)?;
+    Ok(**byte_array)
+}
+
 //////////////////////////////
 // Schema evolution Methods //
 //////////////////////////////
diff --git a/java/core/lance-jni/src/ffi.rs b/java/core/lance-jni/src/ffi.rs
index dd11a1ee38..f92d3ec873 100644
--- a/java/core/lance-jni/src/ffi.rs
+++ b/java/core/lance-jni/src/ffi.rs
@@ -26,6 +26,9 @@ pub trait JNIEnvExt {
     /// Get integers from Java List object.
     fn get_integers(&mut self, obj: &JObject) -> Result<Vec<i32>>;
 
+    /// Get longs from Java List object.
+    fn get_longs(&mut self, obj: &JObject) -> Result<Vec<i64>>;
+
     /// Get strings from Java List object.
     fn get_strings(&mut self, obj: &JObject) -> Result<Vec<String>>;
@@ -127,6 +130,18 @@ impl JNIEnvExt for JNIEnv<'_> {
         Ok(results)
     }
 
+    fn get_longs(&mut self, obj: &JObject) -> Result<Vec<i64>> {
+        let list = self.get_list(obj)?;
+        let mut iter = list.iter(self)?;
+        let mut results = Vec::with_capacity(list.size(self)? as usize);
+        while let Some(elem) = iter.next(self)? {
+            let long_obj = self.call_method(elem, "longValue", "()J", &[])?;
+            let long_value = long_obj.j()?;
+            results.push(long_value);
+        }
+        Ok(results)
+    }
+
     fn get_strings(&mut self, obj: &JObject) -> Result<Vec<String>> {
         let list = self.get_list(obj)?;
         let mut iter = list.iter(self)?;
@@ -348,6 +363,15 @@ pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseInts(
     ok_or_throw_without_return!(env, env.get_integers(&list_obj));
 }
 
+#[no_mangle]
+pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseLongs(
+    mut env: JNIEnv,
+    _obj: JObject,
+    list_obj: JObject, // List<Long>
+) {
+    ok_or_throw_without_return!(env, env.get_longs(&list_obj));
+}
+
 #[no_mangle]
 pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseIntsOpt(
     mut env: JNIEnv,
diff --git a/java/core/lance-jni/src/fragment.rs b/java/core/lance-jni/src/fragment.rs
index dacdd08798..459afab022 100644
--- a/java/core/lance-jni/src/fragment.rs
+++ b/java/core/lance-jni/src/fragment.rs
@@ -62,7 +62,7 @@ fn inner_count_rows_native(
             "Fragment not found: {fragment_id}"
         )));
     };
-    let res = RT.block_on(fragment.count_rows())?;
+    let res = RT.block_on(fragment.count_rows(None))?;
     Ok(res)
 }
diff --git a/java/core/src/main/java/com/lancedb/lance/Dataset.java b/java/core/src/main/java/com/lancedb/lance/Dataset.java
index 9a12d0c36a..8f1e5de507 100644
--- a/java/core/src/main/java/com/lancedb/lance/Dataset.java
+++ b/java/core/src/main/java/com/lancedb/lance/Dataset.java
@@ -24,9 +24,15 @@
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.vector.ipc.ArrowReader;
+import org.apache.arrow.vector.ipc.ArrowStreamReader;
 import org.apache.arrow.vector.types.pojo.Schema;
 
+import java.io.ByteArrayInputStream;
 import java.io.Closeable;
+import java.io.IOException;
+import java.nio.channels.Channels;
+import java.nio.channels.ReadableByteChannel;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -315,6 +321,32 @@ public LanceScanner newScan(ScanOptions options) {
     }
   }
 
+  /**
+   * Select rows of data by index.
+   *
+   * @param indices the indices to take
+   * @param columns the columns to take
+   * @return an ArrowReader
+   */
+  public ArrowReader take(List<Long> indices, List<String> columns) throws IOException {
+    Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed");
+    try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) {
+      byte[] arrowData = nativeTake(indices, columns);
+      ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(arrowData);
+      ReadableByteChannel readChannel = Channels.newChannel(byteArrayInputStream);
+      return new ArrowStreamReader(readChannel, allocator) {
+        @Override
+        public void close() throws IOException {
+          super.close();
+          readChannel.close();
+          byteArrayInputStream.close();
+        }
+      };
+    }
+  }
+
+  private native byte[] nativeTake(List<Long> indices, List<String> columns);
+
   /**
    * Gets the URI of the dataset.
    *
diff --git a/java/core/src/main/java/com/lancedb/lance/test/JniTestHelper.java b/java/core/src/main/java/com/lancedb/lance/test/JniTestHelper.java
index 89f1f8a4b6..28ed442c0b 100644
--- a/java/core/src/main/java/com/lancedb/lance/test/JniTestHelper.java
+++ b/java/core/src/main/java/com/lancedb/lance/test/JniTestHelper.java
@@ -37,6 +37,13 @@ public class JniTestHelper {
    */
   public static native void parseInts(List<Integer> intsList);
 
+  /**
+   * JNI parse longs test.
+   *
+   * @param longsList the given list of longs
+   */
+  public static native void parseLongs(List<Long> longsList);
+
   /**
    * JNI parse ints opts test.
    *
diff --git a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java
index 92765d28f2..4275ef9573 100644
--- a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java
+++ b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java
@@ -15,6 +15,8 @@
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowReader;
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.Schema;
@@ -25,10 +27,9 @@
 import java.io.IOException;
 import java.net.URISyntaxException;
+import java.nio.channels.ClosedChannelException;
 import java.nio.file.Path;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashMap;
+import java.util.*;
 import java.util.stream.Collectors;
 
 import static org.junit.jupiter.api.Assertions.*;
@@ -307,4 +308,32 @@ void testDropPath() {
       Dataset.drop(datasetPath, new HashMap<>());
     }
   }
+
+  @Test
+  void testTake() throws IOException, ClosedChannelException {
+    String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName();
+    String datasetPath = tempDir.resolve(testMethodName).toString();
+    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
+      TestUtils.SimpleTestDataset testDataset =
+          new TestUtils.SimpleTestDataset(allocator, datasetPath);
+      dataset = testDataset.createEmptyDataset();
+
+      try (Dataset dataset2 = testDataset.write(1, 5)) {
+        List<Long> indices = Arrays.asList(1L, 4L);
+        List<String> columns = Arrays.asList("id", "name");
+        try (ArrowReader reader = dataset2.take(indices, columns)) {
+          while (reader.loadNextBatch()) {
+            VectorSchemaRoot result = reader.getVectorSchemaRoot();
+            assertNotNull(result);
+            assertEquals(indices.size(), result.getRowCount());
+
+            for (int i = 0; i < indices.size(); i++) {
+              assertEquals(indices.get(i).intValue(), result.getVector("id").getObject(i));
+              assertNotNull(result.getVector("name").getObject(i));
+            }
+          }
+        }
+      }
+    }
+  }
 }
diff --git a/java/core/src/test/java/com/lancedb/lance/JNITest.java b/java/core/src/test/java/com/lancedb/lance/JNITest.java
index 885379d804..ddb4ea3cdf 100644
--- a/java/core/src/test/java/com/lancedb/lance/JNITest.java
+++ b/java/core/src/test/java/com/lancedb/lance/JNITest.java
@@ -37,6 +37,11 @@ public void testInts() {
     JniTestHelper.parseInts(Arrays.asList(1, 2, 3));
   }
 
+  @Test
+  public void testLongs() {
+    JniTestHelper.parseLongs(Arrays.asList(1L, 2L, 3L, Long.MAX_VALUE));
+  }
+
   @Test
   public void testIntsOpt() {
     JniTestHelper.parseIntsOpt(Optional.of(Arrays.asList(1, 2, 3)));
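A minimal usage sketch of the new Java `Dataset.take` API introduced in the changes above. This is illustrative only: the dataset path and column names are hypothetical, and it assumes a dataset opened via `Dataset.open(path, allocator)` as in the existing tests.

```java
import com.lancedb.lance.Dataset;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;

import java.util.Arrays;

public class TakeExample {
  public static void main(String[] args) throws Exception {
    // Hypothetical dataset location and columns; adjust to a real dataset.
    try (BufferAllocator allocator = new RootAllocator();
        Dataset dataset = Dataset.open("/tmp/example.lance", allocator);
        // take() serializes the selected rows to an Arrow IPC stream on the JNI side
        // and exposes them through an ArrowReader that the caller must close.
        ArrowReader reader = dataset.take(Arrays.asList(1L, 4L), Arrays.asList("id", "name"))) {
      while (reader.loadNextBatch()) {
        VectorSchemaRoot batch = reader.getVectorSchemaRoot();
        System.out.println(batch.contentToTSVString());
      }
    }
  }
}
```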
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/read/FilterPushDown.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/FilterPushDown.java
index 9d7824d033..9202008fcb 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/read/FilterPushDown.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/FilterPushDown.java
@@ -36,6 +36,7 @@
 import java.sql.Date;
 import java.sql.Timestamp;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.stream.Collectors;
 
@@ -89,7 +90,7 @@ public static boolean isFilterSupported(Filter filter) {
     } else if (filter instanceof EqualNullSafe) {
       return false;
     } else if (filter instanceof In) {
-      return false;
+      return true;
     } else if (filter instanceof LessThan) {
       return true;
     } else if (filter instanceof LessThanOrEqual) {
@@ -163,6 +164,13 @@ private static Optional<String> compileFilter(Filter filter) {
       Optional<String> child = compileFilter(f.child());
       if (child.isEmpty()) return child;
       return Optional.of(String.format("NOT (%s)", child.get()));
+    } else if (filter instanceof In) {
+      In in = (In) filter;
+      String values =
+          Arrays.stream(in.values())
+              .map(FilterPushDown::compileValue)
+              .collect(Collectors.joining(","));
+      return Optional.of(String.format("%s IN (%s)", in.attribute(), values));
     }
     return Optional.empty();
diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/read/FilterPushDownTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/FilterPushDownTest.java
index a427fbd3ef..ba15151ae7 100644
--- a/java/spark/src/test/java/com/lancedb/lance/spark/read/FilterPushDownTest.java
+++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/FilterPushDownTest.java
@@ -82,4 +82,26 @@ public void testCompileFiltersToSqlWhereClauseWithEmptyFilters() {
     Optional<String> whereClause = FilterPushDown.compileFiltersToSqlWhereClause(filters);
     assertFalse(whereClause.isPresent());
   }
+
+  @Test
+  public void testIntegerInFilterPushDown() {
+    Object[] values = new Object[2];
+    values[0] = 500;
+    values[1] = 600;
+    Filter[] filters = new Filter[] {new GreaterThan("age", 30), new In("salary", values)};
+    Optional<String> whereClause = FilterPushDown.compileFiltersToSqlWhereClause(filters);
+    assertTrue(whereClause.isPresent());
+    assertEquals("(age > 30) AND (salary IN (500,600))", whereClause.get());
+  }
+
+  @Test
+  public void testStringInFilterPushDown() {
+    Object[] values = new Object[2];
+    values[0] = "500";
+    values[1] = "600";
+    Filter[] filters = new Filter[] {new GreaterThan("age", 30), new In("salary", values)};
+    Optional<String> whereClause = FilterPushDown.compileFiltersToSqlWhereClause(filters);
+    assertTrue(whereClause.isPresent());
+    assertEquals("(age > 30) AND (salary IN ('500','600'))", whereClause.get());
+  }
 }
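To illustrate what the new `In` support in FilterPushDown enables at the Spark level: an `isin` predicate on a Lance-backed DataFrame is turned by Spark into an `In` source filter, which is now compiled into the SQL `IN` clause shown in the tests above rather than being evaluated in Spark. A hedged sketch only; the table path is hypothetical and the `format("lance")` read path is an assumption that may differ from the connector's actual catalog-based setup.

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import static org.apache.spark.sql.functions.col;

public class InPushDownExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("lance-in-pushdown").getOrCreate();
    // Hypothetical read path; the Lance Spark connector may instead be wired up as a catalog.
    Dataset<Row> df = spark.read().format("lance").load("/tmp/example.lance");
    // col("salary").isin(500, 600) becomes a Spark `In` filter, which is now pushed down
    // to Lance as: salary IN (500,600)
    df.filter(col("salary").isin(500, 600)).show();
    spark.stop();
  }
}
```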
diff --git a/python/python/ci_benchmarks/benchmarks/test_search.py b/python/python/ci_benchmarks/benchmarks/test_search.py
index b2229d89b0..2cf31dc32a 100644
--- a/python/python/ci_benchmarks/benchmarks/test_search.py
+++ b/python/python/ci_benchmarks/benchmarks/test_search.py
@@ -34,3 +34,22 @@ def bench():
         )
 
     benchmark.pedantic(bench, rounds=1, iterations=1)
+
+
+BTREE_FILTERS = ["image_widths = 3997", "image_widths >= 3990 AND image_widths <= 3997"]
+
+
+@pytest.mark.parametrize("filt", BTREE_FILTERS)
+def test_eda_btree_search(benchmark, filt):
+    dataset_uri = get_dataset_uri("image_eda")
+    ds = lance.dataset(dataset_uri)
+
+    def bench():
+        ds.to_table(
+            columns=[],
+            filter=filt,
+            with_row_id=True,
+        )
+
+    # We warmup so we can test hot index performance
+    benchmark.pedantic(bench, warmup_rounds=1, rounds=1, iterations=100)
diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py
index 7e7229b6a9..2274ca2a4b 100644
--- a/python/python/lance/dataset.py
+++ b/python/python/lance/dataset.py
@@ -507,13 +507,13 @@ def to_table(
         batch_size: Optional[int] = None,
         batch_readahead: Optional[int] = None,
         fragment_readahead: Optional[int] = None,
-        scan_in_order: bool = True,
+        scan_in_order: Optional[bool] = None,
         *,
-        prefilter: bool = False,
-        with_row_id: bool = False,
-        with_row_address: bool = False,
-        use_stats: bool = True,
-        fast_search: bool = False,
+        prefilter: Optional[bool] = None,
+        with_row_id: Optional[bool] = None,
+        with_row_address: Optional[bool] = None,
+        use_stats: Optional[bool] = None,
+        fast_search: Optional[bool] = None,
         full_text_query: Optional[Union[str, dict]] = None,
         io_buffer_size: Optional[int] = None,
         late_materialization: Optional[bool | List[str]] = None,
@@ -558,11 +558,11 @@ def to_table(
             The number of batches to read ahead.
         fragment_readahead: int, optional
             The number of fragments to read ahead.
-        scan_in_order: bool, default True
+        scan_in_order: bool, optional, default True
             Whether to read the fragments and batches in order. If false,
             throughput may be higher, but batches will be returned out of
             order and memory use might increase.
-        prefilter: bool, default False
+        prefilter: bool, optional, default False
            Run filter before the vector search.
         late_materialization: bool or List[str], default None
             Allows custom control over late materialization.  See
@@ -570,12 +570,13 @@
         use_scalar_index: bool, default True
             Allows custom control over scalar index usage. See
             ``ScannerBuilder.use_scalar_index`` for more information.
-        with_row_id: bool, default False
+        with_row_id: bool, optional, default False
             Return row ID.
-        with_row_address: bool, default False
+        with_row_address: bool, optional, default False
             Return row address
-        use_stats: bool, default True
+        use_stats: bool, optional, default True
             Use stats pushdown during filters.
+        fast_search: bool, optional, default False
         full_text_query: str or dict, optional
             query string to search for, the results will be ranked by BM25.
             e.g. "hello world", would match documents contains "hello" or "world".
@@ -687,12 +688,12 @@ def to_batches(
         batch_size: Optional[int] = None,
         batch_readahead: Optional[int] = None,
         fragment_readahead: Optional[int] = None,
-        scan_in_order: bool = True,
+        scan_in_order: Optional[bool] = None,
         *,
-        prefilter: bool = False,
-        with_row_id: bool = False,
-        with_row_address: bool = False,
-        use_stats: bool = True,
+        prefilter: Optional[bool] = None,
+        with_row_id: Optional[bool] = None,
+        with_row_address: Optional[bool] = None,
+        use_stats: Optional[bool] = None,
         full_text_query: Optional[Union[str, dict]] = None,
         io_buffer_size: Optional[int] = None,
         late_materialization: Optional[bool | List[str]] = None,
diff --git a/python/python/lance/fragment.py b/python/python/lance/fragment.py
index ce9334c682..e3abc3e1de 100644
--- a/python/python/lance/fragment.py
+++ b/python/python/lance/fragment.py
@@ -217,7 +217,7 @@ def __init__(
             if fragment_id is None:
                 raise ValueError("Either fragment or fragment_id must be specified")
             fragment = dataset.get_fragment(fragment_id)._fragment
-        self._fragment = fragment
+        self._fragment: _Fragment = fragment
         if self._fragment is None:
             raise ValueError(f"Fragment id does not exist: {fragment_id}")
 
@@ -367,8 +367,8 @@ def count_rows(
         self, filter: Optional[Union[pa.compute.Expression, str]] = None
     ) -> int:
         if filter is not None:
-            raise ValueError("Does not support filter at the moment")
-        return self._fragment.count_rows()
+            return self.scanner(filter=filter).count_rows()
+        return self._fragment.count_rows(filter)
 
     @property
     def num_deletions(self) -> int:
diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py
index 587f6a8165..955702aa14 100644
--- a/python/python/tests/test_dataset.py
+++ b/python/python/tests/test_dataset.py
@@ -2806,3 +2806,10 @@ def test_dataset_drop(tmp_path: Path):
     assert Path(tmp_path).exists()
     lance.LanceDataset.drop(tmp_path)
     assert not Path(tmp_path).exists()
+
+
+def test_dataset_schema(tmp_path: Path):
+    table = pa.table({"x": [0]})
+    ds = lance.write_dataset(table, str(tmp_path))  # noqa: F841
+    ds._default_scan_options = {"with_row_id": True}
+    assert ds.schema == ds.to_table().schema
diff --git a/python/python/tests/test_filter.py b/python/python/tests/test_filter.py
index 2fad73a7b8..cc864ea245 100644
--- a/python/python/tests/test_filter.py
+++ b/python/python/tests/test_filter.py
@@ -86,6 +86,10 @@ def test_sql_predicates(dataset):
         ("str = 'aa'", 16),
         ("str in ('aa', 'bb')", 26),
         ("rec.bool", 50),
+        ("rec.bool is true", 50),
+        ("rec.bool is not true", 50),
+        ("rec.bool is false", 50),
+        ("rec.bool is not false", 50),
         ("rec.date = cast('2021-01-01' as date)", 1),
         ("rec.dt = cast('2021-01-01 00:00:00' as timestamp(6))", 1),
         ("rec.dt = cast('2021-01-01 00:00:00' as timestamp)", 1),
diff --git a/python/python/tests/test_fragment.py b/python/python/tests/test_fragment.py
index 7bae75759b..7a55e02788 100644
--- a/python/python/tests/test_fragment.py
+++ b/python/python/tests/test_fragment.py
@@ -9,6 +9,7 @@
 import lance
 import pandas as pd
 import pyarrow as pa
+import pyarrow.compute as pc
 import pytest
 from helper import ProgressForTest
 from lance import (
@@ -422,3 +423,15 @@ def test_fragment_merge(tmp_path):
         tmp_path, merge, read_version=dataset.latest_version
     )
     assert [f.name for f in dataset.schema] == ["a", "b", "c", "d"]
+
+
+def test_fragment_count_rows(tmp_path: Path):
+    data = pa.table({"a": range(800), "b": range(800)})
+    ds = write_dataset(data, tmp_path)
+
+    fragments = ds.get_fragments()
+    assert len(fragments) == 1
+
+    assert fragments[0].count_rows() == 800
+    assert fragments[0].count_rows("a < 200") == 200
+    assert fragments[0].count_rows(pc.field("a") < 200) == 200
diff --git a/python/src/fragment.rs b/python/src/fragment.rs
index b5cb75fc3a..1ddf89a21b 100644
--- a/python/src/fragment.rs
+++ b/python/src/fragment.rs
@@ -127,11 +127,11 @@ impl FileFragment {
         PyLance(self.fragment.metadata().clone())
     }
 
-    #[pyo3(signature=(_filter=None))]
-    fn count_rows(&self, _filter: Option<String>) -> PyResult<usize> {
+    #[pyo3(signature=(filter=None))]
+    fn count_rows(&self, filter: Option<String>) -> PyResult<usize> {
         RT.runtime.block_on(async {
             self.fragment
-                .count_rows()
+                .count_rows(filter)
                 .await
                 .map_err(|e| PyIOError::new_err(e.to_string()))
         })
diff --git a/rust/lance-datafusion/src/planner.rs b/rust/lance-datafusion/src/planner.rs
index e9237f1aa2..aa596d05c7 100644
--- a/rust/lance-datafusion/src/planner.rs
+++ b/rust/lance-datafusion/src/planner.rs
@@ -636,7 +636,7 @@ impl Planner {
                 }))
             }
             SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
-            SQLExpr::IsNotFalse(_) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
+            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
             SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
             SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
             SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs
index abbe65490f..2590e2863b 100644
--- a/rust/lance-index/src/scalar/btree.rs
+++ b/rust/lance-index/src/scalar/btree.rs
@@ -23,7 +23,7 @@ use datafusion::{
 use datafusion_common::{DataFusionError, ScalarValue};
 use datafusion_expr::Accumulator;
 use datafusion_physical_expr::{expressions::Column, PhysicalSortExpr};
-use deepsize::DeepSizeOf;
+use deepsize::{Context, DeepSizeOf};
 use futures::{
     future::BoxFuture,
     stream::{self},
@@ -37,6 +37,8 @@ use lance_datafusion::{
     chunker::chunk_concat_stream,
     exec::{execute_plan, LanceExecutionOptions, OneShotExec},
 };
+use log::debug;
+use moka::sync::Cache;
 use roaring::RoaringBitmap;
 use serde::{Serialize, Serializer};
 use snafu::{location, Location};
@@ -53,6 +55,13 @@ const BTREE_PAGES_NAME: &str = "page_data.lance";
 pub const DEFAULT_BTREE_BATCH_SIZE: u64 = 4096;
 const BATCH_SIZE_META_KEY: &str = "batch_size";
 
+lazy_static::lazy_static! {
+    static ref CACHE_SIZE: u64 = std::env::var("LANCE_BTREE_CACHE_SIZE")
+        .ok()
+        .and_then(|s| s.parse().ok())
+        .unwrap_or(512 * 1024 * 1024);
+}
+
 /// Wraps a ScalarValue and implements Ord (ScalarValue only implements PartialOrd)
 #[derive(Clone, Debug)]
 pub struct OrderableScalarValue(pub ScalarValue);
@@ -659,6 +668,42 @@ impl BTreeLookup {
     }
 }
 
+// Caches btree pages in memory
+#[derive(Debug)]
+struct BTreeCache(Cache<u32, Arc<dyn ScalarIndex>>);
+
+impl DeepSizeOf for BTreeCache {
+    fn deep_size_of_children(&self, _: &mut Context) -> usize {
+        self.0.iter().map(|(_, v)| v.deep_size_of()).sum()
+    }
+}
+
+// We only need to open a file reader for pages if we need to load a page.  If all
+// pages are cached we don't open it.  If we do open it we should only open it once.
+#[derive(Clone)]
+struct LazyIndexReader {
+    index_reader: Arc<tokio::sync::Mutex<Option<Arc<dyn IndexReader>>>>,
+    store: Arc<dyn IndexStore>,
+}
+
+impl LazyIndexReader {
+    fn new(store: Arc<dyn IndexStore>) -> Self {
+        Self {
+            index_reader: Arc::new(tokio::sync::Mutex::new(None)),
+            store,
+        }
+    }
+
+    async fn get(&self) -> Result<Arc<dyn IndexReader>> {
+        let mut reader = self.index_reader.lock().await;
+        if reader.is_none() {
+            let index_reader = self.store.open_index_file(BTREE_PAGES_NAME).await?;
+            *reader = Some(index_reader);
+        }
+        Ok(reader.as_ref().unwrap().clone())
+    }
+}
+
 /// A btree index satisfies scalar queries using a b tree
 ///
 /// The upper layers of the btree are expected to be cached and, when unloaded,
@@ -677,6 +722,7 @@ impl BTreeLookup {
 #[derive(Clone, Debug, DeepSizeOf)]
 pub struct BTreeIndex {
     page_lookup: Arc<BTreeLookup>,
+    page_cache: Arc<BTreeCache>,
     store: Arc<dyn IndexStore>,
     sub_index: Arc<dyn BTreeSubIndex>,
     batch_size: u64,
@@ -691,24 +737,45 @@ impl BTreeIndex {
         batch_size: u64,
     ) -> Self {
         let page_lookup = Arc::new(BTreeLookup::new(tree, null_pages));
+        let page_cache = Arc::new(BTreeCache(
+            Cache::builder()
+                .max_capacity(*CACHE_SIZE)
+                .weigher(|_, v: &Arc<dyn ScalarIndex>| v.deep_size_of() as u32)
+                .build(),
+        ));
         Self {
             page_lookup,
+            page_cache,
             store,
             sub_index,
             batch_size,
         }
     }
 
-    async fn search_page(
+    async fn lookup_page(
         &self,
-        query: &SargableQuery,
         page_number: u32,
-        index_reader: Arc<dyn IndexReader>,
-    ) -> Result<RowIdTreeMap> {
+        index_reader: LazyIndexReader,
+    ) -> Result<Arc<dyn ScalarIndex>> {
+        if let Some(cached) = self.page_cache.0.get(&page_number) {
+            return Ok(cached);
+        }
+        let index_reader = index_reader.get().await?;
         let serialized_page = index_reader
             .read_record_batch(page_number as u64, self.batch_size)
             .await?;
         let subindex = self.sub_index.load_subindex(serialized_page).await?;
+        self.page_cache.0.insert(page_number, subindex.clone());
+        Ok(subindex)
+    }
+
+    async fn search_page(
+        &self,
+        query: &SargableQuery,
+        page_number: u32,
+        index_reader: LazyIndexReader,
+    ) -> Result<RowIdTreeMap> {
+        let subindex = self.lookup_page(page_number, index_reader).await?;
         // TODO: If this is an IN query we can perhaps simplify the subindex query by restricting it to the
        // values that might be in the page.  E.g. if we are searching for X IN [5, 3, 7] and five is in pages
        // 1 and 2 and three is in page 2 and seven is in pages 8 and 9 then when we search page 2 we only need
@@ -894,14 +961,15 @@ impl ScalarIndex for BTreeIndex {
             )),
             SargableQuery::IsNull() => self.page_lookup.pages_null(),
         };
-        let sub_index_reader = self.store.open_index_file(BTREE_PAGES_NAME).await?;
+        let lazy_index_reader = LazyIndexReader::new(self.store.clone());
         let page_tasks = pages
             .into_iter()
             .map(|page_index| {
-                self.search_page(query, page_index, sub_index_reader.clone())
+                self.search_page(query, page_index, lazy_index_reader.clone())
                     .boxed()
             })
             .collect::<Vec<_>>();
+        debug!("Searching {} btree pages", page_tasks.len());
         stream::iter(page_tasks)
             // I/O and compute mixed here but important case is index in cache so
             // use compute intensive thread count
diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs
index cbcf878d78..bd27c1fc31 100644
--- a/rust/lance/src/dataset.rs
+++ b/rust/lance/src/dataset.rs
@@ -798,7 +798,7 @@ impl Dataset {
     pub(crate) async fn count_all_rows(&self) -> Result<usize> {
         let cnts = stream::iter(self.get_fragments())
-            .map(|f| async move { f.count_rows().await })
+            .map(|f| async move { f.count_rows(None).await })
             .buffer_unordered(16)
             .try_collect::<Vec<_>>()
             .await?;
@@ -2037,7 +2037,7 @@ mod tests {
         assert_eq!(fragments.len(), 10);
         assert_eq!(dataset.count_fragments(), 10);
         for fragment in &fragments {
-            assert_eq!(fragment.count_rows().await.unwrap(), 100);
+            assert_eq!(fragment.count_rows(None).await.unwrap(), 100);
             let reader = fragment
                 .open(dataset.schema(), FragReadConfig::default(), None)
                 .await
diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs
index 7788f7cbe0..161c97627f 100644
--- a/rust/lance/src/dataset/fragment.rs
+++ b/rust/lance/src/dataset/fragment.rs
@@ -710,7 +710,7 @@ impl FileFragment {
             row_id_sequence,
             opened_files,
             ArrowSchema::from(projection),
-            self.count_rows().await?,
+            self.count_rows(None).await?,
             num_physical_rows,
         )?;
 
@@ -829,7 +829,7 @@ impl FileFragment {
         }
 
         // This should return immediately on modern datasets.
-        let num_rows = self.count_rows().await?;
+        let num_rows = self.count_rows(None).await?;
 
         // Check if there are any fields that are not in any data files
         let field_ids_in_files = opened_files
@@ -849,15 +849,24 @@ impl FileFragment {
     }
 
     /// Count the rows in this fragment.
-    pub async fn count_rows(&self) -> Result<usize> {
-        let total_rows = self.physical_rows();
-
-        let deletion_count = self.count_deletions();
+    pub async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
+        match filter {
+            Some(expr) => self
+                .scan()
+                .filter(&expr)?
+                .count_rows()
+                .await
+                .map(|v| v as usize),
+            None => {
+                let total_rows = self.physical_rows();
+                let deletion_count = self.count_deletions();
 
-        let (total_rows, deletion_count) =
-            futures::future::try_join(total_rows, deletion_count).await?;
+                let (total_rows, deletion_count) =
+                    futures::future::try_join(total_rows, deletion_count).await?;
 
-        Ok(total_rows - deletion_count)
+                Ok(total_rows - deletion_count)
+            }
+        }
     }
 
     /// Get the number of rows that have been deleted in this fragment.
@@ -2644,7 +2653,7 @@ mod tests {
         assert_eq!(fragments.len(), 5);
         for f in fragments {
             assert_eq!(f.metadata.num_rows(), Some(40));
-            assert_eq!(f.count_rows().await.unwrap(), 40);
+            assert_eq!(f.count_rows(None).await.unwrap(), 40);
             assert_eq!(f.metadata().deletion_file, None);
         }
     }
@@ -2660,10 +2669,18 @@ mod tests {
         let dataset = create_dataset(test_uri, data_storage_version).await;
         let fragment = dataset.get_fragments().pop().unwrap();
-        assert_eq!(fragment.count_rows().await.unwrap(), 40);
+        assert_eq!(fragment.count_rows(None).await.unwrap(), 40);
         assert_eq!(fragment.physical_rows().await.unwrap(), 40);
         assert!(fragment.metadata.deletion_file.is_none());
 
+        assert_eq!(
+            fragment
+                .count_rows(Some("i < 170".to_string()))
+                .await
+                .unwrap(),
+            10
+        );
+
         let fragment = fragment
             .delete("i >= 160 and i <= 172")
             .await
@@ -2672,7 +2689,7 @@ mod tests {
 
         fragment.validate().await.unwrap();
 
-        assert_eq!(fragment.count_rows().await.unwrap(), 27);
+        assert_eq!(fragment.count_rows(None).await.unwrap(), 27);
         assert_eq!(fragment.physical_rows().await.unwrap(), 40);
         assert!(fragment.metadata.deletion_file.is_some());
         assert_eq!(
diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs
index 4537b75961..22ee289c97 100644
--- a/rust/lance/src/dataset/scanner.rs
+++ b/rust/lance/src/dataset/scanner.rs
@@ -36,8 +36,9 @@ use datafusion::physical_plan::{
 use datafusion::scalar::ScalarValue;
 use datafusion_physical_expr::aggregate::AggregateExprBuilder;
 use datafusion_physical_expr::{Partitioning, PhysicalExpr};
+use futures::future::BoxFuture;
 use futures::stream::{Stream, StreamExt};
-use futures::TryStreamExt;
+use futures::{FutureExt, TryStreamExt};
 use lance_arrow::floats::{coerce_float_vector, FloatType};
 use lance_arrow::DataTypeExt;
 use lance_core::datatypes::{Field, OnMissing, Projection};
@@ -944,13 +945,17 @@ impl Scanner {
     /// Create a stream from the Scanner.
     #[instrument(skip_all)]
-    pub async fn try_into_stream(&self) -> Result<DatasetRecordBatchStream> {
-        let plan = self.create_plan().await?;
-
-        Ok(DatasetRecordBatchStream::new(execute_plan(
-            plan,
-            LanceExecutionOptions::default(),
-        )?))
+    pub fn try_into_stream(&self) -> BoxFuture<'_, Result<DatasetRecordBatchStream>> {
+        // Future intentionally boxed here to avoid large futures on the stack
+        async move {
+            let plan = self.create_plan().await?;
+
+            Ok(DatasetRecordBatchStream::new(execute_plan(
+                plan,
+                LanceExecutionOptions::default(),
+            )?))
+        }
+        .boxed()
     }
 
     pub(crate) async fn try_into_dfstream(
@@ -970,46 +975,50 @@ impl Scanner {
     /// Scan and return the number of matching rows
     #[instrument(skip_all)]
-    pub async fn count_rows(&self) -> Result<u64> {
-        let plan = self.create_plan().await?;
-        // Datafusion interprets COUNT(*) as COUNT(1)
-        let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1))));
-
-        let input_phy_exprs: &[Arc<dyn PhysicalExpr>] = &[one];
-        let schema = plan.schema();
-
-        let mut builder = AggregateExprBuilder::new(count_udaf(), input_phy_exprs.to_vec());
-        builder = builder.schema(schema);
-        builder = builder.alias("count_rows".to_string());
-
-        let count_expr = builder.build()?;
-
-        let plan_schema = plan.schema();
-        let count_plan = Arc::new(AggregateExec::try_new(
-            AggregateMode::Single,
-            PhysicalGroupBy::new_single(Vec::new()),
-            vec![count_expr],
-            vec![None],
-            plan,
-            plan_schema,
-        )?);
-        let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?;
-
-        // A count plan will always return a single batch with a single row.
-        if let Some(first_batch) = stream.next().await {
-            let batch = first_batch?;
-            let array = batch
-                .column(0)
-                .as_any()
-                .downcast_ref::<UInt64Array>()
-                .ok_or(Error::io(
-                    "Count plan did not return a UInt64Array".to_string(),
-                    location!(),
-                ))?;
-            Ok(array.value(0) as u64)
-        } else {
-            Ok(0)
+    pub fn count_rows(&self) -> BoxFuture<'_, Result<u64>> {
+        // Future intentionally boxed here to avoid large futures on the stack
+        async move {
+            let plan = self.create_plan().await?;
+            // Datafusion interprets COUNT(*) as COUNT(1)
+            let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1))));
+
+            let input_phy_exprs: &[Arc<dyn PhysicalExpr>] = &[one];
+            let schema = plan.schema();
+
+            let mut builder = AggregateExprBuilder::new(count_udaf(), input_phy_exprs.to_vec());
+            builder = builder.schema(schema);
+            builder = builder.alias("count_rows".to_string());
+
+            let count_expr = builder.build()?;
+
+            let plan_schema = plan.schema();
+            let count_plan = Arc::new(AggregateExec::try_new(
+                AggregateMode::Single,
+                PhysicalGroupBy::new_single(Vec::new()),
+                vec![count_expr],
+                vec![None],
+                plan,
+                plan_schema,
+            )?);
+            let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?;
+
+            // A count plan will always return a single batch with a single row.
+            if let Some(first_batch) = stream.next().await {
+                let batch = first_batch?;
+                let array = batch
+                    .column(0)
+                    .as_any()
+                    .downcast_ref::<UInt64Array>()
+                    .ok_or(Error::io(
+                        "Count plan did not return a UInt64Array".to_string(),
+                        location!(),
+                    ))?;
+                Ok(array.value(0) as u64)
+            } else {
+                Ok(0)
+            }
         }
+        .boxed()
     }
 
     /// Given a base schema and a list of desired fields figure out which fields, if any, still need loaded
diff --git a/rust/lance/src/dataset/take.rs b/rust/lance/src/dataset/take.rs
index c390bbd45c..8cbf44cd1f 100644
--- a/rust/lance/src/dataset/take.rs
+++ b/rust/lance/src/dataset/take.rs
@@ -45,7 +45,7 @@ pub async fn take(
     let mut frag_iter = fragments.iter();
     let mut cur_frag = frag_iter.next();
     let mut cur_frag_rows = if let Some(cur_frag) = cur_frag {
-        cur_frag.count_rows().await? as u64
+        cur_frag.count_rows(None).await? as u64
     } else {
         0
     };
@@ -57,7 +57,7 @@ pub async fn take(
             frag_offset += cur_frag_rows;
             cur_frag = frag_iter.next();
             cur_frag_rows = if let Some(cur_frag) = cur_frag {
-                cur_frag.count_rows().await? as u64
+                cur_frag.count_rows(None).await? as u64
             } else {
                 0
             };
diff --git a/rust/lance/src/io/exec/scan.rs b/rust/lance/src/io/exec/scan.rs
index 5ec680c647..9cd6ac825f 100644
--- a/rust/lance/src/io/exec/scan.rs
+++ b/rust/lance/src/io/exec/scan.rs
@@ -159,7 +159,7 @@ impl LanceStream {
                 if let Some(next_frag) = frags_iter.next() {
                     let num_rows_in_frag = next_frag
                         .fragment
-                        .count_rows()
+                        .count_rows(None)
                         // count_rows should be a fast operation in v2 files
                         .now_or_never()
                         .ok_or(Error::Internal {