Skip to content

Commit

Permalink
refactor!: eradicate entire use of Expression
Browse files Browse the repository at this point in the history
  • Loading branch information
varshith257 committed Jan 12, 2025
1 parent b24d761 commit 1bc4808
Show file tree
Hide file tree
Showing 24 changed files with 533 additions and 204 deletions.
29 changes: 29 additions & 0 deletions crates/proof-of-sql-parser/src/sqlparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use alloc::{
vec,
};
use core::fmt::Display;
use serde::{Deserialize, Serialize};
use sqlparser::ast::{
BinaryOperator, DataType, Expr, Function, FunctionArg, FunctionArgExpr, GroupByExpr, Ident,
ObjectName, Offset, OffsetRows, OrderByExpr, Query, Select, SelectItem, SetExpr, TableFactor,
Expand All @@ -33,6 +34,34 @@ fn id(id: Identifier) -> Expr {
Expr::Identifier(id.into())
}

#[must_use]
/// New `AliasedResultExpr` using sqlparser types
/// Represents an aliased SQL expression, e.g., `a + 1 AS alias`.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
pub struct SqlAliasedResultExpr {
/// The SQL expression being aliased (e.g., `a + 1`).
pub expr: Box<Expr>,
/// The alias for the expression (e.g., `alias` in `a + 1 AS alias`).
pub alias: Ident,
}

impl SqlAliasedResultExpr {
/// Creates a new `SqlAliasedResultExpr`.
pub fn new(expr: Box<Expr>, alias: Ident) -> Self {
Self { expr, alias }
}

/// Try to get the identifier of the expression if it is a column
/// Otherwise, return None
#[must_use]
pub fn try_as_identifier(&self) -> Option<&Ident> {
match self.expr.as_ref() {
Expr::Identifier(identifier) => Some(identifier),
_ => None,
}
}
}

/// Provides an extension for the `TimezoneInfo` type for offsets.
pub trait TimezoneInfoExt {
/// Retrieve the offset in seconds for `TimezoneInfo`.
Expand Down
169 changes: 169 additions & 0 deletions crates/proof-of-sql/src/base/database/expr_utility.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
use proof_of_sql_parser::{intermediate_ast::Literal, sqlparser::SqlAliasedResultExpr};
use sqlparser::ast::{
BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, UnaryOperator,
};

/// Compute the sum of an expression
#[must_use]
pub fn sum(expr: Expr) -> Expr {
Expr::Function(Function {
name: ObjectName(vec![Ident::new("SUM")]),
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(*Box::new(expr)))],
filter: None,
null_treatment: None,
over: None,
distinct: false,
special: false,
order_by: vec![],
})
}

/// Get column from name
///
/// # Panics
///
/// This function will panic if the name cannot be parsed into a valid column expression as valid [Identifier]s.
#[must_use]
pub fn col(name: &str) -> Expr {
Expr::Identifier(name.into())
}

/// Compute the maximum of an expression
#[must_use]
pub fn max(expr: Expr) -> Expr {
Expr::Function(Function {
name: ObjectName(vec![Ident::new("MAX")]),
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(*Box::new(expr)))],
filter: None,
null_treatment: None,
over: None,
distinct: false,
special: false,
order_by: vec![],
})
}

/// Construct a new `Expr` A + B
#[must_use]
pub fn add(left: Expr, right: Expr) -> Expr {
Expr::BinaryOp {
op: BinaryOperator::Plus,
left: Box::new(left),
right: Box::new(right),
}
}

/// Construct a new `Expr` A - B
#[must_use]
pub fn sub(left: Expr, right: Expr) -> Expr {
Expr::BinaryOp {
op: BinaryOperator::Minus,
left: Box::new(left),
right: Box::new(right),
}
}

/// Get literal from value
pub fn lit<L>(literal: L) -> Expr
where
L: Into<Literal>,
{
Expr::from(literal.into())
}

/// Count the amount of non-null entries of an expression
#[must_use]
pub fn count(expr: Expr) -> Expr {
Expr::Function(Function {
name: ObjectName(vec![Ident::new("COUNT")]),
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(*Box::new(expr)))],
filter: None,
null_treatment: None,
over: None,
distinct: false,
special: false,
order_by: vec![],
})
}

/// Count the rows
#[must_use]
pub fn count_all() -> Expr {
count(Expr::Wildcard)
}

/// Construct a new `Expr` representing A * B
#[must_use]
pub fn mul(left: Expr, right: Expr) -> Expr {
Expr::BinaryOp {
left: Box::new(left),
op: BinaryOperator::Multiply,
right: Box::new(right),
}
}

/// Compute the minimum of an expression
#[must_use]
pub fn min(expr: Expr) -> Expr {
Expr::Function(Function {
name: ObjectName(vec![Ident::new("MIN")]),
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(*Box::new(expr)))],
filter: None,
null_treatment: None,
over: None,
distinct: false,
special: false,
order_by: vec![],
})
}

/// Construct a new `Expr` for NOT P
#[must_use]
pub fn not(expr: Expr) -> Expr {
Expr::UnaryOp {
op: UnaryOperator::Not,
expr: Box::new(expr),
}
}

/// Construct a new `Expr` for A >= B
#[must_use]
pub fn ge(left: Expr, right: Expr) -> Expr {
Expr::BinaryOp {
left: Box::new(left),
op: BinaryOperator::GtEq,
right: Box::new(right),
}
}

/// Construct a new `Expr` for A == B
#[must_use]
pub fn equal(left: Expr, right: Expr) -> Expr {
Expr::BinaryOp {
left: Box::new(left),
op: BinaryOperator::Eq,
right: Box::new(right),
}
}

/// Construct a new `Expr` for P OR Q
#[must_use]
pub fn or(left: Expr, right: Expr) -> Expr {
Expr::BinaryOp {
left: Box::new(left),
op: BinaryOperator::Or,
right: Box::new(right),
}
}

/// An expression with an alias, i.e., EXPR AS ALIAS
///
/// # Panics
///
/// This function will panic if the `alias` cannot be parsed as a valid [Identifier].
pub fn aliased_expr(expr: Expr, alias: &str) -> SqlAliasedResultExpr {
SqlAliasedResultExpr {
expr: Box::new(expr),
alias: Ident::new(alias),
}
}
42 changes: 24 additions & 18 deletions crates/proof-of-sql/src/base/database/expression_evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@ impl<S: Scalar> OwnedTable<S> {
match expr {
Expr::Identifier(ident) => self.evaluate_column(ident),
Expr::Value(_) | Expr::TypedString { .. } => self.evaluate_literal(expr),
Expr::BinaryOp { op, left, right } => {
self.evaluate_binary_expr(&(*op).clone().into(), left, right)
}
Expr::UnaryOp { op, expr } => self.evaluate_unary_expr((*op).into(), expr),
Expr::BinaryOp { op, left, right } => self.evaluate_binary_expr(op, left, right),
Expr::UnaryOp { op, expr } => self.evaluate_unary_expr(*op, expr),
_ => Err(ExpressionEvaluationError::Unsupported {
expression: format!("Expression {expr:?} is not supported yet"),
}),
Expand All @@ -36,7 +34,13 @@ impl<S: Scalar> OwnedTable<S> {
})?
.clone())
}

/// Evaluates a literal expression and returns its corresponding column representation.
///
/// # Panics
///
/// This function will panic if:
/// - `BigDecimal::parse_bytes` fails to parse a valid decimal string.
/// - `Precision::try_from` fails due to invalid precision or scale values.
fn evaluate_literal(&self, value: &Expr) -> ExpressionEvaluationResult<OwnedColumn<S>> {
let len = self.num_rows();
match value {
Expand All @@ -47,8 +51,8 @@ impl<S: Scalar> OwnedTable<S> {
.map_err(|_| DecimalError::InvalidDecimal {
error: format!("Invalid number: {n}"),
})?;
if num >= i64::MIN as i128 && num <= i64::MAX as i128 {
Ok(OwnedColumn::BigInt(vec![num as i64; len]))
if num >= i128::from(i64::MIN) && num <= i128::from(i64::MAX) {
Ok(OwnedColumn::BigInt(vec![num.try_into().unwrap(); len]))
} else {
Ok(OwnedColumn::Int128(vec![num; len]))
}
Expand All @@ -57,18 +61,20 @@ impl<S: Scalar> OwnedTable<S> {
Ok(OwnedColumn::VarChar(vec![s.clone(); len]))
}
Expr::TypedString { data_type, value } => match data_type {
DataType::Decimal(ExactNumberInfo::PrecisionAndScale(precision, scale)) => {
DataType::Decimal(ExactNumberInfo::PrecisionAndScale(raw_precision, raw_scale)) => {
let decimal = BigDecimal::parse_bytes(value.as_bytes(), 10).unwrap();
let scalar = try_convert_intermediate_decimal_to_scalar(
&decimal,
Precision::try_from(*precision as u64)?,
*scale as i8,
)?;
Ok(OwnedColumn::Decimal75(
Precision::try_from(*precision as u64)?,
*scale as i8,
vec![scalar; len],
))
let precision = Precision::try_from(*raw_precision).map_err(|_| {
DecimalError::InvalidPrecision {
error: raw_precision.to_string(),
}
})?;
let scale =
i8::try_from(*raw_scale).map_err(|_| DecimalError::InvalidScale {
scale: raw_scale.to_string(),
})?;
let scalar =
try_convert_intermediate_decimal_to_scalar(&decimal, precision, scale)?;
Ok(OwnedColumn::Decimal75(precision, scale, vec![scalar; len]))
}
DataType::Timestamp(Some(time_unit), time_zone) => {
let time_unit = PoSQLTimeUnit::from_precision(*time_unit).map_err(|err| {
Expand Down
3 changes: 3 additions & 0 deletions crates/proof-of-sql/src/base/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ mod slice_operation;

mod slice_decimal_operation;

/// util functions for `Expr` tests
pub mod expr_utility;

mod column_type_operation;
pub use column_type_operation::{
try_add_subtract_column_types, try_divide_column_types, try_multiply_column_types,
Expand Down
2 changes: 2 additions & 0 deletions crates/proof-of-sql/src/base/math/big_decimal_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use bigdecimal::BigDecimal;
use num_bigint::BigInt;

pub trait BigDecimalExt {
#[allow(dead_code)]
fn precision(&self) -> u64;
fn scale(&self) -> i64;
fn try_into_bigint_with_precision_and_scale(
Expand All @@ -14,6 +15,7 @@ pub trait BigDecimalExt {
impl BigDecimalExt for BigDecimal {
/// Get the precision of the fixed-point representation of this intermediate decimal.
#[must_use]
#[allow(dead_code)]
fn precision(&self) -> u64 {
self.normalized().digits()
}
Expand Down
1 change: 1 addition & 0 deletions crates/proof-of-sql/src/base/math/i256.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::base::scalar::Scalar;
use alloc::vec::Vec;
use ark_ff::BigInteger;
use serde::{Deserialize, Serialize};
use std::fmt;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use alloc::string::String;
use snafu::Snafu;

/// Errors that can occur during proof plan serialization.
Expand Down
26 changes: 8 additions & 18 deletions crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
use super::ConversionError;
use crate::{
base::{database::ColumnRef, map::IndexMap, math::i256::I256},
sql::{
// parse::{
// dyn_proof_expr_builder::DecimalError::{InvalidPrecision, InvalidScale},
// ConversionError::DecimalConversionError,
// },
proof_exprs::{ColumnExpr, DynProofExpr, ProofExpr},
},
sql::proof_exprs::{ColumnExpr, DynProofExpr, ProofExpr},
};
use alloc::{boxed::Box, format, string::ToString};
use alloc::{boxed::Box, format, string::ToString, vec::Vec};
use proof_of_sql_parser::posql_time::PoSQLTimeUnit;
use sqlparser::ast::{
BinaryOperator, DataType, ExactNumberInfo, Expr, FunctionArg, FunctionArgExpr, Ident,
Expand Down Expand Up @@ -56,10 +50,10 @@ impl DynProofExprBuilder<'_> {
}
Expr::UnaryOp { op, expr } => self.visit_unary_expr(*op, expr.as_ref()),
Expr::Function(function) => {
if let Some(first_arg) = function.args.get(0) {
if let FunctionArg::Unnamed(FunctionArgExpr::Expr(inner_expr)) = first_arg {
return self.visit_aggregate_expr(function.name.to_string(), inner_expr);
}
if let Some(FunctionArg::Unnamed(FunctionArgExpr::Expr(inner_expr))) =
function.args.first()
{
return self.visit_aggregate_expr(&function.name.to_string(), inner_expr);
}
Err(ConversionError::Unprovable {
error: format!("Function {function:?} has unsupported arguments"),
Expand Down Expand Up @@ -242,19 +236,15 @@ impl DynProofExprBuilder<'_> {
}
}

fn visit_aggregate_expr(
&self,
op: String,
expr: &Expr,
) -> Result<DynProofExpr, ConversionError> {
fn visit_aggregate_expr(&self, op: &str, expr: &Expr) -> Result<DynProofExpr, ConversionError> {
if self.in_agg_scope {
return Err(ConversionError::InvalidExpression {
expression: "nested aggregations are invalid".to_string(),
});
}
let expr = DynProofExprBuilder::new_agg(self.column_mapping).visit_expr(expr)?;

match (op.as_str(), expr.data_type().is_numeric()) {
match (op, expr.data_type().is_numeric()) {
("COUNT", _) | ("SUM", true) => Ok(DynProofExpr::new_aggregate(op, expr)?),
("SUM", false) => Err(ConversionError::InvalidExpression {
expression: format!(
Expand Down
Loading

0 comments on commit 1bc4808

Please sign in to comment.