diff --git a/partiql-logical-planner/src/lower.rs b/partiql-logical-planner/src/lower.rs index 83982a8f..5820fcc4 100644 --- a/partiql-logical-planner/src/lower.rs +++ b/partiql-logical-planner/src/lower.rs @@ -32,10 +32,12 @@ use partiql_catalog::call_defs::{CallArgument, CallDef}; use partiql_ast_passes::error::{AstTransformError, AstTransformationError}; +use partiql_ast_passes::name_resolver::{NameLookup, NameRef}; use partiql_catalog::Catalog; use partiql_extension_ion::decode::{IonDecoderBuilder, IonDecoderConfig}; use partiql_extension_ion::Encoding; use partiql_logical::AggFunc::{AggAny, AggAvg, AggCount, AggEvery, AggMax, AggMin, AggSum}; +use partiql_logical::ValueExpr::DynamicLookup; use std::sync::atomic::{AtomicU32, Ordering}; type FnvIndexMap = IndexMap; @@ -161,6 +163,8 @@ pub struct AstToLogical<'a> { from_lets: HashSet, + projection_renames: Vec>>, + aliases: FnvIndexMap, // generator of 'fresh' ids @@ -230,6 +234,8 @@ impl<'a> AstToLogical<'a> { from_lets: Default::default(), + projection_renames: Default::default(), + aliases: Default::default(), // generator of 'fresh' ids @@ -283,6 +289,30 @@ impl<'a> AstToLogical<'a> { } fn resolve_varref(&self, varref: &ast::VarRef) -> logical::ValueExpr { + fn binding_to_symprim<'a>(binding: &'a BindingsName<'a>) -> SymbolPrimitive { + match binding { + BindingsName::CaseSensitive(n) => SymbolPrimitive { + value: n.as_ref().to_owned(), + case: CaseSensitivity::CaseSensitive, + }, + BindingsName::CaseInsensitive(n) => SymbolPrimitive { + value: n.as_ref().to_owned(), + case: CaseSensitivity::CaseInsensitive, + }, + } + } + + fn binding_to_static<'a>(binding: &'a BindingsName<'a>) -> BindingsName<'static> { + match binding { + BindingsName::CaseSensitive(n) => { + BindingsName::CaseSensitive(Cow::Owned(n.as_ref().to_string())) + } + BindingsName::CaseInsensitive(n) => { + BindingsName::CaseInsensitive(Cow::Owned(n.as_ref().to_string())) + } + } + } + // Convert a `SymbolPrimitive` into a `BindingsName` fn symprim_to_binding(sym: &SymbolPrimitive) -> BindingsName<'static> { match sym.case { @@ -306,7 +336,7 @@ impl<'a> AstToLogical<'a> { if let Some(key_schema) = self.key_registry.schema.get(id) { let key_schema: &name_resolver::KeySchema = key_schema; - let name_ref: &name_resolver::NameRef = key_schema + let name_ref: &NameRef = &key_schema .consume .iter() .find(|name_ref| name_ref.sym == varref.name) @@ -314,6 +344,27 @@ impl<'a> AstToLogical<'a> { let var_binding = symprim_to_binding(&name_ref.sym); let mut lookups = vec![]; + + if matches!(self.current_ctx(), Some(QueryContext::Order)) { + let renames = self.projection_renames.last().unwrap(); + let binding = renames + .iter() + .find(|(k, _)| { + let SymbolPrimitive { value, case } = &name_ref.sym; + match case { + CaseSensitivity::CaseSensitive => value == *k, + CaseSensitivity::CaseInsensitive => unicase::eq(value, *k), + } + }) + .map(|(k, v)| binding_to_static(v)) + .unwrap_or_else(|| symprim_to_binding(&name_ref.sym)); + + lookups.push(DynamicLookup(Box::new(vec![ValueExpr::VarRef( + binding, + VarRefType::Local, + )]))); + } + for lookup in &name_ref.lookup { match lookup { name_resolver::NameLookup::Global => { @@ -442,11 +493,13 @@ impl<'a> AstToLogical<'a> { fn enter_q(&mut self) { self.q_stack.push(Default::default()); self.ctx_stack.push(QueryContext::Query); + self.projection_renames.push(Default::default()); } #[inline] fn exit_q(&mut self) -> QueryClauses { - self.ctx_stack.pop(); + self.projection_renames.pop().expect("q level"); + self.ctx_stack.pop().expect("q level"); self.q_stack.pop().expect("q level") } @@ -856,6 +909,15 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { "".to_string() } }; + + if !alias.is_empty() { + if let ValueExpr::VarRef(name, vrtype) = &value { + self.projection_renames + .last_mut() + .expect("renames") + .insert(alias.clone(), name.clone()); + } + } exprs.push((alias, value)); } @@ -1286,16 +1348,16 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { Traverse::Continue } - fn enter_var_ref(&mut self, _var_ref: &'ast VarRef) -> Traverse { + fn enter_var_ref(&mut self, var_ref: &'ast VarRef) -> Traverse { let is_path = matches!(self.current_ctx(), Some(QueryContext::Path)); if !is_path { - let options = self.resolve_varref(_var_ref); + let options = self.resolve_varref(var_ref); self.push_vexpr(options); } else { let VarRef { name: SymbolPrimitive { value, case }, qualifier: _, - } = _var_ref; + } = var_ref; let name = match case { CaseSensitivity::CaseSensitive => { BindingsName::CaseSensitive(Cow::Owned(value.clone())) diff --git a/partiql/Cargo.toml b/partiql/Cargo.toml index 7d048628..a3866b66 100644 --- a/partiql/Cargo.toml +++ b/partiql/Cargo.toml @@ -25,6 +25,7 @@ bench = false [dev-dependencies] partiql-parser = { path = "../partiql-parser" } partiql-ast = { path = "../partiql-ast" } +partiql-ast-passes = { path = "../partiql-ast-passes" } partiql-catalog = { path = "../partiql-catalog"} partiql-value = { path = "../partiql-value" } partiql-logical = { path = "../partiql-logical" }