Skip to content

Commit

Permalink
Add validation for comparison operations (#505)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpschorr authored Oct 11, 2024
1 parent 36b90c5 commit 3f9d17f
Show file tree
Hide file tree
Showing 12 changed files with 387 additions and 163 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Changed
- *BREAKING* partiql-parser: Added a source location to `ParseError::UnexpectedEndOfInput`
- partiql-eval: Fixed behavior of comparison and `BETWEEN` operations w.r.t. type mismatches

### Added
- partiql-value: Pretty-printing of `Value` via `ToPretty` trait
Expand Down
33 changes: 30 additions & 3 deletions partiql-eval/src/eval/eval_expr_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,24 @@ pub(crate) enum ArgCheckControlFlow<B, C, R = B> {
Propagate(R),
}

/// Error produced when a full argument list fails `ArgChecker::validate_args`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct ArgValidateError {
/// Human-readable description of why the arguments were rejected.
pub(crate) message: String,
/// Value to hand back in place of a computed result when validation fails.
pub(crate) propagate: Value,
}

/// A type which performs argument checking during evaluation.
///
/// Checking happens per argument via `arg_check`; `validate_args` then sees the
/// complete argument list at once.
pub(crate) trait ArgChecker: Debug {
/// Check an argument against an expected type.
///
/// `typ` is the expected shape and `arg` the (possibly borrowed) argument value.
fn arg_check<'a>(
typ: &PartiqlShape,
arg: Cow<'a, Value>,
) -> ArgCheckControlFlow<Value, Cow<'a, Value>>;

/// Validate all arguments.
///
/// The default implementation accepts every argument list and returns it
/// unchanged; implementors override it to reject argument combinations by
/// returning an `ArgValidateError`.
fn validate_args(args: Vec<Cow<'_, Value>>) -> Result<Vec<Cow<'_, Value>>, ArgValidateError> {
Ok(args)
}
}

/// How to handle argument mismatch and `MISSING` propagation
Expand Down Expand Up @@ -273,7 +284,12 @@ impl<const STRICT: bool, const N: usize, E: ExecuteEvalExpr<N>, ArgC: ArgChecker
ControlFlow::Break(Missing)
};

match evaluate_args::<{ STRICT }, ArgC, _>(&self.args, |n| &self.types[n], bindings, ctx) {
match evaluate_and_validate_args::<{ STRICT }, ArgC, _>(
&self.args,
|n| &self.types[n],
bindings,
ctx,
) {
ControlFlow::Continue(result) => match result.try_into() {
Ok(a) => ControlFlow::Continue(a),
Err(args) => err_arg_count_mismatch(args),
Expand All @@ -283,7 +299,7 @@ impl<const STRICT: bool, const N: usize, E: ExecuteEvalExpr<N>, ArgC: ArgChecker
}
}

pub(crate) fn evaluate_args<
pub(crate) fn evaluate_and_validate_args<
'a,
'c,
't,
Expand Down Expand Up @@ -352,7 +368,18 @@ where
ControlFlow::Break(v)
} else {
// If `propagate` is `None`, then return result
ControlFlow::Continue(result)
match ArgC::validate_args(result) {
Ok(result) => ControlFlow::Continue(result),
Err(err) => {
if STRICT {
ctx.add_error(EvaluationError::IllegalState(format!(
"Arguments failed validation: {}",
err.message
)))
}
ControlFlow::Break(err.propagate)
}
}
}
}

Expand Down
11 changes: 9 additions & 2 deletions partiql-eval/src/eval/expr/functions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::eval::eval_expr_wrapper::{evaluate_args, DefaultArgChecker, PropagateMissing};
use crate::eval::eval_expr_wrapper::{
evaluate_and_validate_args, DefaultArgChecker, PropagateMissing,
};

use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr};
use crate::eval::EvalContext;
Expand Down Expand Up @@ -41,7 +43,12 @@ impl<const STRICT: bool> EvalExpr for EvalExprFnScalar<STRICT> {
{
type Check<const STRICT: bool> = DefaultArgChecker<STRICT, PropagateMissing<true>>;
let typ = PartiqlShapeBuilder::init_or_get().new_struct(StructType::new_any());
match evaluate_args::<{ STRICT }, Check<STRICT>, _>(&self.args, |_| &typ, bindings, ctx) {
match evaluate_and_validate_args::<{ STRICT }, Check<STRICT>, _>(
&self.args,
|_| &typ,
bindings,
ctx,
) {
ControlFlow::Break(v) => Cow::Owned(v),
ControlFlow::Continue(args) => match self.plan.evaluate(&args, ctx.as_session()) {
Ok(v) => v,
Expand Down
82 changes: 73 additions & 9 deletions partiql-eval/src/eval/expr/operators.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::eval::eval_expr_wrapper::{
ArgCheckControlFlow, ArgChecker, ArgShortCircuit, BinaryValueExpr, DefaultArgChecker,
ExecuteEvalExpr, NullArgChecker, PropagateMissing, PropagateNull, TernaryValueExpr,
UnaryValueExpr,
ArgCheckControlFlow, ArgChecker, ArgShortCircuit, ArgValidateError, BinaryValueExpr,
DefaultArgChecker, ExecuteEvalExpr, NullArgChecker, PropagateMissing, PropagateNull,
TernaryValueExpr, UnaryValueExpr,
};

use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr};
Expand All @@ -12,7 +12,7 @@ use partiql_types::{
Static, StructType,
};
use partiql_value::Value::{Boolean, Missing, Null};
use partiql_value::{BinaryAnd, EqualityValue, NullableEq, NullableOrd, Tuple, Value};
use partiql_value::{BinaryAnd, Comparable, EqualityValue, NullableEq, NullableOrd, Tuple, Value};

use std::borrow::{Borrow, Cow};
use std::fmt::{Debug, Formatter};
Expand Down Expand Up @@ -146,6 +146,34 @@ impl<const TARGET: bool, OnMissing: ArgShortCircuit> ArgChecker
}
}

/// Argument checker that additionally requires its two arguments to be
/// comparable to one another (see its `ArgChecker` implementation).
///
/// `STRICT` and `OnMissing` are forwarded to the `DefaultArgChecker` used for
/// the per-argument check.
#[derive(Debug)]
pub(crate) struct ComparisonArgChecker<const STRICT: bool, OnMissing: ArgShortCircuit> {
// Marker only: ties the type parameters to the struct without storing data.
check: PhantomData<DefaultArgChecker<STRICT, OnMissing>>,
}

impl<const STRICT: bool, OnMissing: ArgShortCircuit> ArgChecker
    for ComparisonArgChecker<STRICT, OnMissing>
{
    /// Per-argument check; defers entirely to the `DefaultArgChecker` with the
    /// same `STRICT` / `OnMissing` configuration.
    #[inline]
    fn arg_check<'a>(
        typ: &PartiqlShape,
        arg: Cow<'a, Value>,
    ) -> ArgCheckControlFlow<Value, Cow<'a, Value>> {
        DefaultArgChecker::<STRICT, OnMissing>::arg_check(typ, arg)
    }

    /// Accepts the arguments only when there are exactly two of them and the
    /// first is comparable to the second; otherwise reports a data-type
    /// mismatch carrying `OnMissing::propagate()` as the value to propagate.
    fn validate_args(args: Vec<Cow<'_, Value>>) -> Result<Vec<Cow<'_, Value>>, ArgValidateError> {
        let comparable = matches!(
            args.as_slice(),
            [lhs, rhs] if lhs.is_comparable_to(rhs)
        );

        if !comparable {
            return Err(ArgValidateError {
                message: String::from("data-type mismatch"),
                propagate: OnMissing::propagate(),
            });
        }

        Ok(args)
    }
}

impl BindEvalExpr for EvalOpBinary {
#[inline]
fn bind<const STRICT: bool>(
Expand All @@ -157,6 +185,7 @@ impl BindEvalExpr for EvalOpBinary {
type InCheck<const STRICT: bool> = DefaultArgChecker<STRICT, PropagateNull<false>>;
type Check<const STRICT: bool> = DefaultArgChecker<STRICT, PropagateMissing<true>>;
type EqCheck<const STRICT: bool> = DefaultArgChecker<STRICT, PropagateMissing<false>>;
type CompCheck<const STRICT: bool> = ComparisonArgChecker<STRICT, PropagateMissing<true>>;
type MathCheck<const STRICT: bool> = DefaultArgChecker<STRICT, PropagateMissing<true>>;

macro_rules! create {
Expand All @@ -177,6 +206,12 @@ impl BindEvalExpr for EvalOpBinary {
};
}

macro_rules! comparison {
($f:expr) => {
create!(CompCheck<STRICT>, [type_dynamic!(), type_dynamic!()], $f)
};
}

macro_rules! math {
($f:expr) => {{
let nums = PartiqlShapeBuilder::init_or_get().any_of(type_numeric!());
Expand All @@ -195,10 +230,10 @@ impl BindEvalExpr for EvalOpBinary {
let wrap = EqualityValue::<false, Value>;
NullableEq::neq(&wrap(lhs), &wrap(rhs))
}),
EvalOpBinary::Gt => equality!(NullableOrd::gt),
EvalOpBinary::Gteq => equality!(NullableOrd::gteq),
EvalOpBinary::Lt => equality!(NullableOrd::lt),
EvalOpBinary::Lteq => equality!(NullableOrd::lteq),
EvalOpBinary::Gt => comparison!(NullableOrd::gt),
EvalOpBinary::Gteq => comparison!(NullableOrd::gteq),
EvalOpBinary::Lt => comparison!(NullableOrd::lt),
EvalOpBinary::Lteq => comparison!(NullableOrd::lteq),
EvalOpBinary::Add => math!(|lhs, rhs| lhs + rhs),
EvalOpBinary::Sub => math!(|lhs, rhs| lhs - rhs),
EvalOpBinary::Mul => math!(|lhs, rhs| lhs * rhs),
Expand Down Expand Up @@ -275,6 +310,35 @@ impl BindEvalExpr for EvalOpBinary {
}
}

/// Argument checker for three-argument expressions that requires the first
/// argument to be comparable to each of the other two (see its `ArgChecker`
/// implementation).
#[derive(Debug)]
pub(crate) struct BetweenArgChecker<const STRICT: bool> {
// Marker only: records that per-argument checking is done by `NullArgChecker`.
check: PhantomData<NullArgChecker>,
}

impl<const STRICT: bool> ArgChecker for BetweenArgChecker<STRICT> {
    /// Per-argument check; defers entirely to `NullArgChecker`.
    #[inline]
    fn arg_check<'a>(
        typ: &PartiqlShape,
        arg: Cow<'a, Value>,
    ) -> ArgCheckControlFlow<Value, Cow<'a, Value>> {
        NullArgChecker::arg_check(typ, arg)
    }

    /// Accepts the arguments only when there are exactly three of them and the
    /// first is comparable to both the second and the third; otherwise reports
    /// a data-type mismatch that propagates `MISSING`.
    fn validate_args(args: Vec<Cow<'_, Value>>) -> Result<Vec<Cow<'_, Value>>, ArgValidateError> {
        let comparable = matches!(
            args.as_slice(),
            [value, from, to] if value.is_comparable_to(from) && value.is_comparable_to(to)
        );

        if !comparable {
            return Err(ArgValidateError {
                message: String::from("data-type mismatch"),
                propagate: Value::Missing,
            });
        }

        Ok(args)
    }
}

/// Represents the evaluation of a `PartiQL` `BETWEEN` operator, e.g. `x BETWEEN 10 AND 20`.
#[derive(Debug, Default, Clone)]
pub(crate) struct EvalBetweenExpr {}
Expand All @@ -285,7 +349,7 @@ impl BindEvalExpr for EvalBetweenExpr {
args: Vec<Box<dyn EvalExpr>>,
) -> Result<Box<dyn EvalExpr>, BindError> {
let types = [type_dynamic!(), type_dynamic!(), type_dynamic!()];
TernaryValueExpr::create_checked::<{ STRICT }, NullArgChecker, _>(
TernaryValueExpr::create_checked::<{ STRICT }, BetweenArgChecker<{ STRICT }>, _>(
types,
args,
|value, from, to| value.gteq(from).and(&value.lteq(to)),
Expand Down
3 changes: 3 additions & 0 deletions partiql-eval/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,9 @@ mod tests {
// left part of AND evaluates to false
eval_between_op(Value::from(1), Value::from(2), Null, Value::from(false));
eval_between_op(Value::from(1), Value::from(2), Missing, Value::from(false));
// right part of AND evaluates to false
eval_between_op(Value::from(2), Null, Value::from(1), Value::from(false));
eval_between_op(Value::from(2), Missing, Value::from(1), Value::from(false));
}

#[test]
Expand Down
12 changes: 11 additions & 1 deletion partiql-value/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::cmp::Ordering;

use std::borrow::Cow;

use std::fmt::{Debug, Formatter};
use std::fmt::{Debug, Display, Formatter};
use std::hash::Hash;

use std::iter::Once;
Expand All @@ -27,6 +27,7 @@ pub use list::*;
pub use pretty::*;
pub use tuple::*;

use partiql_common::pretty::ToPretty;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -60,6 +61,15 @@ pub enum Value {
// TODO: add other supported PartiQL values -- sexp
}

impl Display for Value {
    /// Writes the pretty-printed form of the value, using the formatter's
    /// requested width or 80 columns when none was given. If pretty-printing
    /// fails, a fixed placeholder string is written instead.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let width = f.width().unwrap_or(80);

        if let Ok(pretty) = self.to_pretty_string(width) {
            return f.write_str(&pretty);
        }

        f.write_str("<internal value error occurred>")
    }
}

impl ops::Add for &Value {
type Output = Value;

Expand Down
115 changes: 115 additions & 0 deletions partiql/tests/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use partiql_ast_passes::error::AstTransformationError;
use partiql_catalog::catalog::{Catalog, PartiqlCatalog};
use partiql_catalog::context::SystemContext;
use partiql_eval as eval;
use partiql_eval::env::basic::MapBindings;
use partiql_eval::error::{EvalErr, PlanErr};
use partiql_eval::eval::{BasicContext, EvalPlan, EvalResult, Evaluated};
use partiql_eval::plan::EvaluationMode;
use partiql_logical as logical;
use partiql_parser::{Parsed, ParserError, ParserResult};
use partiql_value::{DateTime, Value};
use std::error::Error;
use thiserror::Error;

/// Any error that can occur while running a query through the test helpers in
/// this file. Each variant wraps the error of one stage and is rendered with
/// the wrapped error's `Debug` output.
#[derive(Error, Debug)]
pub enum TestError<'a> {
/// Parsing the statement failed.
#[error("Parse error: {0:?}")]
Parse(ParserError<'a>),
/// Lowering the parsed statement to a logical plan failed.
#[error("Lower error: {0:?}")]
Lower(AstTransformationError),
/// Compiling the logical plan to an evaluation plan failed.
#[error("Plan error: {0:?}")]
Plan(PlanErr),
/// Executing the evaluation plan failed.
#[error("Evaluation error: {0:?}")]
Eval(EvalErr),
/// Any other boxed error.
#[error("Other: {0:?}")]
Other(Box<dyn Error>),
}

impl<'a> From<ParserError<'a>> for TestError<'a> {
fn from(err: ParserError<'a>) -> Self {
TestError::Parse(err)
}
}

impl From<AstTransformationError> for TestError<'_> {
fn from(err: AstTransformationError) -> Self {
TestError::Lower(err)
}
}

impl From<PlanErr> for TestError<'_> {
fn from(err: PlanErr) -> Self {
TestError::Plan(err)
}
}

impl From<EvalErr> for TestError<'_> {
fn from(err: EvalErr) -> Self {
TestError::Eval(err)
}
}

impl From<Box<dyn Error>> for TestError<'_> {
fn from(err: Box<dyn Error>) -> Self {
TestError::Other(err)
}
}

/// Parses `statement` with a default-configured `partiql_parser::Parser`.
#[track_caller]
#[inline]
pub fn parse(statement: &str) -> ParserResult<'_> {
    let parser = partiql_parser::Parser::default();
    parser.parse(statement)
}

/// Lowers a parsed statement to a logical plan, resolving against `catalog`.
#[track_caller]
#[inline]
pub fn lower(
    catalog: &dyn Catalog,
    parsed: &Parsed<'_>,
) -> Result<logical::LogicalPlan<logical::BindingsOp>, AstTransformationError> {
    partiql_logical_planner::LogicalPlanner::new(catalog).lower(parsed)
}

/// Compiles a logical plan into an executable `EvalPlan` for the given
/// evaluation `mode`, resolving against `catalog`.
#[track_caller]
#[inline]
pub fn compile(
    mode: EvaluationMode,
    catalog: &dyn Catalog,
    logical: logical::LogicalPlan<logical::BindingsOp>,
) -> Result<EvalPlan, PlanErr> {
    eval::plan::EvaluatorPlanner::new(mode, catalog).compile(&logical)
}

/// Executes `plan` against `bindings`, using a system context whose `now` is
/// the current UTC system time.
#[track_caller]
#[inline]
pub fn evaluate(mut plan: EvalPlan, bindings: MapBindings<Value>) -> EvalResult {
    let ctx = BasicContext::new(
        bindings,
        SystemContext {
            now: DateTime::from_system_now_utc(),
        },
    );
    plan.execute_mut(&ctx)
}

/// Runs `statement` end to end — parse, lower, compile, evaluate — against
/// `catalog` in the given `mode`, with default (empty) bindings. The first
/// failing stage is reported as the matching `TestError` variant.
#[track_caller]
#[inline]
pub fn eval_query_with_catalog<'a>(
    statement: &'a str,
    catalog: &dyn Catalog,
    mode: EvaluationMode,
) -> Result<Evaluated, TestError<'a>> {
    let ast = parse(statement)?;
    let logical_plan = lower(catalog, &ast)?;
    let env = Default::default();
    let executable = compile(mode, catalog, logical_plan)?;
    let evaluated = evaluate(executable, env)?;
    Ok(evaluated)
}

/// Runs `statement` end to end in the given `mode` against a default
/// `PartiqlCatalog`.
#[track_caller]
#[inline]
pub fn eval_query(statement: &str, mode: EvaluationMode) -> Result<Evaluated, TestError<'_>> {
    eval_query_with_catalog(statement, &PartiqlCatalog::default(), mode)
}
Loading

1 comment on commit 3f9d17f

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PartiQL (rust) Benchmark

Benchmark suite Current: 3f9d17f Previous: 36b90c5 Ratio
arith_agg-avg 765428 ns/iter (± 17728) 757550 ns/iter (± 14135) 1.01
arith_agg-avg_distinct 858820 ns/iter (± 20625) 846820 ns/iter (± 4815) 1.01
arith_agg-count 814540 ns/iter (± 16330) 804760 ns/iter (± 17690) 1.01
arith_agg-count_distinct 852711 ns/iter (± 3888) 838944 ns/iter (± 3186) 1.02
arith_agg-min 819503 ns/iter (± 5160) 810775 ns/iter (± 2512) 1.01
arith_agg-min_distinct 856574 ns/iter (± 4490) 843615 ns/iter (± 11284) 1.02
arith_agg-max 827442 ns/iter (± 1701) 817235 ns/iter (± 5860) 1.01
arith_agg-max_distinct 864722 ns/iter (± 1752) 854241 ns/iter (± 29133) 1.01
arith_agg-sum 821912 ns/iter (± 11050) 810277 ns/iter (± 3458) 1.01
arith_agg-sum_distinct 857937 ns/iter (± 2288) 844957 ns/iter (± 23645) 1.02
arith_agg-avg-count-min-max-sum 968498 ns/iter (± 45038) 961188 ns/iter (± 5157) 1.01
arith_agg-avg-count-min-max-sum-group_by 1221287 ns/iter (± 4423) 1204105 ns/iter (± 13924) 1.01
arith_agg-avg-count-min-max-sum-group_by-group_as 1838260 ns/iter (± 4472) 1832649 ns/iter (± 9959) 1.00
arith_agg-avg_distinct-count_distinct-min_distinct-max_distinct-sum_distinct 1266617 ns/iter (± 22698) 1262413 ns/iter (± 14368) 1.00
arith_agg-avg_distinct-count_distinct-min_distinct-max_distinct-sum_distinct-group_by 1556324 ns/iter (± 8387) 1534439 ns/iter (± 28757) 1.01
arith_agg-avg_distinct-count_distinct-min_distinct-max_distinct-sum_distinct-group_by-group_as 2188553 ns/iter (± 5721) 2116623 ns/iter (± 6642) 1.03
parse-1 6102 ns/iter (± 28) 6042 ns/iter (± 15) 1.01
parse-15 50272 ns/iter (± 88) 49793 ns/iter (± 149) 1.01
parse-30 99654 ns/iter (± 3473) 93069 ns/iter (± 355) 1.07
compile-1 4225 ns/iter (± 22) 4169 ns/iter (± 45) 1.01
compile-15 32652 ns/iter (± 260) 31690 ns/iter (± 115) 1.03
compile-30 67918 ns/iter (± 193) 66238 ns/iter (± 456) 1.03
plan-1 69724 ns/iter (± 423) 66874 ns/iter (± 402) 1.04
plan-15 1094861 ns/iter (± 9322) 1055270 ns/iter (± 20019) 1.04
plan-30 2191116 ns/iter (± 13765) 2112655 ns/iter (± 10183) 1.04
eval-1 13197485 ns/iter (± 249211) 13247704 ns/iter (± 391970) 1.00
eval-15 89259168 ns/iter (± 934982) 95034828 ns/iter (± 947484) 0.94
eval-30 171668148 ns/iter (± 423756) 182929663 ns/iter (± 493150) 0.94
join 10085 ns/iter (± 421) 9839 ns/iter (± 47) 1.03
simple 2543 ns/iter (± 9) 2590 ns/iter (± 13) 0.98
simple-no 475 ns/iter (± 1) 474 ns/iter (± 2) 1.00
numbers 48 ns/iter (± 0) 48 ns/iter (± 0) 1
parse-simple 937 ns/iter (± 4) 820 ns/iter (± 31) 1.14
parse-ion 2733 ns/iter (± 8) 2667 ns/iter (± 7) 1.02
parse-group 7922 ns/iter (± 23) 8040 ns/iter (± 43) 0.99
parse-complex 20662 ns/iter (± 208) 21267 ns/iter (± 231) 0.97
parse-complex-fexpr 28614 ns/iter (± 227) 28958 ns/iter (± 102) 0.99

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.