diff --git a/java/core/lance-jni/src/blocking_dataset.rs b/java/core/lance-jni/src/blocking_dataset.rs index 2e763afca5..1cc01b88c1 100644 --- a/java/core/lance-jni/src/blocking_dataset.rs +++ b/java/core/lance-jni/src/blocking_dataset.rs @@ -22,15 +22,16 @@ use arrow::datatypes::Schema; use arrow::ffi::FFI_ArrowSchema; use arrow::ffi_stream::ArrowArrayStreamReader; use arrow::ffi_stream::FFI_ArrowArrayStream; +use arrow::ipc::writer::StreamWriter; use arrow::record_batch::RecordBatchIterator; use arrow_schema::DataType; use jni::objects::{JMap, JString, JValue}; -use jni::sys::jlong; use jni::sys::{jboolean, jint}; +use jni::sys::{jbyteArray, jlong}; use jni::{objects::JObject, JNIEnv}; use lance::dataset::builder::DatasetBuilder; use lance::dataset::transaction::Operation; -use lance::dataset::{ColumnAlteration, Dataset, ReadParams, WriteParams}; +use lance::dataset::{ColumnAlteration, Dataset, ProjectionRequest, ReadParams, WriteParams}; use lance::io::{ObjectStore, ObjectStoreParams}; use lance::table::format::Fragment; use lance::table::format::Index; @@ -683,6 +684,59 @@ fn inner_list_indexes<'local>( Ok(array_list) } +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeTake<'local>( + mut env: JNIEnv<'local>, + java_dataset: JObject, + indices_obj: JObject, // List<Integer> + columns_obj: JObject, // List<String> +) -> jbyteArray { + match inner_take(&mut env, java_dataset, indices_obj, columns_obj) { + Ok(byte_array) => byte_array, + Err(e) => { + let _ = env.throw_new("java/lang/RuntimeException", format!("{:?}", e)); + std::ptr::null_mut() + } + } +} + +fn inner_take<'local>( + env: &mut JNIEnv<'local>, + java_dataset: JObject, + indices_obj: JObject, // List<Integer> + columns_obj: JObject, // List<String> +) -> Result<jbyteArray> { + let indices: Vec<i32> = env.get_integers(&indices_obj)?; + let indices_u64: Vec<u64> = indices.iter().map(|&x| x as u64).collect(); + let indices_slice: &[u64] = &indices_u64; + let columns: Vec<String> = env.get_strings(&columns_obj)?; + + let result = { + let 
dataset_guard = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?; + let dataset = &dataset_guard.inner; + + let projection = ProjectionRequest::from_columns(columns, dataset.schema()); + + match RT.block_on(dataset.take(&indices_slice, projection)) { + Ok(res) => res, + Err(e) => { + return Err(e.into()); + } + } + }; + + let mut buffer = Vec::new(); + { + let mut writer = StreamWriter::try_new(&mut buffer, &result.schema())?; + writer.write(&result)?; + writer.finish()?; + } + + let byte_array = env.byte_array_from_slice(&buffer)?; + Ok(**byte_array) +} + ////////////////////////////// // Schema evolution Methods // ////////////////////////////// diff --git a/java/core/src/main/java/com/lancedb/lance/Dataset.java b/java/core/src/main/java/com/lancedb/lance/Dataset.java index 9a12d0c36a..c8208dfca5 100644 --- a/java/core/src/main/java/com/lancedb/lance/Dataset.java +++ b/java/core/src/main/java/com/lancedb/lance/Dataset.java @@ -24,9 +24,15 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.ipc.ArrowStreamReader; import org.apache.arrow.vector.types.pojo.Schema; +import java.io.ByteArrayInputStream; import java.io.Closeable; +import java.io.IOException; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -315,6 +321,26 @@ public LanceScanner newScan(ScanOptions options) { } } + /** + * Select rows of data by index. 
+ * + * @param indices the indices to take + * @param columns the columns to take + * @return an ArrowReader + */ + public ArrowReader take(List<Integer> indices, List<String> columns) throws IOException { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Scanner is closed"); + byte[] arrowData = nativeTake(indices, columns); + try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(arrowData); + ReadableByteChannel readChannel = Channels.newChannel(byteArrayInputStream)) { + return new ArrowStreamReader(readChannel, allocator); + } + } + } + + private native byte[] nativeTake(List<Integer> indices, List<String> columns); + /** * Gets the URI of the dataset. * diff --git a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java index 92765d28f2..7496e77be2 100644 --- a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java +++ b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java @@ -15,6 +15,8 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; @@ -26,9 +28,7 @@ import java.io.IOException; import java.net.URISyntaxException; import java.nio.file.Path; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; +import java.util.*; import java.util.stream.Collectors; import static org.junit.jupiter.api.Assertions.*; @@ -307,4 +307,30 @@ void testDropPath() { Dataset.drop(datasetPath, new HashMap<>()); } } + + @Test + void testTake() throws IOException { + String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName(); + String datasetPath = tempDir.resolve(testMethodName).toString(); + try 
(RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + dataset = testDataset.createEmptyDataset(); + + try (Dataset dataset2 = testDataset.write(1, 5)) { + List<Integer> indices = Arrays.asList(1, 4); + List<String> columns = Arrays.asList("id", "name"); + try (ArrowReader reader = dataset2.take(indices, columns)) { + reader.loadNextBatch(); + VectorSchemaRoot result = reader.getVectorSchemaRoot(); + assertNotNull(result); + assertEquals(indices.size(), result.getRowCount()); + + for (int i = 0; i < indices.size(); i++) { + assertEquals(indices.get(i), result.getVector("id").getObject(i)); + } + } + } + } + } }