diff --git a/java/core/lance-jni/src/fragment.rs b/java/core/lance-jni/src/fragment.rs index dacdd08798..459afab022 100644 --- a/java/core/lance-jni/src/fragment.rs +++ b/java/core/lance-jni/src/fragment.rs @@ -62,7 +62,7 @@ fn inner_count_rows_native( "Fragment not found: {fragment_id}" ))); }; - let res = RT.block_on(fragment.count_rows())?; + let res = RT.block_on(fragment.count_rows(None))?; Ok(res) } diff --git a/python/python/lance/fragment.py b/python/python/lance/fragment.py index ce9334c682..e3abc3e1de 100644 --- a/python/python/lance/fragment.py +++ b/python/python/lance/fragment.py @@ -217,7 +217,7 @@ def __init__( if fragment_id is None: raise ValueError("Either fragment or fragment_id must be specified") fragment = dataset.get_fragment(fragment_id)._fragment - self._fragment = fragment + self._fragment: _Fragment = fragment if self._fragment is None: raise ValueError(f"Fragment id does not exist: {fragment_id}") @@ -367,8 +367,8 @@ def count_rows( self, filter: Optional[Union[pa.compute.Expression, str]] = None ) -> int: if filter is not None: - raise ValueError("Does not support filter at the moment") - return self._fragment.count_rows() + return self.scanner(filter=filter).count_rows() + return self._fragment.count_rows(filter) @property def num_deletions(self) -> int: diff --git a/python/python/tests/test_fragment.py b/python/python/tests/test_fragment.py index 7bae75759b..7a55e02788 100644 --- a/python/python/tests/test_fragment.py +++ b/python/python/tests/test_fragment.py @@ -9,6 +9,7 @@ import lance import pandas as pd import pyarrow as pa +import pyarrow.compute as pc import pytest from helper import ProgressForTest from lance import ( @@ -422,3 +423,15 @@ def test_fragment_merge(tmp_path): tmp_path, merge, read_version=dataset.latest_version ) assert [f.name for f in dataset.schema] == ["a", "b", "c", "d"] + + +def test_fragment_count_rows(tmp_path: Path): + data = pa.table({"a": range(800), "b": range(800)}) + ds = write_dataset(data, tmp_path) + + fragments = ds.get_fragments() + assert len(fragments) == 1 + + assert fragments[0].count_rows() == 800 + assert fragments[0].count_rows("a < 200") == 200 + assert fragments[0].count_rows(pc.field("a") < 200) == 200 diff --git a/python/src/fragment.rs b/python/src/fragment.rs index b5cb75fc3a..1ddf89a21b 100644 --- a/python/src/fragment.rs +++ b/python/src/fragment.rs @@ -127,11 +127,11 @@ impl FileFragment { PyLance(self.fragment.metadata().clone()) } - #[pyo3(signature=(_filter=None))] - fn count_rows(&self, _filter: Option) -> PyResult { + #[pyo3(signature=(filter=None))] + fn count_rows(&self, filter: Option) -> PyResult { RT.runtime.block_on(async { self.fragment - .count_rows() + .count_rows(filter) .await .map_err(|e| PyIOError::new_err(e.to_string())) }) diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index cbcf878d78..bd27c1fc31 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -798,7 +798,7 @@ impl Dataset { pub(crate) async fn count_all_rows(&self) -> Result { let cnts = stream::iter(self.get_fragments()) - .map(|f| async move { f.count_rows().await }) + .map(|f| async move { f.count_rows(None).await }) .buffer_unordered(16) .try_collect::>() .await?; @@ -2037,7 +2037,7 @@ mod tests { assert_eq!(fragments.len(), 10); assert_eq!(dataset.count_fragments(), 10); for fragment in &fragments { - assert_eq!(fragment.count_rows().await.unwrap(), 100); + assert_eq!(fragment.count_rows(None).await.unwrap(), 100); let reader = fragment .open(dataset.schema(), FragReadConfig::default(), None) .await diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs index 7788f7cbe0..161c97627f 100644 --- a/rust/lance/src/dataset/fragment.rs +++ b/rust/lance/src/dataset/fragment.rs @@ -710,7 +710,7 @@ impl FileFragment { row_id_sequence, opened_files, ArrowSchema::from(projection), - self.count_rows().await?, + self.count_rows(None).await?, num_physical_rows, )?; @@ -829,7 +829,7 @@ impl FileFragment { } // This should return immediately on modern datasets. - let num_rows = self.count_rows().await?; + let num_rows = self.count_rows(None).await?; // Check if there are any fields that are not in any data files let field_ids_in_files = opened_files @@ -849,15 +849,24 @@ impl FileFragment { } /// Count the rows in this fragment. - pub async fn count_rows(&self) -> Result { - let total_rows = self.physical_rows(); - - let deletion_count = self.count_deletions(); + pub async fn count_rows(&self, filter: Option) -> Result { + match filter { + Some(expr) => self + .scan() + .filter(&expr)? + .count_rows() + .await + .map(|v| v as usize), + None => { + let total_rows = self.physical_rows(); + let deletion_count = self.count_deletions(); - let (total_rows, deletion_count) = - futures::future::try_join(total_rows, deletion_count).await?; + let (total_rows, deletion_count) = + futures::future::try_join(total_rows, deletion_count).await?; - Ok(total_rows - deletion_count) + Ok(total_rows - deletion_count) + } + } } /// Get the number of rows that have been deleted in this fragment. @@ -2644,7 +2653,7 @@ mod tests { assert_eq!(fragments.len(), 5); for f in fragments { assert_eq!(f.metadata.num_rows(), Some(40)); - assert_eq!(f.count_rows().await.unwrap(), 40); + assert_eq!(f.count_rows(None).await.unwrap(), 40); assert_eq!(f.metadata().deletion_file, None); } } @@ -2660,10 +2669,18 @@ mod tests { let dataset = create_dataset(test_uri, data_storage_version).await; let fragment = dataset.get_fragments().pop().unwrap(); - assert_eq!(fragment.count_rows().await.unwrap(), 40); + assert_eq!(fragment.count_rows(None).await.unwrap(), 40); assert_eq!(fragment.physical_rows().await.unwrap(), 40); assert!(fragment.metadata.deletion_file.is_none()); + assert_eq!( + fragment + .count_rows(Some("i < 170".to_string())) + .await + .unwrap(), + 10 + ); + let fragment = fragment .delete("i >= 160 and i <= 172") .await @@ -2672,7 +2689,7 @@ mod tests { fragment.validate().await.unwrap(); - assert_eq!(fragment.count_rows().await.unwrap(), 27); + assert_eq!(fragment.count_rows(None).await.unwrap(), 27); assert_eq!(fragment.physical_rows().await.unwrap(), 40); assert!(fragment.metadata.deletion_file.is_some()); assert_eq!( diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 4537b75961..22ee289c97 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -36,8 +36,9 @@ use datafusion::physical_plan::{ use datafusion::scalar::ScalarValue; use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::{Partitioning, PhysicalExpr}; +use futures::future::BoxFuture; use futures::stream::{Stream, StreamExt}; -use futures::TryStreamExt; +use futures::{FutureExt, TryStreamExt}; use lance_arrow::floats::{coerce_float_vector, FloatType}; use lance_arrow::DataTypeExt; use lance_core::datatypes::{Field, OnMissing, Projection}; @@ -944,13 +945,17 @@ impl Scanner { /// Create a stream from the Scanner. #[instrument(skip_all)] - pub async fn try_into_stream(&self) -> Result { - let plan = self.create_plan().await?; - - Ok(DatasetRecordBatchStream::new(execute_plan( - plan, - LanceExecutionOptions::default(), - )?)) + pub fn try_into_stream(&self) -> BoxFuture> { + // Future intentionally boxed here to avoid large futures on the stack + async move { + let plan = self.create_plan().await?; + + Ok(DatasetRecordBatchStream::new(execute_plan( + plan, + LanceExecutionOptions::default(), + )?)) + } + .boxed() } pub(crate) async fn try_into_dfstream( @@ -970,46 +975,50 @@ impl Scanner { /// Scan and return the number of matching rows #[instrument(skip_all)] - pub async fn count_rows(&self) -> Result { - let plan = self.create_plan().await?; - // Datafusion interprets COUNT(*) as COUNT(1) - let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1)))); - - let input_phy_exprs: &[Arc] = &[one]; - let schema = plan.schema(); - - let mut builder = AggregateExprBuilder::new(count_udaf(), input_phy_exprs.to_vec()); - builder = builder.schema(schema); - builder = builder.alias("count_rows".to_string()); - - let count_expr = builder.build()?; - - let plan_schema = plan.schema(); - let count_plan = Arc::new(AggregateExec::try_new( - AggregateMode::Single, - PhysicalGroupBy::new_single(Vec::new()), - vec![count_expr], - vec![None], - plan, - plan_schema, - )?); - let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?; - - // A count plan will always return a single batch with a single row. - if let Some(first_batch) = stream.next().await { - let batch = first_batch?; - let array = batch - .column(0) - .as_any() - .downcast_ref::() - .ok_or(Error::io( - "Count plan did not return a UInt64Array".to_string(), - location!(), - ))?; - Ok(array.value(0) as u64) - } else { - Ok(0) + pub fn count_rows(&self) -> BoxFuture> { + // Future intentionally boxed here to avoid large futures on the stack + async move { + let plan = self.create_plan().await?; + // Datafusion interprets COUNT(*) as COUNT(1) + let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1)))); + + let input_phy_exprs: &[Arc] = &[one]; + let schema = plan.schema(); + + let mut builder = AggregateExprBuilder::new(count_udaf(), input_phy_exprs.to_vec()); + builder = builder.schema(schema); + builder = builder.alias("count_rows".to_string()); + + let count_expr = builder.build()?; + + let plan_schema = plan.schema(); + let count_plan = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new_single(Vec::new()), + vec![count_expr], + vec![None], + plan, + plan_schema, + )?); + let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?; + + // A count plan will always return a single batch with a single row. + if let Some(first_batch) = stream.next().await { + let batch = first_batch?; + let array = batch + .column(0) + .as_any() + .downcast_ref::() + .ok_or(Error::io( + "Count plan did not return a UInt64Array".to_string(), + location!(), + ))?; + Ok(array.value(0) as u64) + } else { + Ok(0) + } } + .boxed() } /// Given a base schema and a list of desired fields figure out which fields, if any, still need loaded diff --git a/rust/lance/src/dataset/take.rs b/rust/lance/src/dataset/take.rs index c390bbd45c..8cbf44cd1f 100644 --- a/rust/lance/src/dataset/take.rs +++ b/rust/lance/src/dataset/take.rs @@ -45,7 +45,7 @@ pub async fn take( let mut frag_iter = fragments.iter(); let mut cur_frag = frag_iter.next(); let mut cur_frag_rows = if let Some(cur_frag) = cur_frag { - cur_frag.count_rows().await? as u64 + cur_frag.count_rows(None).await? as u64 } else { 0 }; @@ -57,7 +57,7 @@ pub async fn take( frag_offset += cur_frag_rows; cur_frag = frag_iter.next(); cur_frag_rows = if let Some(cur_frag) = cur_frag { - cur_frag.count_rows().await? as u64 + cur_frag.count_rows(None).await? as u64 } else { 0 }; diff --git a/rust/lance/src/io/exec/scan.rs b/rust/lance/src/io/exec/scan.rs index 5ec680c647..9cd6ac825f 100644 --- a/rust/lance/src/io/exec/scan.rs +++ b/rust/lance/src/io/exec/scan.rs @@ -159,7 +159,7 @@ impl LanceStream { if let Some(next_frag) = frags_iter.next() { let num_rows_in_frag = next_frag .fragment - .count_rows() + .count_rows(None) // count_rows should be a fast operation in v2 files .now_or_never() .ok_or(Error::Internal {