diff --git a/java/core/src/main/java/com/lancedb/lance/Dataset.java b/java/core/src/main/java/com/lancedb/lance/Dataset.java index c8208dfca5..6f14fee1a9 100644 --- a/java/core/src/main/java/com/lancedb/lance/Dataset.java +++ b/java/core/src/main/java/com/lancedb/lance/Dataset.java @@ -332,10 +332,16 @@ public ArrowReader take(List indices, List columns) throws IOEx try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { Preconditions.checkArgument(nativeDatasetHandle != 0, "Scanner is closed"); byte[] arrowData = nativeTake(indices, columns); - try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(arrowData); - ReadableByteChannel readChannel = Channels.newChannel(byteArrayInputStream)) { - return new ArrowStreamReader(readChannel, allocator); - } + ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(arrowData); + ReadableByteChannel readChannel = Channels.newChannel(byteArrayInputStream); + return new ArrowStreamReader(readChannel, allocator) { + @Override + public void close() throws IOException { + super.close(); + readChannel.close(); + byteArrayInputStream.close(); + } + }; } } diff --git a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java index b1b44862fe..7bba99c8b0 100644 --- a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java +++ b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java @@ -27,6 +27,7 @@ import java.io.IOException; import java.net.URISyntaxException; +import java.nio.channels.ClosedChannelException; import java.nio.file.Path; import java.util.*; import java.util.stream.Collectors; @@ -309,7 +310,7 @@ void testDropPath() { } @Test - void testTake() throws IOException { + void testTake() throws IOException, ClosedChannelException { String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName(); String datasetPath = tempDir.resolve(testMethodName).toString(); try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { @@ -321,14 +322,15 @@ void testTake() throws IOException { List indices = Arrays.asList(1, 4); List columns = Arrays.asList("id", "name"); try (ArrowReader reader = dataset2.take(indices, columns)) { - reader.loadNextBatch(); - VectorSchemaRoot result = reader.getVectorSchemaRoot(); - assertNotNull(result); - assertEquals(indices.size(), result.getRowCount()); - - for (int i = 0; i < indices.size(); i++) { - assertEquals(indices.get(i), result.getVector("id").getObject(i)); - assertNotNull(result.getVector("name").getObject(i)); + while (reader.loadNextBatch()) { + VectorSchemaRoot result = reader.getVectorSchemaRoot(); + assertNotNull(result); + assertEquals(indices.size(), result.getRowCount()); + + for (int i = 0; i < indices.size(); i++) { + assertEquals(indices.get(i), result.getVector("id").getObject(i)); + assertNotNull(result.getVector("name").getObject(i)); + } } } }