diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 4a021c154331..36f055db1c68 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -608,15 +608,12 @@ impl Stream for GroupedHashAggregateStream { loop { match &self.exec_state { ExecutionState::ReadingInput => 'reading_input: { - match ready!(self.input.poll_next_unpin(cx)) { - // new batch to aggregate - Some(Ok(batch)) => { + match (ready!(self.input.poll_next_unpin(cx)), self.mode) { + // New batch to aggregate in partial aggregation operator + (Some(Ok(batch)), AggregateMode::Partial) => { let timer = elapsed_compute.timer(); let input_rows = batch.num_rows(); - // Make sure we have enough capacity for `batch`, otherwise spill - extract_ok!(self.spill_previous_if_necessary(&batch)); - // Do the grouping extract_ok!(self.group_aggregate_batch(batch)); @@ -649,11 +646,50 @@ impl Stream for GroupedHashAggregateStream { timer.done(); } - Some(Err(e)) => { + + // New batch to aggregate in terminal aggregation operator + // (Final/FinalPartitioned/Single/SinglePartitioned) + (Some(Ok(batch)), _) => { + let timer = elapsed_compute.timer(); + + // Make sure we have enough capacity for `batch`, otherwise spill + extract_ok!(self.spill_previous_if_necessary(&batch)); + + // Do the grouping + extract_ok!(self.group_aggregate_batch(batch)); + + // If we can begin emitting rows, do so, + // otherwise keep consuming input + assert!(!self.input_done); + + // If the number of group values equals or exceeds the soft limit, + // emit all groups and switch to producing output + if self.hit_soft_group_limit() { + timer.done(); + extract_ok!(self.set_input_done_and_produce_output()); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + + if let Some(to_emit) = self.group_ordering.emit_to() { + let batch = extract_ok!(self.emit(to_emit, false)); + self.exec_state = ExecutionState::ProducingOutput(batch); + timer.done(); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + + timer.done(); + } + + // Found error from input stream + (Some(Err(e)), _) => { // inner had error, return to caller return Poll::Ready(Some(Err(e))); } - None => { + + // Found end from input stream + (None, _) => { // inner is done, emit all rows and switch to producing output extract_ok!(self.set_input_done_and_produce_output()); } @@ -1003,16 +1039,19 @@ impl GroupedHashAggregateStream { /// Updates skip aggregation probe state. fn update_skip_aggregation_probe(&mut self, input_rows: usize) { - if let Some(probe) = self.skip_aggregation_probe.as_mut() { - // Skip aggregation probe is only supported in Partial aggregation. - // And it is not supported if stream has any spills even in Partial aggregation. - // Although currently spilling is actually not supported in Partial aggregation, - // it is possible to be supported in future, so we also add an assertion for it. - assert!( - self.mode == AggregateMode::Partial && self.spill_state.spills.is_empty() - ); - probe.update_state(input_rows, self.group_values.len()); - }; + // Skip aggregation probe is only supported and called in Partial aggregation. + // And it is not supported if stream has any spills even in Partial aggregation. + // Although currently spilling is actually not supported in Partial aggregation, + // it is possible to be supported in future, so we also add an assertion for it. + assert!(self.spill_state.spills.is_empty()); + + // As mention above, skip aggregation probe will only be called in Partial aggregation. + // And naturally, in Partial aggregation, we should ensure `skip_aggregation_probe` + // is not `None`, so it is safe to unwrap here. + self.skip_aggregation_probe + .as_mut() + .expect("skip_aggregation_probe must be some in partial aggregation") + .update_state(input_rows, self.group_values.len()); } /// In case the probe indicates that aggregation may be