feat(py): support count rows with filter in a fragment (#3318)
Co-authored-by: Weston Pace <[email protected]>
eddyxu and westonpace authored Dec 31, 2024
1 parent 2092808 commit 898396d
Showing 9 changed files with 110 additions and 71 deletions.
2 changes: 1 addition & 1 deletion java/core/lance-jni/src/fragment.rs
@@ -62,7 +62,7 @@ fn inner_count_rows_native(
"Fragment not found: {fragment_id}"
)));
};
- let res = RT.block_on(fragment.count_rows())?;
+ let res = RT.block_on(fragment.count_rows(None))?;
Ok(res)
}

6 changes: 3 additions & 3 deletions python/python/lance/fragment.py
@@ -217,7 +217,7 @@ def __init__(
if fragment_id is None:
raise ValueError("Either fragment or fragment_id must be specified")
fragment = dataset.get_fragment(fragment_id)._fragment
- self._fragment = fragment
+ self._fragment: _Fragment = fragment
if self._fragment is None:
raise ValueError(f"Fragment id does not exist: {fragment_id}")

@@ -367,8 +367,8 @@ def count_rows(
self, filter: Optional[Union[pa.compute.Expression, str]] = None
) -> int:
if filter is not None:
-     raise ValueError("Does not support filter at the moment")
- return self._fragment.count_rows()
+     return self.scanner(filter=filter).count_rows()
+ return self._fragment.count_rows(filter)

@property
def num_deletions(self) -> int:
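For context, the new Python surface works as sketched below: calling count_rows with no filter keeps the metadata fast path (physical rows minus deletions), while a string or pyarrow compute expression predicate is answered by a filtered scan of the fragment. This is a minimal usage sketch, not part of the diff; the ./example.lance path and the integer column "a" are assumptions.

import lance
import pyarrow.compute as pc

ds = lance.dataset("./example.lance")  # assumed dataset with an integer column "a"
frag = ds.get_fragments()[0]

total = frag.count_rows()                       # no filter: metadata path, no scan
by_string = frag.count_rows("a < 200")          # SQL-style predicate string
by_expr = frag.count_rows(pc.field("a") < 200)  # pyarrow compute expression
assert by_string == by_expr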
13 changes: 13 additions & 0 deletions python/python/tests/test_fragment.py
@@ -9,6 +9,7 @@
import lance
import pandas as pd
import pyarrow as pa
+ import pyarrow.compute as pc
import pytest
from helper import ProgressForTest
from lance import (
@@ -422,3 +423,15 @@ def test_fragment_merge(tmp_path):
tmp_path, merge, read_version=dataset.latest_version
)
assert [f.name for f in dataset.schema] == ["a", "b", "c", "d"]


+ def test_fragment_count_rows(tmp_path: Path):
+     data = pa.table({"a": range(800), "b": range(800)})
+     ds = write_dataset(data, tmp_path)
+
+     fragments = ds.get_fragments()
+     assert len(fragments) == 1
+
+     assert fragments[0].count_rows() == 800
+     assert fragments[0].count_rows("a < 200") == 200
+     assert fragments[0].count_rows(pc.field("a") < 200) == 200
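Because the wrapper change in fragment.py routes any non-None predicate through a fragment scan, the test's expectations can also be phrased as an equivalence. A small sketch, reusing the ds handle from the test setup above (hypothetical):

frag = ds.get_fragments()[0]
assert frag.count_rows("a < 200") == frag.scanner(filter="a < 200").count_rows()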
6 changes: 3 additions & 3 deletions python/src/fragment.rs
@@ -127,11 +127,11 @@ impl FileFragment {
PyLance(self.fragment.metadata().clone())
}

- #[pyo3(signature=(_filter=None))]
- fn count_rows(&self, _filter: Option<String>) -> PyResult<usize> {
+ #[pyo3(signature=(filter=None))]
+ fn count_rows(&self, filter: Option<String>) -> PyResult<usize> {
RT.runtime.block_on(async {
self.fragment
- .count_rows()
+ .count_rows(filter)
.await
.map_err(|e| PyIOError::new_err(e.to_string()))
})
4 changes: 2 additions & 2 deletions rust/lance/src/dataset.rs
@@ -798,7 +798,7 @@ impl Dataset {

pub(crate) async fn count_all_rows(&self) -> Result<usize> {
let cnts = stream::iter(self.get_fragments())
- .map(|f| async move { f.count_rows().await })
+ .map(|f| async move { f.count_rows(None).await })
.buffer_unordered(16)
.try_collect::<Vec<_>>()
.await?;
@@ -2037,7 +2037,7 @@
assert_eq!(fragments.len(), 10);
assert_eq!(dataset.count_fragments(), 10);
for fragment in &fragments {
- assert_eq!(fragment.count_rows().await.unwrap(), 100);
+ assert_eq!(fragment.count_rows(None).await.unwrap(), 100);
let reader = fragment
.open(dataset.schema(), FragReadConfig::default(), None)
.await
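count_all_rows above fans out one unfiltered count_rows(None) per fragment and resolves up to 16 of them concurrently via buffer_unordered(16), so the dataset-level count is simply the sum of the per-fragment counts. A sketch of that invariant in Python, assuming the ./example.lance dataset from the earlier sketch:

import lance

ds = lance.dataset("./example.lance")
assert ds.count_rows() == sum(f.count_rows() for f in ds.get_fragments())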
41 changes: 29 additions & 12 deletions rust/lance/src/dataset/fragment.rs
@@ -710,7 +710,7 @@ impl FileFragment {
row_id_sequence,
opened_files,
ArrowSchema::from(projection),
- self.count_rows().await?,
+ self.count_rows(None).await?,
num_physical_rows,
)?;

@@ -829,7 +829,7 @@
}

// This should return immediately on modern datasets.
- let num_rows = self.count_rows().await?;
+ let num_rows = self.count_rows(None).await?;

// Check if there are any fields that are not in any data files
let field_ids_in_files = opened_files
@@ -849,15 +849,24 @@
}

/// Count the rows in this fragment.
- pub async fn count_rows(&self) -> Result<usize> {
-     let total_rows = self.physical_rows();
-
-     let deletion_count = self.count_deletions();
+ pub async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
+     match filter {
+         Some(expr) => self
+             .scan()
+             .filter(&expr)?
+             .count_rows()
+             .await
+             .map(|v| v as usize),
+         None => {
+             let total_rows = self.physical_rows();
+             let deletion_count = self.count_deletions();

-     let (total_rows, deletion_count) =
-         futures::future::try_join(total_rows, deletion_count).await?;
+             let (total_rows, deletion_count) =
+                 futures::future::try_join(total_rows, deletion_count).await?;

-     Ok(total_rows - deletion_count)
+             Ok(total_rows - deletion_count)
+         }
+     }
}

/// Get the number of rows that have been deleted in this fragment.
@@ -2644,7 +2653,7 @@ mod tests {
assert_eq!(fragments.len(), 5);
for f in fragments {
assert_eq!(f.metadata.num_rows(), Some(40));
- assert_eq!(f.count_rows().await.unwrap(), 40);
+ assert_eq!(f.count_rows(None).await.unwrap(), 40);
assert_eq!(f.metadata().deletion_file, None);
}
}
@@ -2660,10 +2669,18 @@
let dataset = create_dataset(test_uri, data_storage_version).await;
let fragment = dataset.get_fragments().pop().unwrap();

- assert_eq!(fragment.count_rows().await.unwrap(), 40);
+ assert_eq!(fragment.count_rows(None).await.unwrap(), 40);
assert_eq!(fragment.physical_rows().await.unwrap(), 40);
assert!(fragment.metadata.deletion_file.is_none());

+ assert_eq!(
+     fragment
+         .count_rows(Some("i < 170".to_string()))
+         .await
+         .unwrap(),
+     10
+ );
+
let fragment = fragment
.delete("i >= 160 and i <= 172")
.await
@@ -2672,7 +2689,7 @@

fragment.validate().await.unwrap();

- assert_eq!(fragment.count_rows().await.unwrap(), 27);
+ assert_eq!(fragment.count_rows(None).await.unwrap(), 27);
assert_eq!(fragment.physical_rows().await.unwrap(), 40);
assert!(fragment.metadata.deletion_file.is_some());
assert_eq!(
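The None branch keeps the logical/physical split from before this change: count_rows(None) is physical_rows minus the deletion count, which the test above exercises by deleting "i >= 160 and i <= 172" (13 rows) and expecting 27 of 40 rows to remain. A rough Python analogue, assuming a physical_rows property exists alongside the num_deletions property visible in the fragment.py hunk earlier:

import lance

ds = lance.dataset("./example.lance")  # assumed dataset from the earlier sketches
ds.delete("a >= 160 and a <= 172")     # tombstones 13 rows; data files stay intact
frag = lance.dataset("./example.lance").get_fragments()[0]
assert frag.count_rows() == frag.physical_rows - frag.num_deletions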
103 changes: 56 additions & 47 deletions rust/lance/src/dataset/scanner.rs
@@ -36,8 +36,9 @@ use datafusion::physical_plan::{
use datafusion::scalar::ScalarValue;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::{Partitioning, PhysicalExpr};
+ use futures::future::BoxFuture;
use futures::stream::{Stream, StreamExt};
- use futures::TryStreamExt;
+ use futures::{FutureExt, TryStreamExt};
use lance_arrow::floats::{coerce_float_vector, FloatType};
use lance_arrow::DataTypeExt;
use lance_core::datatypes::{Field, OnMissing, Projection};
@@ -944,13 +945,17 @@

/// Create a stream from the Scanner.
#[instrument(skip_all)]
- pub async fn try_into_stream(&self) -> Result<DatasetRecordBatchStream> {
-     let plan = self.create_plan().await?;
-
-     Ok(DatasetRecordBatchStream::new(execute_plan(
-         plan,
-         LanceExecutionOptions::default(),
-     )?))
+ pub fn try_into_stream(&self) -> BoxFuture<Result<DatasetRecordBatchStream>> {
+     // Future intentionally boxed here to avoid large futures on the stack
+     async move {
+         let plan = self.create_plan().await?;
+
+         Ok(DatasetRecordBatchStream::new(execute_plan(
+             plan,
+             LanceExecutionOptions::default(),
+         )?))
+     }
+     .boxed()
}

pub(crate) async fn try_into_dfstream(
@@ -970,46 +975,50 @@

/// Scan and return the number of matching rows
#[instrument(skip_all)]
- pub async fn count_rows(&self) -> Result<u64> {
-     let plan = self.create_plan().await?;
-     // Datafusion interprets COUNT(*) as COUNT(1)
-     let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1))));
-
-     let input_phy_exprs: &[Arc<dyn PhysicalExpr>] = &[one];
-     let schema = plan.schema();
-
-     let mut builder = AggregateExprBuilder::new(count_udaf(), input_phy_exprs.to_vec());
-     builder = builder.schema(schema);
-     builder = builder.alias("count_rows".to_string());
-
-     let count_expr = builder.build()?;
-
-     let plan_schema = plan.schema();
-     let count_plan = Arc::new(AggregateExec::try_new(
-         AggregateMode::Single,
-         PhysicalGroupBy::new_single(Vec::new()),
-         vec![count_expr],
-         vec![None],
-         plan,
-         plan_schema,
-     )?);
-     let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?;
-
-     // A count plan will always return a single batch with a single row.
-     if let Some(first_batch) = stream.next().await {
-         let batch = first_batch?;
-         let array = batch
-             .column(0)
-             .as_any()
-             .downcast_ref::<Int64Array>()
-             .ok_or(Error::io(
-                 "Count plan did not return a UInt64Array".to_string(),
-                 location!(),
-             ))?;
-         Ok(array.value(0) as u64)
-     } else {
-         Ok(0)
+ pub fn count_rows(&self) -> BoxFuture<Result<u64>> {
+     // Future intentionally boxed here to avoid large futures on the stack
+     async move {
+         let plan = self.create_plan().await?;
+         // Datafusion interprets COUNT(*) as COUNT(1)
+         let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1))));
+
+         let input_phy_exprs: &[Arc<dyn PhysicalExpr>] = &[one];
+         let schema = plan.schema();
+
+         let mut builder = AggregateExprBuilder::new(count_udaf(), input_phy_exprs.to_vec());
+         builder = builder.schema(schema);
+         builder = builder.alias("count_rows".to_string());
+
+         let count_expr = builder.build()?;
+
+         let plan_schema = plan.schema();
+         let count_plan = Arc::new(AggregateExec::try_new(
+             AggregateMode::Single,
+             PhysicalGroupBy::new_single(Vec::new()),
+             vec![count_expr],
+             vec![None],
+             plan,
+             plan_schema,
+         )?);
+         let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?;
+
+         // A count plan will always return a single batch with a single row.
+         if let Some(first_batch) = stream.next().await {
+             let batch = first_batch?;
+             let array = batch
+                 .column(0)
+                 .as_any()
+                 .downcast_ref::<Int64Array>()
+                 .ok_or(Error::io(
+                     "Count plan did not return a UInt64Array".to_string(),
+                     location!(),
+                 ))?;
+             Ok(array.value(0) as u64)
+         } else {
+             Ok(0)
+         }
+     }
+     .boxed()
}

/// Given a base schema and a list of desired fields figure out which fields, if any, still need loaded
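Scanner::count_rows answers counts with a DataFusion AggregateExec running COUNT(1) over the (possibly filtered) plan, so no row data is materialized; the filtered fragment path added in this commit bottoms out here. The switch to returning BoxFuture, per the inline comments, deliberately boxes these large futures to keep them off the caller's stack. The same machinery backs scanner-level counts in Python; a sketch, assuming the 800-row table from the earlier sketches:

import lance

ds = lance.dataset("./example.lance")
assert ds.scanner(filter="a < 200").count_rows() == 200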
4 changes: 2 additions & 2 deletions rust/lance/src/dataset/take.rs
@@ -45,7 +45,7 @@ pub async fn take(
let mut frag_iter = fragments.iter();
let mut cur_frag = frag_iter.next();
let mut cur_frag_rows = if let Some(cur_frag) = cur_frag {
- cur_frag.count_rows().await? as u64
+ cur_frag.count_rows(None).await? as u64
} else {
0
};
@@ -57,7 +57,7 @@
frag_offset += cur_frag_rows;
cur_frag = frag_iter.next();
cur_frag_rows = if let Some(cur_frag) = cur_frag {
- cur_frag.count_rows().await? as u64
+ cur_frag.count_rows(None).await? as u64
} else {
0
};
2 changes: 1 addition & 1 deletion rust/lance/src/io/exec/scan.rs
@@ -159,7 +159,7 @@ impl LanceStream {
if let Some(next_frag) = frags_iter.next() {
let num_rows_in_frag = next_frag
.fragment
- .count_rows()
+ .count_rows(None)
// count_rows should be a fast operation in v2 files
.now_or_never()
.ok_or(Error::Internal {
