Skip to content

Commit

Permalink
Allow ORDER BY to 'see' projection names (#443)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpschorr authored Feb 7, 2024
1 parent 3374e62 commit 05d3fc8
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `partiql-extension-visualize` for visualizing AST and logical plan

### Fixed
- Fixed `ORDER BY`'s ability to see into projection aliases

## [0.6.0] - 2023-10-31
### Changed
Expand Down
2 changes: 1 addition & 1 deletion partiql-eval/src/eval/expr/path.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::env::Bindings;

pub use core::borrow::{Borrow, BorrowMut};
pub use core::borrow::Borrow;

use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr};
use crate::eval::EvalContext;
Expand Down
64 changes: 57 additions & 7 deletions partiql-logical-planner/src/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ use partiql_catalog::call_defs::{CallArgument, CallDef};

use partiql_ast_passes::error::{AstTransformError, AstTransformationError};

use partiql_ast_passes::name_resolver::NameRef;
use partiql_catalog::Catalog;
use partiql_extension_ion::decode::{IonDecoderBuilder, IonDecoderConfig};
use partiql_extension_ion::Encoding;
use partiql_logical::AggFunc::{AggAny, AggAvg, AggCount, AggEvery, AggMax, AggMin, AggSum};
use partiql_logical::ValueExpr::DynamicLookup;
use std::sync::atomic::{AtomicU32, Ordering};

type FnvIndexMap<K, V> = IndexMap<K, V, FnvBuildHasher>;
Expand Down Expand Up @@ -161,6 +163,8 @@ pub struct AstToLogical<'a> {

from_lets: HashSet<ast::NodeId>,

projection_renames: Vec<FnvIndexMap<String, BindingsName<'a>>>,

aliases: FnvIndexMap<NodeId, SymbolPrimitive>,

// generator of 'fresh' ids
Expand Down Expand Up @@ -230,6 +234,8 @@ impl<'a> AstToLogical<'a> {

from_lets: Default::default(),

projection_renames: Default::default(),

aliases: Default::default(),

// generator of 'fresh' ids
Expand Down Expand Up @@ -283,6 +289,17 @@ impl<'a> AstToLogical<'a> {
}

fn resolve_varref(&self, varref: &ast::VarRef) -> logical::ValueExpr {
fn binding_to_static<'a>(binding: &'a BindingsName<'a>) -> BindingsName<'static> {
match binding {
BindingsName::CaseSensitive(n) => {
BindingsName::CaseSensitive(Cow::Owned(n.as_ref().to_string()))
}
BindingsName::CaseInsensitive(n) => {
BindingsName::CaseInsensitive(Cow::Owned(n.as_ref().to_string()))
}
}
}

// Convert a `SymbolPrimitive` into a `BindingsName`
fn symprim_to_binding(sym: &SymbolPrimitive) -> BindingsName<'static> {
match sym.case {
Expand All @@ -306,14 +323,36 @@ impl<'a> AstToLogical<'a> {
if let Some(key_schema) = self.key_registry.schema.get(id) {
let key_schema: &name_resolver::KeySchema = key_schema;

let name_ref: &name_resolver::NameRef = key_schema
let name_ref: &NameRef = key_schema
.consume
.iter()
.find(|name_ref| name_ref.sym == varref.name)
.expect("NameRef");

let var_binding = symprim_to_binding(&name_ref.sym);
let mut lookups = vec![];

if matches!(self.current_ctx(), Some(QueryContext::Order)) {
if let Some(renames) = self.projection_renames.last() {
let binding = renames
.iter()
.find(|(k, _)| {
let SymbolPrimitive { value, case } = &name_ref.sym;
match case {
CaseSensitivity::CaseSensitive => value == *k,
CaseSensitivity::CaseInsensitive => unicase::eq(value, *k),
}
})
.map(|(_k, v)| binding_to_static(v))
.unwrap_or_else(|| symprim_to_binding(&name_ref.sym));

lookups.push(DynamicLookup(Box::new(vec![ValueExpr::VarRef(
binding,
VarRefType::Local,
)])));
}
}

for lookup in &name_ref.lookup {
match lookup {
name_resolver::NameLookup::Global => {
Expand Down Expand Up @@ -442,11 +481,13 @@ impl<'a> AstToLogical<'a> {
fn enter_q(&mut self) {
self.q_stack.push(Default::default());
self.ctx_stack.push(QueryContext::Query);
self.projection_renames.push(Default::default());
}

#[inline]
fn exit_q(&mut self) -> QueryClauses {
self.ctx_stack.pop();
self.projection_renames.pop().expect("q level");
self.ctx_stack.pop().expect("q level");
self.q_stack.pop().expect("q level")
}

Expand Down Expand Up @@ -856,6 +897,15 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
"".to_string()
}
};

if !alias.is_empty() {
if let ValueExpr::VarRef(name, _vrtype) = &value {
self.projection_renames
.last_mut()
.expect("renames")
.insert(alias.clone(), name.clone());
}
}
exprs.push((alias, value));
}

Expand Down Expand Up @@ -1286,16 +1336,16 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
Traverse::Continue
}

fn enter_var_ref(&mut self, _var_ref: &'ast VarRef) -> Traverse {
fn enter_var_ref(&mut self, var_ref: &'ast VarRef) -> Traverse {
let is_path = matches!(self.current_ctx(), Some(QueryContext::Path));
if !is_path {
let options = self.resolve_varref(_var_ref);
let options = self.resolve_varref(var_ref);
self.push_vexpr(options);
} else {
let VarRef {
name: SymbolPrimitive { value, case },
qualifier: _,
} = _var_ref;
} = var_ref;
let name = match case {
CaseSensitivity::CaseSensitive => {
BindingsName::CaseSensitive(Cow::Owned(value.clone()))
Expand Down Expand Up @@ -1963,7 +2013,7 @@ mod tests {
let lowering_errs = logical.expect_err("Expect errs").errors;
assert_eq!(lowering_errs.len(), 2);
assert_eq!(
lowering_errs.get(0),
lowering_errs.first(),
Some(&AstTransformError::UnsupportedFunction("foo".to_string()))
);
assert_eq!(
Expand All @@ -1985,7 +2035,7 @@ mod tests {
let lowering_errs = logical.expect_err("Expect errs").errors;
assert_eq!(lowering_errs.len(), 2);
assert_eq!(
lowering_errs.get(0),
lowering_errs.first(),
Some(&AstTransformError::InvalidNumberOfArguments(
"abs".to_string()
))
Expand Down
8 changes: 6 additions & 2 deletions partiql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,18 @@ bench = false
[dev-dependencies]
partiql-parser = { path = "../partiql-parser" }
partiql-ast = { path = "../partiql-ast" }
partiql-ast-passes = { path = "../partiql-ast-passes" }
partiql-catalog = { path = "../partiql-catalog"}
partiql-value = { path = "../partiql-value" }
partiql-logical = { path = "../partiql-logical" }
partiql-logical-planner = { path = "../partiql-logical-planner" }
partiql-eval = { path = "../partiql-eval" }

itertools = "0.10"
criterion = "0.4"

thiserror = "1.0"

itertools = "0.12"
criterion = "0.5"
rand = "0.8"

[[bench]]
Expand Down
109 changes: 108 additions & 1 deletion partiql/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,112 @@
#[cfg(test)]
mod tests {
use partiql_ast_passes::error::AstTransformationError;
use partiql_catalog::{Catalog, PartiqlCatalog};
use partiql_eval as eval;
use partiql_eval::env::basic::MapBindings;
use partiql_eval::error::{EvalErr, PlanErr};
use partiql_eval::eval::{EvalPlan, EvalResult, Evaluated};
use partiql_eval::plan::EvaluationMode;
use partiql_logical as logical;
use partiql_parser::{Parsed, ParserError, ParserResult};
use partiql_value::Value;
use thiserror::Error;

#[derive(Error, Debug)]
enum TestError<'a> {
#[error("Parse error: {0:?}")]
Parse(ParserError<'a>),
#[error("Lower error: {0:?}")]
Lower(AstTransformationError),
#[error("Plan error: {0:?}")]
Plan(PlanErr),
#[error("Evaluation error: {0:?}")]
Eval(EvalErr),
}

impl<'a> From<ParserError<'a>> for TestError<'a> {
fn from(err: ParserError<'a>) -> Self {
TestError::Parse(err)
}
}

impl From<AstTransformationError> for TestError<'_> {
fn from(err: AstTransformationError) -> Self {
TestError::Lower(err)
}
}

impl From<PlanErr> for TestError<'_> {
fn from(err: PlanErr) -> Self {
TestError::Plan(err)
}
}

impl From<EvalErr> for TestError<'_> {
fn from(err: EvalErr) -> Self {
TestError::Eval(err)
}
}

#[track_caller]
#[inline]
fn parse(statement: &str) -> ParserResult {
partiql_parser::Parser::default().parse(statement)
}

#[track_caller]
#[inline]
fn lower(
catalog: &dyn Catalog,
parsed: &Parsed,
) -> Result<logical::LogicalPlan<logical::BindingsOp>, AstTransformationError> {
let planner = partiql_logical_planner::LogicalPlanner::new(catalog);
planner.lower(parsed)
}

#[track_caller]
#[inline]
fn compile(
mode: EvaluationMode,
catalog: &dyn Catalog,
logical: logical::LogicalPlan<logical::BindingsOp>,
) -> Result<EvalPlan, PlanErr> {
let mut planner = eval::plan::EvaluatorPlanner::new(mode, catalog);
planner.compile(&logical)
}

#[track_caller]
#[inline]
fn evaluate(mut plan: EvalPlan, bindings: MapBindings<Value>) -> EvalResult {
plan.execute_mut(bindings)
}

#[track_caller]
#[inline]
fn eval(statement: &str, mode: EvaluationMode) -> Result<Evaluated, TestError<'_>> {
let catalog = PartiqlCatalog::default();

let parsed = parse(statement)?;
let lowered = lower(&catalog, &parsed)?;
let bindings = Default::default();
let plan = compile(mode, &catalog, lowered)?;
Ok(evaluate(plan, bindings)?)
}

#[test]
fn todo() {}
fn order_by_count() {
let query = "select foo, count(1) as n from
<<
{ 'foo': 'foo' },
{ 'foo': 'bar' },
{ 'foo': 'qux' },
{ 'foo': 'bar' },
{ 'foo': 'baz' },
{ 'foo': 'bar' },
{ 'foo': 'baz' }
>> group by foo order by n desc";

let res = eval(query, EvaluationMode::Permissive);
assert!(res.is_ok());
}
}

1 comment on commit 05d3fc8

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PartiQL (rust) Benchmark

Benchmark suite Current: 05d3fc8 Previous: 3374e62 Ratio
arith_agg-avg 762950 ns/iter (± 5108) 763583 ns/iter (± 3390) 1.00
arith_agg-avg_distinct 848481 ns/iter (± 10112) 848959 ns/iter (± 3747) 1.00
arith_agg-count 806134 ns/iter (± 14060) 804209 ns/iter (± 29915) 1.00
arith_agg-count_distinct 841558 ns/iter (± 3359) 841379 ns/iter (± 2663) 1.00
arith_agg-min 814702 ns/iter (± 5147) 811458 ns/iter (± 12162) 1.00
arith_agg-min_distinct 845150 ns/iter (± 23113) 848553 ns/iter (± 2675) 1.00
arith_agg-max 819597 ns/iter (± 5641) 822782 ns/iter (± 2645) 1.00
arith_agg-max_distinct 857070 ns/iter (± 8162) 859345 ns/iter (± 2350) 1.00
arith_agg-sum 811700 ns/iter (± 4935) 811349 ns/iter (± 3125) 1.00
arith_agg-sum_distinct 845651 ns/iter (± 3480) 849049 ns/iter (± 4051) 1.00
arith_agg-avg-count-min-max-sum 964272 ns/iter (± 8929) 959853 ns/iter (± 3200) 1.00
arith_agg-avg-count-min-max-sum-group_by 1208187 ns/iter (± 11106) 1257292 ns/iter (± 8831) 0.96
arith_agg-avg-count-min-max-sum-group_by-group_as 1810359 ns/iter (± 20019) 1809821 ns/iter (± 24568) 1.00
arith_agg-avg_distinct-count_distinct-min_distinct-max_distinct-sum_distinct 1245624 ns/iter (± 10761) 1226070 ns/iter (± 12413) 1.02
arith_agg-avg_distinct-count_distinct-min_distinct-max_distinct-sum_distinct-group_by 1537157 ns/iter (± 21402) 1562560 ns/iter (± 8900) 0.98
arith_agg-avg_distinct-count_distinct-min_distinct-max_distinct-sum_distinct-group_by-group_as 2109521 ns/iter (± 15943) 2142416 ns/iter (± 9792) 0.98
parse-1 4302 ns/iter (± 69) 4249 ns/iter (± 15) 1.01
parse-15 39811 ns/iter (± 222) 38868 ns/iter (± 402) 1.02
parse-30 79992 ns/iter (± 499) 75578 ns/iter (± 10744) 1.06
compile-1 4462 ns/iter (± 21) 4355 ns/iter (± 26) 1.02
compile-15 32043 ns/iter (± 180) 32542 ns/iter (± 178) 0.98
compile-30 64769 ns/iter (± 271) 66497 ns/iter (± 420) 0.97
plan-1 65654 ns/iter (± 458) 65369 ns/iter (± 217) 1.00
plan-15 1020884 ns/iter (± 27825) 1016409 ns/iter (± 38498) 1.00
plan-30 2042760 ns/iter (± 17083) 2052512 ns/iter (± 13028) 1.00
eval-1 13009758 ns/iter (± 289446) 13391126 ns/iter (± 150215) 0.97
eval-15 86227676 ns/iter (± 929942) 85921655 ns/iter (± 2375249) 1.00
eval-30 165768763 ns/iter (± 809202) 164438426 ns/iter (± 2221754) 1.01
join 9723 ns/iter (± 96) 9623 ns/iter (± 492) 1.01
simple 2498 ns/iter (± 10) 2451 ns/iter (± 17) 1.02
simple-no 435 ns/iter (± 1) 437 ns/iter (± 3) 1.00
numbers 57 ns/iter (± 0) 58 ns/iter (± 3) 0.98
parse-simple 618 ns/iter (± 2) 624 ns/iter (± 2) 0.99
parse-ion 1890 ns/iter (± 8) 1802 ns/iter (± 16) 1.05
parse-group 5822 ns/iter (± 23) 5628 ns/iter (± 11) 1.03
parse-complex 14946 ns/iter (± 184) 14538 ns/iter (± 46) 1.03
parse-complex-fexpr 21988 ns/iter (± 77) 21356 ns/iter (± 70) 1.03

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.