Skip to content

Commit

Permalink
fix: add support for list fields with names other than 'item' (#2580)
Browse files Browse the repository at this point in the history
In the future we can maybe add support for a few more things as well:

* Users should be able to read back into whatever schema they want
(this is already supported in the Rust code but missing from Python because
projection is missing from Python, so this is mostly just tests)
* Perhaps add some kind of read option to "normalize" a schema so we
always read back the field as "item". Right now, if no schema is
provided at read time, we mirror exactly the write-time schema; this
will cause non-item field names to propagate, which is maybe not the best
choice. This could also be a write-time option to normalize the schema
on write.
  • Loading branch information
westonpace authored Jul 9, 2024
1 parent ab7349e commit ebf7c5d
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 1 deletion.
33 changes: 33 additions & 0 deletions python/python/tests/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright The Lance Authors

import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from lance.file import LanceFileReader, LanceFileWriter

Expand Down Expand Up @@ -156,3 +157,35 @@ def test_metadata(tmp_path):
assert page.buffers[0].size == 24

assert len(page.encoding) > 0


def test_round_trip_parquet(tmp_path):
    """Round-trip a table through Parquet and then through a Lance file.

    Writing a table that was read back from Parquet exercises list columns
    whose item field name comes from another writer, then verifies the Lance
    reader reproduces the table exactly.
    """
    parquet_file = tmp_path / "foo.parquet"
    original = pa.table(
        {"int": [1, 2], "list_str": [["x", "yz", "abc"], ["foo", "bar"]]}
    )
    pq.write_table(original, str(parquet_file))
    # Re-read so the table carries whatever schema Parquet produced.
    table = pq.read_table(str(parquet_file))

    lance_file = tmp_path / "foo.lance"
    with LanceFileWriter(str(lance_file)) as writer:
        writer.write_batch(table)

    round_tripped = LanceFileReader(str(lance_file)).read_all().to_table()
    assert round_tripped == table


def test_list_field_name(tmp_path):
    """Verify a list column whose item field is NOT named "item" survives a
    Lance write/read round trip, including the custom item field name in the
    schema.
    """
    # Deliberately unusual item field name to prove the name is preserved.
    item_field = pa.field("why does this name even exist", pa.string())
    unusual_list_type = pa.list_(item_field)
    schema = pa.schema([pa.field("list_str", unusual_list_type)])
    table = pa.table(
        {"list_str": [["x", "yz", "abc"], ["foo", "bar"]]}, schema=schema
    )

    lance_file = tmp_path / "foo.lance"
    with LanceFileWriter(str(lance_file)) as writer:
        writer.write_batch(table)

    round_tripped = LanceFileReader(str(lance_file)).read_all().to_table()

    assert round_tripped == table
    # The odd item field name must come back intact, not normalized to "item".
    assert round_tripped.schema.field("list_str").type == unusual_list_type
2 changes: 2 additions & 0 deletions rust/lance-encoding/src/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ impl FieldDecoderStrategy for CoreFieldDecoderStrategy {
file_buffers: buffers,
positions_and_sizes: &offsets_column.buffer_offsets_and_sizes,
};
let item_field_name = items_field.name().clone();
let (chain, items_scheduler) = chain.new_child(
/*child_idx=*/ 0,
&field.children[0],
Expand Down Expand Up @@ -688,6 +689,7 @@ impl FieldDecoderStrategy for CoreFieldDecoderStrategy {
Ok(Arc::new(ListFieldScheduler::new(
inner,
items_scheduler,
item_field_name.clone(),
items_type,
offset_type,
null_offset_adjustments,
Expand Down
13 changes: 12 additions & 1 deletion rust/lance-encoding/src/encodings/logical/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ impl<'a> SchedulingJob for ListFieldSchedulingJob<'a> {
item_decoder: None,
rows_drained: 0,
lists_available: 0,
item_field_name: self.scheduler.item_field_name.clone(),
num_rows,
unloaded: Some(indirect_fut),
items_type: self.scheduler.items_type.clone(),
Expand Down Expand Up @@ -491,6 +492,7 @@ impl<'a> SchedulingJob for ListFieldSchedulingJob<'a> {
pub struct ListFieldScheduler {
offsets_scheduler: Arc<dyn FieldScheduler>,
items_scheduler: Arc<dyn FieldScheduler>,
item_field_name: String,
items_type: DataType,
offset_type: DataType,
list_type: DataType,
Expand All @@ -512,6 +514,7 @@ impl ListFieldScheduler {
pub fn new(
offsets_scheduler: Arc<dyn FieldScheduler>,
items_scheduler: Arc<dyn FieldScheduler>,
item_field_name: String,
items_type: DataType,
// Should be int32 or int64
offset_type: DataType,
Expand All @@ -529,6 +532,7 @@ impl ListFieldScheduler {
Self {
offsets_scheduler,
items_scheduler,
item_field_name,
items_type,
offset_type,
offset_page_info,
Expand Down Expand Up @@ -573,6 +577,7 @@ struct ListPageDecoder {
lists_available: u64,
num_rows: u64,
rows_drained: u64,
item_field_name: String,
items_type: DataType,
offset_type: DataType,
data_type: DataType,
Expand All @@ -583,6 +588,7 @@ struct ListDecodeTask {
validity: BooleanBuffer,
// Will be None if there are no items (all empty / null lists)
items: Option<Box<dyn DecodeArrayTask>>,
item_field_name: String,
items_type: DataType,
offset_type: DataType,
}
Expand All @@ -601,7 +607,11 @@ impl DecodeArrayTask for ListDecodeTask {

// TODO: we default to nullable true here, should probably use the nullability given to
// us from the input schema
let item_field = Arc::new(Field::new("item", self.items_type.clone(), true));
let item_field = Arc::new(Field::new(
self.item_field_name,
self.items_type.clone(),
true,
));

// The offsets are already decoded but they need to be shifted back to 0 and cast
// to the appropriate type
Expand Down Expand Up @@ -756,6 +766,7 @@ impl LogicalPageDecoder for ListPageDecoder {
task: Box::new(ListDecodeTask {
offsets,
validity,
item_field_name: self.item_field_name.clone(),
items: item_decode,
items_type: self.items_type.clone(),
offset_type: self.offset_type.clone(),
Expand Down

0 comments on commit ebf7c5d

Please sign in to comment.