Skip to content

Commit

Permalink
Remove RowAccumulators
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jul 14, 2023
1 parent c6891cb commit 50ea550
Show file tree
Hide file tree
Showing 8 changed files with 1 addition and 1,284 deletions.
134 changes: 0 additions & 134 deletions datafusion/physical-expr/src/aggregate/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ use std::convert::TryFrom;
use std::sync::Arc;

use crate::aggregate::groups_accumulator::accumulate::NullState;
use crate::aggregate::row_accumulator::{
is_row_accumulator_support_dtype, RowAccumulator,
};
use crate::aggregate::sum;
use crate::aggregate::sum::sum_batch;
use crate::aggregate::utils::calculate_result_decimal_for_avg;
Expand All @@ -46,7 +43,6 @@ use arrow_array::{
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::Accumulator;
use datafusion_row::accessor::RowAccessor;

use super::groups_accumulator::EmitTo;
use super::utils::{adjust_output_array, Decimal128Averager};
Expand Down Expand Up @@ -139,21 +135,6 @@ impl AggregateExpr for Avg {
&self.name
}

fn row_accumulator_supported(&self) -> bool {
is_row_accumulator_support_dtype(&self.sum_data_type)
}

fn create_row_accumulator(
&self,
start_index: usize,
) -> Result<Box<dyn RowAccumulator>> {
Ok(Box::new(AvgRowAccumulator::new(
start_index,
&self.sum_data_type,
&self.rt_data_type,
)))
}

fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
Some(Arc::new(self.clone()))
}
Expand Down Expand Up @@ -321,121 +302,6 @@ impl Accumulator for AvgAccumulator {
}
}

#[derive(Debug)]
struct AvgRowAccumulator {
state_index: usize,
sum_datatype: DataType,
return_data_type: DataType,
}

impl AvgRowAccumulator {
pub fn new(
start_index: usize,
sum_datatype: &DataType,
return_data_type: &DataType,
) -> Self {
Self {
state_index: start_index,
sum_datatype: sum_datatype.clone(),
return_data_type: return_data_type.clone(),
}
}
}

impl RowAccumulator for AvgRowAccumulator {
fn update_batch(
&mut self,
values: &[ArrayRef],
accessor: &mut RowAccessor,
) -> Result<()> {
let values = &values[0];
// count
let delta = (values.len() - values.null_count()) as u64;
accessor.add_u64(self.state_index(), delta);

// sum
sum::add_to_row(
self.state_index() + 1,
accessor,
&sum::sum_batch(values, &self.sum_datatype)?,
)
}

fn update_scalar_values(
&mut self,
values: &[ScalarValue],
accessor: &mut RowAccessor,
) -> Result<()> {
let value = &values[0];
sum::update_avg_to_row(self.state_index(), accessor, value)
}

fn update_scalar(
&mut self,
value: &ScalarValue,
accessor: &mut RowAccessor,
) -> Result<()> {
sum::update_avg_to_row(self.state_index(), accessor, value)
}

fn merge_batch(
&mut self,
states: &[ArrayRef],
accessor: &mut RowAccessor,
) -> Result<()> {
let counts = downcast_value!(states[0], UInt64Array);
// count
let delta = compute::sum(counts).unwrap_or(0);
accessor.add_u64(self.state_index(), delta);

// sum
let difference = sum::sum_batch(&states[1], &self.sum_datatype)?;
sum::add_to_row(self.state_index() + 1, accessor, &difference)
}

fn evaluate(&self, accessor: &RowAccessor) -> Result<ScalarValue> {
match self.sum_datatype {
DataType::Decimal128(p, s) => {
match accessor.get_u64_opt(self.state_index()) {
None => Ok(ScalarValue::Decimal128(None, p, s)),
Some(0) => Ok(ScalarValue::Decimal128(None, p, s)),
Some(n) => {
// now the sum_type and return type is not the same, need to convert the sum type to return type
accessor.get_i128_opt(self.state_index() + 1).map_or_else(
|| Ok(ScalarValue::Decimal128(None, p, s)),
|f| {
calculate_result_decimal_for_avg(
f,
n as i128,
s,
&self.return_data_type,
)
},
)
}
}
}
DataType::Float64 => Ok(match accessor.get_u64_opt(self.state_index()) {
None => ScalarValue::Float64(None),
Some(0) => ScalarValue::Float64(None),
Some(n) => ScalarValue::Float64(
accessor
.get_f64_opt(self.state_index() + 1)
.map(|f| f / n as f64),
),
}),
_ => Err(DataFusionError::Internal(
"Sum should be f64 or decimal128 on average".to_string(),
)),
}
}

#[inline(always)]
fn state_index(&self) -> usize {
self.state_index
}
}

/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
/// Stores values as native types, and does overflow checking
///
Expand Down
Loading

0 comments on commit 50ea550

Please sign in to comment.