From 3f9d17fb38edafb23c1908af27b06c78d66680cd Mon Sep 17 00:00:00 2001 From: Josh Pschorr Date: Fri, 11 Oct 2024 10:19:49 -0700 Subject: [PATCH] Add validation for comparison operations (#505) --- CHANGELOG.md | 1 + partiql-eval/src/eval/eval_expr_wrapper.rs | 33 +++++- partiql-eval/src/eval/expr/functions.rs | 11 +- partiql-eval/src/eval/expr/operators.rs | 82 +++++++++++++-- partiql-eval/src/lib.rs | 3 + partiql-value/src/lib.rs | 12 ++- partiql/tests/common.rs | 115 +++++++++++++++++++++ partiql/tests/comparisons.rs | 111 ++++++++++++++++++++ partiql/tests/extension_error.rs | 41 +++----- partiql/tests/pretty.rs | 10 +- partiql/tests/tuple_ops.rs | 105 ++----------------- partiql/tests/user_context.rs | 26 ++--- 12 files changed, 387 insertions(+), 163 deletions(-) create mode 100644 partiql/tests/common.rs create mode 100644 partiql/tests/comparisons.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 9008f486..c8385b24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Changed - *BREAKING* partiql-parser: Added a source location to `ParseError::UnexpectedEndOfInput` +- partiql-eval: Fixed behavior of comparison and `BETWEEN` operations w.r.t. type mismatches ### Added - partiql-value: Pretty-printing of `Value` via `ToPretty` trait diff --git a/partiql-eval/src/eval/eval_expr_wrapper.rs b/partiql-eval/src/eval/eval_expr_wrapper.rs index b88d3cc6..b90568b0 100644 --- a/partiql-eval/src/eval/eval_expr_wrapper.rs +++ b/partiql-eval/src/eval/eval_expr_wrapper.rs @@ -94,6 +94,12 @@ pub(crate) enum ArgCheckControlFlow { Propagate(R), } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct ArgValidateError { + pub(crate) message: String, + pub(crate) propagate: Value, +} + /// A type which performs argument checking during evaluation. pub(crate) trait ArgChecker: Debug { /// Check an argument against an expected type. @@ -101,6 +107,11 @@ pub(crate) trait ArgChecker: Debug { typ: &PartiqlShape, arg: Cow<'a, Value>, ) -> ArgCheckControlFlow>; + + /// Validate all arguments. + fn validate_args(args: Vec>) -> Result>, ArgValidateError> { + Ok(args) + } } /// How to handle argument mismatch and `MISSING` propagation @@ -273,7 +284,12 @@ impl, ArgC: ArgChecker ControlFlow::Break(Missing) }; - match evaluate_args::<{ STRICT }, ArgC, _>(&self.args, |n| &self.types[n], bindings, ctx) { + match evaluate_and_validate_args::<{ STRICT }, ArgC, _>( + &self.args, + |n| &self.types[n], + bindings, + ctx, + ) { ControlFlow::Continue(result) => match result.try_into() { Ok(a) => ControlFlow::Continue(a), Err(args) => err_arg_count_mismatch(args), @@ -283,7 +299,7 @@ impl, ArgC: ArgChecker } } -pub(crate) fn evaluate_args< +pub(crate) fn evaluate_and_validate_args< 'a, 'c, 't, @@ -352,7 +368,18 @@ where ControlFlow::Break(v) } else { // If `propagate` is `None`, then return result - ControlFlow::Continue(result) + match ArgC::validate_args(result) { + Ok(result) => ControlFlow::Continue(result), + Err(err) => { + if STRICT { + ctx.add_error(EvaluationError::IllegalState(format!( + "Arguments failed validation: {}", + err.message + ))) + } + ControlFlow::Break(err.propagate) + } + } } } diff --git a/partiql-eval/src/eval/expr/functions.rs b/partiql-eval/src/eval/expr/functions.rs index 1d6e77e4..485ddac0 100644 --- a/partiql-eval/src/eval/expr/functions.rs +++ b/partiql-eval/src/eval/expr/functions.rs @@ -1,4 +1,6 @@ -use crate::eval::eval_expr_wrapper::{evaluate_args, DefaultArgChecker, PropagateMissing}; +use crate::eval::eval_expr_wrapper::{ + evaluate_and_validate_args, DefaultArgChecker, PropagateMissing, +}; use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr}; use crate::eval::EvalContext; @@ -41,7 +43,12 @@ impl EvalExpr for EvalExprFnScalar { { type Check = DefaultArgChecker>; let typ = PartiqlShapeBuilder::init_or_get().new_struct(StructType::new_any()); - match evaluate_args::<{ STRICT }, Check, _>(&self.args, |_| &typ, bindings, ctx) { + match evaluate_and_validate_args::<{ STRICT }, Check, _>( + &self.args, + |_| &typ, + bindings, + ctx, + ) { ControlFlow::Break(v) => Cow::Owned(v), ControlFlow::Continue(args) => match self.plan.evaluate(&args, ctx.as_session()) { Ok(v) => v, diff --git a/partiql-eval/src/eval/expr/operators.rs b/partiql-eval/src/eval/expr/operators.rs index 693bf695..6cecac0a 100644 --- a/partiql-eval/src/eval/expr/operators.rs +++ b/partiql-eval/src/eval/expr/operators.rs @@ -1,7 +1,7 @@ use crate::eval::eval_expr_wrapper::{ - ArgCheckControlFlow, ArgChecker, ArgShortCircuit, BinaryValueExpr, DefaultArgChecker, - ExecuteEvalExpr, NullArgChecker, PropagateMissing, PropagateNull, TernaryValueExpr, - UnaryValueExpr, + ArgCheckControlFlow, ArgChecker, ArgShortCircuit, ArgValidateError, BinaryValueExpr, + DefaultArgChecker, ExecuteEvalExpr, NullArgChecker, PropagateMissing, PropagateNull, + TernaryValueExpr, UnaryValueExpr, }; use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr}; @@ -12,7 +12,7 @@ use partiql_types::{ Static, StructType, }; use partiql_value::Value::{Boolean, Missing, Null}; -use partiql_value::{BinaryAnd, EqualityValue, NullableEq, NullableOrd, Tuple, Value}; +use partiql_value::{BinaryAnd, Comparable, EqualityValue, NullableEq, NullableOrd, Tuple, Value}; use std::borrow::{Borrow, Cow}; use std::fmt::{Debug, Formatter}; @@ -146,6 +146,34 @@ impl ArgChecker } } +#[derive(Debug)] +pub(crate) struct ComparisonArgChecker { + check: PhantomData>, +} + +impl ArgChecker + for ComparisonArgChecker +{ + #[inline] + fn arg_check<'a>( + typ: &PartiqlShape, + arg: Cow<'a, Value>, + ) -> ArgCheckControlFlow> { + DefaultArgChecker::<{ STRICT }, OnMissing>::arg_check(typ, arg) + } + + fn validate_args(args: Vec>) -> Result>, ArgValidateError> { + if args.len() == 2 && args[0].is_comparable_to(&args[1]) { + Ok(args) + } else { + Err(ArgValidateError { + message: "data-type mismatch".to_string(), + propagate: OnMissing::propagate(), + }) + } + } +} + impl BindEvalExpr for EvalOpBinary { #[inline] fn bind( @@ -157,6 +185,7 @@ impl BindEvalExpr for EvalOpBinary { type InCheck = DefaultArgChecker>; type Check = DefaultArgChecker>; type EqCheck = DefaultArgChecker>; + type CompCheck = ComparisonArgChecker>; type MathCheck = DefaultArgChecker>; macro_rules! create { @@ -177,6 +206,12 @@ impl BindEvalExpr for EvalOpBinary { }; } + macro_rules! comparison { + ($f:expr) => { + create!(CompCheck, [type_dynamic!(), type_dynamic!()], $f) + }; + } + macro_rules! math { ($f:expr) => {{ let nums = PartiqlShapeBuilder::init_or_get().any_of(type_numeric!()); @@ -195,10 +230,10 @@ impl BindEvalExpr for EvalOpBinary { let wrap = EqualityValue::; NullableEq::neq(&wrap(lhs), &wrap(rhs)) }), - EvalOpBinary::Gt => equality!(NullableOrd::gt), - EvalOpBinary::Gteq => equality!(NullableOrd::gteq), - EvalOpBinary::Lt => equality!(NullableOrd::lt), - EvalOpBinary::Lteq => equality!(NullableOrd::lteq), + EvalOpBinary::Gt => comparison!(NullableOrd::gt), + EvalOpBinary::Gteq => comparison!(NullableOrd::gteq), + EvalOpBinary::Lt => comparison!(NullableOrd::lt), + EvalOpBinary::Lteq => comparison!(NullableOrd::lteq), EvalOpBinary::Add => math!(|lhs, rhs| lhs + rhs), EvalOpBinary::Sub => math!(|lhs, rhs| lhs - rhs), EvalOpBinary::Mul => math!(|lhs, rhs| lhs * rhs), @@ -275,6 +310,35 @@ impl BindEvalExpr for EvalOpBinary { } } +#[derive(Debug)] +pub(crate) struct BetweenArgChecker { + check: PhantomData, +} + +impl ArgChecker for BetweenArgChecker { + #[inline] + fn arg_check<'a>( + typ: &PartiqlShape, + arg: Cow<'a, Value>, + ) -> ArgCheckControlFlow> { + NullArgChecker::arg_check(typ, arg) + } + + fn validate_args(args: Vec>) -> Result>, ArgValidateError> { + if args.len() == 3 + && args[0].is_comparable_to(&args[1]) + && args[0].is_comparable_to(&args[2]) + { + Ok(args) + } else { + Err(ArgValidateError { + message: "data-type mismatch".to_string(), + propagate: Value::Missing, + }) + } + } +} + /// Represents an evaluation `PartiQL` `BETWEEN` operator, e.g. `x BETWEEN 10 AND 20`. #[derive(Debug, Default, Clone)] pub(crate) struct EvalBetweenExpr {} @@ -285,7 +349,7 @@ impl BindEvalExpr for EvalBetweenExpr { args: Vec>, ) -> Result, BindError> { let types = [type_dynamic!(), type_dynamic!(), type_dynamic!()]; - TernaryValueExpr::create_checked::<{ STRICT }, NullArgChecker, _>( + TernaryValueExpr::create_checked::<{ STRICT }, BetweenArgChecker<{ STRICT }>, _>( types, args, |value, from, to| value.gteq(from).and(&value.lteq(to)), diff --git a/partiql-eval/src/lib.rs b/partiql-eval/src/lib.rs index 846a16db..69240b44 100644 --- a/partiql-eval/src/lib.rs +++ b/partiql-eval/src/lib.rs @@ -741,6 +741,9 @@ mod tests { // left part of AND evaluates to false eval_between_op(Value::from(1), Value::from(2), Null, Value::from(false)); eval_between_op(Value::from(1), Value::from(2), Missing, Value::from(false)); + // right part of AND evaluates to false + eval_between_op(Value::from(2), Null, Value::from(1), Value::from(false)); + eval_between_op(Value::from(2), Missing, Value::from(1), Value::from(false)); } #[test] diff --git a/partiql-value/src/lib.rs b/partiql-value/src/lib.rs index e90ab87d..938769c9 100644 --- a/partiql-value/src/lib.rs +++ b/partiql-value/src/lib.rs @@ -6,7 +6,7 @@ use std::cmp::Ordering; use std::borrow::Cow; -use std::fmt::{Debug, Formatter}; +use std::fmt::{Debug, Display, Formatter}; use std::hash::Hash; use std::iter::Once; @@ -27,6 +27,7 @@ pub use list::*; pub use pretty::*; pub use tuple::*; +use partiql_common::pretty::ToPretty; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -60,6 +61,15 @@ pub enum Value { // TODO: add other supported PartiQL values -- sexp } +impl Display for Value { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self.to_pretty_string(f.width().unwrap_or(80)) { + Ok(pretty) => f.write_str(&pretty), + Err(_) => f.write_str(""), + } + } +} + impl ops::Add for &Value { type Output = Value; diff --git a/partiql/tests/common.rs b/partiql/tests/common.rs new file mode 100644 index 00000000..b34b18f4 --- /dev/null +++ b/partiql/tests/common.rs @@ -0,0 +1,115 @@ +use partiql_ast_passes::error::AstTransformationError; +use partiql_catalog::catalog::{Catalog, PartiqlCatalog}; +use partiql_catalog::context::SystemContext; +use partiql_eval as eval; +use partiql_eval::env::basic::MapBindings; +use partiql_eval::error::{EvalErr, PlanErr}; +use partiql_eval::eval::{BasicContext, EvalPlan, EvalResult, Evaluated}; +use partiql_eval::plan::EvaluationMode; +use partiql_logical as logical; +use partiql_parser::{Parsed, ParserError, ParserResult}; +use partiql_value::{DateTime, Value}; +use std::error::Error; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum TestError<'a> { + #[error("Parse error: {0:?}")] + Parse(ParserError<'a>), + #[error("Lower error: {0:?}")] + Lower(AstTransformationError), + #[error("Plan error: {0:?}")] + Plan(PlanErr), + #[error("Evaluation error: {0:?}")] + Eval(EvalErr), + #[error("Other: {0:?}")] + Other(Box), +} + +impl<'a> From> for TestError<'a> { + fn from(err: ParserError<'a>) -> Self { + TestError::Parse(err) + } +} + +impl From for TestError<'_> { + fn from(err: AstTransformationError) -> Self { + TestError::Lower(err) + } +} + +impl From for TestError<'_> { + fn from(err: PlanErr) -> Self { + TestError::Plan(err) + } +} + +impl From for TestError<'_> { + fn from(err: EvalErr) -> Self { + TestError::Eval(err) + } +} + +impl From> for TestError<'_> { + fn from(err: Box) -> Self { + TestError::Other(err) + } +} + +#[track_caller] +#[inline] +pub fn parse(statement: &str) -> ParserResult<'_> { + partiql_parser::Parser::default().parse(statement) +} + +#[track_caller] +#[inline] +pub fn lower( + catalog: &dyn Catalog, + parsed: &Parsed<'_>, +) -> Result, AstTransformationError> { + let planner = partiql_logical_planner::LogicalPlanner::new(catalog); + planner.lower(parsed) +} + +#[track_caller] +#[inline] +pub fn compile( + mode: EvaluationMode, + catalog: &dyn Catalog, + logical: logical::LogicalPlan, +) -> Result { + let mut planner = eval::plan::EvaluatorPlanner::new(mode, catalog); + planner.compile(&logical) +} + +#[track_caller] +#[inline] +pub fn evaluate(mut plan: EvalPlan, bindings: MapBindings) -> EvalResult { + let sys = SystemContext { + now: DateTime::from_system_now_utc(), + }; + let ctx = BasicContext::new(bindings, sys); + plan.execute_mut(&ctx) +} + +#[track_caller] +#[inline] +pub fn eval_query_with_catalog<'a>( + statement: &'a str, + catalog: &dyn Catalog, + mode: EvaluationMode, +) -> Result> { + let parsed = parse(statement)?; + let lowered = lower(catalog, &parsed)?; + let bindings = Default::default(); + let plan = compile(mode, catalog, lowered)?; + Ok(evaluate(plan, bindings)?) +} + +#[track_caller] +#[inline] +pub fn eval_query(statement: &str, mode: EvaluationMode) -> Result> { + let catalog = PartiqlCatalog::default(); + eval_query_with_catalog(statement, &catalog, mode) +} diff --git a/partiql/tests/comparisons.rs b/partiql/tests/comparisons.rs new file mode 100644 index 00000000..67716dd9 --- /dev/null +++ b/partiql/tests/comparisons.rs @@ -0,0 +1,111 @@ +use crate::common::{eval_query, TestError}; +use assert_matches::assert_matches; +use itertools::Itertools; +use partiql_eval::eval::Evaluated; +use partiql_eval::plan::EvaluationMode; +use partiql_value::{Comparable, Value}; + +mod common; + +#[track_caller] +#[inline] +pub fn eval_modes(statement: &str) -> (Result, Result) { + let permissive = eval_query(statement, EvaluationMode::Permissive); + let strict = eval_query(statement, EvaluationMode::Strict); + (permissive, strict) +} + +#[track_caller] +#[inline] +pub fn eval_fail(statement: &str) { + let (permissive, strict) = eval_modes(statement); + + assert_matches!(permissive, Ok(_)); + let permissive = permissive.unwrap().result; + assert_matches!(permissive, Value::Missing); + + assert_matches!(strict, Err(_)); + let err = strict.unwrap_err(); + assert_matches!(err, TestError::Eval(_)); +} + +#[track_caller] +#[inline] +pub fn eval_success(statement: &str) { + let (permissive, strict) = eval_modes(statement); + + assert_matches!(permissive, Ok(_)); + assert_matches!(strict, Ok(_)); + assert_eq!(permissive.unwrap().result, strict.unwrap().result); +} + +#[track_caller] +#[inline] +pub fn eval_op(op: &str) { + let vals = op_values(); + let pairs = vals.clone().into_iter().cartesian_product(vals); + for (l, r) in pairs { + let statement = format!("{l} {op} {r}"); + if l.is_comparable_to(&r) { + println!("`{statement}` should compare"); + eval_success(&statement); + } else { + println!("`{statement}` should error"); + eval_fail(&statement); + } + } +} + +fn op_values() -> [Value; 4] { + [ + Value::Integer(1), + Value::Real(3.14.into()), + Value::Boolean(true), + Value::String("foo".to_string().into()), + /* TODO currently DateTimes can be printed but not yet parsed + Value::DateTime(Box::new(DateTime::TimestampWithTz( + time::OffsetDateTime::now_utc(), + ))), + */ + ] +} + +#[test] +fn lt() { + eval_op("<") +} + +#[test] +fn gt() { + eval_op(">") +} + +#[test] +fn lte() { + eval_op("<=") +} + +#[test] +fn gte() { + eval_op(">=") +} + +#[test] +fn between() { + let vals = op_values(); + let pairs = vals.clone().into_iter().cartesian_product(vals.clone()); + let trios = pairs + .into_iter() + .cartesian_product(vals) + .map(|((l, m), r)| (l, m, r)); + for (l, m, r) in trios { + let statement = format!("{l} BETWEEN {m} AND {r}"); + if l.is_comparable_to(&r) && l.is_comparable_to(&m) { + println!("`{statement}` should compare"); + eval_success(&statement); + } else { + println!("`{statement}` should error"); + eval_fail(&statement); + } + } +} diff --git a/partiql/tests/extension_error.rs b/partiql/tests/extension_error.rs index 3b71f856..d32c4af6 100644 --- a/partiql/tests/extension_error.rs +++ b/partiql/tests/extension_error.rs @@ -16,11 +16,12 @@ use partiql_eval::env::basic::MapBindings; use partiql_eval::error::{EvalErr, EvaluationError}; use partiql_eval::eval::{BasicContext, Evaluated}; use partiql_eval::plan::EvaluationMode; -use partiql_parser::{Parsed, ParserResult}; use partiql_value::{bag, tuple, DateTime, Value}; +use crate::common::{lower, parse, TestError}; use partiql_logical as logical; +mod common; #[derive(Debug)] pub struct UserCtxTestExtension {} @@ -115,21 +116,6 @@ impl Iterator for TestDataGen { Some(Err(Box::new(UserCtxError::Runtime))) } } -#[track_caller] -#[inline] -pub(crate) fn parse(statement: &str) -> ParserResult { - partiql_parser::Parser::default().parse(statement) -} - -#[track_caller] -#[inline] -pub(crate) fn lower( - catalog: &dyn Catalog, - parsed: &Parsed<'_>, -) -> partiql_logical::LogicalPlan { - let planner = partiql_logical_planner::LogicalPlanner::new(catalog); - planner.lower(parsed).expect("lower") -} #[track_caller] #[inline] @@ -156,7 +142,7 @@ pub(crate) fn evaluate( } #[test] -fn test_context_bad_args_permissive() { +fn test_context_bad_args_permissive() -> Result<(), TestError<'static>> { let query = "SELECT foo, bar from test_user_context(9) as data"; let mut catalog = PartiqlCatalog::default(); @@ -164,7 +150,7 @@ fn test_context_bad_args_permissive() { ext.load(&mut catalog).expect("extension load to succeed"); let parsed = parse(query); - let lowered = lower(&catalog, &parsed.expect("parse")); + let lowered = lower(&catalog, &parsed.expect("parse"))?; let bindings = Default::default(); let ctx: [(String, &dyn Any); 0] = []; @@ -178,9 +164,11 @@ fn test_context_bad_args_permissive() { assert!(out.is_ok()); assert_eq!(out.unwrap().result, bag!(tuple!()).into()); + + Ok(()) } #[test] -fn test_context_bad_args_strict() { +fn test_context_bad_args_strict() -> Result<(), TestError<'static>> { use assert_matches::assert_matches; let query = "SELECT foo, bar from test_user_context(9) as data"; @@ -189,7 +177,7 @@ fn test_context_bad_args_strict() { ext.load(&mut catalog).expect("extension load to succeed"); let parsed = parse(query); - let lowered = lower(&catalog, &parsed.expect("parse")); + let lowered = lower(&catalog, &parsed.expect("parse"))?; let bindings = Default::default(); let ctx: [(String, &dyn Any); 0] = []; @@ -202,10 +190,12 @@ fn test_context_bad_args_strict() { assert_matches!(err, EvaluationError::ExtensionResultError(err) => { assert_eq!(err.to_string(), "bad arguments") }); + + Ok(()) } #[test] -fn test_context_runtime_permissive() { +fn test_context_runtime_permissive() -> Result<(), TestError<'static>> { let query = "SELECT foo, bar from test_user_context('counter') as data"; let mut catalog = PartiqlCatalog::default(); @@ -213,7 +203,7 @@ fn test_context_runtime_permissive() { ext.load(&mut catalog).expect("extension load to succeed"); let parsed = parse(query); - let lowered = lower(&catalog, &parsed.expect("parse")); + let lowered = lower(&catalog, &parsed.expect("parse"))?; let bindings = Default::default(); let ctx: [(String, &dyn Any); 0] = []; @@ -227,10 +217,11 @@ fn test_context_runtime_permissive() { assert!(out.is_ok()); assert_eq!(out.unwrap().result, bag!(tuple!()).into()); + Ok(()) } #[test] -fn test_context_runtime_strict() { +fn test_context_runtime_strict() -> Result<(), TestError<'static>> { use assert_matches::assert_matches; let query = "SELECT foo, bar from test_user_context('counter') as data"; @@ -239,7 +230,7 @@ fn test_context_runtime_strict() { ext.load(&mut catalog).expect("extension load to succeed"); let parsed = parse(query); - let lowered = lower(&catalog, &parsed.expect("parse")); + let lowered = lower(&catalog, &parsed.expect("parse"))?; let bindings = Default::default(); let ctx: [(String, &dyn Any); 0] = []; @@ -252,4 +243,6 @@ fn test_context_runtime_strict() { assert_matches!(err, EvaluationError::ExtensionResultError(err) => { assert_eq!(err.to_string(), "runtime error") }); + + Ok(()) } diff --git a/partiql/tests/pretty.rs b/partiql/tests/pretty.rs index e4f431f5..86d30e3a 100644 --- a/partiql/tests/pretty.rs +++ b/partiql/tests/pretty.rs @@ -1,21 +1,17 @@ +use crate::common::parse; use itertools::Itertools; use partiql_ast::ast::{AstNode, TopLevelQuery}; use partiql_common::pretty::ToPretty; -use partiql_parser::ParserResult; use partiql_value::{bag, list, tuple, DateTime, Value}; use rust_decimal::prelude::FromPrimitive; use time::macros::{date, datetime, offset, time}; -#[track_caller] -#[inline] -fn parse(statement: &str) -> ParserResult<'_> { - partiql_parser::Parser::default().parse(statement) -} +mod common; #[track_caller] #[inline] fn pretty_print_test(name: &str, statement: &str) { - let res = parse(statement); + let res = common::parse(statement); assert!(res.is_ok()); let res = res.unwrap(); diff --git a/partiql/tests/tuple_ops.rs b/partiql/tests/tuple_ops.rs index 37112401..97af119e 100644 --- a/partiql/tests/tuple_ops.rs +++ b/partiql/tests/tuple_ops.rs @@ -1,113 +1,22 @@ +use crate::common::{eval_query_with_catalog, TestError}; use assert_matches::assert_matches; -use partiql_ast_passes::error::AstTransformationError; -use partiql_catalog::catalog::{Catalog, PartiqlCatalog}; -use partiql_catalog::context::SystemContext; +use partiql_catalog::catalog::PartiqlCatalog; use partiql_catalog::extension::Extension; -use partiql_eval as eval; -use partiql_eval::env::basic::MapBindings; -use partiql_eval::error::{EvalErr, PlanErr}; -use partiql_eval::eval::{BasicContext, EvalPlan, EvalResult, Evaluated}; +use partiql_eval::eval::Evaluated; use partiql_eval::plan::EvaluationMode; use partiql_extension_value_functions::PartiqlValueFnExtension; -use partiql_logical as logical; -use partiql_parser::{Parsed, ParserError, ParserResult}; -use partiql_value::{DateTime, Value}; -use std::error::Error; -use thiserror::Error; +use partiql_value::Value; -#[derive(Error, Debug)] -enum TestError<'a> { - #[error("Parse error: {0:?}")] - Parse(ParserError<'a>), - #[error("Lower error: {0:?}")] - Lower(AstTransformationError), - #[error("Plan error: {0:?}")] - Plan(PlanErr), - #[error("Evaluation error: {0:?}")] - Eval(EvalErr), - #[error("Other: {0:?}")] - Other(Box), -} - -impl<'a> From> for TestError<'a> { - fn from(err: ParserError<'a>) -> Self { - TestError::Parse(err) - } -} - -impl From for TestError<'_> { - fn from(err: AstTransformationError) -> Self { - TestError::Lower(err) - } -} - -impl From for TestError<'_> { - fn from(err: PlanErr) -> Self { - TestError::Plan(err) - } -} - -impl From for TestError<'_> { - fn from(err: EvalErr) -> Self { - TestError::Eval(err) - } -} - -impl From> for TestError<'_> { - fn from(err: Box) -> Self { - TestError::Other(err) - } -} - -#[track_caller] -#[inline] -fn parse(statement: &str) -> ParserResult<'_> { - partiql_parser::Parser::default().parse(statement) -} - -#[track_caller] -#[inline] -fn lower( - catalog: &dyn Catalog, - parsed: &Parsed<'_>, -) -> Result, AstTransformationError> { - let planner = partiql_logical_planner::LogicalPlanner::new(catalog); - planner.lower(parsed) -} - -#[track_caller] -#[inline] -fn compile( - mode: EvaluationMode, - catalog: &dyn Catalog, - logical: logical::LogicalPlan, -) -> Result { - let mut planner = eval::plan::EvaluatorPlanner::new(mode, catalog); - planner.compile(&logical) -} - -#[track_caller] -#[inline] -fn evaluate(mut plan: EvalPlan, bindings: MapBindings) -> EvalResult { - let sys = SystemContext { - now: DateTime::from_system_now_utc(), - }; - let ctx = BasicContext::new(bindings, sys); - plan.execute_mut(&ctx) -} +mod common; #[track_caller] #[inline] -fn eval(statement: &str, mode: EvaluationMode) -> Result> { +pub fn eval(statement: &str, mode: EvaluationMode) -> Result> { let mut catalog = PartiqlCatalog::default(); let ext = PartiqlValueFnExtension::default(); ext.load(&mut catalog)?; - let parsed = parse(statement)?; - let lowered = lower(&catalog, &parsed)?; - let bindings = Default::default(); - let plan = compile(mode, &catalog, lowered)?; - Ok(evaluate(plan, bindings)?) + eval_query_with_catalog(statement, &catalog, mode) } #[test] diff --git a/partiql/tests/user_context.rs b/partiql/tests/user_context.rs index 079acfb4..6a87dffa 100644 --- a/partiql/tests/user_context.rs +++ b/partiql/tests/user_context.rs @@ -16,11 +16,12 @@ use partiql_catalog::table_fn::{ use partiql_eval::env::basic::MapBindings; use partiql_eval::eval::BasicContext; use partiql_eval::plan::EvaluationMode; -use partiql_parser::{Parsed, ParserResult}; use partiql_value::{bag, tuple, DateTime, Value}; +use crate::common::{lower, parse, TestError}; use partiql_logical as logical; +mod common; #[derive(Debug)] pub struct UserCtxTestExtension {} @@ -141,22 +142,6 @@ pub struct Counter { data: RefCell, } -#[track_caller] -#[inline] -pub(crate) fn parse(statement: &str) -> ParserResult { - partiql_parser::Parser::default().parse(statement) -} - -#[track_caller] -#[inline] -pub(crate) fn lower( - catalog: &dyn Catalog, - parsed: &Parsed<'_>, -) -> partiql_logical::LogicalPlan { - let planner = partiql_logical_planner::LogicalPlanner::new(catalog); - planner.lower(parsed).expect("lower") -} - #[track_caller] #[inline] pub(crate) fn evaluate( @@ -183,8 +168,9 @@ pub(crate) fn evaluate( Value::Missing } } + #[test] -fn test_context() { +fn test_context() -> Result<(), TestError<'static>> { let expected: Value = bag![ tuple![("foo", 1), ("bar", "id_1")], tuple![("foo", 0), ("bar", "id_2")], @@ -201,7 +187,7 @@ fn test_context() { ext.load(&mut catalog).expect("extension load to succeed"); let parsed = parse(query); - let lowered = lower(&catalog, &parsed.expect("parse")); + let lowered = lower(&catalog, &parsed.expect("parse"))?; let bindings = Default::default(); let counter = Counter { @@ -213,4 +199,6 @@ fn test_context() { assert!(out.is_bag()); assert_eq!(&out, &expected); assert_eq!(*counter.data.borrow(), 0); + + Ok(()) }