From 20928088df35c40032e4b49291da4b8aaf44f0d1 Mon Sep 17 00:00:00 2001
From: Chongchen Chen
Date: Wed, 1 Jan 2025 01:58:44 +0800
Subject: [PATCH 1/3] fix: default value is overwritten (#3319)

---
 python/python/lance/dataset.py      | 33 +++++++++++++++--------------
 python/python/tests/test_dataset.py |  7 ++++++
 2 files changed, 24 insertions(+), 16 deletions(-)

diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py
index 7e7229b6a9..2274ca2a4b 100644
--- a/python/python/lance/dataset.py
+++ b/python/python/lance/dataset.py
@@ -507,13 +507,13 @@ def to_table(
         batch_size: Optional[int] = None,
         batch_readahead: Optional[int] = None,
         fragment_readahead: Optional[int] = None,
-        scan_in_order: bool = True,
+        scan_in_order: Optional[bool] = None,
         *,
-        prefilter: bool = False,
-        with_row_id: bool = False,
-        with_row_address: bool = False,
-        use_stats: bool = True,
-        fast_search: bool = False,
+        prefilter: Optional[bool] = None,
+        with_row_id: Optional[bool] = None,
+        with_row_address: Optional[bool] = None,
+        use_stats: Optional[bool] = None,
+        fast_search: Optional[bool] = None,
         full_text_query: Optional[Union[str, dict]] = None,
         io_buffer_size: Optional[int] = None,
         late_materialization: Optional[bool | List[str]] = None,
@@ -558,11 +558,11 @@ def to_table(
             The number of batches to read ahead.
         fragment_readahead: int, optional
             The number of fragments to read ahead.
-        scan_in_order: bool, default True
+        scan_in_order: bool, optional, default True
             Whether to read the fragments and batches in order. If false,
             throughput may be higher, but batches will be returned out of order
             and memory use might increase.
-        prefilter: bool, default False
+        prefilter: bool, optional, default False
             Run filter before the vector search.
         late_materialization: bool or List[str], default None
             Allows custom control over late materialization. See
@@ -569,13 +569,15 @@ def to_table(
             ``ScannerBuilder.late_materialization`` for more information.
         use_scalar_index: bool, default True
             Allows custom control over scalar index usage. See
             ``ScannerBuilder.use_scalar_index`` for more information.
-        with_row_id: bool, default False
+        with_row_id: bool, optional, default False
             Return row ID.
-        with_row_address: bool, default False
+        with_row_address: bool, optional, default False
             Return row address
-        use_stats: bool, default True
+        use_stats: bool, optional, default True
             Use stats pushdown during filters.
+        fast_search: bool, optional, default False
+            Whether to only search the indexed data, which yields faster search latency.
         full_text_query: str or dict, optional
             query string to search for, the results will be ranked by BM25.
             e.g. "hello world", would match documents contains "hello" or "world".
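Note: to_table()/to_batches() forward these arguments to the scanner builder. With
hard defaults such as with_row_id=False, the scanner could not distinguish "caller
explicitly passed False" from "caller passed nothing", so the hard default silently
overwrote any dataset-level _default_scan_options. With None defaults, only arguments
the caller actually supplies take precedence. A minimal sketch of the failure mode,
mirroring the regression test below (the path is illustrative, and
_default_scan_options is an internal attribute, set here only to demonstrate):

    import lance
    import pyarrow as pa

    ds = lance.write_dataset(pa.table({"x": [0]}), "/tmp/demo.lance")
    # Ask that every scan of this dataset include the physical row id column.
    ds._default_scan_options = {"with_row_id": True}

    # Before the fix, to_table() always forwarded with_row_id=False (its hard
    # default), clobbering the dataset-level option, so the returned table was
    # missing the _rowid column and this assertion failed.
    assert ds.schema == ds.to_table().schema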
@@ -687,12 +689,12 @@ def to_batches(
         batch_size: Optional[int] = None,
         batch_readahead: Optional[int] = None,
         fragment_readahead: Optional[int] = None,
-        scan_in_order: bool = True,
+        scan_in_order: Optional[bool] = None,
         *,
-        prefilter: bool = False,
-        with_row_id: bool = False,
-        with_row_address: bool = False,
-        use_stats: bool = True,
+        prefilter: Optional[bool] = None,
+        with_row_id: Optional[bool] = None,
+        with_row_address: Optional[bool] = None,
+        use_stats: Optional[bool] = None,
         full_text_query: Optional[Union[str, dict]] = None,
         io_buffer_size: Optional[int] = None,
         late_materialization: Optional[bool | List[str]] = None,
diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py
index 587f6a8165..955702aa14 100644
--- a/python/python/tests/test_dataset.py
+++ b/python/python/tests/test_dataset.py
@@ -2806,3 +2806,10 @@ def test_dataset_drop(tmp_path: Path):
     assert Path(tmp_path).exists()
     lance.LanceDataset.drop(tmp_path)
     assert not Path(tmp_path).exists()
+
+
+def test_dataset_schema(tmp_path: Path):
+    table = pa.table({"x": [0]})
+    ds = lance.write_dataset(table, str(tmp_path))
+    ds._default_scan_options = {"with_row_id": True}
+    assert ds.schema == ds.to_table().schema

From 898396de974b6ecada0a84c0c78918684fbd9271 Mon Sep 17 00:00:00 2001
From: Lei Xu
Date: Tue, 31 Dec 2024 10:18:37 -0800
Subject: [PATCH 2/3] feat(py): support count rows with filter in a fragment (#3318)

Co-authored-by: Weston Pace
---
 java/core/lance-jni/src/fragment.rs  |   2 +-
 python/python/lance/fragment.py      |   6 +-
 python/python/tests/test_fragment.py |  13 ++++
 python/src/fragment.rs               |   6 +-
 rust/lance/src/dataset.rs            |   4 +-
 rust/lance/src/dataset/fragment.rs   |  41 +++++++----
 rust/lance/src/dataset/scanner.rs    | 103 +++++++++++++++------------
 rust/lance/src/dataset/take.rs       |   4 +-
 rust/lance/src/io/exec/scan.rs       |   2 +-
 9 files changed, 110 insertions(+), 71 deletions(-)

diff --git a/java/core/lance-jni/src/fragment.rs b/java/core/lance-jni/src/fragment.rs
index dacdd08798..459afab022 100644
--- a/java/core/lance-jni/src/fragment.rs
+++ b/java/core/lance-jni/src/fragment.rs
@@ -62,7 +62,7 @@ fn inner_count_rows_native(
             "Fragment not found: {fragment_id}"
         )));
     };
-    let res = RT.block_on(fragment.count_rows())?;
+    let res = RT.block_on(fragment.count_rows(None))?;
     Ok(res)
 }

diff --git a/python/python/lance/fragment.py b/python/python/lance/fragment.py
index ce9334c682..e3abc3e1de 100644
--- a/python/python/lance/fragment.py
+++ b/python/python/lance/fragment.py
@@ -217,7 +217,7 @@ def __init__(
             if fragment_id is None:
                 raise ValueError("Either fragment or fragment_id must be specified")
             fragment = dataset.get_fragment(fragment_id)._fragment
-        self._fragment = fragment
+        self._fragment: _Fragment = fragment
         if self._fragment is None:
             raise ValueError(f"Fragment id does not exist: {fragment_id}")

@@ -367,8 +367,8 @@ def count_rows(
         self, filter: Optional[Union[pa.compute.Expression, str]] = None
     ) -> int:
         if filter is not None:
-            raise ValueError("Does not support filter at the moment")
-        return self._fragment.count_rows()
+            return self.scanner(filter=filter).count_rows()
+        return self._fragment.count_rows(filter)

     @property
     def num_deletions(self) -> int:
diff --git a/python/python/tests/test_fragment.py b/python/python/tests/test_fragment.py
index 7bae75759b..7a55e02788 100644
--- a/python/python/tests/test_fragment.py
+++ b/python/python/tests/test_fragment.py
@@ -9,6 +9,7 @@
 import lance
 import pandas as pd
 import pyarrow as pa
+import pyarrow.compute as pc
 import pytest
 from helper import ProgressForTest
 from lance import (
@@ -422,3 +423,15 @@ def test_fragment_merge(tmp_path):
             tmp_path, merge, read_version=dataset.latest_version
         )
         assert [f.name for f in dataset.schema] == ["a", "b", "c", "d"]
+
+
+def test_fragment_count_rows(tmp_path: Path):
+    data = pa.table({"a": range(800), "b": range(800)})
+    ds = write_dataset(data, tmp_path)
+
+    fragments = ds.get_fragments()
+    assert len(fragments) == 1
+
+    assert fragments[0].count_rows() == 800
+    assert fragments[0].count_rows("a < 200") == 200
+    assert fragments[0].count_rows(pc.field("a") < 200) == 200
diff --git a/python/src/fragment.rs b/python/src/fragment.rs
index b5cb75fc3a..1ddf89a21b 100644
--- a/python/src/fragment.rs
+++ b/python/src/fragment.rs
@@ -127,11 +127,11 @@ impl FileFragment {
         PyLance(self.fragment.metadata().clone())
     }

-    #[pyo3(signature=(_filter=None))]
-    fn count_rows(&self, _filter: Option<String>) -> PyResult<usize> {
+    #[pyo3(signature=(filter=None))]
+    fn count_rows(&self, filter: Option<String>) -> PyResult<usize> {
         RT.runtime.block_on(async {
             self.fragment
-                .count_rows()
+                .count_rows(filter)
                 .await
                 .map_err(|e| PyIOError::new_err(e.to_string()))
         })
diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs
index cbcf878d78..bd27c1fc31 100644
--- a/rust/lance/src/dataset.rs
+++ b/rust/lance/src/dataset.rs
@@ -798,7 +798,7 @@ impl Dataset {

     pub(crate) async fn count_all_rows(&self) -> Result<usize> {
         let cnts = stream::iter(self.get_fragments())
-            .map(|f| async move { f.count_rows().await })
+            .map(|f| async move { f.count_rows(None).await })
             .buffer_unordered(16)
             .try_collect::<Vec<_>>()
             .await?;
@@ -2037,7 +2037,7 @@ mod tests {
         assert_eq!(fragments.len(), 10);
         assert_eq!(dataset.count_fragments(), 10);
         for fragment in &fragments {
-            assert_eq!(fragment.count_rows().await.unwrap(), 100);
+            assert_eq!(fragment.count_rows(None).await.unwrap(), 100);
             let reader = fragment
                 .open(dataset.schema(), FragReadConfig::default(), None)
                 .await
diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs
index 7788f7cbe0..161c97627f 100644
--- a/rust/lance/src/dataset/fragment.rs
+++ b/rust/lance/src/dataset/fragment.rs
@@ -710,7 +710,7 @@ impl FileFragment {
             row_id_sequence,
             opened_files,
             ArrowSchema::from(projection),
-            self.count_rows().await?,
+            self.count_rows(None).await?,
             num_physical_rows,
         )?;

@@ -829,7 +829,7 @@ impl FileFragment {
         }

         // This should return immediately on modern datasets.
-        let num_rows = self.count_rows().await?;
+        let num_rows = self.count_rows(None).await?;

         // Check if there are any fields that are not in any data files
         let field_ids_in_files = opened_files
@@ -849,15 +849,24 @@ impl FileFragment {
     }

     /// Count the rows in this fragment.
-    pub async fn count_rows(&self) -> Result<usize> {
-        let total_rows = self.physical_rows();
-
-        let deletion_count = self.count_deletions();
+    pub async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
+        match filter {
+            Some(expr) => self
+                .scan()
+                .filter(&expr)?
+                .count_rows()
+                .await
+                .map(|v| v as usize),
+            None => {
+                let total_rows = self.physical_rows();
+                let deletion_count = self.count_deletions();

-        let (total_rows, deletion_count) =
-            futures::future::try_join(total_rows, deletion_count).await?;
+                let (total_rows, deletion_count) =
+                    futures::future::try_join(total_rows, deletion_count).await?;

-        Ok(total_rows - deletion_count)
+                Ok(total_rows - deletion_count)
+            }
+        }
     }

     /// Get the number of rows that have been deleted in this fragment.
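Note: the unfiltered path above stays metadata-only (physical rows minus deletion-file
entries), while a filter now falls back to a scan of the fragment; the Python binding
in this patch routes both string filters and pyarrow expressions through
FileFragment.scanner(). A short sketch of the resulting Python surface, mirroring
test_fragment_count_rows above (path and data are illustrative):

    import lance
    import pyarrow as pa
    import pyarrow.compute as pc

    ds = lance.write_dataset(pa.table({"a": range(800)}), "/tmp/frag_demo.lance")
    frag = ds.get_fragments()[0]

    frag.count_rows()                     # fast path: metadata only
    frag.count_rows("a < 200")            # SQL-style string filter -> filtered scan
    frag.count_rows(pc.field("a") < 200)  # pyarrow compute expressions also work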
@@ -2644,7 +2653,7 @@ mod tests {
         assert_eq!(fragments.len(), 5);
         for f in fragments {
             assert_eq!(f.metadata.num_rows(), Some(40));
-            assert_eq!(f.count_rows().await.unwrap(), 40);
+            assert_eq!(f.count_rows(None).await.unwrap(), 40);
             assert_eq!(f.metadata().deletion_file, None);
         }
     }
@@ -2660,10 +2669,18 @@ mod tests {
         let dataset = create_dataset(test_uri, data_storage_version).await;
         let fragment = dataset.get_fragments().pop().unwrap();
-        assert_eq!(fragment.count_rows().await.unwrap(), 40);
+        assert_eq!(fragment.count_rows(None).await.unwrap(), 40);
         assert_eq!(fragment.physical_rows().await.unwrap(), 40);
         assert!(fragment.metadata.deletion_file.is_none());

+        assert_eq!(
+            fragment
+                .count_rows(Some("i < 170".to_string()))
+                .await
+                .unwrap(),
+            10
+        );
+
         let fragment = fragment
             .delete("i >= 160 and i <= 172")
             .await
@@ -2672,7 +2689,7 @@ mod tests {

         fragment.validate().await.unwrap();

-        assert_eq!(fragment.count_rows().await.unwrap(), 27);
+        assert_eq!(fragment.count_rows(None).await.unwrap(), 27);
         assert_eq!(fragment.physical_rows().await.unwrap(), 40);
         assert!(fragment.metadata.deletion_file.is_some());
         assert_eq!(
diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs
index 4537b75961..22ee289c97 100644
--- a/rust/lance/src/dataset/scanner.rs
+++ b/rust/lance/src/dataset/scanner.rs
@@ -36,8 +36,9 @@ use datafusion::physical_plan::{
 use datafusion::scalar::ScalarValue;
 use datafusion_physical_expr::aggregate::AggregateExprBuilder;
 use datafusion_physical_expr::{Partitioning, PhysicalExpr};
+use futures::future::BoxFuture;
 use futures::stream::{Stream, StreamExt};
-use futures::TryStreamExt;
+use futures::{FutureExt, TryStreamExt};
 use lance_arrow::floats::{coerce_float_vector, FloatType};
 use lance_arrow::DataTypeExt;
 use lance_core::datatypes::{Field, OnMissing, Projection};
@@ -944,13 +945,17 @@ impl Scanner {

     /// Create a stream from the Scanner.
     #[instrument(skip_all)]
-    pub async fn try_into_stream(&self) -> Result<DatasetRecordBatchStream> {
-        let plan = self.create_plan().await?;
-
-        Ok(DatasetRecordBatchStream::new(execute_plan(
-            plan,
-            LanceExecutionOptions::default(),
-        )?))
+    pub fn try_into_stream(&self) -> BoxFuture<'_, Result<DatasetRecordBatchStream>> {
+        // Future intentionally boxed here to avoid large futures on the stack
+        async move {
+            let plan = self.create_plan().await?;
+
+            Ok(DatasetRecordBatchStream::new(execute_plan(
+                plan,
+                LanceExecutionOptions::default(),
+            )?))
+        }
+        .boxed()
     }

     pub(crate) async fn try_into_dfstream(
@@ -970,46 +975,50 @@ impl Scanner {

     /// Scan and return the number of matching rows
     #[instrument(skip_all)]
-    pub async fn count_rows(&self) -> Result<u64> {
-        let plan = self.create_plan().await?;
-        // Datafusion interprets COUNT(*) as COUNT(1)
-        let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1))));
-
-        let input_phy_exprs: &[Arc<dyn PhysicalExpr>] = &[one];
-        let schema = plan.schema();
-
-        let mut builder = AggregateExprBuilder::new(count_udaf(), input_phy_exprs.to_vec());
-        builder = builder.schema(schema);
-        builder = builder.alias("count_rows".to_string());
-
-        let count_expr = builder.build()?;
-
-        let plan_schema = plan.schema();
-        let count_plan = Arc::new(AggregateExec::try_new(
-            AggregateMode::Single,
-            PhysicalGroupBy::new_single(Vec::new()),
-            vec![count_expr],
-            vec![None],
-            plan,
-            plan_schema,
-        )?);
-        let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?;
-
-        // A count plan will always return a single batch with a single row.
-        if let Some(first_batch) = stream.next().await {
-            let batch = first_batch?;
-            let array = batch
-                .column(0)
-                .as_any()
-                .downcast_ref::<UInt64Array>()
-                .ok_or(Error::io(
-                    "Count plan did not return a UInt64Array".to_string(),
-                    location!(),
-                ))?;
-            Ok(array.value(0) as u64)
-        } else {
-            Ok(0)
+    pub fn count_rows(&self) -> BoxFuture<'_, Result<u64>> {
+        // Future intentionally boxed here to avoid large futures on the stack
+        async move {
+            let plan = self.create_plan().await?;
+            // Datafusion interprets COUNT(*) as COUNT(1)
+            let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1))));
+
+            let input_phy_exprs: &[Arc<dyn PhysicalExpr>] = &[one];
+            let schema = plan.schema();
+
+            let mut builder = AggregateExprBuilder::new(count_udaf(), input_phy_exprs.to_vec());
+            builder = builder.schema(schema);
+            builder = builder.alias("count_rows".to_string());
+
+            let count_expr = builder.build()?;
+
+            let plan_schema = plan.schema();
+            let count_plan = Arc::new(AggregateExec::try_new(
+                AggregateMode::Single,
+                PhysicalGroupBy::new_single(Vec::new()),
+                vec![count_expr],
+                vec![None],
+                plan,
+                plan_schema,
+            )?);
+            let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?;
+
+            // A count plan will always return a single batch with a single row.
+            if let Some(first_batch) = stream.next().await {
+                let batch = first_batch?;
+                let array = batch
+                    .column(0)
+                    .as_any()
+                    .downcast_ref::<UInt64Array>()
+                    .ok_or(Error::io(
+                        "Count plan did not return a UInt64Array".to_string(),
+                        location!(),
+                    ))?;
+                Ok(array.value(0) as u64)
+            } else {
+                Ok(0)
+            }
         }
+        .boxed()
     }

     /// Given a base schema and a list of desired fields figure out which fields, if any, still need loaded
diff --git a/rust/lance/src/dataset/take.rs b/rust/lance/src/dataset/take.rs
index c390bbd45c..8cbf44cd1f 100644
--- a/rust/lance/src/dataset/take.rs
+++ b/rust/lance/src/dataset/take.rs
@@ -45,7 +45,7 @@ pub async fn take(
     let mut frag_iter = fragments.iter();
     let mut cur_frag = frag_iter.next();
     let mut cur_frag_rows = if let Some(cur_frag) = cur_frag {
-        cur_frag.count_rows().await? as u64
+        cur_frag.count_rows(None).await? as u64
     } else {
         0
     };
@@ -57,7 +57,7 @@ pub async fn take(
         frag_offset += cur_frag_rows;
         cur_frag = frag_iter.next();
         cur_frag_rows = if let Some(cur_frag) = cur_frag {
-            cur_frag.count_rows().await? as u64
+            cur_frag.count_rows(None).await? as u64
         } else {
             0
         };
diff --git a/rust/lance/src/io/exec/scan.rs b/rust/lance/src/io/exec/scan.rs
index 5ec680c647..9cd6ac825f 100644
--- a/rust/lance/src/io/exec/scan.rs
+++ b/rust/lance/src/io/exec/scan.rs
@@ -159,7 +159,7 @@ impl LanceStream {
                 if let Some(next_frag) = frags_iter.next() {
                     let num_rows_in_frag = next_frag
                         .fragment
-                        .count_rows()
+                        .count_rows(None)
                         // count_rows should be a fast operation in v2 files
                         .now_or_never()
                         .ok_or(Error::Internal {

From 8767c102f08de80e6648268695f7b2815465d48c Mon Sep 17 00:00:00 2001
From: vinoyang
Date: Wed, 1 Jan 2025 02:19:03 +0800
Subject: [PATCH 3/3] feat(java): support take api for java module (#3316)

---
 java/core/lance-jni/src/blocking_dataset.rs   | 58 ++++++++++++++++++-
 java/core/lance-jni/src/ffi.rs                | 24 ++++++++
 .../main/java/com/lancedb/lance/Dataset.java  | 32 ++++++++++
 .../com/lancedb/lance/test/JniTestHelper.java |  7 +++
 .../java/com/lancedb/lance/DatasetTest.java   | 35 ++++++++++-
 .../test/java/com/lancedb/lance/JNITest.java  |  5 ++
 6 files changed, 156 insertions(+), 5 deletions(-)

diff --git a/java/core/lance-jni/src/blocking_dataset.rs b/java/core/lance-jni/src/blocking_dataset.rs
index 2e763afca5..322156d5a2 100644
--- a/java/core/lance-jni/src/blocking_dataset.rs
+++ b/java/core/lance-jni/src/blocking_dataset.rs
@@ -22,15 +22,16 @@ use arrow::datatypes::Schema;
 use arrow::ffi::FFI_ArrowSchema;
 use arrow::ffi_stream::ArrowArrayStreamReader;
 use arrow::ffi_stream::FFI_ArrowArrayStream;
+use arrow::ipc::writer::StreamWriter;
 use arrow::record_batch::RecordBatchIterator;
 use arrow_schema::DataType;
 use jni::objects::{JMap, JString, JValue};
-use jni::sys::jlong;
 use jni::sys::{jboolean, jint};
+use jni::sys::{jbyteArray, jlong};
 use jni::{objects::JObject, JNIEnv};
 use lance::dataset::builder::DatasetBuilder;
 use lance::dataset::transaction::Operation;
-use lance::dataset::{ColumnAlteration, Dataset, ReadParams, WriteParams};
+use lance::dataset::{ColumnAlteration, Dataset, ProjectionRequest, ReadParams, WriteParams};
 use lance::io::{ObjectStore, ObjectStoreParams};
 use lance::table::format::Fragment;
 use lance::table::format::Index;
@@ -683,6 +684,59 @@ fn inner_list_indexes<'local>(
     Ok(array_list)
 }

+#[no_mangle]
+pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeTake(
+    mut env: JNIEnv,
+    java_dataset: JObject,
+    indices_obj: JObject,  // List<Long>
+    columns_obj: JObject,  // List<String>
+) -> jbyteArray {
+    match inner_take(&mut env, java_dataset, indices_obj, columns_obj) {
+        Ok(byte_array) => byte_array,
+        Err(e) => {
+            let _ = env.throw_new("java/lang/RuntimeException", format!("{:?}", e));
+            std::ptr::null_mut()
+        }
+    }
+}
+
+fn inner_take(
+    env: &mut JNIEnv,
+    java_dataset: JObject,
+    indices_obj: JObject,  // List<Long>
+    columns_obj: JObject,  // List<String>
+) -> Result<jbyteArray> {
+    let indices: Vec<i64> = env.get_longs(&indices_obj)?;
+    let indices_u64: Vec<u64> = indices.iter().map(|&x| x as u64).collect();
+    let indices_slice: &[u64] = &indices_u64;
+    let columns: Vec<String> = env.get_strings(&columns_obj)?;
+
+    let result = {
+        let dataset_guard =
+            unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?;
+        let dataset = &dataset_guard.inner;
+
+        let projection = ProjectionRequest::from_columns(columns, dataset.schema());
+
+        match RT.block_on(dataset.take(indices_slice, projection)) {
+            Ok(res) => res,
+            Err(e) => {
+                return Err(e.into());
+            }
+        }
+    };
+
+    let mut buffer = Vec::new();
+    {
+        let mut writer = StreamWriter::try_new(&mut buffer, &result.schema())?;
+        writer.write(&result)?;
+        writer.finish()?;
+    }
+
+    let byte_array = env.byte_array_from_slice(&buffer)?;
+    Ok(**byte_array)
+}
+
 //////////////////////////////
 // Schema evolution Methods //
 //////////////////////////////
diff --git a/java/core/lance-jni/src/ffi.rs b/java/core/lance-jni/src/ffi.rs
index dd11a1ee38..f92d3ec873 100644
--- a/java/core/lance-jni/src/ffi.rs
+++ b/java/core/lance-jni/src/ffi.rs
@@ -26,6 +26,9 @@ pub trait JNIEnvExt {
     /// Get integers from Java List<Integer> object.
     fn get_integers(&mut self, obj: &JObject) -> Result<Vec<i32>>;

+    /// Get longs from Java List<Long> object.
+    fn get_longs(&mut self, obj: &JObject) -> Result<Vec<i64>>;
+
     /// Get strings from Java List<String> object.
     fn get_strings(&mut self, obj: &JObject) -> Result<Vec<String>>;

@@ -127,6 +130,18 @@ impl JNIEnvExt for JNIEnv<'_> {
         Ok(results)
     }

+    fn get_longs(&mut self, obj: &JObject) -> Result<Vec<i64>> {
+        let list = self.get_list(obj)?;
+        let mut iter = list.iter(self)?;
+        let mut results = Vec::with_capacity(list.size(self)? as usize);
+        while let Some(elem) = iter.next(self)? {
+            let long_obj = self.call_method(elem, "longValue", "()J", &[])?;
+            let long_value = long_obj.j()?;
+            results.push(long_value);
+        }
+        Ok(results)
+    }
+
     fn get_strings(&mut self, obj: &JObject) -> Result<Vec<String>> {
         let list = self.get_list(obj)?;
         let mut iter = list.iter(self)?;
@@ -348,6 +363,15 @@ pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseInts(
     ok_or_throw_without_return!(env, env.get_integers(&list_obj));
 }

+#[no_mangle]
+pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseLongs(
+    mut env: JNIEnv,
+    _obj: JObject,
+    list_obj: JObject,  // List<Long>
+) {
+    ok_or_throw_without_return!(env, env.get_longs(&list_obj));
+}
+
 #[no_mangle]
 pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseIntsOpt(
     mut env: JNIEnv,
diff --git a/java/core/src/main/java/com/lancedb/lance/Dataset.java b/java/core/src/main/java/com/lancedb/lance/Dataset.java
index 9a12d0c36a..8f1e5de507 100644
--- a/java/core/src/main/java/com/lancedb/lance/Dataset.java
+++ b/java/core/src/main/java/com/lancedb/lance/Dataset.java
@@ -24,9 +24,15 @@
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.vector.ipc.ArrowReader;
+import org.apache.arrow.vector.ipc.ArrowStreamReader;
 import org.apache.arrow.vector.types.pojo.Schema;

+import java.io.ByteArrayInputStream;
 import java.io.Closeable;
+import java.io.IOException;
+import java.nio.channels.Channels;
+import java.nio.channels.ReadableByteChannel;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -315,6 +321,32 @@ public LanceScanner newScan(ScanOptions options) {
     }
   }

+  /**
+   * Select rows of data by index.
+   *
+   * @param indices the indices to take
+   * @param columns the columns to take
+   * @return an ArrowReader
+   */
+  public ArrowReader take(List<Long> indices, List<String> columns) throws IOException {
+    Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed");
+    try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) {
+      byte[] arrowData = nativeTake(indices, columns);
+      ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(arrowData);
+      ReadableByteChannel readChannel = Channels.newChannel(byteArrayInputStream);
+      return new ArrowStreamReader(readChannel, allocator) {
+        @Override
+        public void close() throws IOException {
+          super.close();
+          readChannel.close();
+          byteArrayInputStream.close();
+        }
+      };
+    }
+  }
+
+  private native byte[] nativeTake(List<Long> indices, List<String> columns);
+
   /**
    * Gets the URI of the dataset.
    *
diff --git a/java/core/src/main/java/com/lancedb/lance/test/JniTestHelper.java b/java/core/src/main/java/com/lancedb/lance/test/JniTestHelper.java
index 89f1f8a4b6..28ed442c0b 100644
--- a/java/core/src/main/java/com/lancedb/lance/test/JniTestHelper.java
+++ b/java/core/src/main/java/com/lancedb/lance/test/JniTestHelper.java
@@ -37,6 +37,13 @@ public class JniTestHelper {
    */
   public static native void parseInts(List<Integer> intsList);

+  /**
+   * JNI parse longs test.
+   *
+   * @param longsList the given list of longs
+   */
+  public static native void parseLongs(List<Long> longsList);
+
   /**
    * JNI parse ints opts test.
    *
diff --git a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java
index 92765d28f2..4275ef9573 100644
--- a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java
+++ b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java
@@ -15,6 +15,8 @@
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowReader;
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.Schema;
@@ -25,10 +27,9 @@
 import java.io.IOException;
 import java.net.URISyntaxException;
+import java.nio.channels.ClosedChannelException;
 import java.nio.file.Path;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashMap;
+import java.util.*;
 import java.util.stream.Collectors;

 import static org.junit.jupiter.api.Assertions.*;
@@ -307,4 +308,32 @@ void testDropPath() {
       Dataset.drop(datasetPath, new HashMap<>());
     }
   }
+
+  @Test
+  void testTake() throws IOException, ClosedChannelException {
+    String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName();
+    String datasetPath = tempDir.resolve(testMethodName).toString();
+    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
+      TestUtils.SimpleTestDataset testDataset =
+          new TestUtils.SimpleTestDataset(allocator, datasetPath);
+      dataset = testDataset.createEmptyDataset();
+
+      try (Dataset dataset2 = testDataset.write(1, 5)) {
+        List<Long> indices = Arrays.asList(1L, 4L);
+        List<String> columns = Arrays.asList("id", "name");
+        try (ArrowReader reader = dataset2.take(indices, columns)) {
+          while (reader.loadNextBatch()) {
+            VectorSchemaRoot result = reader.getVectorSchemaRoot();
+            assertNotNull(result);
+            assertEquals(indices.size(), result.getRowCount());
+
+            for (int i = 0; i < indices.size(); i++) {
+              assertEquals(indices.get(i).intValue(), result.getVector("id").getObject(i));
+              assertNotNull(result.getVector("name").getObject(i));
+            }
+          }
+        }
+      }
+    }
+  }
 }
diff --git a/java/core/src/test/java/com/lancedb/lance/JNITest.java b/java/core/src/test/java/com/lancedb/lance/JNITest.java
index 885379d804..ddb4ea3cdf 100644
--- a/java/core/src/test/java/com/lancedb/lance/JNITest.java
+++ b/java/core/src/test/java/com/lancedb/lance/JNITest.java
@@ -37,6 +37,11 @@ public void testInts() {
     JniTestHelper.parseInts(Arrays.asList(1, 2, 3));
   }

+  @Test
+  public void testLongs() {
+    JniTestHelper.parseLongs(Arrays.asList(1L, 2L, 3L, Long.MAX_VALUE));
+  }
+
   @Test
   public void testIntsOpt() {
     JniTestHelper.parseIntsOpt(Optional.of(Arrays.asList(1, 2, 3)));
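Note: the Java take() above round-trips the result through Arrow IPC: the JNI side
serializes the taken rows into a byte[] with StreamWriter, and the Java side rehydrates
them with ArrowStreamReader. For comparison, a short sketch of the equivalent call
through the Python binding, which returns a pyarrow Table directly (path and data are
illustrative; assumes the existing LanceDataset.take API):

    import lance
    import pyarrow as pa

    tbl = pa.table({"id": range(5), "name": ["a", "b", "c", "d", "e"]})
    ds = lance.write_dataset(tbl, "/tmp/take_demo.lance")

    # Indices are row positions in the dataset, matching the Java test's
    # Arrays.asList(1L, 4L).
    result = ds.take([1, 4], columns=["id", "name"])
    assert result.num_rows == 2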