-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement GROUPING aggregate function (following Postgres behavior.) #12565
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,20 +19,34 @@ | |
|
||
use std::any::Any; | ||
use std::fmt; | ||
use std::sync::Arc; | ||
|
||
use arrow::array::ArrayRef; | ||
use arrow::array::AsArray; | ||
use arrow::array::BooleanArray; | ||
use arrow::array::UInt32Array; | ||
use arrow::datatypes::DataType; | ||
use arrow::datatypes::Field; | ||
use arrow::datatypes::UInt32Type; | ||
use datafusion_common::internal_datafusion_err; | ||
use datafusion_common::internal_err; | ||
use datafusion_common::plan_err; | ||
use datafusion_common::{not_impl_err, Result}; | ||
use datafusion_expr::function::AccumulatorArgs; | ||
use datafusion_expr::function::StateFieldsArgs; | ||
use datafusion_expr::utils::format_state_name; | ||
use datafusion_expr::EmitTo; | ||
use datafusion_expr::GroupsAccumulator; | ||
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; | ||
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate; | ||
use datafusion_physical_expr::expressions::Column; | ||
use datafusion_physical_expr::PhysicalExpr; | ||
|
||
make_udaf_expr_and_func!( | ||
Grouping, | ||
grouping, | ||
expression, | ||
"Returns 1 if the data is aggregated across the specified column or 0 for not aggregated in the result set.", | ||
"Returns a bitmap where bit i is 1 if this row is aggregated across the ith argument to GROUPING and 0 otherwise.", | ||
grouping_udaf | ||
); | ||
|
||
|
@@ -59,9 +73,55 @@ impl Grouping { | |
/// Create a new GROUPING aggregate function. | ||
pub fn new() -> Self { | ||
Self { | ||
signature: Signature::any(1, Volatility::Immutable), | ||
signature: Signature::variadic_any(Volatility::Immutable), | ||
} | ||
} | ||
|
||
/// Create an accumulator for GROUPING(grouping_args) in a GROUP BY over group_exprs | ||
/// A special creation function is necessary because GROUPING has unusual input requirements. | ||
pub fn create_grouping_accumulator( | ||
&self, | ||
grouping_args: &[Arc<dyn PhysicalExpr>], | ||
group_exprs: &[(Arc<dyn PhysicalExpr>, String)], | ||
) -> Result<Box<dyn GroupsAccumulator>> { | ||
if grouping_args.len() > 32 { | ||
return plan_err!( | ||
"GROUPING is supported for up to 32 columns. Consider another \ | ||
GROUPING statement if you need to aggregate over more columns." | ||
); | ||
} | ||
// The PhysicalExprs of grouping_exprs must be Column PhysicalExpr. Because if | ||
// the group by PhysicalExpr in SQL is non-Column PhysicalExpr, then there is | ||
// a ProjectionExec before AggregateExec to convert the non-column PhysicalExpr | ||
// to Column PhysicalExpr. | ||
let column_index = | ||
|expr: &Arc<dyn PhysicalExpr>| match expr.as_any().downcast_ref::<Column>() { | ||
Some(column) => Ok(column.index()), | ||
None => internal_err!("Grouping doesn't support expr: {}", expr), | ||
}; | ||
Comment on lines
+93
to
+101
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is only true when one enabled the optimizer rule There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we look for equal PhysicalExprs? The Postgres docs imply they do ~text comparison but I'm not sure how accessible that info is at this layer. |
||
let group_by_columns: Result<Vec<_>> = | ||
group_exprs.iter().map(|(e, _)| column_index(e)).collect(); | ||
let group_by_columns = group_by_columns?; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this can be 1 liner? |
||
|
||
let arg_columns: Result<Vec<_>> = | ||
grouping_args.iter().map(column_index).collect(); | ||
let expr_indices: Result<Vec<_>> = arg_columns? | ||
.iter() | ||
.map(|arg| { | ||
group_by_columns | ||
.iter() | ||
.position(|gb| arg == gb) | ||
.ok_or_else(|| { | ||
internal_datafusion_err!("Invalid grouping set indices.") | ||
}) | ||
}) | ||
.collect(); | ||
|
||
Ok(Box::new(GroupingAccumulator { | ||
grouping_ids: vec![], | ||
expr_indices: expr_indices?, | ||
})) | ||
} | ||
} | ||
|
||
impl AggregateUDFImpl for Grouping { | ||
|
@@ -78,20 +138,145 @@ impl AggregateUDFImpl for Grouping { | |
} | ||
|
||
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { | ||
Ok(DataType::Int32) | ||
Ok(DataType::UInt32) | ||
} | ||
|
||
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { | ||
Ok(vec![Field::new( | ||
format_state_name(args.name, "grouping"), | ||
DataType::Int32, | ||
DataType::UInt32, | ||
true, | ||
)]) | ||
} | ||
|
||
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { | ||
not_impl_err!( | ||
"physical plan is not yet implemented for GROUPING aggregate function" | ||
) | ||
not_impl_err!("The GROUPING function requires a GROUP BY context.") | ||
} | ||
|
||
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { | ||
false | ||
} | ||
|
||
fn create_groups_accumulator( | ||
&self, | ||
_args: AccumulatorArgs, | ||
) -> Result<Box<dyn GroupsAccumulator>> { | ||
// Use `create_grouping_accumulator` instead. | ||
not_impl_err!("GROUPING is not supported when invoked this way.") | ||
} | ||
} | ||
|
||
struct GroupingAccumulator { | ||
// Grouping ID value for each group | ||
grouping_ids: Vec<u32>, | ||
// Indices of GROUPING arguments as they appear in the GROUPING SET | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we have more details or example on indices? |
||
expr_indices: Vec<usize>, | ||
} | ||
|
||
impl GroupingAccumulator { | ||
fn mask_to_id(&self, mask: &[bool]) -> Result<u32> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add more description on this method, how it changes the mask |
||
let mut id: u32 = 0; | ||
// rightmost entry is the LSB | ||
for (i, &idx) in self.expr_indices.iter().rev().enumerate() { | ||
match mask.get(idx) { | ||
Some(true) => id |= 1 << i, | ||
Some(false) => {} | ||
None => { | ||
return internal_err!( | ||
"Index out of bounds while calculating GROUPING id." | ||
) | ||
} | ||
} | ||
} | ||
Ok(id) | ||
} | ||
} | ||
|
||
impl GroupsAccumulator for GroupingAccumulator { | ||
fn update_batch( | ||
&mut self, | ||
_values: &[ArrayRef], | ||
_group_indices: &[usize], | ||
_opt_filter: Option<&BooleanArray>, | ||
_total_num_groups: usize, | ||
) -> Result<()> { | ||
// No-op since GROUPING doesn't care about values | ||
Ok(()) | ||
} | ||
|
||
fn merge_batch( | ||
&mut self, | ||
values: &[ArrayRef], | ||
group_indices: &[usize], | ||
_opt_filter: Option<&BooleanArray>, | ||
total_num_groups: usize, | ||
) -> Result<()> { | ||
assert_eq!(values.len(), 1, "single argument to merge_batch"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so we always expect only 1 array ? |
||
self.grouping_ids.resize(total_num_groups, 0); | ||
let other_ids = values[0].as_primitive::<UInt32Type>(); | ||
accumulate(group_indices, other_ids, None, |group_index, group_id| { | ||
self.grouping_ids[group_index] |= group_id; | ||
}); | ||
Ok(()) | ||
} | ||
|
||
fn update_groupings( | ||
&mut self, | ||
group_indices: &[usize], | ||
group_mask: &[bool], | ||
total_num_groups: usize, | ||
) -> Result<()> { | ||
self.grouping_ids.resize(total_num_groups, 0); | ||
let group_id = self.mask_to_id(group_mask)?; | ||
for &group_idx in group_indices { | ||
self.grouping_ids[group_idx] = group_id; | ||
} | ||
Ok(()) | ||
} | ||
|
||
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> { | ||
let values = emit_to.take_needed(&mut self.grouping_ids); | ||
let values = UInt32Array::new(values.into(), None); | ||
Ok(Arc::new(values)) | ||
} | ||
|
||
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> { | ||
self.evaluate(emit_to).map(|arr| vec![arr]) | ||
} | ||
|
||
fn size(&self) -> usize { | ||
self.grouping_ids.capacity() * std::mem::size_of::<u32>() | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use crate::grouping::GroupingAccumulator; | ||
|
||
#[test] | ||
fn test_group_ids() { | ||
let grouping = GroupingAccumulator { | ||
grouping_ids: vec![], | ||
expr_indices: vec![0, 1, 3, 2], | ||
}; | ||
let cases = vec![ | ||
(0b0000, vec![false, false, false, false]), | ||
(0b1000, vec![true, false, false, false]), | ||
(0b0100, vec![false, true, false, false]), | ||
(0b1010, vec![true, false, false, true]), | ||
(0b1001, vec![true, false, true, false]), | ||
]; | ||
for (expected, input) in cases { | ||
assert_eq!(expected, grouping.mask_to_id(&input).unwrap()); | ||
} | ||
} | ||
#[test] | ||
fn test_bad_index() { | ||
let grouping = GroupingAccumulator { | ||
grouping_ids: vec![], | ||
expr_indices: vec![5], | ||
}; | ||
let res = grouping.mask_to_id(&[false]); | ||
assert!(res.is_err()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you may want to check the error message as well |
||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -238,6 +238,13 @@ impl PartialEq for PhysicalGroupBy { | |
} | ||
} | ||
|
||
pub(crate) struct PhysicalGroupingSet { | ||
/// Exprs/columns over which the grouping set is aggregated | ||
values: Vec<ArrayRef>, | ||
/// True if the corresponding value is null in this grouping set | ||
mask: Vec<bool>, | ||
} | ||
|
||
enum StreamType { | ||
AggregateStream(AggregateStream), | ||
GroupedHash(GroupedHashAggregateStream), | ||
|
@@ -1140,13 +1147,13 @@ fn evaluate_optional( | |
/// - `batch`: the `RecordBatch` to evaluate against | ||
/// | ||
/// Returns: A Vec of Vecs of Array of results | ||
/// The outer Vec appears to be for grouping sets | ||
/// The outer Vec contains the grouping sets defined by `group_by.groups` | ||
/// The inner Vec contains the results per expression | ||
/// The inner-inner Array contains the results per row | ||
pub(crate) fn evaluate_group_by( | ||
group_by: &PhysicalGroupBy, | ||
batch: &RecordBatch, | ||
) -> Result<Vec<Vec<ArrayRef>>> { | ||
) -> Result<Vec<PhysicalGroupingSet>> { | ||
let exprs: Vec<ArrayRef> = group_by | ||
.expr | ||
.iter() | ||
|
@@ -1169,7 +1176,7 @@ pub(crate) fn evaluate_group_by( | |
.groups | ||
.iter() | ||
.map(|group| { | ||
group | ||
let v = group | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. lets have more meaningful name? |
||
.iter() | ||
.enumerate() | ||
.map(|(idx, is_null)| { | ||
|
@@ -1179,7 +1186,11 @@ pub(crate) fn evaluate_group_by( | |
Arc::clone(&exprs[idx]) | ||
} | ||
}) | ||
.collect() | ||
.collect(); | ||
PhysicalGroupingSet { | ||
values: v, | ||
mask: group.clone(), | ||
} | ||
}) | ||
.collect()) | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lets have it as a const