Skip to content

Commit

Permalink
Merge branch 'main' into 3293-type-check
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu authored Dec 31, 2024
2 parents f35c051 + 8767c10 commit 0544719
Show file tree
Hide file tree
Showing 24 changed files with 421 additions and 101 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci-benchmarks.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
name: Run Regression Benchmarks

on:
workflow_dispatch:
push:
branches:
- main
Expand Down
58 changes: 56 additions & 2 deletions java/core/lance-jni/src/blocking_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ use arrow::datatypes::Schema;
use arrow::ffi::FFI_ArrowSchema;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::ffi_stream::FFI_ArrowArrayStream;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatchIterator;
use arrow_schema::DataType;
use jni::objects::{JMap, JString, JValue};
use jni::sys::jlong;
use jni::sys::{jboolean, jint};
use jni::sys::{jbyteArray, jlong};
use jni::{objects::JObject, JNIEnv};
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::transaction::Operation;
use lance::dataset::{ColumnAlteration, Dataset, ReadParams, WriteParams};
use lance::dataset::{ColumnAlteration, Dataset, ProjectionRequest, ReadParams, WriteParams};
use lance::io::{ObjectStore, ObjectStoreParams};
use lance::table::format::Fragment;
use lance::table::format::Index;
Expand Down Expand Up @@ -683,6 +684,59 @@ fn inner_list_indexes<'local>(
Ok(array_list)
}

#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeTake(
mut env: JNIEnv,
java_dataset: JObject,
indices_obj: JObject, // List<Long>
columns_obj: JObject, // List<String>
) -> jbyteArray {
match inner_take(&mut env, java_dataset, indices_obj, columns_obj) {
Ok(byte_array) => byte_array,
Err(e) => {
let _ = env.throw_new("java/lang/RuntimeException", format!("{:?}", e));
std::ptr::null_mut()
}
}
}

fn inner_take(
env: &mut JNIEnv,
java_dataset: JObject,
indices_obj: JObject, // List<Long>
columns_obj: JObject, // List<String>
) -> Result<jbyteArray> {
let indices: Vec<i64> = env.get_longs(&indices_obj)?;
let indices_u64: Vec<u64> = indices.iter().map(|&x| x as u64).collect();
let indices_slice: &[u64] = &indices_u64;
let columns: Vec<String> = env.get_strings(&columns_obj)?;

let result = {
let dataset_guard =
unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?;
let dataset = &dataset_guard.inner;

let projection = ProjectionRequest::from_columns(columns, dataset.schema());

match RT.block_on(dataset.take(indices_slice, projection)) {
Ok(res) => res,
Err(e) => {
return Err(e.into());
}
}
};

let mut buffer = Vec::new();
{
let mut writer = StreamWriter::try_new(&mut buffer, &result.schema())?;
writer.write(&result)?;
writer.finish()?;
}

let byte_array = env.byte_array_from_slice(&buffer)?;
Ok(**byte_array)
}

//////////////////////////////
// Schema evolution Methods //
//////////////////////////////
Expand Down
24 changes: 24 additions & 0 deletions java/core/lance-jni/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ pub trait JNIEnvExt {
/// Get integers from Java List<Integer> object.
fn get_integers(&mut self, obj: &JObject) -> Result<Vec<i32>>;

/// Get longs from Java List<Long> object.
fn get_longs(&mut self, obj: &JObject) -> Result<Vec<i64>>;

/// Get strings from Java List<String> object.
fn get_strings(&mut self, obj: &JObject) -> Result<Vec<String>>;

Expand Down Expand Up @@ -127,6 +130,18 @@ impl JNIEnvExt for JNIEnv<'_> {
Ok(results)
}

fn get_longs(&mut self, obj: &JObject) -> Result<Vec<i64>> {
let list = self.get_list(obj)?;
let mut iter = list.iter(self)?;
let mut results = Vec::with_capacity(list.size(self)? as usize);
while let Some(elem) = iter.next(self)? {
let long_obj = self.call_method(elem, "longValue", "()J", &[])?;
let long_value = long_obj.j()?;
results.push(long_value);
}
Ok(results)
}

fn get_strings(&mut self, obj: &JObject) -> Result<Vec<String>> {
let list = self.get_list(obj)?;
let mut iter = list.iter(self)?;
Expand Down Expand Up @@ -348,6 +363,15 @@ pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseInts(
ok_or_throw_without_return!(env, env.get_integers(&list_obj));
}

#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseLongs(
mut env: JNIEnv,
_obj: JObject,
list_obj: JObject, // List<Long>
) {
ok_or_throw_without_return!(env, env.get_longs(&list_obj));
}

#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseIntsOpt(
mut env: JNIEnv,
Expand Down
2 changes: 1 addition & 1 deletion java/core/lance-jni/src/fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ fn inner_count_rows_native(
"Fragment not found: {fragment_id}"
)));
};
let res = RT.block_on(fragment.count_rows())?;
let res = RT.block_on(fragment.count_rows(None))?;
Ok(res)
}

Expand Down
32 changes: 32 additions & 0 deletions java/core/src/main/java/com/lancedb/lance/Dataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,15 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.types.pojo.Schema;

import java.io.ByteArrayInputStream;
import java.io.Closeable;
import java.io.IOException;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -315,6 +321,32 @@ public LanceScanner newScan(ScanOptions options) {
}
}

/**
* Select rows of data by index.
*
* @param indices the indices to take
* @param columns the columns to take
* @return an ArrowReader
*/
public ArrowReader take(List<Long> indices, List<String> columns) throws IOException {
Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed");
try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) {
byte[] arrowData = nativeTake(indices, columns);
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(arrowData);
ReadableByteChannel readChannel = Channels.newChannel(byteArrayInputStream);
return new ArrowStreamReader(readChannel, allocator) {
@Override
public void close() throws IOException {
super.close();
readChannel.close();
byteArrayInputStream.close();
}
};
}
}

private native byte[] nativeTake(List<Long> indices, List<String> columns);

/**
* Gets the URI of the dataset.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ public class JniTestHelper {
*/
public static native void parseInts(List<Integer> intsList);

/**
* JNI parse longs test.
*
* @param longsList the given list of longs
*/
public static native void parseLongs(List<Long> longsList);

/**
* JNI parse ints opts test.
*
Expand Down
35 changes: 32 additions & 3 deletions java/core/src/test/java/com/lancedb/lance/DatasetTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
Expand All @@ -25,10 +27,9 @@

import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.channels.ClosedChannelException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.*;
import java.util.stream.Collectors;

import static org.junit.jupiter.api.Assertions.*;
Expand Down Expand Up @@ -307,4 +308,32 @@ void testDropPath() {
Dataset.drop(datasetPath, new HashMap<>());
}
}

@Test
void testTake() throws IOException, ClosedChannelException {
String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName();
String datasetPath = tempDir.resolve(testMethodName).toString();
try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
TestUtils.SimpleTestDataset testDataset =
new TestUtils.SimpleTestDataset(allocator, datasetPath);
dataset = testDataset.createEmptyDataset();

try (Dataset dataset2 = testDataset.write(1, 5)) {
List<Long> indices = Arrays.asList(1L, 4L);
List<String> columns = Arrays.asList("id", "name");
try (ArrowReader reader = dataset2.take(indices, columns)) {
while (reader.loadNextBatch()) {
VectorSchemaRoot result = reader.getVectorSchemaRoot();
assertNotNull(result);
assertEquals(indices.size(), result.getRowCount());

for (int i = 0; i < indices.size(); i++) {
assertEquals(indices.get(i).intValue(), result.getVector("id").getObject(i));
assertNotNull(result.getVector("name").getObject(i));
}
}
}
}
}
}
}
5 changes: 5 additions & 0 deletions java/core/src/test/java/com/lancedb/lance/JNITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ public void testInts() {
JniTestHelper.parseInts(Arrays.asList(1, 2, 3));
}

@Test
public void testLongs() {
JniTestHelper.parseLongs(Arrays.asList(1L, 2L, 3L, Long.MAX_VALUE));
}

@Test
public void testIntsOpt() {
JniTestHelper.parseIntsOpt(Optional.of(Arrays.asList(1, 2, 3)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.sql.Date;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -89,7 +90,7 @@ public static boolean isFilterSupported(Filter filter) {
} else if (filter instanceof EqualNullSafe) {
return false;
} else if (filter instanceof In) {
return false;
return true;
} else if (filter instanceof LessThan) {
return true;
} else if (filter instanceof LessThanOrEqual) {
Expand Down Expand Up @@ -163,6 +164,13 @@ private static Optional<String> compileFilter(Filter filter) {
Optional<String> child = compileFilter(f.child());
if (child.isEmpty()) return child;
return Optional.of(String.format("NOT (%s)", child.get()));
} else if (filter instanceof In) {
In in = (In) filter;
String values =
Arrays.stream(in.values())
.map(FilterPushDown::compileValue)
.collect(Collectors.joining(","));
return Optional.of(String.format("%s IN (%s)", in.attribute(), values));
}

return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,26 @@ public void testCompileFiltersToSqlWhereClauseWithEmptyFilters() {
Optional<String> whereClause = FilterPushDown.compileFiltersToSqlWhereClause(filters);
assertFalse(whereClause.isPresent());
}

@Test
public void testIntegerInFilterPushDown() {
Object[] values = new Object[2];
values[0] = 500;
values[1] = 600;
Filter[] filters = new Filter[] {new GreaterThan("age", 30), new In("salary", values)};
Optional<String> whereClause = FilterPushDown.compileFiltersToSqlWhereClause(filters);
assertTrue(whereClause.isPresent());
assertEquals("(age > 30) AND (salary IN (500,600))", whereClause.get());
}

@Test
public void testStringInFilterPushDown() {
Object[] values = new Object[2];
values[0] = "500";
values[1] = "600";
Filter[] filters = new Filter[] {new GreaterThan("age", 30), new In("salary", values)};
Optional<String> whereClause = FilterPushDown.compileFiltersToSqlWhereClause(filters);
assertTrue(whereClause.isPresent());
assertEquals("(age > 30) AND (salary IN ('500','600'))", whereClause.get());
}
}
19 changes: 19 additions & 0 deletions python/python/ci_benchmarks/benchmarks/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,22 @@ def bench():
)

benchmark.pedantic(bench, rounds=1, iterations=1)


BTREE_FILTERS = ["image_widths = 3997", "image_widths >= 3990 AND image_widths <= 3997"]


@pytest.mark.parametrize("filt", BTREE_FILTERS)
def test_eda_btree_search(benchmark, filt):
dataset_uri = get_dataset_uri("image_eda")
ds = lance.dataset(dataset_uri)

def bench():
ds.to_table(
columns=[],
filter=filt,
with_row_id=True,
)

# We warmup so we can test hot index performance
benchmark.pedantic(bench, warmup_rounds=1, rounds=1, iterations=100)
Loading

0 comments on commit 0544719

Please sign in to comment.