Branch for upgrade to DataFusion March 5 Upgrade #18

Open
wants to merge 2 commits into alamb/march_5_base
2 changes: 1 addition & 1 deletion datafusion/common/src/scalar/mod.rs
@@ -1746,7 +1746,7 @@ impl ScalarValue {
}

/// Converts `Vec<ScalarValue>` where each element has type corresponding to
/// `data_type`, to a [`ListArray`].
/// `data_type`, to a single element [`ListArray`].
///
/// Example
/// ```
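To make the doc change above concrete: `ScalarValue::new_list` packs all of the input values into one list row. A standalone sketch (not part of this diff, using only the `arrow_array` crate; the values and names are illustrative) of what such a single element `ListArray` looks like:

```rust
// Illustrative only: a "single element" ListArray is one row whose value is
// the whole list of inputs.
use arrow_array::cast::AsArray;
use arrow_array::types::Int32Type;
use arrow_array::{Array, ListArray};

fn main() {
    // One row whose value is the list [1, 2, 3]
    let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![
        Some(1),
        Some(2),
        Some(3),
    ])]);

    assert_eq!(list.len(), 1); // a single element (one row) ListArray
    let row0 = list.value(0); // the inner values array holding all three values
    assert_eq!(row0.len(), 3);
    assert_eq!(row0.as_primitive::<Int32Type>().value(2), 3);
}
```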
130 changes: 79 additions & 51 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -29,8 +29,7 @@ use datafusion_common::tree_node::{
TreeNodeVisitor,
};
use datafusion_common::{
internal_datafusion_err, internal_err, Column, DFField, DFSchema, DFSchemaRef,
DataFusionError, Result,
internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
};
use datafusion_expr::expr::Alias;
use datafusion_expr::logical_plan::{
@@ -44,8 +43,36 @@ use datafusion_expr::{col, Expr, ExprSchemable};
/// - DataType of this expression.
type ExprSet = HashMap<Identifier, (Expr, usize, DataType)>;

/// Identifier type. Current implementation use describe of a expression (type String) as
/// Identifier.
/// An ordered map of Identifiers assigned by `ExprIdentifierVisitor` in an
/// initial expression walk.
///
/// Used by `CommonSubexprRewriter`, which rewrites the expressions to remove
/// common subexpressions.
///
/// Elements in this array are created on the walk down the expression tree
/// during `f_down`. Thus element 0 is the root of the expression tree. The
/// tuple contains:
/// - series_number.
///   - Incremented during `f_up`, starting from 1.
///   - Thus, items with a higher idx have a lower series_number.
/// - [`Identifier`]
///   - The identifier of the expression. If empty (`""`), the expr should not be considered for common elimination.
///
/// # Example
/// An expression like `(a + b)` would have the following `IdArray`:
/// ```text
/// [
/// (3, "a + b"),
/// (2, "a"),
/// (1, "b")
/// ]
/// ```
type IdArray = Vec<(usize, Identifier)>;
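As a reading aid for the new `IdArray` docs, here is a small self-contained sketch (a toy expression type, not DataFusion's visitor) that assigns numbers the same way: the array index follows the walk down (`f_down`, pre-order) while the series number follows the walk up (`f_up`, post-order), so the root sits at index 0 with the highest series number. The exact child numbering may differ from the real visitor; the point is the index/series relationship.

```rust
// Toy illustration of the bookkeeping described above; `Node` is a made-up
// expression type, not DataFusion's `Expr`.
enum Node {
    Col(&'static str),
    Add(Box<Node>, Box<Node>),
}

fn display(node: &Node) -> String {
    match node {
        Node::Col(name) => name.to_string(),
        Node::Add(l, r) => format!("{} + {}", display(l), display(r)),
    }
}

/// Walk the tree: reserve an id_array slot on the way down (like `f_down`),
/// fill in the series number on the way up (like `f_up`).
fn walk(node: &Node, id_array: &mut Vec<(usize, String)>, series: &mut usize) {
    let idx = id_array.len(); // index follows enter (pre-order) position
    id_array.push((0, String::new())); // placeholder, as in the real visitor
    if let Node::Add(l, r) = node {
        walk(l, id_array, series);
        walk(r, id_array, series);
    }
    *series += 1; // series number follows leave (post-order) position
    id_array[idx] = (*series, display(node));
}

fn main() {
    let expr = Node::Add(Box::new(Node::Col("a")), Box::new(Node::Col("b")));
    let (mut id_array, mut series) = (Vec::new(), 0);
    walk(&expr, &mut id_array, &mut series);
    // Root is element 0 and carries the highest series number:
    // [(3, "a + b"), (1, "a"), (2, "b")]
    println!("{id_array:?}");
}
```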

/// Identifier for each subexpression.
///
/// Note that the current implementation uses the `Display` of an expression
/// (a `String`) as `Identifier`.
///
/// An Identifier should (ideally) be able to "hash", "accumulate", "equal" and "have no
/// collision (as low as possible)"
@@ -328,8 +355,9 @@ impl CommonSubexprEliminate {
agg_exprs.push(expr.alias(&name));
proj_exprs.push(Expr::Column(Column::from_name(name)));
} else {
let id =
ExprIdentifierVisitor::<'static>::desc_expr(&expr_rewritten);
let id = ExprIdentifierVisitor::<'static>::expr_identifier(
&expr_rewritten,
);
let out_name =
expr_rewritten.to_field(&new_input_schema)?.qualified_name();
agg_exprs.push(expr_rewritten.alias(&id));
@@ -597,15 +625,15 @@ impl ExprMask {
/// This visitor implementation use a stack `visit_stack` to track traversal, which
/// lets us know when a sub-tree's visiting is finished. When `pre_visit` is called
/// (traversing to a new node), an `EnterMark` and an `ExprItem` will be pushed into stack.
/// And try to pop out a `EnterMark` on leaving a node (`post_visit()`). All `ExprItem`
/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All `ExprItem`
/// before the first `EnterMark` is considered to be sub-tree of the leaving node.
///
/// This visitor also records identifier in `id_array`. Makes the following traverse
/// pass can get the identifier of a node without recalculate it. We assign each node
/// in the expr tree a series number, start from 1, maintained by `series_number`.
/// Series number represents the order we left (`post_visit`) a node. Has the property
/// Series number represents the order we left (`f_up()`) a node. Has the property
/// that child node's series number always smaller than parent's. While `id_array` is
/// organized in the order we enter (`pre_visit`) a node. `node_count` helps us to
/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to
/// get the index of `id_array` for each node.
///
/// `Expr` without sub-expr (column, literal etc.) will not have identifier
@@ -614,15 +642,15 @@ struct ExprIdentifierVisitor<'a> {
// param
expr_set: &'a mut ExprSet,
/// series number (usize) and identifier.
id_array: &'a mut Vec<(usize, Identifier)>,
id_array: &'a mut IdArray,
/// input schema for the node that we're optimizing, so we can determine the correct datatype
/// for each subexpression
input_schema: DFSchemaRef,
// inner states
visit_stack: Vec<VisitRecord>,
/// increased in pre_visit, start from 0.
/// increased in fn_down, start from 0.
node_count: usize,
/// increased in post_visit, start from 1.
/// increased in fn_up, start from 1.
series_number: usize,
/// which expression should be skipped?
expr_mask: ExprMask,
@@ -633,66 +661,73 @@
/// `usize` is the monotone increasing series number assigned in pre_visit().
/// Starts from 0. Is used to index the identifier array `id_array` in post_visit().
EnterMark(usize),
/// the node's children were skipped => jump to f_up on same node
JumpMark(usize),
/// Accumulated identifier of sub expression.
ExprItem(Identifier),
}

impl ExprIdentifierVisitor<'_> {
fn desc_expr(expr: &Expr) -> String {
fn expr_identifier(expr: &Expr) -> Identifier {
format!("{expr}")
}

/// Find the first `EnterMark` in the stack, and accumulate every `ExprItem`
/// before it.
fn pop_enter_mark(&mut self) -> Option<(usize, Identifier)> {
fn pop_enter_mark(&mut self) -> (usize, Identifier) {
let mut desc = String::new();

while let Some(item) = self.visit_stack.pop() {
match item {
VisitRecord::EnterMark(idx) => {
return Some((idx, desc));
VisitRecord::EnterMark(idx) | VisitRecord::JumpMark(idx) => {
return (idx, desc);
}
VisitRecord::ExprItem(s) => {
desc.push_str(&s);
VisitRecord::ExprItem(id) => {
desc.push_str(&id);
}
}
}
None
unreachable!("Enter mark should paired with node number");
}
}
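A minimal sketch of the `visit_stack` protocol that `pop_enter_mark` relies on (toy types, not the real `VisitRecord`): everything pushed after the most recent mark belongs to the node currently being left, and is popped and concatenated into that node's identifier.

```rust
// Standalone sketch of the visit_stack protocol; not DataFusion code.
enum Record {
    EnterMark(usize),
    ExprItem(String),
}

fn pop_enter_mark(stack: &mut Vec<Record>) -> (usize, String) {
    let mut acc = String::new();
    while let Some(record) = stack.pop() {
        match record {
            Record::EnterMark(idx) => return (idx, acc),
            Record::ExprItem(id) => acc.push_str(&id),
        }
    }
    unreachable!("every f_up has a matching mark pushed by f_down");
}

fn main() {
    // Simulated stack state when leaving `a + b`: the mark was pushed when
    // entering the node, the items were pushed as each child finished.
    let mut stack = vec![
        Record::EnterMark(0),         // entered `a + b` (id_array index 0)
        Record::ExprItem("a".into()), // finished child `a`
        Record::ExprItem("b".into()), // finished child `b`
    ];
    let (idx, children) = pop_enter_mark(&mut stack);
    assert_eq!(idx, 0);
    assert_eq!(children, "ba"); // children come back in reverse push order
}
```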

impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
type Node = Expr;

fn f_down(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> {
// put a placeholder, which sets the proper array length
self.id_array.push((0, "".to_string()));

// related to https://github.com/apache/arrow-datafusion/issues/8814
// If the expr contain volatile expression or is a short-circuit expression, skip it.
if expr.short_circuits() || is_volatile_expression(expr)? {
return Ok(TreeNodeRecursion::Jump);
self.visit_stack
.push(VisitRecord::JumpMark(self.node_count));
return Ok(TreeNodeRecursion::Jump); // go to f_up
}

self.visit_stack
.push(VisitRecord::EnterMark(self.node_count));
self.node_count += 1;
// put placeholder
self.id_array.push((0, "".to_string()));

Ok(TreeNodeRecursion::Continue)
}

fn f_up(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> {
self.series_number += 1;

let Some((idx, sub_expr_desc)) = self.pop_enter_mark() else {
return Ok(TreeNodeRecursion::Continue);
};
let (idx, sub_expr_identifier) = self.pop_enter_mark();

// skip exprs that should not be recognized.
if self.expr_mask.ignores(expr) {
self.id_array[idx].0 = self.series_number;
let desc = Self::desc_expr(expr);
self.visit_stack.push(VisitRecord::ExprItem(desc));
let curr_expr_identifier = Self::expr_identifier(expr);
self.visit_stack
.push(VisitRecord::ExprItem(curr_expr_identifier));
self.id_array[idx].0 = self.series_number; // leave the Identifier as empty "", since it will not be used as a common expr
return Ok(TreeNodeRecursion::Continue);
}
let mut desc = Self::desc_expr(expr);
desc.push_str(&sub_expr_desc);
let mut desc = Self::expr_identifier(expr);
desc.push_str(&sub_expr_identifier);

self.id_array[idx] = (self.series_number, desc.clone());
self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));
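The `JumpMark` addition above is the crux of this hunk: when `f_down` returns `TreeNodeRecursion::Jump` the children are skipped but `f_up` still runs for the same node and pops the stack, so a mark must be pushed even on the jump path. A standalone toy simulation (not DataFusion code) of that invariant:

```rust
// Toy simulation of why f_down must always leave a mark, even when it skips
// a node's children: the matching "f_up" still runs and pops the stack.
#[derive(Debug)]
enum Mark {
    Enter(usize),
    Jump(usize),
}

struct Node {
    skip_children: bool, // stands in for volatile / short-circuiting exprs
    children: Vec<Node>,
}

fn visit(node: &Node, stack: &mut Vec<Mark>, count: &mut usize) {
    // "f_down"
    if node.skip_children {
        // without this push, the pop in "f_up" below would underflow
        stack.push(Mark::Jump(*count));
    } else {
        stack.push(Mark::Enter(*count));
        *count += 1;
        for child in &node.children {
            visit(child, stack, count);
        }
    }
    // "f_up": runs for every node, jumped or not, and expects a mark
    let mark = stack.pop().expect("f_down always leaves a mark");
    println!("left a node, popped {mark:?}");
}

fn main() {
    let tree = Node {
        skip_children: false,
        children: vec![Node { skip_children: true, children: vec![] }],
    };
    visit(&tree, &mut Vec::new(), &mut 0);
}
```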
@@ -733,7 +768,7 @@ fn expr_to_identifier(
/// evaluate result of replaced expression.
struct CommonSubexprRewriter<'a> {
expr_set: &'a ExprSet,
id_array: &'a [(usize, Identifier)],
id_array: &'a IdArray,
/// Which identifier is replaced.
affected_id: &'a mut BTreeSet<Identifier>,

@@ -755,20 +790,26 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
if expr.short_circuits() || is_volatile_expression(&expr)? {
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
}

let (series_number, curr_id) = &self.id_array[self.curr_index];

// halting conditions
if self.curr_index >= self.id_array.len()
|| self.max_series_number > self.id_array[self.curr_index].0
|| self.max_series_number > *series_number
{
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
}

let curr_id = &self.id_array[self.curr_index].1;
// skip `Expr`s without identifier (empty identifier).
if curr_id.is_empty() {
self.curr_index += 1;
self.curr_index += 1; // incr idx for id_array, when not jumping
return Ok(Transformed::no(expr));
}

// lookup previously visited expression
match self.expr_set.get(curr_id) {
Some((_, counter, _)) => {
// if this is a commonly used (i.e. used more than once) expr
if *counter > 1 {
self.affected_id.insert(curr_id.clone());

@@ -781,23 +822,10 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
));
}

let (series_number, id) = &self.id_array[self.curr_index];
// incr idx for id_array, when not jumping
self.curr_index += 1;
// Skip sub-node of a replaced tree, or without identifier, or is not repeated expr.
let expr_set_item = self.expr_set.get(id).ok_or_else(|| {
internal_datafusion_err!("expr_set invalid state")
})?;
if *series_number < self.max_series_number
|| id.is_empty()
|| expr_set_item.1 <= 1
{
return Ok(Transformed::new(
expr,
false,
TreeNodeRecursion::Jump,
));
}

// series_number is assigned in f_up order, i.e. the inverse of the id_array (f_down) order
self.max_series_number = *series_number;
// step index to skip all sub-node (which has smaller series number).
while self.curr_index < self.id_array.len()
@@ -811,7 +839,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
// `projection_push_down` optimizer use "expr name" to eliminate useless
// projections.
Ok(Transformed::new(
col(id).alias(expr_name),
col(curr_id).alias(expr_name),
true,
TreeNodeRecursion::Jump,
))
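For context, the net effect of `CommonSubexprRewriter` is that every occurrence of a repeated subexpression becomes a column reference to a single aliased value computed once by a projection added elsewhere in this rule. A minimal before/after sketch using the `datafusion_expr` expression builders; the alias name here is made up for illustration, the real rewriter aliases with the expression's identifier string:

```rust
// Conceptual before/after of the rewrite (hypothetical alias name).
use datafusion_expr::{col, lit, Expr};

fn main() {
    // before: `a + b` appears twice and would be evaluated twice
    let before: Expr = (col("a") + col("b"))
        .gt(lit(5))
        .and((col("a") + col("b")).lt(lit(10)));

    // after: both occurrences are column references to one shared value
    let shared = col("__common_a_plus_b"); // made-up alias for the example
    let after: Expr = shared.clone().gt(lit(5)).and(shared.lt(lit(10)));

    println!("before: {before}\nafter:  {after}");
}
```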
@@ -827,7 +855,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {

fn replace_common_expr(
expr: Expr,
id_array: &[(usize, Identifier)],
id_array: &IdArray,
expr_set: &ExprSet,
affected_id: &mut BTreeSet<Identifier>,
) -> Result<Expr> {
32 changes: 25 additions & 7 deletions datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
@@ -47,7 +47,7 @@ use crate::binary_map::OutputType;
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};

/// Expression for a COUNT(DISTINCT) aggregation.
/// Expression for a `COUNT(DISTINCT)` aggregation.
#[derive(Debug)]
pub struct DistinctCount {
/// Column name
@@ -100,6 +100,7 @@ impl AggregateExpr for DistinctCount {
use TimeUnit::*;

Ok(match &self.state_data_type {
// try and use a specialized accumulator if possible, otherwise fall back to generic accumulator
Int8 => Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new()),
Int16 => Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new()),
Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new()),
@@ -157,6 +158,7 @@ impl AggregateExpr for DistinctCount {
OutputType::Binary,
)),

// Use the generic accumulator based on `ScalarValue` for all other types
_ => Box::new(DistinctCountAccumulator {
values: HashSet::default(),
state_data_type: self.state_data_type.clone(),
@@ -183,7 +185,11 @@ impl PartialEq<dyn Any> for DistinctCount {
}

/// General purpose distinct accumulator that works for any DataType by using
/// [`ScalarValue`]. Some types have specialized accumulators that are (much)
/// [`ScalarValue`].
///
/// It stores intermediate results as a `ListArray`
///
/// Note that many types have specialized accumulators that are (much)
/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
/// [`BytesDistinctCountAccumulator`]
#[derive(Debug)]
@@ -193,8 +199,9 @@ struct DistinctCountAccumulator {
}

impl DistinctCountAccumulator {
// calculating the size for fixed length values, taking first batch size * number of batches
// This method is faster than .full_size(), however it is not suitable for variable length values like strings or complex types
// calculating the size for fixed length values, taking first batch size *
// number of batches. This method is faster than .full_size(), however it is
// not suitable for variable length values like strings or complex types
fn fixed_size(&self) -> usize {
std::mem::size_of_val(self)
+ (std::mem::size_of::<ScalarValue>() * self.values.capacity())
@@ -207,7 +214,8 @@ impl DistinctCountAccumulator {
+ std::mem::size_of::<DataType>()
}

// calculates the size as accurate as possible, call to this method is expensive
// calculates the size as accurately as possible. Note that calling this
// method is expensive
fn full_size(&self) -> usize {
std::mem::size_of_val(self)
+ (std::mem::size_of::<ScalarValue>() * self.values.capacity())
@@ -221,6 +229,7 @@ impl DistinctCountAccumulator {
}

impl Accumulator for DistinctCountAccumulator {
/// Returns the distinct values seen so far as (one element) ListArray.
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let scalars = self.values.iter().cloned().collect::<Vec<_>>();
let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type);
@@ -246,15 +255,24 @@ impl Accumulator for DistinctCountAccumulator {
})
}

/// Merges multiple sets of distinct values into the current set.
///
/// The input to this function is a `ListArray` with **multiple** rows,
/// where each row contains the values from one partial aggregation (e.g.
/// the result of calling `Self::state` on multiple accumulators).
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
assert_eq!(states.len(), 1, "array_agg states must be singleton!");
let array = &states[0];
let list_array = array.as_list::<i32>();
let inner_array = list_array.value(0);
self.update_batch(&[inner_array])
for inner_array in list_array.iter() {
let inner_array = inner_array
.expect("counts are always non null, so are intermediate results");
self.update_batch(&[inner_array])?;
}
Ok(())
}
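A standalone sketch (arrow builder API, hypothetical values) of why the new loop matters: each row of the state `ListArray` is one partial accumulator's distinct set, so reading only `value(0)` as the old code did would silently drop every partial state after the first.

```rust
// Not part of the diff: builds a two-row state ListArray and iterates it
// the way the new merge_batch does.
use arrow_array::builder::{Int32Builder, ListBuilder};
use arrow_array::Array;

fn main() {
    let mut builder = ListBuilder::new(Int32Builder::new());

    // row 0: distinct values from partial accumulator #1
    builder.values().append_value(1);
    builder.values().append_value(2);
    builder.append(true);

    // row 1: distinct values from partial accumulator #2
    builder.values().append_value(2);
    builder.values().append_value(3);
    builder.append(true);

    let states = builder.finish();
    assert_eq!(states.len(), 2);

    // the new merge_batch visits every row, roughly like this:
    for row in states.iter() {
        let inner = row.expect("state rows are non-null");
        println!("merging {} values from one partial state", inner.len());
    }
}
```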

fn evaluate(&mut self) -> Result<ScalarValue> {