diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index f431e6264367..d23c4f321aaa 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1746,7 +1746,7 @@ impl ScalarValue { } /// Converts `Vec` where each element has type corresponding to - /// `data_type`, to a [`ListArray`]. + /// `data_type`, to a single element [`ListArray`]. /// /// Example /// ``` diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 30c184a28e33..382793a40f3e 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -29,8 +29,7 @@ use datafusion_common::tree_node::{ TreeNodeVisitor, }; use datafusion_common::{ - internal_datafusion_err, internal_err, Column, DFField, DFSchema, DFSchemaRef, - DataFusionError, Result, + internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; use datafusion_expr::expr::Alias; use datafusion_expr::logical_plan::{ @@ -44,8 +43,36 @@ use datafusion_expr::{col, Expr, ExprSchemable}; /// - DataType of this expression. type ExprSet = HashMap; -/// Identifier type. Current implementation use describe of a expression (type String) as -/// Identifier. +/// An ordered map of Identifiers assigned by `ExprIdentifierVisitor` in an +/// initial expression walk. +/// +/// Used by `CommonSubexprRewriter`, which rewrites the expressions to remove +/// common subexpressions. +/// +/// Elements in this array are created on the walk down the expression tree +/// during `f_down`. Thus element 0 is the root of the expression tree. The +/// tuple contains: +/// - series_number. +/// - Incremented during `f_up`, start from 1. +/// - Thus, items with higher idx have the lower series_number. +/// - [`Identifier`] +/// - Identifier of the expression. If empty (`""`), expr should not be considered for common elimination. 
+/// +/// # Example +/// An expression like `(a + b)` would have the following `IdArray`: +/// ```text +/// [ +/// (3, "a + b"), +/// (2, "a"), +/// (1, "b") +/// ] +/// ``` +type IdArray = Vec<(usize, Identifier)>; + +/// Identifier for each subexpression. +/// +/// Note that the current implementation uses the `Display` of an expression +/// (a `String`) as `Identifier`. /// /// A Identifier should (ideally) be able to "hash", "accumulate", "equal" and "have no /// collision (as low as possible)" @@ -328,8 +355,9 @@ impl CommonSubexprEliminate { agg_exprs.push(expr.alias(&name)); proj_exprs.push(Expr::Column(Column::from_name(name))); } else { - let id = - ExprIdentifierVisitor::<'static>::desc_expr(&expr_rewritten); + let id = ExprIdentifierVisitor::<'static>::expr_identifier( + &expr_rewritten, + ); let out_name = expr_rewritten.to_field(&new_input_schema)?.qualified_name(); agg_exprs.push(expr_rewritten.alias(&id)); @@ -597,15 +625,15 @@ impl ExprMask { /// This visitor implementation use a stack `visit_stack` to track traversal, which /// lets us know when a sub-tree's visiting is finished. When `pre_visit` is called /// (traversing to a new node), an `EnterMark` and an `ExprItem` will be pushed into stack. -/// And try to pop out a `EnterMark` on leaving a node (`post_visit()`). All `ExprItem` +/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All `ExprItem` /// before the first `EnterMark` is considered to be sub-tree of the leaving node. /// /// This visitor also records identifier in `id_array`. Makes the following traverse /// pass can get the identifier of a node without recalculate it. We assign each node /// in the expr tree a series number, start from 1, maintained by `series_number`. -/// Series number represents the order we left (`post_visit`) a node. Has the property +/// Series number represents the order we left (`f_up()`) a node. Has the property /// that child node's series number always smaller than parent's. 
While `id_array` is -/// organized in the order we enter (`pre_visit`) a node. `node_count` helps us to +/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to /// get the index of `id_array` for each node. /// /// `Expr` without sub-expr (column, literal etc.) will not have identifier @@ -614,15 +642,15 @@ struct ExprIdentifierVisitor<'a> { // param expr_set: &'a mut ExprSet, /// series number (usize) and identifier. - id_array: &'a mut Vec<(usize, Identifier)>, + id_array: &'a mut IdArray, /// input schema for the node that we're optimizing, so we can determine the correct datatype /// for each subexpression input_schema: DFSchemaRef, // inner states visit_stack: Vec, - /// increased in pre_visit, start from 0. + /// increased in f_down, start from 0. node_count: usize, - /// increased in post_visit, start from 1. + /// increased in f_up, start from 1. series_number: usize, /// which expression should be skipped? expr_mask: ExprMask, @@ -633,31 +661,33 @@ enum VisitRecord { /// `usize` is the monotone increasing series number assigned in pre_visit(). /// Starts from 0. Is used to index the identifier array `id_array` in post_visit(). EnterMark(usize), + /// the node's children were skipped => jump to f_up on same node + JumpMark(usize), /// Accumulated identifier of sub expression. ExprItem(Identifier), } impl ExprIdentifierVisitor<'_> { - fn desc_expr(expr: &Expr) -> String { + fn expr_identifier(expr: &Expr) -> Identifier { format!("{expr}") } /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` /// before it.
- fn pop_enter_mark(&mut self) -> Option<(usize, Identifier)> { + fn pop_enter_mark(&mut self) -> (usize, Identifier) { let mut desc = String::new(); while let Some(item) = self.visit_stack.pop() { match item { - VisitRecord::EnterMark(idx) => { - return Some((idx, desc)); + VisitRecord::EnterMark(idx) | VisitRecord::JumpMark(idx) => { + return (idx, desc); } - VisitRecord::ExprItem(s) => { - desc.push_str(&s); + VisitRecord::ExprItem(id) => { + desc.push_str(&id); } } } - None + unreachable!("Enter mark should paired with node number"); } } @@ -665,34 +695,39 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { type Node = Expr; fn f_down(&mut self, expr: &Expr) -> Result { + // put placeholder, sets the proper array length + self.id_array.push((0, "".to_string())); + // related to https://github.com/apache/arrow-datafusion/issues/8814 // If the expr contain volatile expression or is a short-circuit expression, skip it. if expr.short_circuits() || is_volatile_expression(expr)? { - return Ok(TreeNodeRecursion::Jump); + self.visit_stack + .push(VisitRecord::JumpMark(self.node_count)); + return Ok(TreeNodeRecursion::Jump); // go to f_up } + self.visit_stack .push(VisitRecord::EnterMark(self.node_count)); self.node_count += 1; - // put placeholder - self.id_array.push((0, "".to_string())); + Ok(TreeNodeRecursion::Continue) } fn f_up(&mut self, expr: &Expr) -> Result { self.series_number += 1; - let Some((idx, sub_expr_desc)) = self.pop_enter_mark() else { - return Ok(TreeNodeRecursion::Continue); - }; + let (idx, sub_expr_identifier) = self.pop_enter_mark(); + // skip exprs should not be recognize. 
if self.expr_mask.ignores(expr) { - self.id_array[idx].0 = self.series_number; - let desc = Self::desc_expr(expr); - self.visit_stack.push(VisitRecord::ExprItem(desc)); + let curr_expr_identifier = Self::expr_identifier(expr); + self.visit_stack + .push(VisitRecord::ExprItem(curr_expr_identifier)); + self.id_array[idx].0 = self.series_number; // leave Identifier as empty "", since it will not be used as a common expr return Ok(TreeNodeRecursion::Continue); } - let mut desc = Self::desc_expr(expr); - desc.push_str(&sub_expr_desc); + let mut desc = Self::expr_identifier(expr); + desc.push_str(&sub_expr_identifier); self.id_array[idx] = (self.series_number, desc.clone()); self.visit_stack.push(VisitRecord::ExprItem(desc.clone())); @@ -733,7 +768,7 @@ fn expr_to_identifier( /// evaluate result of replaced expression. struct CommonSubexprRewriter<'a> { expr_set: &'a ExprSet, - id_array: &'a [(usize, Identifier)], + id_array: &'a IdArray, /// Which identifier is replaced. affected_id: &'a mut BTreeSet, @@ -755,20 +790,26 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { if expr.short_circuits() || is_volatile_expression(&expr)? { return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } + + let (series_number, curr_id) = &self.id_array[self.curr_index]; + + // halting conditions if self.curr_index >= self.id_array.len() - || self.max_series_number > self.id_array[self.curr_index].0 + || self.max_series_number > *series_number { return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } - let curr_id = &self.id_array[self.curr_index].1; // skip `Expr`s without identifier (empty identifier). if curr_id.is_empty() { - self.curr_index += 1; + self.curr_index += 1; // incr idx for id_array, when not jumping return Ok(Transformed::no(expr)); } + + // lookup previously visited expression match self.expr_set.get(curr_id) { Some((_, counter, _)) => { + // if has a commonly used (a.k.a.
1+ use) expr if *counter > 1 { self.affected_id.insert(curr_id.clone()); @@ -781,23 +822,10 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { )); } - let (series_number, id) = &self.id_array[self.curr_index]; + // incr idx for id_array, when not jumping self.curr_index += 1; - // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr. - let expr_set_item = self.expr_set.get(id).ok_or_else(|| { - internal_datafusion_err!("expr_set invalid state") - })?; - if *series_number < self.max_series_number - || id.is_empty() - || expr_set_item.1 <= 1 - { - return Ok(Transformed::new( - expr, - false, - TreeNodeRecursion::Jump, - )); - } + // series_number was the inverse number ordering (when doing f_up) self.max_series_number = *series_number; // step index to skip all sub-node (which has smaller series number). while self.curr_index < self.id_array.len() @@ -811,7 +839,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { // `projection_push_down` optimizer use "expr name" to eliminate useless // projections. Ok(Transformed::new( - col(id).alias(expr_name), + col(curr_id).alias(expr_name), true, TreeNodeRecursion::Jump, )) @@ -827,7 +855,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { fn replace_common_expr( expr: Expr, - id_array: &[(usize, Identifier)], + id_array: &IdArray, expr_set: &ExprSet, affected_id: &mut BTreeSet, ) -> Result { diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs index 71782fcc5f9b..fb5e7710496c 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs @@ -47,7 +47,7 @@ use crate::binary_map::OutputType; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -/// Expression for a COUNT(DISTINCT) aggregation. +/// Expression for a `COUNT(DISTINCT)` aggregation. 
#[derive(Debug)] pub struct DistinctCount { /// Column name @@ -100,6 +100,7 @@ impl AggregateExpr for DistinctCount { use TimeUnit::*; Ok(match &self.state_data_type { + // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator Int8 => Box::new(PrimitiveDistinctCountAccumulator::::new()), Int16 => Box::new(PrimitiveDistinctCountAccumulator::::new()), Int32 => Box::new(PrimitiveDistinctCountAccumulator::::new()), @@ -157,6 +158,7 @@ impl AggregateExpr for DistinctCount { OutputType::Binary, )), + // Use the generic accumulator based on `ScalarValue` for all other types _ => Box::new(DistinctCountAccumulator { values: HashSet::default(), state_data_type: self.state_data_type.clone(), @@ -183,7 +185,11 @@ impl PartialEq for DistinctCount { } /// General purpose distinct accumulator that works for any DataType by using -/// [`ScalarValue`]. Some types have specialized accumulators that are (much) +/// [`ScalarValue`]. +/// +/// It stores intermediate results as a `ListArray`. +/// +/// Note that many types have specialized accumulators that are (much) /// more efficient such as [`PrimitiveDistinctCountAccumulator`] and /// [`BytesDistinctCountAccumulator`] #[derive(Debug)] @@ -193,8 +199,9 @@ struct DistinctCountAccumulator { } impl DistinctCountAccumulator { - // calculating the size for fixed length values, taking first batch size * number of batches - // This method is faster than .full_size(), however it is not suitable for variable length values like strings or complex types + // calculating the size for fixed length values, taking first batch size * + // number of batches. This method is faster than .full_size(), however it is + // not suitable for variable length values like strings or complex types fn fixed_size(&self) -> usize { std::mem::size_of_val(self) + (std::mem::size_of::() * self.values.capacity()) @@ -207,7 +214,8 @@ impl DistinctCountAccumulator { + std::mem::size_of::() } - // calculates the size as accurate as
possible, call to this method is expensive + // calculates the size as accurately as possible. Note that calling this + // method is expensive fn full_size(&self) -> usize { std::mem::size_of_val(self) + (std::mem::size_of::() * self.values.capacity()) @@ -221,6 +229,7 @@ impl DistinctCountAccumulator { } impl Accumulator for DistinctCountAccumulator { + /// Returns the distinct values seen so far as (one element) ListArray. fn state(&mut self) -> Result> { let scalars = self.values.iter().cloned().collect::>(); let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type); @@ -246,6 +255,11 @@ impl Accumulator for DistinctCountAccumulator { }) } + /// Merges multiple sets of distinct values into the current set. + /// + /// The input to this function is a `ListArray` with **multiple** rows, + /// where each row contains the values from a partial aggregate's phase (e.g. + /// the result of calling `Self::state` on multiple accumulators). fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { if states.is_empty() { return Ok(()); @@ -253,8 +267,12 @@ impl Accumulator for DistinctCountAccumulator { assert_eq!(states.len(), 1, "array_agg states must be singleton!"); let array = &states[0]; let list_array = array.as_list::(); - let inner_array = list_array.value(0); - self.update_batch(&[inner_array]) + for inner_array in list_array.iter() { + let inner_array = inner_array + .expect("counts are always non null, so are intermediate results"); + self.update_batch(&[inner_array])?; + } + Ok(()) } fn evaluate(&mut self) -> Result { diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt index 002aade2528e..af7bf5cb16e8 100644 --- a/datafusion/sqllogictest/test_files/dictionary.slt +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -280,3 +280,70 @@ ORDER BY 2023-12-20T01:20:00 1000 f2 foo 2023-12-20T01:30:00 1000 f1 32.0 2023-12-20T01:30:00 1000 f2 foo + +# Cleanup +statement ok +drop 
view m1; + +statement ok +drop view m2; + +###### +# Create a table using UNION ALL to get 2 partitions (very important) +###### +statement ok +create table m3_source as + select * from (values('foo', 'bar', 1)) + UNION ALL + select * from (values('foo', 'baz', 1)); + +###### +# Now, create a table with the same data, but column2 has type `Dictionary(Int32)` to trigger the fallback code +###### +statement ok +create table m3 as + select + column1, + arrow_cast(column2, 'Dictionary(Int32, Utf8)') as "column2", + column3 +from m3_source; + +# there are two values in column2 +query T?I rowsort +SELECT * +FROM m3; +---- +foo bar 1 +foo baz 1 + +# There is 1 distinct value in column1 +query I +SELECT count(distinct column1) +FROM m3 +GROUP BY column3; +---- +1 + +# There are 2 distinct values in column2 +query I +SELECT count(distinct column2) +FROM m3 +GROUP BY column3; +---- +2 + +# Should still get the same results when querying in the same query +query II +SELECT count(distinct column1), count(distinct column2) +FROM m3 +GROUP BY column3; +---- +1 2 + + +# Cleanup +statement ok +drop table m3; + +statement ok +drop table m3_source; diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 9e4e3aa8185d..e0290efddd35 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -1910,3 +1910,39 @@ false true false true NULL NULL NULL NULL false false true true false false true false + + +############# +## Common Subexpr Eliminate Tests +############# + +statement ok +CREATE TABLE doubles ( + f64 DOUBLE +) as VALUES + (10.1) +; + +# common subexpr with alias +query RRR rowsort +select f64, round(1.0 / f64) as i64_1, acos(round(1.0 / f64)) from doubles; +---- +10.1 0 1.570796326795 + +# common subexpr with coalesce (short-circuited) +query RRR rowsort +select f64, coalesce(1.0 / f64, 0.0), acos(coalesce(1.0 / f64, 0.0)) from doubles; +---- +10.1 0.09900990099 
1.471623942989 + +# common subexpr with coalesce (short-circuited) and alias +query RRR rowsort +select f64, coalesce(1.0 / f64, 0.0) as f64_1, acos(coalesce(1.0 / f64, 0.0)) from doubles; +---- +10.1 0.09900990099 1.471623942989 + +# common subexpr with case (short-circuited) +query RRR rowsort +select f64, case when f64 > 0 then 1.0 / f64 else null end, acos(case when f64 > 0 then 1.0 / f64 else null end) from doubles; +---- +10.1 0.09900990099 1.471623942989