Skip to content

Commit

Permalink
Merge branch 'main' into SupportStatisticsRowNum
Browse files Browse the repository at this point in the history
  • Loading branch information
SaintBacchus authored Jan 1, 2025
2 parents 9d9fd53 + 8767c10 commit ee4cc7c
Show file tree
Hide file tree
Showing 17 changed files with 290 additions and 92 deletions.
58 changes: 56 additions & 2 deletions java/core/lance-jni/src/blocking_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ use arrow::datatypes::Schema;
use arrow::ffi::FFI_ArrowSchema;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::ffi_stream::FFI_ArrowArrayStream;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatchIterator;
use arrow_schema::DataType;
use jni::objects::{JMap, JString, JValue};
use jni::sys::jlong;
use jni::sys::{jboolean, jint};
use jni::sys::{jbyteArray, jlong};
use jni::{objects::JObject, JNIEnv};
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::transaction::Operation;
use lance::dataset::{ColumnAlteration, Dataset, ReadParams, WriteParams};
use lance::dataset::{ColumnAlteration, Dataset, ProjectionRequest, ReadParams, WriteParams};
use lance::io::{ObjectStore, ObjectStoreParams};
use lance::table::format::Fragment;
use lance::table::format::Index;
Expand Down Expand Up @@ -683,6 +684,59 @@ fn inner_list_indexes<'local>(
Ok(array_list)
}

#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeTake(
mut env: JNIEnv,
java_dataset: JObject,
indices_obj: JObject, // List<Long>
columns_obj: JObject, // List<String>
) -> jbyteArray {
match inner_take(&mut env, java_dataset, indices_obj, columns_obj) {
Ok(byte_array) => byte_array,
Err(e) => {
let _ = env.throw_new("java/lang/RuntimeException", format!("{:?}", e));
std::ptr::null_mut()
}
}
}

fn inner_take(
env: &mut JNIEnv,
java_dataset: JObject,
indices_obj: JObject, // List<Long>
columns_obj: JObject, // List<String>
) -> Result<jbyteArray> {
let indices: Vec<i64> = env.get_longs(&indices_obj)?;
let indices_u64: Vec<u64> = indices.iter().map(|&x| x as u64).collect();
let indices_slice: &[u64] = &indices_u64;
let columns: Vec<String> = env.get_strings(&columns_obj)?;

let result = {
let dataset_guard =
unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?;
let dataset = &dataset_guard.inner;

let projection = ProjectionRequest::from_columns(columns, dataset.schema());

match RT.block_on(dataset.take(indices_slice, projection)) {
Ok(res) => res,
Err(e) => {
return Err(e.into());
}
}
};

let mut buffer = Vec::new();
{
let mut writer = StreamWriter::try_new(&mut buffer, &result.schema())?;
writer.write(&result)?;
writer.finish()?;
}

let byte_array = env.byte_array_from_slice(&buffer)?;
Ok(**byte_array)
}

//////////////////////////////
// Schema evolution Methods //
//////////////////////////////
Expand Down
24 changes: 24 additions & 0 deletions java/core/lance-jni/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ pub trait JNIEnvExt {
/// Get integers from Java List<Integer> object.
fn get_integers(&mut self, obj: &JObject) -> Result<Vec<i32>>;

/// Get longs from Java List<Long> object.
fn get_longs(&mut self, obj: &JObject) -> Result<Vec<i64>>;

/// Get strings from Java List<String> object.
fn get_strings(&mut self, obj: &JObject) -> Result<Vec<String>>;

Expand Down Expand Up @@ -127,6 +130,18 @@ impl JNIEnvExt for JNIEnv<'_> {
Ok(results)
}

fn get_longs(&mut self, obj: &JObject) -> Result<Vec<i64>> {
let list = self.get_list(obj)?;
let mut iter = list.iter(self)?;
let mut results = Vec::with_capacity(list.size(self)? as usize);
while let Some(elem) = iter.next(self)? {
let long_obj = self.call_method(elem, "longValue", "()J", &[])?;
let long_value = long_obj.j()?;
results.push(long_value);
}
Ok(results)
}

fn get_strings(&mut self, obj: &JObject) -> Result<Vec<String>> {
let list = self.get_list(obj)?;
let mut iter = list.iter(self)?;
Expand Down Expand Up @@ -348,6 +363,15 @@ pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseInts(
ok_or_throw_without_return!(env, env.get_integers(&list_obj));
}

#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseLongs(
mut env: JNIEnv,
_obj: JObject,
list_obj: JObject, // List<Long>
) {
ok_or_throw_without_return!(env, env.get_longs(&list_obj));
}

#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseIntsOpt(
mut env: JNIEnv,
Expand Down
2 changes: 1 addition & 1 deletion java/core/lance-jni/src/fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ fn inner_count_rows_native(
"Fragment not found: {fragment_id}"
)));
};
let res = RT.block_on(fragment.count_rows())?;
let res = RT.block_on(fragment.count_rows(None))?;
Ok(res)
}

Expand Down
32 changes: 32 additions & 0 deletions java/core/src/main/java/com/lancedb/lance/Dataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,15 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.types.pojo.Schema;

import java.io.ByteArrayInputStream;
import java.io.Closeable;
import java.io.IOException;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -315,6 +321,32 @@ public LanceScanner newScan(ScanOptions options) {
}
}

/**
* Select rows of data by index.
*
* @param indices the indices to take
* @param columns the columns to take
* @return an ArrowReader
*/
public ArrowReader take(List<Long> indices, List<String> columns) throws IOException {
Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed");
try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) {
byte[] arrowData = nativeTake(indices, columns);
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(arrowData);
ReadableByteChannel readChannel = Channels.newChannel(byteArrayInputStream);
return new ArrowStreamReader(readChannel, allocator) {
@Override
public void close() throws IOException {
super.close();
readChannel.close();
byteArrayInputStream.close();
}
};
}
}

private native byte[] nativeTake(List<Long> indices, List<String> columns);

/**
* Gets the URI of the dataset.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ public class JniTestHelper {
*/
public static native void parseInts(List<Integer> intsList);

/**
* JNI parse longs test.
*
* @param longsList the given list of longs
*/
public static native void parseLongs(List<Long> longsList);

/**
* JNI parse ints opts test.
*
Expand Down
35 changes: 32 additions & 3 deletions java/core/src/test/java/com/lancedb/lance/DatasetTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
Expand All @@ -25,10 +27,9 @@

import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.channels.ClosedChannelException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.*;
import java.util.stream.Collectors;

import static org.junit.jupiter.api.Assertions.*;
Expand Down Expand Up @@ -307,4 +308,32 @@ void testDropPath() {
Dataset.drop(datasetPath, new HashMap<>());
}
}

@Test
void testTake() throws IOException, ClosedChannelException {
String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName();
String datasetPath = tempDir.resolve(testMethodName).toString();
try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
TestUtils.SimpleTestDataset testDataset =
new TestUtils.SimpleTestDataset(allocator, datasetPath);
dataset = testDataset.createEmptyDataset();

try (Dataset dataset2 = testDataset.write(1, 5)) {
List<Long> indices = Arrays.asList(1L, 4L);
List<String> columns = Arrays.asList("id", "name");
try (ArrowReader reader = dataset2.take(indices, columns)) {
while (reader.loadNextBatch()) {
VectorSchemaRoot result = reader.getVectorSchemaRoot();
assertNotNull(result);
assertEquals(indices.size(), result.getRowCount());

for (int i = 0; i < indices.size(); i++) {
assertEquals(indices.get(i).intValue(), result.getVector("id").getObject(i));
assertNotNull(result.getVector("name").getObject(i));
}
}
}
}
}
}
}
5 changes: 5 additions & 0 deletions java/core/src/test/java/com/lancedb/lance/JNITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ public void testInts() {
JniTestHelper.parseInts(Arrays.asList(1, 2, 3));
}

@Test
public void testLongs() {
JniTestHelper.parseLongs(Arrays.asList(1L, 2L, 3L, Long.MAX_VALUE));
}

@Test
public void testIntsOpt() {
JniTestHelper.parseIntsOpt(Optional.of(Arrays.asList(1, 2, 3)));
Expand Down
33 changes: 17 additions & 16 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,13 +507,13 @@ def to_table(
batch_size: Optional[int] = None,
batch_readahead: Optional[int] = None,
fragment_readahead: Optional[int] = None,
scan_in_order: bool = True,
scan_in_order: Optional[bool] = None,
*,
prefilter: bool = False,
with_row_id: bool = False,
with_row_address: bool = False,
use_stats: bool = True,
fast_search: bool = False,
prefilter: Optional[bool] = None,
with_row_id: Optional[bool] = None,
with_row_address: Optional[bool] = None,
use_stats: Optional[bool] = None,
fast_search: Optional[bool] = None,
full_text_query: Optional[Union[str, dict]] = None,
io_buffer_size: Optional[int] = None,
late_materialization: Optional[bool | List[str]] = None,
Expand Down Expand Up @@ -558,24 +558,25 @@ def to_table(
The number of batches to read ahead.
fragment_readahead: int, optional
The number of fragments to read ahead.
scan_in_order: bool, default True
scan_in_order: bool, optional, default True
Whether to read the fragments and batches in order. If false,
throughput may be higher, but batches will be returned out of order
and memory use might increase.
prefilter: bool, default False
prefilter: bool, optional, default False
Run filter before the vector search.
late_materialization: bool or List[str], default None
Allows custom control over late materialization. See
``ScannerBuilder.late_materialization`` for more information.
use_scalar_index: bool, default True
Allows custom control over scalar index usage. See
``ScannerBuilder.use_scalar_index`` for more information.
with_row_id: bool, default False
with_row_id: bool, optional, default False
Return row ID.
with_row_address: bool, default False
with_row_address: bool, optional, default False
Return row address
use_stats: bool, default True
use_stats: bool, optional, default True
Use stats pushdown during filters.
fast_search: bool, optional, default False
full_text_query: str or dict, optional
query string to search for, the results will be ranked by BM25.
e.g. "hello world", would match documents contains "hello" or "world".
Expand Down Expand Up @@ -687,12 +688,12 @@ def to_batches(
batch_size: Optional[int] = None,
batch_readahead: Optional[int] = None,
fragment_readahead: Optional[int] = None,
scan_in_order: bool = True,
scan_in_order: Optional[bool] = None,
*,
prefilter: bool = False,
with_row_id: bool = False,
with_row_address: bool = False,
use_stats: bool = True,
prefilter: Optional[bool] = None,
with_row_id: Optional[bool] = None,
with_row_address: Optional[bool] = None,
use_stats: Optional[bool] = None,
full_text_query: Optional[Union[str, dict]] = None,
io_buffer_size: Optional[int] = None,
late_materialization: Optional[bool | List[str]] = None,
Expand Down
6 changes: 3 additions & 3 deletions python/python/lance/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def __init__(
if fragment_id is None:
raise ValueError("Either fragment or fragment_id must be specified")
fragment = dataset.get_fragment(fragment_id)._fragment
self._fragment = fragment
self._fragment: _Fragment = fragment
if self._fragment is None:
raise ValueError(f"Fragment id does not exist: {fragment_id}")

Expand Down Expand Up @@ -367,8 +367,8 @@ def count_rows(
self, filter: Optional[Union[pa.compute.Expression, str]] = None
) -> int:
if filter is not None:
raise ValueError("Does not support filter at the moment")
return self._fragment.count_rows()
return self.scanner(filter=filter).count_rows()
return self._fragment.count_rows(filter)

@property
def num_deletions(self) -> int:
Expand Down
7 changes: 7 additions & 0 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2806,3 +2806,10 @@ def test_dataset_drop(tmp_path: Path):
assert Path(tmp_path).exists()
lance.LanceDataset.drop(tmp_path)
assert not Path(tmp_path).exists()


def test_dataset_schema(tmp_path: Path):
table = pa.table({"x": [0]})
ds = lance.write_dataset(table, str(tmp_path)) # noqa: F841
ds._default_scan_options = {"with_row_id": True}
assert ds.schema == ds.to_table().schema
13 changes: 13 additions & 0 deletions python/python/tests/test_fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lance
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import pytest
from helper import ProgressForTest
from lance import (
Expand Down Expand Up @@ -422,3 +423,15 @@ def test_fragment_merge(tmp_path):
tmp_path, merge, read_version=dataset.latest_version
)
assert [f.name for f in dataset.schema] == ["a", "b", "c", "d"]


def test_fragment_count_rows(tmp_path: Path):
data = pa.table({"a": range(800), "b": range(800)})
ds = write_dataset(data, tmp_path)

fragments = ds.get_fragments()
assert len(fragments) == 1

assert fragments[0].count_rows() == 800
assert fragments[0].count_rows("a < 200") == 200
assert fragments[0].count_rows(pc.field("a") < 200) == 200
Loading

0 comments on commit ee4cc7c

Please sign in to comment.