Skip to content

Commit

Permalink
feat(java): support take api for java module
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghua committed Dec 30, 2024
1 parent bcb040e commit 8c8c64a
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 5 deletions.
58 changes: 56 additions & 2 deletions java/core/lance-jni/src/blocking_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ use arrow::datatypes::Schema;
use arrow::ffi::FFI_ArrowSchema;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::ffi_stream::FFI_ArrowArrayStream;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatchIterator;
use arrow_schema::DataType;
use jni::objects::{JMap, JString, JValue};
use jni::sys::jlong;
use jni::sys::{jboolean, jint};
use jni::sys::{jbyteArray, jlong};
use jni::{objects::JObject, JNIEnv};
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::transaction::Operation;
use lance::dataset::{ColumnAlteration, Dataset, ReadParams, WriteParams};
use lance::dataset::{ColumnAlteration, Dataset, ProjectionRequest, ReadParams, WriteParams};
use lance::io::{ObjectStore, ObjectStoreParams};
use lance::table::format::Fragment;
use lance::table::format::Index;
Expand Down Expand Up @@ -683,6 +684,59 @@ fn inner_list_indexes<'local>(
Ok(array_list)
}

#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeTake<'local>(
mut env: JNIEnv<'local>,
java_dataset: JObject,
indices_obj: JObject, // List<Integer>
columns_obj: JObject, // List<String>
) -> jbyteArray {
match inner_take(&mut env, java_dataset, indices_obj, columns_obj) {
Ok(byte_array) => byte_array,
Err(e) => {
let _ = env.throw_new("java/lang/RuntimeException", format!("{:?}", e));
std::ptr::null_mut()
}
}
}

fn inner_take<'local>(
env: &mut JNIEnv<'local>,
java_dataset: JObject,
indices_obj: JObject, // List<Integer>
columns_obj: JObject, // List<String>
) -> Result<jbyteArray> {
let indices: Vec<i32> = env.get_integers(&indices_obj)?;
let indices_u64: Vec<u64> = indices.iter().map(|&x| x as u64).collect();
let indices_slice: &[u64] = &indices_u64;
let columns: Vec<String> = env.get_strings(&columns_obj)?;

let result = {
let dataset_guard =
unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?;
let dataset = &dataset_guard.inner;

let projection = ProjectionRequest::from_columns(columns, dataset.schema());

match RT.block_on(dataset.take(&indices_slice, projection)) {
Ok(res) => res,
Err(e) => {
return Err(e.into());
}
}
};

let mut buffer = Vec::new();
{
let mut writer = StreamWriter::try_new(&mut buffer, &result.schema())?;
writer.write(&result)?;
writer.finish()?;
}

let byte_array = env.byte_array_from_slice(&buffer)?;
Ok(**byte_array)
}

//////////////////////////////
// Schema evolution Methods //
//////////////////////////////
Expand Down
26 changes: 26 additions & 0 deletions java/core/src/main/java/com/lancedb/lance/Dataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,15 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.types.pojo.Schema;

import java.io.ByteArrayInputStream;
import java.io.Closeable;
import java.io.IOException;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -315,6 +321,26 @@ public LanceScanner newScan(ScanOptions options) {
}
}

/**
* Select rows of data by index.
*
* @param indices the indices to take
* @param columns the columns to take
* @return an ArrowReader
*/
public ArrowReader take(List<Integer> indices, List<String> columns) throws IOException {
try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) {
Preconditions.checkArgument(nativeDatasetHandle != 0, "Scanner is closed");
byte[] arrowData = nativeTake(indices, columns);
try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(arrowData);
ReadableByteChannel readChannel = Channels.newChannel(byteArrayInputStream)) {
return new ArrowStreamReader(readChannel, allocator);
}
}
}

private native byte[] nativeTake(List<Integer> indices, List<String> columns);

/**
* Gets the URI of the dataset.
*
Expand Down
32 changes: 29 additions & 3 deletions java/core/src/test/java/com/lancedb/lance/DatasetTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
Expand All @@ -26,9 +28,7 @@
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.*;
import java.util.stream.Collectors;

import static org.junit.jupiter.api.Assertions.*;
Expand Down Expand Up @@ -307,4 +307,30 @@ void testDropPath() {
Dataset.drop(datasetPath, new HashMap<>());
}
}

@Test
void testTake() throws IOException {
String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName();
String datasetPath = tempDir.resolve(testMethodName).toString();
try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
TestUtils.SimpleTestDataset testDataset =
new TestUtils.SimpleTestDataset(allocator, datasetPath);
dataset = testDataset.createEmptyDataset();

try (Dataset dataset2 = testDataset.write(1, 5)) {
List<Integer> indices = Arrays.asList(1, 4);
List<String> columns = Arrays.asList("id", "name");
try (ArrowReader reader = dataset2.take(indices, columns)) {
reader.loadNextBatch();
VectorSchemaRoot result = reader.getVectorSchemaRoot();
assertNotNull(result);
assertEquals(indices.size(), result.getRowCount());

for (int i = 0; i < indices.size(); i++) {
assertEquals(indices.get(i), result.getVector("id").getObject(i));
}
}
}
}
}
}

0 comments on commit 8c8c64a

Please sign in to comment.