Skip to content

Commit

Permalink
Update some docs + add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
alancai98 committed Apr 6, 2023
1 parent 75eb39c commit 832a1f1
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 10 deletions.
27 changes: 22 additions & 5 deletions partiql-eval/src/eval/evaluable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,24 +284,28 @@ impl Evaluable for EvalJoin {
}

/// An SQL aggregation function call that has been rewritten to be evaluated with the `GROUP BY`
/// clause. The `[name]` is the string (generated in lowering step) that replaces the
/// clause. The `[name]` is the string (generated in AST lowering step) that replaces the
/// aggregation call expression. This name will be used as the field in the binding tuple output
/// by `GROUP BY`. `[expr]` corresponds to the expression within the aggregation function. And
/// `[func]` corresponds to the aggregation function that's being called (e.g. sum, count, avg).
///
/// For example, `SELECT a AS a, SUM(b) AS b FROM t GROUP BY a` is rewritten to the following form
/// `SELECT a AS a, __agg1 AS b FROM t GROUP BY a`
/// In the above example, `name` corresponds to '__agg1', expr refers to the variable reference `b`,
/// and `func` corresponds to the sum aggregation function, `[AggSum]`.
/// `SELECT a AS a, $__agg_1 AS b FROM t GROUP BY a`
/// In the above example, `name` corresponds to '$__agg_1', `expr` refers to the expression within
/// the aggregation function, `b`, and `func` corresponds to the sum aggregation function,
/// `[AggSum]`.
#[derive(Debug)]
pub struct AggregateExpression {
/// Generated field name (e.g. `$__agg_1`) that replaces the aggregation call expression and
/// is used as the field in the binding tuple output by `GROUP BY`.
pub name: String,
/// The expression within the aggregation function call; it is evaluated against each input
/// binding tuple to obtain the value handed to `func`.
pub expr: Box<dyn EvalExpr>,
/// The aggregation function being called (e.g. sum, count, avg).
pub func: AggFunc,
}

/// Represents an SQL aggregation function computed on a collection of input values.
///
/// Values are supplied one at a time through [`next_value`](Self::next_value); the aggregated
/// result for a group is then read back through [`compute`](Self::compute).
pub trait AggregateFunction {
/// Provides the next value for the given `group`.
///
/// `input_value` is the value to incorporate into the aggregation and `group` is the tuple
/// identifying which group the value belongs to. Takes `&mut self` because the
/// implementation may update its state.
fn next_value(&mut self, input_value: &Value, group: &Tuple);
/// Returns the result of the aggregation function for a given `group`.
///
/// Takes `&self`, so computing a result does not modify the aggregation state.
fn compute(&self, group: &Tuple) -> Value;
}

Expand Down Expand Up @@ -337,14 +341,19 @@ impl AggregateFunction for AggFunc {
}
}

/// Filter values based on the given condition
///
/// Determines, per group, whether an input value should be processed by the aggregation
/// function.
#[derive(Debug, Default)]
pub enum AggFilterFn {
/// Keeps only distinct values in each group
Distinct(AggFilterDistinct),
/// Keeps all values. This is the variant produced by `Default::default()`.
#[default]
All,
}

impl AggFilterFn {
/// Returns true if and only if for the given `group`, `input_value` should be processed
/// by the aggregation function
fn filter_value(&mut self, input_value: Value, group: &Tuple) -> bool {
match self {
AggFilterFn::Distinct(d) => d.filter_value(input_value, group),
Expand Down Expand Up @@ -386,6 +395,7 @@ impl AggFilterDistinct {
}
}

/// Represents SQL's `AVG` aggregation function
#[derive(Debug)]
pub struct AggAvg {
avgs: HashMap<Tuple, (usize, Value)>,
Expand Down Expand Up @@ -433,6 +443,7 @@ impl AggregateFunction for AggAvg {
}
}

/// Represents SQL's `COUNT` aggregation function
#[derive(Debug)]
pub struct AggCount {
counts: HashMap<Tuple, usize>,
Expand Down Expand Up @@ -480,6 +491,7 @@ impl AggregateFunction for AggCount {
}
}

/// Represents SQL's `MAX` aggregation function
#[derive(Debug)]
pub struct AggMax {
maxes: HashMap<Tuple, Value>,
Expand Down Expand Up @@ -526,6 +538,7 @@ impl AggregateFunction for AggMax {
}
}

/// Represents SQL's `MIN` aggregation function
#[derive(Debug)]
pub struct AggMin {
mins: HashMap<Tuple, Value>,
Expand Down Expand Up @@ -572,6 +585,7 @@ impl AggregateFunction for AggMin {
}
}

/// Represents SQL's `SUM` aggregation function
#[derive(Debug)]
pub struct AggSum {
sums: HashMap<Tuple, Value>,
Expand Down Expand Up @@ -621,7 +635,7 @@ impl AggregateFunction for AggSum {
/// Represents an evaluation `GROUP BY` operator. For `GROUP BY` operational semantics, see section
/// `11` of
/// [PartiQL Specification — August 1, 2019](https://partiql.org/assets/PartiQL-Specification.pdf).
/// TODO: some docs on `aggregate_exprs`
/// `aggregate_exprs` represents the set of aggregate expressions to compute.
#[derive(Debug)]
pub struct EvalGroupBy {
pub strategy: EvalGroupingStrategy,
Expand Down Expand Up @@ -665,6 +679,7 @@ impl Evaluable for EvalGroupBy {
for v in input_value.into_iter() {
let v_as_tuple = v.coerce_to_tuple();
let group = self.eval_group(&v_as_tuple, ctx);
// Compute next aggregation result for each of the aggregation expressions
for aggregate_expr in self.aggregate_exprs.iter_mut() {
let evaluated_val =
aggregate_expr.expr.evaluate(&v_as_tuple, ctx).into_owned();
Expand All @@ -679,6 +694,8 @@ impl Evaluable for EvalGroupBy {
let bag = groups
.into_iter()
.map(|(mut k, v)| {
// Finalize aggregation computation and include result in output binding
// tuple
let mut agg_results: Vec<(&str, Value)> = vec![];
for aggregate_expr in &self.aggregate_exprs {
let agg_result = aggregate_expr.func.compute(&k);
Expand Down
24 changes: 19 additions & 5 deletions partiql-logical-planner/src/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -880,16 +880,21 @@ impl<'ast> Visitor<'ast> for AstToLogical {
}

fn exit_call_agg(&mut self, call_agg: &'ast CallAgg) {
// TODO distinguishing between PartiQL/top-level aggregation function calls and SQL
// aggregation functions. Currently only handles SQL aggregation functions.
// Relates to the SQL aggregation functions (e.g. AVG, COUNT, SUM) -- not the `COLL_`
// functions
let env = self.exit_call();
let name = call_agg.func_name.value.to_lowercase();

let new_name = "__agg".to_owned() + &self.agg_id.id();
// Rewrites the SQL aggregation function call to be a variable reference that the `GROUP BY`
// clause will add to the binding tuples.
// E.g. `SELECT a AS a, SUM(b) AS b FROM t GROUP BY a` is rewritten to
// `SELECT a AS a, $__agg_1 AS b FROM t GROUP BY a`
let new_name = "$__agg".to_owned() + &self.agg_id.id();
let new_binding_name = BindingsName::CaseSensitive(new_name.clone());
let new_expr = ValueExpr::VarRef(new_binding_name);
self.push_vexpr(new_expr);

// If the set quantifier keyword is omitted, the set quantifier defaults to `ALL`
let mut setq = logical::SetQuantifier::All;

let arg = match env.last().unwrap() {
Expand Down Expand Up @@ -937,11 +942,11 @@ impl<'ast> Visitor<'ast> for AstToLogical {
};
self.aggregate_exprs.push(agg_expr);
// PartiQL permits SQL aggregations without a GROUP BY (e.g. SELECT SUM(t.a) FROM ...)
// What follows adds a GROUP BY clause with the rewrite `... GROUP BY true AS __gk`
// What follows adds a GROUP BY clause with the rewrite `... GROUP BY true AS $__gk`
if self.current_clauses_mut().group_by_clause.is_none() {
let mut exprs = HashMap::new();
exprs.insert(
"__gk".to_string(),
"$__gk".to_string(),
ValueExpr::Lit(Box::new(Value::from(true))),
);
let group_by: BindingsOp = BindingsOp::GroupBy(logical::GroupBy {
Expand Down Expand Up @@ -1202,6 +1207,15 @@ impl<'ast> Visitor<'ast> for AstToLogical {
GroupingStrategy::GroupPartial => logical::GroupingStrategy::GroupPartial,
};

// What follows is an approach to implement section 11.2.1 of the PartiQL spec
// (https://partiql.org/assets/PartiQL-Specification.pdf#subsubsection.11.2.1)
// "Grouping Attributes and Direct Use of Grouping Expressions"
// Consider the query:
// SELECT t.a + 1 AS a FROM t GROUP BY t.a + 1 AS some_alias
// Since the group by key expression (t.a + 1) is the same as the select list expression, we
// can rewrite the query as `SELECT some_alias AS a FROM t GROUP BY t.a + 1 AS some_alias`.
// This isn't quite correct as it doesn't deal with SELECT VALUE expressions and expressions
// that are in the `HAVING` and `ORDER BY` clauses.
let select_clause_op_id = self.current_clauses_mut().select_clause.unwrap();
let select_clause = self.plan.operator_as_mut(select_clause_op_id).unwrap();
let mut binding = HashMap::new();
Expand Down
7 changes: 7 additions & 0 deletions partiql-logical/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ pub enum JoinKind {
Cross,
}

/// An SQL aggregation function call with its arguments
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AggregateExpression {
Expand All @@ -293,14 +294,20 @@ pub struct AggregateExpression {
pub setq: SetQuantifier,
}

/// SQL aggregate function
///
/// Each variant names one of the supported SQL aggregation functions. `Serialize` and
/// `Deserialize` are derived only when the `serde` feature is enabled.
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum AggFunc {
// TODO: modeling of COUNT(*)
/// Represents SQL's `AVG` aggregation function
AggAvg,
/// Represents SQL's `COUNT` aggregation function
AggCount,
/// Represents SQL's `MAX` aggregation function
AggMax,
/// Represents SQL's `MIN` aggregation function
AggMin,
/// Represents SQL's `SUM` aggregation function
AggSum,
}

Expand Down
2 changes: 2 additions & 0 deletions partiql-value/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -521,11 +521,13 @@ impl Value {
}

/// Reports whether this value is numeric.
///
/// Returns `true` if and only if `self` is a [`Value::Integer`], [`Value::Real`], or
/// [`Value::Decimal`]; every other variant yields `false`.
#[inline]
pub fn is_number(&self) -> bool {
    matches!(
        self,
        Value::Decimal(_) | Value::Real(_) | Value::Integer(_)
    )
}

/// Reports whether this value is absent.
///
/// Returns `true` if and only if `self` is [`Value::Null`] or [`Value::Missing`]; every other
/// variant yields `false`.
#[inline]
pub fn is_null_or_missing(&self) -> bool {
    matches!(
        self,
        Value::Null | Value::Missing
    )
}
Expand Down

0 comments on commit 832a1f1

Please sign in to comment.