Skip to content

Commit

Permalink
Cleanup synchronous take
Browse files Browse the repository at this point in the history
  • Loading branch information
westonpace committed Jan 16, 2025
1 parent db18d27 commit 2b886d9
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 54 deletions.
122 changes: 90 additions & 32 deletions rust/lance-encoding/src/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1472,10 +1472,51 @@ impl BatchDecodeStream {
}
}

// Utility types to smooth out the differences between the 2.0 and 2.1 decoders so that
// we can have a single implementation of the batch decode iterator

/// A message handed to a root decoder through [`RootDecoderType::accept_message`].
///
/// Each root decoder implementation accepts exactly one of these variants and
/// treats the other one as unreachable.
enum RootDecoderMessage {
/// A page that has already been loaded; accepted by `StructuralStructDecoder`.
LoadedPage(LoadedPage),
/// A `DecoderReady` message; accepted by `SimpleStructDecoder`.
LegacyPage(DecoderReady),
}
/// Common interface over the root decoders so the batch decode iterator can be
/// generic over which one it drives.
///
/// Implemented in this file for `StructuralStructDecoder` and `SimpleStructDecoder`.
trait RootDecoderType {
/// Hands one message to the root decoder.
///
/// Implementations expect a specific [`RootDecoderMessage`] variant and hit
/// `unreachable!()` if they are given the other one.
fn accept_message(&mut self, message: RootDecoderMessage) -> Result<()>;
/// Produces the decode task for the next `num_rows` rows.
fn drain_batch(&mut self, num_rows: u64) -> Result<NextDecodeTask>;
/// Gives the decoder a chance to block until it is ready to be drained.
///
/// `loaded_need` is forwarded to the decoder's own wait routine where one exists;
/// `runtime` is the tokio runtime used to block on that wait. Implementations
/// that do their waiting elsewhere simply return `Ok(())`.
fn wait(&mut self, loaded_need: u64, runtime: &tokio::runtime::Runtime) -> Result<()>;
}
/// Adapter that lets the structural root decoder be driven through
/// [`RootDecoderType`].
impl RootDecoderType for StructuralStructDecoder {
    /// Forwards an already-loaded page to `accept_page`.
    ///
    /// # Panics
    ///
    /// Panics if given a `LegacyPage` message; that variant is never expected
    /// by this decoder.
    fn accept_message(&mut self, message: RootDecoderMessage) -> Result<()> {
        match message {
            RootDecoderMessage::LoadedPage(page) => self.accept_page(page),
            RootDecoderMessage::LegacyPage(_) => unreachable!(),
        }
    }

    /// Delegates to `drain_batch_task`.
    fn drain_batch(&mut self, num_rows: u64) -> Result<NextDecodeTask> {
        self.drain_batch_task(num_rows)
    }

    /// No-op: both arguments are ignored.
    fn wait(&mut self, _loaded_need: u64, _runtime: &tokio::runtime::Runtime) -> Result<()> {
        // Waiting happens elsewhere (not as part of the decoder)
        Ok(())
    }
}
/// Adapter that lets the legacy root decoder be driven through
/// [`RootDecoderType`].
impl RootDecoderType for SimpleStructDecoder {
    /// Forwards a `DecoderReady` message to `accept_child`.
    ///
    /// # Panics
    ///
    /// Panics if given a `LoadedPage` message; that variant is never expected
    /// by this decoder.
    fn accept_message(&mut self, message: RootDecoderMessage) -> Result<()> {
        match message {
            RootDecoderMessage::LegacyPage(ready) => self.accept_child(ready),
            RootDecoderMessage::LoadedPage(_) => unreachable!(),
        }
    }

    /// Delegates to `drain`.
    fn drain_batch(&mut self, num_rows: u64) -> Result<NextDecodeTask> {
        self.drain(num_rows)
    }

    /// Blocks the calling thread on `runtime` until `wait_for_loaded(loaded_need)`
    /// completes, returning its result.
    fn wait(&mut self, loaded_need: u64, runtime: &tokio::runtime::Runtime) -> Result<()> {
        let loading = self.wait_for_loaded(loaded_need);
        runtime.block_on(loading)
    }
}

/// A blocking batch decoder that performs synchronous decoding
pub struct BatchDecodeIterator {
struct BatchDecodeIterator<T: RootDecoderType> {
messages: VecDeque<Result<DecoderMessage>>,
root_decoder: StructuralStructDecoder,
root_decoder: T,
rows_remaining: u64,
rows_per_batch: u32,
rows_scheduled: u64,
Expand All @@ -1488,13 +1529,13 @@ pub struct BatchDecodeIterator {
schema: Arc<ArrowSchema>,
}

impl BatchDecodeIterator {
impl<T: RootDecoderType> BatchDecodeIterator<T> {
/// Create a new instance of a batch decode iterator
pub fn new(
messages: VecDeque<Result<DecoderMessage>>,
rows_per_batch: u32,
num_rows: u64,
root_decoder: StructuralStructDecoder,
root_decoder: T,
schema: Arc<ArrowSchema>,
) -> Self {
Self {
Expand Down Expand Up @@ -1536,11 +1577,24 @@ impl BatchDecodeIterator {
let message = self.messages.pop_front().unwrap()?;
self.rows_scheduled = message.scheduled_so_far;
for decoder_message in message.decoders {
let unloaded_page = decoder_message.into_structural();
let loaded_page = self.wait_for_page(unloaded_page)?;
self.root_decoder.accept_page(loaded_page)?;
match decoder_message {
MessageType::UnloadedPage(unloaded_page) => {
let loaded_page = self.wait_for_page(unloaded_page)?;
self.root_decoder
.accept_message(RootDecoderMessage::LoadedPage(loaded_page))?;
}
MessageType::DecoderReady(decoder_ready) => {
// The root decoder we can ignore
if !decoder_ready.path.is_empty() {
self.root_decoder
.accept_message(RootDecoderMessage::LegacyPage(decoder_ready))?;
}
}
}
}
}
self.root_decoder
.wait(self.rows_scheduled, &self.wait_for_io_runtime)?;
Ok(self.rows_scheduled)
}

Expand Down Expand Up @@ -1578,12 +1632,8 @@ impl BatchDecodeIterator {
return Ok(None);
}

let next_task = self.root_decoder.drain(to_take)?;
let next_task = NextDecodeTask {
has_more: self.rows_remaining > 0,
num_rows: to_take,
task: Box::new(next_task),
};
let next_task = self.root_decoder.drain_batch(to_take)?;

self.rows_drained += to_take;

let batch = next_task.into_batch(self.emitted_batch_size_warning.clone())?;
Expand All @@ -1592,7 +1642,7 @@ impl BatchDecodeIterator {
}
}

impl Iterator for BatchDecodeIterator {
impl<T: RootDecoderType> Iterator for BatchDecodeIterator<T> {
type Item = ArrowResult<RecordBatch>;

fn next(&mut self) -> Option<Self::Item> {
Expand All @@ -1602,7 +1652,7 @@ impl Iterator for BatchDecodeIterator {
}
}

impl RecordBatchReader for BatchDecodeIterator {
impl<T: RootDecoderType> RecordBatchReader for BatchDecodeIterator<T> {
fn schema(&self) -> Arc<ArrowSchema> {
self.schema.clone()
}
Expand Down Expand Up @@ -1711,12 +1761,7 @@ impl StructuralBatchDecodeStream {
return Ok(None);
}

let next_task = self.root_decoder.drain(to_take)?;
let next_task = NextDecodeTask {
has_more: self.rows_remaining > 0,
num_rows: to_take,
task: Box::new(next_task),
};
let next_task = self.root_decoder.drain_batch_task(to_take)?;
self.rows_drained += to_take;
Ok(Some(next_task))
}
Expand Down Expand Up @@ -1825,19 +1870,32 @@ pub fn create_decode_iterator(
num_rows: u64,
batch_size: u32,
should_validate: bool,
is_structural: bool,
messages: VecDeque<Result<DecoderMessage>>,
) -> Box<dyn RecordBatchReader> {
let arrow_schema = Arc::new(ArrowSchema::from(schema));
let root_fields = arrow_schema.fields.clone();
let simple_struct_decoder =
StructuralStructDecoder::new(root_fields, should_validate, /*is_root=*/ true);
Box::new(BatchDecodeIterator::new(
messages,
batch_size,
num_rows,
simple_struct_decoder,
arrow_schema,
))
if is_structural {
let simple_struct_decoder =
StructuralStructDecoder::new(root_fields, should_validate, /*is_root=*/ true);
Box::new(BatchDecodeIterator::new(
messages,
batch_size,
num_rows,
simple_struct_decoder,
arrow_schema,
))
} else {
let root_decoder = SimpleStructDecoder::new(root_fields, num_rows);
let _legacy_iterator = Box::new(BatchDecodeIterator::new(
messages,
batch_size,
num_rows,
root_decoder,
arrow_schema,
));
todo!("The synchronous path for 2.0 is not quite working yet")
}
}

fn create_scheduler_decoder(
Expand Down Expand Up @@ -1968,6 +2026,7 @@ pub fn schedule_and_decode_blocking(
}

let num_rows = requested_rows.num_rows();
let is_structural = column_infos[0].is_structural();

let (tx, mut rx) = mpsc::unbounded_channel();

Expand Down Expand Up @@ -2010,6 +2069,7 @@ pub fn schedule_and_decode_blocking(
num_rows,
config.batch_size,
config.should_validate,
is_structural,
messages.into(),
);

Expand Down Expand Up @@ -2408,8 +2468,6 @@ pub struct NextDecodeTask {
pub task: Box<dyn DecodeArrayTask>,
/// The number of rows that will be created
pub num_rows: u64,
/// Whether or not the decoder that created this still has more rows to decode
pub has_more: bool,
}

impl NextDecodeTask {
Expand Down
1 change: 0 additions & 1 deletion rust/lance-encoding/src/encodings/logical/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ impl LogicalPageDecoder for BinaryPageDecoder {
fn drain(&mut self, num_rows: u64) -> Result<NextDecodeTask> {
let inner_task = self.inner.drain(num_rows)?;
Ok(NextDecodeTask {
has_more: inner_task.has_more,
num_rows: inner_task.num_rows,
task: Box::new(BinaryArrayDecoder {
inner: inner_task.task,
Expand Down
1 change: 0 additions & 1 deletion rust/lance-encoding/src/encodings/logical/blob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ impl LogicalPageDecoder for BlobFieldDecoder {
let validity = self.drain_validity(num_rows as usize)?;
self.rows_drained += num_rows;
Ok(NextDecodeTask {
has_more: self.rows_drained < self.num_rows,
num_rows,
task: Box::new(BlobArrayDecodeTask::new(bytes, validity)),
})
Expand Down
2 changes: 0 additions & 2 deletions rust/lance-encoding/src/encodings/logical/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -786,9 +786,7 @@ impl LogicalPageDecoder for ListPageDecoder {
};

self.rows_drained += num_rows;
let has_more = self.rows_left() > 0;
Ok(NextDecodeTask {
has_more,
num_rows,
task: Box::new(ListDecodeTask {
offsets,
Expand Down
1 change: 0 additions & 1 deletion rust/lance-encoding/src/encodings/logical/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2548,7 +2548,6 @@ impl LogicalPageDecoder for PrimitiveFieldDecoder {
Ok(NextDecodeTask {
task,
num_rows: rows_to_take,
has_more: self.rows_drained != self.num_rows,
})
}

Expand Down
12 changes: 9 additions & 3 deletions rust/lance-encoding/src/encodings/logical/struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,14 @@ impl StructuralStructDecoder {
_ => Box::new(StructuralPrimitiveFieldDecoder::new(field, should_validate)),
}
}

/// Drains `num_rows` rows from this decoder and wraps the resulting array
/// task in a [`NextDecodeTask`].
///
/// # Errors
///
/// Returns whatever error the underlying `drain` call reports.
pub fn drain_batch_task(&mut self, num_rows: u64) -> Result<NextDecodeTask> {
    // Box the drained task up front; it is coerced to the trait object
    // expected by `NextDecodeTask::task` when the struct is built.
    let task = Box::new(self.drain(num_rows)?);
    Ok(NextDecodeTask { task, num_rows })
}
}

impl StructuralFieldDecoder for StructuralStructDecoder {
Expand Down Expand Up @@ -725,6 +733,7 @@ impl SimpleStructDecoder {
}

async fn do_wait_for_loaded(&mut self, loaded_need: u64) -> Result<()> {
println!("Wait for loaded_need: {}", loaded_need);
let mut wait_orders = self
.children
.iter_mut()
Expand Down Expand Up @@ -787,16 +796,13 @@ impl LogicalPageDecoder for SimpleStructDecoder {
.map(|child| child.drain(num_rows))
.collect::<Result<Vec<_>>>()?;
let num_rows = child_tasks[0].num_rows;
let has_more = child_tasks[0].has_more;
debug_assert!(child_tasks.iter().all(|task| task.num_rows == num_rows));
debug_assert!(child_tasks.iter().all(|task| task.has_more == has_more));
Ok(NextDecodeTask {
task: Box::new(SimpleStructDecodeTask {
children: child_tasks,
child_fields: self.child_fields.clone(),
}),
num_rows,
has_more,
})
}

Expand Down
Loading

0 comments on commit 2b886d9

Please sign in to comment.