diff --git a/partiql-catalog/src/lib.rs b/partiql-catalog/src/lib.rs index a88ca976..25ddba7d 100644 --- a/partiql-catalog/src/lib.rs +++ b/partiql-catalog/src/lib.rs @@ -3,7 +3,7 @@ use crate::call_defs::CallDef; -use partiql_types::PartiqlType; +use partiql_types::PartiqlShape; use partiql_value::Value; use std::borrow::Cow; @@ -128,12 +128,12 @@ pub trait Catalog: Debug { pub struct TypeEnvEntry<'a> { name: UniCase, aliases: Vec<&'a str>, - ty: PartiqlType, + ty: PartiqlShape, } impl<'a> TypeEnvEntry<'a> { #[must_use] - pub fn new(name: &str, aliases: &[&'a str], ty: PartiqlType) -> Self { + pub fn new(name: &str, aliases: &[&'a str], ty: PartiqlShape) -> Self { TypeEnvEntry { name: UniCase::from(name.to_string()), aliases: aliases.to_vec(), @@ -145,7 +145,7 @@ impl<'a> TypeEnvEntry<'a> { #[derive(Debug, Clone)] pub struct TypeEntry { id: ObjectId, - ty: PartiqlType, + ty: PartiqlShape, } impl TypeEntry { @@ -155,7 +155,7 @@ impl TypeEntry { } #[must_use] - pub fn ty(&self) -> &PartiqlType { + pub fn ty(&self) -> &PartiqlShape { &self.ty } } @@ -197,7 +197,7 @@ impl<'a> FunctionEntry<'a> { #[derive(Debug)] pub struct PartiqlCatalog { functions: CatalogEntrySet, - types: CatalogEntrySet, + types: CatalogEntrySet, id: CatalogId, } diff --git a/partiql-eval/src/eval/eval_expr_wrapper.rs b/partiql-eval/src/eval/eval_expr_wrapper.rs index 542936aa..7784e3c5 100644 --- a/partiql-eval/src/eval/eval_expr_wrapper.rs +++ b/partiql-eval/src/eval/eval_expr_wrapper.rs @@ -4,7 +4,7 @@ use crate::eval::expr::{BindError, EvalExpr}; use crate::eval::EvalContext; use itertools::Itertools; -use partiql_types::{PartiqlType, TypeKind, TYPE_ANY}; +use partiql_types::{PartiqlShape, StaticTypeVariant, TYPE_DYNAMIC}; use partiql_value::Value::{Missing, Null}; use partiql_value::{Tuple, Value}; @@ -18,36 +18,47 @@ use std::ops::ControlFlow; // TODO replace with type system's subsumption once it is in place #[inline] -pub(crate) fn subsumes(typ: &PartiqlType, value: &Value) -> bool { - match (typ.kind(), value) { +pub(crate) fn subsumes(typ: &PartiqlShape, value: &Value) -> bool { + match (typ, value) { (_, Value::Null) => true, (_, Value::Missing) => true, - (TypeKind::Any, _) => true, - (TypeKind::AnyOf(anyof), val) => anyof.types().any(|typ| subsumes(typ, val)), - ( - TypeKind::Int | TypeKind::Int8 | TypeKind::Int16 | TypeKind::Int32 | TypeKind::Int64, - Value::Integer(_), - ) => true, - (TypeKind::Bool, Value::Boolean(_)) => true, - (TypeKind::Decimal | TypeKind::DecimalP(_, _), Value::Decimal(_)) => true, - (TypeKind::Float32 | TypeKind::Float64, Value::Real(_)) => true, - ( - TypeKind::String | TypeKind::StringFixed(_) | TypeKind::StringVarying(_), - Value::String(_), - ) => true, - (TypeKind::Struct(_), Value::Tuple(_)) => true, - (TypeKind::Bag(b_type), Value::Bag(b_values)) => { - let bag_element_type = b_type.element_type(); - let mut b_values = b_values.iter(); - b_values.all(|b_value| subsumes(bag_element_type, b_value)) - } - (TypeKind::DateTime, Value::DateTime(_)) => true, + (PartiqlShape::Dynamic, _) => true, + (PartiqlShape::AnyOf(anyof), val) => anyof.types().any(|typ| subsumes(typ, val)), + (PartiqlShape::Static(s), val) => match (s.ty(), val) { + ( + StaticTypeVariant::Int + | StaticTypeVariant::Int8 + | StaticTypeVariant::Int16 + | StaticTypeVariant::Int32 + | StaticTypeVariant::Int64, + Value::Integer(_), + ) => true, + (StaticTypeVariant::Bool, Value::Boolean(_)) => true, + (StaticTypeVariant::Decimal | StaticTypeVariant::DecimalP(_, _), Value::Decimal(_)) => { + true + } + (StaticTypeVariant::Float32 | StaticTypeVariant::Float64, Value::Real(_)) => true, + ( + StaticTypeVariant::String + | StaticTypeVariant::StringFixed(_) + | StaticTypeVariant::StringVarying(_), + Value::String(_), + ) => true, + (StaticTypeVariant::Struct(_), Value::Tuple(_)) => true, + (StaticTypeVariant::Bag(b_type), Value::Bag(b_values)) => { + let bag_element_type = b_type.element_type(); + let mut b_values = b_values.iter(); + b_values.all(|b_value| subsumes(bag_element_type, b_value)) + } + (StaticTypeVariant::DateTime, Value::DateTime(_)) => true, - (TypeKind::Array(a_type), Value::List(l_values)) => { - let array_element_type = a_type.element_type(); - let mut l_values = l_values.iter(); - l_values.all(|l_value| subsumes(array_element_type, l_value)) - } + (StaticTypeVariant::Array(a_type), Value::List(l_values)) => { + let array_element_type = a_type.element_type(); + let mut l_values = l_values.iter(); + l_values.all(|l_value| subsumes(array_element_type, l_value)) + } + _ => false, + }, _ => false, } } @@ -95,7 +106,7 @@ pub(crate) enum ArgCheckControlFlow { pub(crate) trait ArgChecker: Debug { /// Check an argument against an expected type. fn arg_check<'a>( - typ: &PartiqlType, + typ: &PartiqlShape, arg: Cow<'a, Value>, ) -> ArgCheckControlFlow>; } @@ -158,7 +169,7 @@ impl ArgChecker for DefaultArgChecker { fn arg_check<'a>( - typ: &PartiqlType, + typ: &PartiqlShape, arg: Cow<'a, Value>, ) -> ArgCheckControlFlow> { let err = || { @@ -189,7 +200,7 @@ pub(crate) struct NullArgChecker {} impl ArgChecker for NullArgChecker { fn arg_check<'a>( - _typ: &PartiqlType, + _typ: &PartiqlShape, arg: Cow<'a, Value>, ) -> ArgCheckControlFlow> { ArgCheckControlFlow::Continue(arg) @@ -209,7 +220,7 @@ pub(crate) struct ArgCheckEvalExpr< ArgC: ArgChecker, > { /// The expected type of expression's positional arguments - pub(crate) types: [PartiqlType; N], + pub(crate) types: [PartiqlShape; N], /// The expression's positional arguments pub(crate) args: [Box; N], /// the expression @@ -237,7 +248,7 @@ impl, ArgC: ArgChecker impl, ArgC: ArgChecker> ArgCheckEvalExpr { - pub fn new(types: [PartiqlType; N], args: [Box; N], expr: E) -> Self { + pub fn new(types: [PartiqlShape; N], args: [Box; N], expr: E) -> Self { Self { types, args, @@ -298,11 +309,7 @@ impl, ArgC: ArgChecker ArgCheckControlFlow::ShortCircuit(v) => return ControlFlow::Break(v), ArgCheckControlFlow::ErrorOrShortCircuit(v) => { if STRICT { - let signature = self - .types - .iter() - .map(|typ| format!("{}", typ.kind())) - .join(","); + let signature = self.types.iter().map(|typ| format!("{}", typ)).join(","); let before = (0..i).map(|_| "_"); let arg = "MISSING"; // TODO display actual argument? let after = (i + 1..N).map(|_| "_"); @@ -368,7 +375,7 @@ impl EvalExprWrapper { #[inline] pub(crate) fn create_checked( ident: E, - types: [PartiqlType; N], + types: [PartiqlShape; N], args: Vec>, f: F, ) -> Result, BindError> @@ -414,13 +421,13 @@ impl UnaryValueExpr { where F: 'static + Fn(&Value) -> Value, { - Self::create_typed::([TYPE_ANY; 1], args, f) + Self::create_typed::([TYPE_DYNAMIC; 1], args, f) } #[allow(dead_code)] #[inline] pub(crate) fn create_typed( - types: [PartiqlType; 1], + types: [PartiqlShape; 1], args: Vec>, f: F, ) -> Result, BindError> @@ -434,7 +441,7 @@ impl UnaryValueExpr { #[allow(dead_code)] #[inline] pub(crate) fn create_checked( - types: [PartiqlType; 1], + types: [PartiqlShape; 1], args: Vec>, f: F, ) -> Result, BindError> @@ -478,13 +485,13 @@ impl BinaryValueExpr { where F: 'static + Fn(&Value, &Value) -> Value, { - Self::create_typed::([TYPE_ANY; 2], args, f) + Self::create_typed::([TYPE_DYNAMIC; 2], args, f) } #[allow(dead_code)] #[inline] pub(crate) fn create_typed( - types: [PartiqlType; 2], + types: [PartiqlShape; 2], args: Vec>, f: F, ) -> Result, BindError> @@ -498,7 +505,7 @@ impl BinaryValueExpr { #[allow(dead_code)] #[inline] pub(crate) fn create_checked( - types: [PartiqlType; 2], + types: [PartiqlShape; 2], args: Vec>, f: F, ) -> Result, BindError> @@ -542,13 +549,13 @@ impl TernaryValueExpr { where F: 'static + Fn(&Value, &Value, &Value) -> Value, { - Self::create_typed::([TYPE_ANY; 3], args, f) + Self::create_typed::([TYPE_DYNAMIC; 3], args, f) } #[allow(dead_code)] #[inline] pub(crate) fn create_typed( - types: [PartiqlType; 3], + types: [PartiqlShape; 3], args: Vec>, f: F, ) -> Result, BindError> @@ -562,7 +569,7 @@ impl TernaryValueExpr { #[allow(dead_code)] #[inline] pub(crate) fn create_checked( - types: [PartiqlType; 3], + types: [PartiqlShape; 3], args: Vec>, f: F, ) -> Result, BindError> @@ -611,13 +618,13 @@ impl QuaternaryValueExpr { where F: 'static + Fn(&Value, &Value, &Value, &Value) -> Value, { - Self::create_typed::([TYPE_ANY; 4], args, f) + Self::create_typed::([TYPE_DYNAMIC; 4], args, f) } #[allow(dead_code)] #[inline] pub(crate) fn create_typed( - types: [PartiqlType; 4], + types: [PartiqlShape; 4], args: Vec>, f: F, ) -> Result, BindError> @@ -631,7 +638,7 @@ impl QuaternaryValueExpr { #[allow(dead_code)] #[inline] pub(crate) fn create_checked( - types: [PartiqlType; 4], + types: [PartiqlShape; 4], args: Vec>, f: F, ) -> Result, BindError> diff --git a/partiql-eval/src/eval/expr/coll.rs b/partiql-eval/src/eval/expr/coll.rs index 5068e469..b18e3c8b 100644 --- a/partiql-eval/src/eval/expr/coll.rs +++ b/partiql-eval/src/eval/expr/coll.rs @@ -4,7 +4,9 @@ use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr}; use itertools::{Itertools, Unique}; -use partiql_types::{ArrayType, BagType, PartiqlType, TypeKind, TYPE_BOOL, TYPE_NUMERIC_TYPES}; +use partiql_types::{ + ArrayType, BagType, PartiqlShape, StaticTypeVariant, TYPE_BOOL, TYPE_NUMERIC_TYPES, +}; use partiql_value::Value::{Missing, Null}; use partiql_value::{BinaryAnd, BinaryOr, Value, ValueIter}; @@ -38,7 +40,7 @@ impl BindEvalExpr for EvalCollFn { args: Vec>, ) -> Result, BindError> { fn create( - types: [PartiqlType; 1], + types: [PartiqlShape; 1], args: Vec>, f: F, ) -> Result, BindError> @@ -49,21 +51,23 @@ impl BindEvalExpr for EvalCollFn { value.sequence_iter().map_or(Missing, &f) }) } - let boolean_elems = [PartiqlType::any_of([ - PartiqlType::new(TypeKind::Array(ArrayType::new(Box::new(TYPE_BOOL)))), - PartiqlType::new(TypeKind::Bag(BagType::new(Box::new(TYPE_BOOL)))), + let boolean_elems = [PartiqlShape::any_of([ + PartiqlShape::new(StaticTypeVariant::Array(ArrayType::new(Box::new( + TYPE_BOOL, + )))), + PartiqlShape::new(StaticTypeVariant::Bag(BagType::new(Box::new(TYPE_BOOL)))), ])]; - let numeric_elems = [PartiqlType::any_of([ - PartiqlType::new(TypeKind::Array(ArrayType::new(Box::new( - PartiqlType::any_of(TYPE_NUMERIC_TYPES), + let numeric_elems = [PartiqlShape::any_of([ + PartiqlShape::new(StaticTypeVariant::Array(ArrayType::new(Box::new( + PartiqlShape::any_of(TYPE_NUMERIC_TYPES), + )))), + PartiqlShape::new(StaticTypeVariant::Bag(BagType::new(Box::new( + PartiqlShape::any_of(TYPE_NUMERIC_TYPES), )))), - PartiqlType::new(TypeKind::Bag(BagType::new(Box::new(PartiqlType::any_of( - TYPE_NUMERIC_TYPES, - ))))), ])]; - let any_elems = [PartiqlType::any_of([ - PartiqlType::new(TypeKind::Array(ArrayType::new_any())), - PartiqlType::new(TypeKind::Bag(BagType::new_any())), + let any_elems = [PartiqlShape::any_of([ + PartiqlShape::new(StaticTypeVariant::Array(ArrayType::new_any())), + PartiqlShape::new(StaticTypeVariant::Bag(BagType::new_any())), ])]; match *self { diff --git a/partiql-eval/src/eval/expr/operators.rs b/partiql-eval/src/eval/expr/operators.rs index 461e5762..c514aff0 100644 --- a/partiql-eval/src/eval/expr/operators.rs +++ b/partiql-eval/src/eval/expr/operators.rs @@ -8,7 +8,8 @@ use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr}; use crate::eval::EvalContext; use partiql_types::{ - ArrayType, BagType, PartiqlType, StructType, TypeKind, TYPE_ANY, TYPE_BOOL, TYPE_NUMERIC_TYPES, + ArrayType, BagType, PartiqlShape, StaticTypeVariant, StructType, TYPE_BOOL, TYPE_DYNAMIC, + TYPE_NUMERIC_TYPES, }; use partiql_value::Value::{Boolean, Missing, Null}; use partiql_value::{BinaryAnd, EqualityValue, NullableEq, NullableOrd, Tuple, Value}; @@ -79,7 +80,7 @@ impl BindEvalExpr for EvalOpUnary { &self, args: Vec>, ) -> Result, BindError> { - let any_num = PartiqlType::any_of(TYPE_NUMERIC_TYPES); + let any_num = PartiqlShape::any_of(TYPE_NUMERIC_TYPES); let unop = |types, f: fn(&Value) -> Value| { UnaryValueExpr::create_typed::<{ STRICT }, _>(types, args, f) @@ -133,7 +134,7 @@ impl ArgChecker for BoolShortCircuitArgChecker { fn arg_check<'a>( - _typ: &PartiqlType, + _typ: &PartiqlShape, arg: Cow<'a, Value>, ) -> ArgCheckControlFlow> { match arg.borrow() { @@ -172,13 +173,13 @@ impl BindEvalExpr for EvalOpBinary { macro_rules! equality { ($f:expr) => { - create!(EqCheck, [TYPE_ANY, TYPE_ANY], $f) + create!(EqCheck, [TYPE_DYNAMIC, TYPE_DYNAMIC], $f) }; } macro_rules! math { ($f:expr) => {{ - let nums = PartiqlType::any_of(TYPE_NUMERIC_TYPES); + let nums = PartiqlShape::any_of(TYPE_NUMERIC_TYPES); create!(MathCheck, [nums.clone(), nums], $f) }}; } @@ -208,10 +209,10 @@ impl BindEvalExpr for EvalOpBinary { create!( InCheck, [ - TYPE_ANY, - PartiqlType::any_of([ - PartiqlType::new(TypeKind::Array(ArrayType::new_any())), - PartiqlType::new(TypeKind::Bag(BagType::new_any())), + TYPE_DYNAMIC, + PartiqlShape::any_of([ + PartiqlShape::new(StaticTypeVariant::Array(ArrayType::new_any())), + PartiqlShape::new(StaticTypeVariant::Bag(BagType::new_any())), ]) ], |lhs, rhs| { @@ -249,7 +250,7 @@ impl BindEvalExpr for EvalOpBinary { ) } EvalOpBinary::Concat => { - create!(Check, [TYPE_ANY, TYPE_ANY], |lhs, rhs| { + create!(Check, [TYPE_DYNAMIC, TYPE_DYNAMIC], |lhs, rhs| { // TODO non-naive concat (i.e., don't just use debug print for non-strings). let lhs = if let Value::String(s) = lhs { s.as_ref().clone() @@ -277,7 +278,7 @@ impl BindEvalExpr for EvalBetweenExpr { &self, args: Vec>, ) -> Result, BindError> { - let types = [TYPE_ANY, TYPE_ANY, TYPE_ANY]; + let types = [TYPE_DYNAMIC, TYPE_DYNAMIC, TYPE_DYNAMIC]; TernaryValueExpr::create_checked::<{ STRICT }, NullArgChecker, _>( types, args, @@ -315,7 +316,7 @@ impl BindEvalExpr for EvalFnAbs { &self, args: Vec>, ) -> Result, BindError> { - let nums = PartiqlType::any_of(TYPE_NUMERIC_TYPES); + let nums = PartiqlShape::any_of(TYPE_NUMERIC_TYPES); UnaryValueExpr::create_typed::<{ STRICT }, _>([nums], args, |v| { match NullableOrd::lt(v, &Value::from(0)) { Null => Null, @@ -336,10 +337,10 @@ impl BindEvalExpr for EvalFnCardinality { &self, args: Vec>, ) -> Result, BindError> { - let collections = PartiqlType::any_of([ - PartiqlType::new(TypeKind::Array(ArrayType::new_any())), - PartiqlType::new(TypeKind::Bag(BagType::new_any())), - PartiqlType::new(TypeKind::Struct(StructType::new_any())), + let collections = PartiqlShape::any_of([ + PartiqlShape::new(StaticTypeVariant::Array(ArrayType::new_any())), + PartiqlShape::new(StaticTypeVariant::Bag(BagType::new_any())), + PartiqlShape::new(StaticTypeVariant::Struct(StructType::new_any())), ]); UnaryValueExpr::create_typed::<{ STRICT }, _>([collections], args, |v| match v { diff --git a/partiql-logical-planner/src/lower.rs b/partiql-logical-planner/src/lower.rs index 93ce0324..d2fe5203 100644 --- a/partiql-logical-planner/src/lower.rs +++ b/partiql-logical-planner/src/lower.rs @@ -2000,7 +2000,7 @@ mod tests { use partiql_catalog::{PartiqlCatalog, TypeEnvEntry}; use partiql_logical::BindingsOp::Project; use partiql_logical::ValueExpr; - use partiql_types::any; + use partiql_types::dynamic; #[test] fn test_plan_non_existent_fns() { @@ -2100,7 +2100,7 @@ mod tests { expected_logical.add_flow_with_branch_num(project, sink, 0); let mut catalog = PartiqlCatalog::default(); - let _oid = catalog.add_type_entry(TypeEnvEntry::new("customers", &[], any!())); + let _oid = catalog.add_type_entry(TypeEnvEntry::new("customers", &[], dynamic!())); let statement = "SELECT c.id AS my_id, customers.name AS my_name FROM customers AS c"; let parsed = partiql_parser::Parser::default() .parse(statement) diff --git a/partiql-logical-planner/src/typer.rs b/partiql-logical-planner/src/typer.rs index d1f8d16f..716682f9 100644 --- a/partiql-logical-planner/src/typer.rs +++ b/partiql-logical-planner/src/typer.rs @@ -4,8 +4,8 @@ use partiql_ast::ast::{CaseSensitivity, SymbolPrimitive}; use partiql_catalog::Catalog; use partiql_logical::{BindingsOp, LogicalPlan, OpId, PathComponent, ValueExpr, VarRefType}; use partiql_types::{ - any, undefined, ArrayType, BagType, PartiqlType, StructConstraint, StructField, StructType, - TypeKind, + dynamic, undefined, ArrayType, BagType, PartiqlShape, ShapeResultError, StaticTypeVariant, + StructConstraint, StructField, StructType, }; use partiql_value::{BindingsName, Value}; use petgraph::algo::toposort; @@ -33,7 +33,16 @@ const OUTPUT_SCHEMA_KEY: &str = "_output_schema"; #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct TypeErr { pub errors: Vec, - pub output: Option, + pub output: Option, +} + +impl From for TypeErr { + fn from(value: ShapeResultError) -> Self { + TypeErr { + errors: vec![TypingError::InvalidType(value)], + output: None, + } + } } #[derive(Error, Debug, Clone, PartialEq, Eq, Hash)] @@ -49,6 +58,9 @@ pub enum TypingError { /// Represents an error in type checking #[error("TypeCheck: {0}")] TypeCheck(String), + + #[error("TypeCheck: {0}")] + InvalidType(#[from] ShapeResultError), } #[derive(Debug, Clone)] @@ -73,7 +85,7 @@ enum LookupOrder { struct TypeEnvContext { env: LocalTypeEnv, /// Represents the type that is used for creating the `env` in the [TypeEnvContext] - derived_type: PartiqlType, + derived_type: PartiqlShape, } #[allow(dead_code)] @@ -86,7 +98,7 @@ impl TypeEnvContext { &self.env } - fn derived_type(&self) -> &PartiqlType { + fn derived_type(&self) -> &PartiqlShape { &self.derived_type } } @@ -95,13 +107,13 @@ impl Default for TypeEnvContext { fn default() -> Self { TypeEnvContext { env: LocalTypeEnv::new(), - derived_type: any!(), + derived_type: dynamic!(), } } } -impl From<(&LocalTypeEnv, &PartiqlType)> for TypeEnvContext { - fn from(value: (&LocalTypeEnv, &PartiqlType)) -> Self { +impl From<(&LocalTypeEnv, &PartiqlShape)> for TypeEnvContext { + fn from(value: (&LocalTypeEnv, &PartiqlShape)) -> Self { TypeEnvContext { env: value.0.clone(), derived_type: value.1.clone(), @@ -110,7 +122,7 @@ impl From<(&LocalTypeEnv, &PartiqlType)> for TypeEnvContext { } /// Represents a Local Type Environment as opposed to the Global Type Environment in the Catalog. -type LocalTypeEnv = IndexMap; +type LocalTypeEnv = IndexMap; #[derive(Debug, Clone)] pub struct PlanTyper<'c> { @@ -120,7 +132,7 @@ pub struct PlanTyper<'c> { errors: Vec, type_env_stack: Vec, current_bindings_op: Option, - output: Option, + output: Option, } #[allow(dead_code)] @@ -152,7 +164,7 @@ impl<'c> PlanTyper<'c> { } /// Returns the typing result for the Typer - pub fn type_plan(&mut self) -> Result { + pub fn type_plan(&mut self) -> Result { let ops = self.sort()?; for idx in ops { @@ -188,11 +200,10 @@ impl<'c> PlanTyper<'c> { } self.type_vexpr(expr, LookupOrder::Delegate); - if !as_key.is_empty() { let type_ctx = &self.local_type_ctx(); for (_name, ty) in type_ctx.env() { - if let TypeKind::Struct(_s) = ty.kind() { + if let Ok(_s) = ty.expect_struct() { self.type_env_stack.push(ty_ctx![( &ty_env![(string_to_sym(as_key.as_str()), ty.clone())], ty @@ -207,16 +218,16 @@ impl<'c> PlanTyper<'c> { StructField::new(k.as_str(), self.get_singleton_type_from_env()) }); - let ty = PartiqlType::new_struct(StructType::new(BTreeSet::from([ + let ty = PartiqlShape::new_struct(StructType::new(BTreeSet::from([ StructConstraint::Fields(fields.collect()), ]))); let derived_type_ctx = self.local_type_ctx(); let derived_type = &self.derived_type(&derived_type_ctx); let schema = if derived_type.is_ordered_collection() { - PartiqlType::new_array(ArrayType::new(Box::new(ty))) + PartiqlShape::new_array(ArrayType::new(Box::new(ty))) } else if derived_type.is_unordered_collection() { - PartiqlType::new_bag(BagType::new(Box::new(ty))) + PartiqlShape::new_bag(BagType::new(Box::new(ty))) } else { self.errors.push(TypingError::IllegalState(format!( "Expecting Collection for the output Schema but found {:?}", @@ -317,25 +328,28 @@ impl<'c> PlanTyper<'c> { } } ValueExpr::Lit(v) => { - let kind = match **v { - Value::Null => TypeKind::Undefined, - Value::Missing => TypeKind::Undefined, - Value::Integer(_) => TypeKind::Int, - Value::Decimal(_) => TypeKind::Decimal, - Value::Boolean(_) => TypeKind::Bool, - Value::String(_) => TypeKind::String, - Value::Tuple(_) => TypeKind::Struct(StructType::new_any()), - Value::List(_) => TypeKind::Array(ArrayType::new_any()), - Value::Bag(_) => TypeKind::Bag(BagType::new_any()), + let ty = match **v { + Value::Null => PartiqlShape::Undefined, + Value::Missing => PartiqlShape::Undefined, + Value::Integer(_) => PartiqlShape::new(StaticTypeVariant::Int), + Value::Decimal(_) => PartiqlShape::new(StaticTypeVariant::Decimal), + Value::Boolean(_) => PartiqlShape::new(StaticTypeVariant::Bool), + Value::String(_) => PartiqlShape::new(StaticTypeVariant::String), + Value::Tuple(_) => { + PartiqlShape::new(StaticTypeVariant::Struct(StructType::new_any())) + } + Value::List(_) => { + PartiqlShape::new(StaticTypeVariant::Array(ArrayType::new_any())) + } + Value::Bag(_) => PartiqlShape::new(StaticTypeVariant::Bag(BagType::new_any())), _ => { self.errors.push(TypingError::NotYetImplemented( "Unsupported Literal".to_string(), )); - TypeKind::Undefined + PartiqlShape::Undefined } }; - let ty = PartiqlType::new(kind); let new_type_env = IndexMap::from([(string_to_sym("_1"), ty.clone())]); self.type_env_stack.push(ty_ctx![(&new_type_env, &ty)]); } @@ -398,25 +412,31 @@ impl<'c> PlanTyper<'c> { Ok(graph) } - fn element_type<'a>(&'a mut self, ty: &'a PartiqlType) -> PartiqlType { - match ty.kind() { - TypeKind::Bag(b) => b.element_type().clone(), - TypeKind::Array(a) => a.element_type().clone(), - TypeKind::Any => any!(), - _ => ty.clone(), + fn element_type<'a>(&'a mut self, ty: &'a PartiqlShape) -> PartiqlShape { + match ty { + PartiqlShape::Dynamic => dynamic!(), + PartiqlShape::Static(s) => match s.ty() { + StaticTypeVariant::Bag(b) => b.element_type().clone(), + StaticTypeVariant::Array(a) => a.element_type().clone(), + _ => ty.clone(), + }, + undefined!() => { + todo!("Undefined type in catalog") + } + PartiqlShape::AnyOf(_any_of) => ty.clone(), } } - fn retrieve_type_from_local_ctx(&mut self, key: &SymbolPrimitive) -> Option { + fn retrieve_type_from_local_ctx(&mut self, key: &SymbolPrimitive) -> Option { let type_ctx = self.local_type_ctx(); let env = type_ctx.env().clone(); let derived_type = self.derived_type(&type_ctx); if let Some(ty) = env.get(key) { Some(ty.clone()) - } else if let TypeKind::Struct(s) = derived_type.kind() { + } else if let Ok(s) = derived_type.expect_struct() { if s.is_partial() { - Some(any!()) + Some(dynamic!()) } else { match &self.typing_mode { TypingMode::Permissive => Some(undefined!()), @@ -429,8 +449,8 @@ impl<'c> PlanTyper<'c> { } } } - } else if derived_type.is_any() { - Some(any!()) + } else if derived_type.is_dynamic() { + Some(dynamic!()) } else { self.errors.push(TypingError::IllegalState(format!( "Illegal Derive Type {:?}", @@ -440,7 +460,7 @@ impl<'c> PlanTyper<'c> { } } - fn derived_type(&mut self, ty_ctx: &TypeEnvContext) -> PartiqlType { + fn derived_type(&mut self, ty_ctx: &TypeEnvContext) -> PartiqlShape { let ty = ty_ctx.derived_type(); ty.clone() } @@ -473,7 +493,7 @@ impl<'c> PlanTyper<'c> { } } - fn resolve_global_then_local(&mut self, key: &SymbolPrimitive) -> PartiqlType { + fn resolve_global_then_local(&mut self, key: &SymbolPrimitive) -> PartiqlShape { let ty = self.resolve_global(key); match ty.is_undefined() { true => self.resolve_local(key), @@ -481,7 +501,7 @@ impl<'c> PlanTyper<'c> { } } - fn resolve_local_then_global(&mut self, key: &SymbolPrimitive) -> PartiqlType { + fn resolve_local_then_global(&mut self, key: &SymbolPrimitive) -> PartiqlShape { let ty = self.resolve_local(key); match ty.is_undefined() { true => self.resolve_global(key), @@ -489,7 +509,7 @@ impl<'c> PlanTyper<'c> { } } - fn resolve_global(&mut self, key: &SymbolPrimitive) -> PartiqlType { + fn resolve_global(&mut self, key: &SymbolPrimitive) -> PartiqlShape { if let Some(type_entry) = self.catalog.resolve_type(key.value.as_str()) { let ty = self.element_type(type_entry.ty()); ty @@ -498,7 +518,7 @@ impl<'c> PlanTyper<'c> { } } - fn resolve_local(&mut self, key: &SymbolPrimitive) -> PartiqlType { + fn resolve_local(&mut self, key: &SymbolPrimitive) -> PartiqlShape { for type_ctx in self.type_env_stack.iter().rev() { if let Some(ty) = type_ctx.env().get(key) { return ty.clone(); @@ -520,7 +540,7 @@ impl<'c> PlanTyper<'c> { // A helper function to extract one type out of the environment when we expect it. // E.g., in projections, when we expect to infer one type from the project list items. - fn get_singleton_type_from_env(&mut self) -> PartiqlType { + fn get_singleton_type_from_env(&mut self) -> PartiqlShape { let ctx = self.local_type_ctx(); let env = ctx.env(); if env.len() != 1 { @@ -534,13 +554,13 @@ impl<'c> PlanTyper<'c> { } } - fn type_varef(&mut self, key: &SymbolPrimitive, ty: &PartiqlType) { + fn type_varef(&mut self, key: &SymbolPrimitive, ty: &PartiqlShape) { if ty.is_undefined() { self.type_with_undefined(key); } else { let mut new_type_env = LocalTypeEnv::new(); - if let TypeKind::Struct(s) = ty.kind() { - for b in to_bindings(s) { + if let Ok(s) = ty.expect_struct() { + for b in to_bindings(&s) { new_type_env.insert(b.0, b.1); } @@ -562,7 +582,7 @@ fn string_to_sym(name: &str) -> SymbolPrimitive { } } -fn to_bindings(s: &StructType) -> Vec<(SymbolPrimitive, PartiqlType)> { +fn to_bindings(s: &StructType) -> Vec<(SymbolPrimitive, PartiqlShape)> { s.fields() .into_iter() .map(|field| { @@ -596,7 +616,7 @@ mod tests { [ StructField::new("id", int!()), StructField::new("name", str!()), - StructField::new("age", any!()), + StructField::new("age", dynamic!()), ] .into(), ), @@ -616,7 +636,7 @@ mod tests { [ StructField::new("id", int!()), StructField::new("name", str!()), - StructField::new("age", any!()), + StructField::new("age", dynamic!()), ] .into(), ), @@ -636,17 +656,18 @@ mod tests { [ StructField::new("id", int!()), StructField::new("name", str!()), - StructField::new("age", any!()), + StructField::new("age", dynamic!()), ] .into(), ), vec![ StructField::new("id", int!()), StructField::new("name", str!()), - StructField::new("age", any!()), + StructField::new("age", dynamic!()), ], ) .expect("Type"); + // Closed Schema with `Permissive` typing mode and `age` non-existent projection. assert_query_typing( TypingMode::Permissive, @@ -704,7 +725,7 @@ mod tests { StructField::new("id", int!()), StructField::new("name", str!()), StructField::new("age", int!()), - StructField::new("bar", any!()), + StructField::new("bar", dynamic!()), ], ) .expect("Type"); @@ -743,7 +764,7 @@ mod tests { [ StructField::new("id", int!()), StructField::new("name", str!()), - StructField::new("age", any!()), + StructField::new("age", dynamic!()), ] .into(), ), @@ -758,7 +779,7 @@ mod tests { #[test] fn simple_sfw_err() { // Closed Schema with `Strict` typing mode and `age` non-existent projection. - let err1 = r#"No Typing Information for SymbolPrimitive { value: "age", case: CaseInsensitive } in closed Schema PartiqlType(Struct(StructType { constraints: {Open(false), Fields({StructField { name: "id", ty: PartiqlType(Int) }, StructField { name: "name", ty: PartiqlType(String) }})} }))"#; + let err1 = r#"No Typing Information for SymbolPrimitive { value: "age", case: CaseInsensitive } in closed Schema Static(StaticType { ty: Struct(StructType { constraints: {Open(false), Fields({StructField { name: "id", ty: Static(StaticType { ty: Int, nullable: true }) }, StructField { name: "name", ty: Static(StaticType { ty: String, nullable: true }) }})} }), nullable: true })"#; assert_err( assert_query_typing( @@ -774,10 +795,7 @@ mod tests { ), vec![], ), - vec![ - TypingError::TypeCheck(err1.to_string()), - // TypingError::IllegalState(err2.to_string()), - ], + vec![TypingError::TypeCheck(err1.to_string())], Some(bag![r#struct![BTreeSet::from([StructConstraint::Fields( [ StructField::new("id", int!()), @@ -795,8 +813,8 @@ mod tests { StructConstraint::Open(false) ])]; - let err1 = r#"No Typing Information for SymbolPrimitive { value: "details", case: CaseInsensitive } in closed Schema PartiqlType(Struct(StructType { constraints: {Open(false), Fields({StructField { name: "age", ty: PartiqlType(Int) }})} }))"#; - let err2 = r"Illegal Derive Type PartiqlType(Undefined)"; + let err1 = r#"No Typing Information for SymbolPrimitive { value: "details", case: CaseInsensitive } in closed Schema Static(StaticType { ty: Struct(StructType { constraints: {Open(false), Fields({StructField { name: "age", ty: Static(StaticType { ty: Int, nullable: true }) }})} }), nullable: true })"#; + let err2 = r"Illegal Derive Type Undefined"; assert_err( assert_query_typing( @@ -831,7 +849,7 @@ mod tests { fn assert_err( result: Result<(), TypeErr>, expected_errors: Vec, - output: Option, + output: Option, ) { match result { Ok(()) => { @@ -849,7 +867,7 @@ mod tests { }; } - fn create_customer_schema(is_open: bool, fields: BTreeSet) -> PartiqlType { + fn create_customer_schema(is_open: bool, fields: BTreeSet) -> PartiqlShape { bag![r#struct![BTreeSet::from([ StructConstraint::Fields(fields), StructConstraint::Open(is_open) @@ -859,41 +877,38 @@ mod tests { fn assert_query_typing( mode: TypingMode, query: &str, - schema: PartiqlType, + schema: PartiqlShape, expected_fields: Vec, ) -> Result<(), TypeErr> { let expected_fields: BTreeSet<_> = expected_fields.into_iter().collect(); - let actual = type_query(mode, query, TypeEnvEntry::new("customers", &[], schema)); - - match actual { - Ok(actual) => match &actual.kind() { - TypeKind::Bag(b) => { - if let TypeKind::Struct(s) = b.element_type().kind() { - let fields = s.fields(); - let f: Vec<_> = expected_fields - .iter() - .filter(|f| !fields.contains(f)) - .collect(); - assert!(f.is_empty()); - assert_eq!(expected_fields.len(), fields.len()); - println!("query: {query:?}"); - println!("actual: {actual:?}"); - Ok(()) - } else { - Err(TypeErr { - errors: vec![TypingError::TypeCheck( - "[Struct] type expected".to_string(), - )], - output: None, - }) - } + let actual = type_query(mode, query, TypeEnvEntry::new("customers", &[], schema))? + .expect_static()?; + + match &actual.ty() { + StaticTypeVariant::Bag(b) => { + if let Ok(s) = b.element_type().expect_struct() { + let fields = s.fields(); + + let f: Vec<_> = expected_fields + .iter() + .filter(|f| !fields.contains(f)) + .collect(); + assert!(f.is_empty()); + assert_eq!(expected_fields.len(), fields.len()); + println!("query: {query:?}"); + println!("actual: {actual:?}"); + Ok(()) + } else { + Err(TypeErr { + errors: vec![TypingError::TypeCheck("[Struct] type expected".to_string())], + output: None, + }) } - _ => Err(TypeErr { - errors: vec![TypingError::TypeCheck("[Bag] type expected".to_string())], - output: None, - }), - }, - Err(e) => Err(e), + } + _ => Err(TypeErr { + errors: vec![TypingError::TypeCheck("[Bag] type expected".to_string())], + output: None, + }), } } @@ -901,7 +916,7 @@ mod tests { mode: TypingMode, query: &str, type_env_entry: TypeEnvEntry<'_>, - ) -> Result { + ) -> Result { let mut catalog = PartiqlCatalog::default(); let _oid = catalog.add_type_entry(type_env_entry); diff --git a/partiql-types/Cargo.toml b/partiql-types/Cargo.toml index 60d70ddb..468d7583 100644 --- a/partiql-types/Cargo.toml +++ b/partiql-types/Cargo.toml @@ -26,5 +26,8 @@ ordered-float = "3.*" itertools = "0.10.*" unicase = "2.6" +miette = { version ="7.2.*", features = ["fancy"] } +thiserror = "1.*" + [dev-dependencies] criterion = "0.4" diff --git a/partiql-types/src/lib.rs b/partiql-types/src/lib.rs index 56e989c5..8f1c9759 100644 --- a/partiql-types/src/lib.rs +++ b/partiql-types/src/lib.rs @@ -2,60 +2,73 @@ #![deny(clippy::all)] use itertools::Itertools; +use miette::Diagnostic; use std::collections::BTreeSet; use std::fmt::{Debug, Display, Formatter}; use std::hash::Hash; +use thiserror::Error; + +#[derive(Debug, Clone, Eq, PartialEq, Hash, Error, Diagnostic)] +#[error("ShapeResult Error")] +#[non_exhaustive] +pub enum ShapeResultError { + #[error("Unexpected type `{0:?}` for static type bool")] + UnexpectedType(String), +} + +/// Result of attempts to encode to Ion. +pub type ShapeResult = Result; pub trait Type {} -impl Type for PartiqlType {} +impl Type for StaticType {} #[macro_export] -macro_rules! any { +macro_rules! dynamic { () => { - $crate::PartiqlType::new($crate::TypeKind::Any) + $crate::PartiqlShape::Dynamic }; } #[macro_export] macro_rules! int { () => { - $crate::PartiqlType::new($crate::TypeKind::Int) + $crate::PartiqlShape::new($crate::StaticTypeVariant::Int) }; } #[macro_export] macro_rules! int8 { () => { - $crate::PartiqlType::new($crate::TypeKind::Int8) + $crate::PartiqlShape::new($crate::StaticTypeVariant::Int8) }; } #[macro_export] macro_rules! int16 { () => { - $crate::PartiqlType::new($crate::TypeKind::Int16) + $crate::PartiqlShape::new($crate::StaticTypeVariant::Int16) }; } #[macro_export] macro_rules! int32 { () => { - $crate::PartiqlType::new($crate::TypeKind::Int32) + $crate::PartiqlShape::new($crate::StaticTypeVariant::Int32) }; } #[macro_export] macro_rules! int64 { () => { - $crate::PartiqlType::new($crate::TypeKind::Int64) + $crate::PartiqlShape::new($crate::StaticTypeVariant::Int64) }; } #[macro_export] macro_rules! dec { () => { - $crate::PartiqlType::new($crate::TypeKind::Decimal) + $crate::PartiqlShape::new($crate::StaticTypeVariant::Decimal) }; } @@ -64,31 +77,31 @@ macro_rules! dec { #[macro_export] macro_rules! f32 { () => { - $crate::PartiqlType::new($crate::TypeKind::Float32) + $crate::PartiqlShape::new($crate::StaticTypeVariant::Float32) }; } #[macro_export] macro_rules! f64 { () => { - $crate::PartiqlType::new($crate::TypeKind::Float64) + $crate::PartiqlShape::new($crate::StaticTypeVariant::Float64) }; } #[macro_export] macro_rules! str { () => { - $crate::PartiqlType::new($crate::TypeKind::String) + $crate::PartiqlShape::new($crate::StaticTypeVariant::String) }; } #[macro_export] macro_rules! r#struct { () => { - $crate::PartiqlType::new_struct(StructType::new_any()) + $crate::PartiqlShape::new_struct(StructType::new_any()) }; ($elem:expr) => { - $crate::PartiqlType::new_struct(StructType::new($elem)) + $crate::PartiqlShape::new_struct(StructType::new($elem)) }; } @@ -102,39 +115,50 @@ macro_rules! struct_fields { #[macro_export] macro_rules! r#bag { () => { - $crate::PartiqlType::new_bag(BagType::new_any()); + $crate::PartiqlShape::new_bag(BagType::new_any()); }; ($elem:expr) => { - $crate::PartiqlType::new_bag(BagType::new(Box::new($elem))) + $crate::PartiqlShape::new_bag(BagType::new(Box::new($elem))) }; } #[macro_export] macro_rules! r#array { () => { - $crate::PartiqlType::new_array(ArrayType::new_any()); + $crate::PartiqlShape::new_array(ArrayType::new_any()); }; ($elem:expr) => { - $crate::PartiqlType::new_array(ArrayType::new(Box::new($elem))) + $crate::PartiqlShape::new_array(ArrayType::new(Box::new($elem))) }; } #[macro_export] macro_rules! undefined { () => { - $crate::PartiqlType::new($crate::TypeKind::Undefined) + $crate::PartiqlShape::Undefined }; } +/// Represents a PartiQL Shape #[derive(Debug, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] -pub struct PartiqlType(TypeKind); +// With this implementation `Dynamic` and `AnyOf` cannot have `nullability`; this does not mean their +// `null` value at runtime cannot belong to their domain. +// TODO adopt the correct model Pending PartiQL Types semantics finalization: https://github.com/partiql/partiql-lang/issues/18 +pub enum PartiqlShape { + Dynamic, + AnyOf(AnyOf), + Static(StaticType), + Undefined, +} #[derive(Debug, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] -#[non_exhaustive] -pub enum TypeKind { - Any, - AnyOf(AnyOf), +pub struct StaticType { + ty: StaticTypeVariant, + nullable: bool, +} +#[derive(Debug, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] +pub enum StaticTypeVariant { // Scalar Types Int, Int8, @@ -158,144 +182,197 @@ pub enum TypeKind { Struct(StructType), Bag(BagType), Array(ArrayType), - // Serves as Bottom Type - Undefined, // TODO Add BitString, ByteString, Blob, Clob, and Graph types } -impl Display for TypeKind { +impl StaticType { + #[must_use] + pub fn new(&self, ty: StaticTypeVariant) -> StaticType { + StaticType { ty, nullable: true } + } + + #[must_use] + pub fn new_non_nullable(&self, ty: StaticTypeVariant) -> StaticType { + StaticType { + ty, + nullable: false, + } + } + + #[must_use] + pub fn ty(&self) -> StaticTypeVariant { + self.ty.clone() + } + + #[must_use] + pub fn is_nullable(&self) -> bool { + self.nullable + } +} + +impl Display for StaticType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let nullable = if self.nullable { + "nullable" + } else { + "non_nullable" + }; + write!(f, "({}, {})", self.ty, nullable) + } +} + +impl Display for StaticTypeVariant { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let x = match self { - TypeKind::Any => "Any".to_string(), - TypeKind::AnyOf(anyof) => { - format!( - "AnyOf({})", - anyof.types.iter().map(PartiqlType::kind).join(",") - ) - } - TypeKind::Int => "Int".to_string(), - TypeKind::Int8 => "Int8".to_string(), - TypeKind::Int16 => "Int16".to_string(), - TypeKind::Int32 => "Int32".to_string(), - TypeKind::Int64 => "Int64".to_string(), - TypeKind::Bool => "Bool".to_string(), - TypeKind::Decimal => "Decimal".to_string(), - TypeKind::DecimalP(_, _) => { + StaticTypeVariant::Int => "Int".to_string(), + StaticTypeVariant::Int8 => "Int8".to_string(), + StaticTypeVariant::Int16 => "Int16".to_string(), + StaticTypeVariant::Int32 => "Int32".to_string(), + StaticTypeVariant::Int64 => "Int64".to_string(), + StaticTypeVariant::Bool => "Bool".to_string(), + StaticTypeVariant::Decimal => "Decimal".to_string(), + StaticTypeVariant::DecimalP(_, _) => { todo!() } - TypeKind::Float32 => "Float32".to_string(), - TypeKind::Float64 => "Float64".to_string(), - TypeKind::String => "String".to_string(), - TypeKind::StringFixed(_) => { + StaticTypeVariant::Float32 => "Float32".to_string(), + StaticTypeVariant::Float64 => "Float64".to_string(), + StaticTypeVariant::String => "String".to_string(), + StaticTypeVariant::StringFixed(_) => { todo!() } - TypeKind::StringVarying(_) => { + StaticTypeVariant::StringVarying(_) => { todo!() } - TypeKind::DateTime => "DateTime".to_string(), - TypeKind::Struct(_) => "Struct".to_string(), - TypeKind::Bag(_) => "Bag".to_string(), - TypeKind::Array(_) => "Array".to_string(), - TypeKind::Undefined => "Undefined".to_string(), + StaticTypeVariant::DateTime => "DateTime".to_string(), + StaticTypeVariant::Struct(_) => "Struct".to_string(), + StaticTypeVariant::Bag(_) => "Bag".to_string(), + StaticTypeVariant::Array(_) => "Array".to_string(), }; write!(f, "{x}") } } -pub const TYPE_ANY: PartiqlType = PartiqlType::new(TypeKind::Any); -pub const TYPE_BOOL: PartiqlType = PartiqlType::new(TypeKind::Bool); -pub const TYPE_INT: PartiqlType = PartiqlType::new(TypeKind::Int); -pub const TYPE_INT8: PartiqlType = PartiqlType::new(TypeKind::Int8); -pub const TYPE_INT16: PartiqlType = PartiqlType::new(TypeKind::Int16); -pub const TYPE_INT32: PartiqlType = PartiqlType::new(TypeKind::Int32); -pub const TYPE_INT64: PartiqlType = PartiqlType::new(TypeKind::Int64); -pub const TYPE_REAL: PartiqlType = PartiqlType::new(TypeKind::Float32); -pub const TYPE_DOUBLE: PartiqlType = PartiqlType::new(TypeKind::Float64); -pub const TYPE_DECIMAL: PartiqlType = PartiqlType::new(TypeKind::Decimal); -pub const TYPE_STRING: PartiqlType = PartiqlType::new(TypeKind::String); -pub const TYPE_DATETIME: PartiqlType = PartiqlType::new(TypeKind::DateTime); -pub const TYPE_NUMERIC_TYPES: [PartiqlType; 4] = [TYPE_INT, TYPE_REAL, TYPE_DOUBLE, TYPE_DECIMAL]; +pub const TYPE_DYNAMIC: PartiqlShape = PartiqlShape::Dynamic; +pub const TYPE_BOOL: PartiqlShape = PartiqlShape::new(StaticTypeVariant::Bool); +pub const TYPE_INT: PartiqlShape = PartiqlShape::new(StaticTypeVariant::Int); +pub const TYPE_INT8: PartiqlShape = PartiqlShape::new(StaticTypeVariant::Int8); +pub const TYPE_INT16: PartiqlShape = PartiqlShape::new(StaticTypeVariant::Int16); +pub const TYPE_INT32: PartiqlShape = PartiqlShape::new(StaticTypeVariant::Int32); +pub const TYPE_INT64: PartiqlShape = PartiqlShape::new(StaticTypeVariant::Int64); +pub const TYPE_REAL: PartiqlShape = PartiqlShape::new(StaticTypeVariant::Float32); +pub const TYPE_DOUBLE: PartiqlShape = PartiqlShape::new(StaticTypeVariant::Float64); +pub const TYPE_DECIMAL: PartiqlShape = PartiqlShape::new(StaticTypeVariant::Decimal); +pub const TYPE_STRING: PartiqlShape = PartiqlShape::new(StaticTypeVariant::String); +pub const TYPE_DATETIME: PartiqlShape = PartiqlShape::new(StaticTypeVariant::DateTime); +pub const TYPE_NUMERIC_TYPES: [PartiqlShape; 4] = [TYPE_INT, TYPE_REAL, TYPE_DOUBLE, TYPE_DECIMAL]; #[allow(dead_code)] -impl PartiqlType { +impl PartiqlShape { #[must_use] - pub const fn new(kind: TypeKind) -> PartiqlType { - PartiqlType(kind) + pub const fn new(ty: StaticTypeVariant) -> PartiqlShape { + PartiqlShape::Static(StaticType { ty, nullable: true }) } #[must_use] - pub fn new_any() -> PartiqlType { - PartiqlType(TypeKind::Any) + pub const fn new_non_nullable(ty: StaticTypeVariant) -> PartiqlShape { + PartiqlShape::Static(StaticType { + ty, + nullable: false, + }) } #[must_use] - pub fn new_struct(s: StructType) -> PartiqlType { - PartiqlType(TypeKind::Struct(s)) + pub fn new_dynamic() -> PartiqlShape { + PartiqlShape::Dynamic } #[must_use] - pub fn new_bag(b: BagType) -> PartiqlType { - PartiqlType(TypeKind::Bag(b)) + pub fn new_struct(s: StructType) -> PartiqlShape { + PartiqlShape::new(StaticTypeVariant::Struct(s)) } #[must_use] - pub fn new_array(a: ArrayType) -> PartiqlType { - PartiqlType(TypeKind::Array(a)) + pub fn new_bag(b: BagType) -> PartiqlShape { + PartiqlShape::new(StaticTypeVariant::Bag(b)) } - pub fn any_of(types: I) -> PartiqlType + #[must_use] + pub fn new_array(a: ArrayType) -> PartiqlShape { + PartiqlShape::new(StaticTypeVariant::Array(a)) + } + + pub fn any_of(types: I) -> PartiqlShape where - I: IntoIterator, + I: IntoIterator, { let any_of = AnyOf::from_iter(types); match any_of.types.len() { - 0 => TYPE_ANY, + 0 => TYPE_DYNAMIC, 1 => { let AnyOf { types } = any_of; types.into_iter().next().unwrap() } - _ => PartiqlType(TypeKind::AnyOf(any_of)), + // TODO figure out what does it mean for a Union to be nullable or not + _ => PartiqlShape::AnyOf(any_of), } } #[must_use] - pub fn union_with(self, other: PartiqlType) -> PartiqlType { - match (self.0, other.0) { - (TypeKind::Any, _) | (_, TypeKind::Any) => PartiqlType::new(TypeKind::Any), - (TypeKind::AnyOf(lhs), TypeKind::AnyOf(rhs)) => { - PartiqlType::any_of(lhs.types.into_iter().chain(rhs.types)) + pub fn union_with(self, other: PartiqlShape) -> PartiqlShape { + match (self, other) { + (PartiqlShape::Dynamic, _) | (_, PartiqlShape::Dynamic) => PartiqlShape::new_dynamic(), + (PartiqlShape::AnyOf(lhs), PartiqlShape::AnyOf(rhs)) => { + PartiqlShape::any_of(lhs.types.into_iter().chain(rhs.types)) } - (TypeKind::AnyOf(anyof), other) | (other, TypeKind::AnyOf(anyof)) => { + (PartiqlShape::AnyOf(anyof), other) | (other, PartiqlShape::AnyOf(anyof)) => { let mut types = anyof.types; - types.insert(PartiqlType::new(other)); - PartiqlType::any_of(types) + types.insert(other); + PartiqlShape::any_of(types) } (l, r) => { - let types = [PartiqlType::new(l), PartiqlType::new(r)]; - PartiqlType::any_of(types) + let types = [l, r]; + PartiqlShape::any_of(types) } } } #[must_use] pub fn is_string(&self) -> bool { - matches!(&self, PartiqlType(TypeKind::String)) - } - - #[must_use] - pub fn kind(&self) -> &TypeKind { - &self.0 + matches!( + &self, + PartiqlShape::Static(StaticType { + ty: StaticTypeVariant::String, + nullable: true + }) + ) } #[must_use] pub fn is_struct(&self) -> bool { - matches!(*self, PartiqlType(TypeKind::Struct(_))) + matches!( + *self, + PartiqlShape::Static(StaticType { + ty: StaticTypeVariant::Struct(_), + nullable: true + }) + ) } #[must_use] pub fn is_collection(&self) -> bool { - matches!(*self, PartiqlType(TypeKind::Bag(_))) - || matches!(*self, PartiqlType(TypeKind::Array(_))) + matches!( + *self, + PartiqlShape::Static(StaticType { + ty: StaticTypeVariant::Bag(_), + nullable: true + }) + ) || matches!( + *self, + PartiqlShape::Static(StaticType { + ty: StaticTypeVariant::Array(_), + nullable: true + }) + ) } #[must_use] @@ -306,49 +383,132 @@ impl PartiqlType { #[must_use] pub fn is_ordered_collection(&self) -> bool { // TODO Add Sexp when added - matches!(*self, PartiqlType(TypeKind::Array(_))) + matches!( + *self, + PartiqlShape::Static(StaticType { + ty: StaticTypeVariant::Array(_), + nullable: true + }) + ) } #[must_use] pub fn is_bag(&self) -> bool { - matches!(*self, PartiqlType(TypeKind::Bag(_))) + matches!( + *self, + PartiqlShape::Static(StaticType { + ty: StaticTypeVariant::Bag(_), + nullable: true + }) + ) } #[must_use] pub fn is_array(&self) -> bool { - matches!(*self, PartiqlType(TypeKind::Array(_))) + matches!( + *self, + PartiqlShape::Static(StaticType { + ty: StaticTypeVariant::Array(_), + nullable: true + }) + ) } #[must_use] - pub fn is_any(&self) -> bool { - matches!(*self, PartiqlType(TypeKind::Any)) + pub fn is_dynamic(&self) -> bool { + matches!(*self, PartiqlShape::Dynamic) } #[must_use] pub fn is_undefined(&self) -> bool { - matches!(*self, PartiqlType(TypeKind::Undefined)) + matches!(*self, PartiqlShape::Undefined) + } + + pub fn expect_bool(&self) -> ShapeResult { + if let PartiqlShape::Static(StaticType { + ty: StaticTypeVariant::Bool, + nullable: n, + }) = self + { + Ok(StaticType { + ty: StaticTypeVariant::Bool, + nullable: *n, + }) + } else { + Err(ShapeResultError::UnexpectedType(format!("{self}"))) + } + } + + pub fn expect_struct(&self) -> ShapeResult { + if let PartiqlShape::Static(StaticType { + ty: StaticTypeVariant::Struct(stct), + .. + }) = self + { + Ok(stct.clone()) + } else { + Err(ShapeResultError::UnexpectedType(format!("{self}"))) + } + } + + pub fn expect_static(&self) -> ShapeResult { + if let PartiqlShape::Static(s) = self { + Ok(s.clone()) + } else { + Err(ShapeResultError::UnexpectedType(format!("{self}"))) + } + } + + pub fn expect_dynamic_type(&self) -> ShapeResult { + if let PartiqlShape::Dynamic = self { + Ok(PartiqlShape::Dynamic) + } else { + Err(ShapeResultError::UnexpectedType(format!("{self}"))) + } + } + + pub fn expect_undefined(&self) -> ShapeResult { + if let PartiqlShape::Undefined = self { + Ok(PartiqlShape::Undefined) + } else { + Err(ShapeResultError::UnexpectedType(format!("{self}"))) + } + } +} + +impl Display for PartiqlShape { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let x = match self { + PartiqlShape::Dynamic => "Dynamic".to_string(), + PartiqlShape::AnyOf(anyof) => { + format!("AnyOf({})", anyof.types.iter().cloned().join(",")) + } + PartiqlShape::Static(s) => format!("{s}"), + PartiqlShape::Undefined => "Undefined".to_string(), + }; + write!(f, "{x}") } } #[derive(Hash, Eq, PartialEq, Debug, Clone, Ord, PartialOrd)] #[allow(dead_code)] pub struct AnyOf { - types: BTreeSet, + types: BTreeSet, } impl AnyOf { #[must_use] - pub const fn new(types: BTreeSet) -> Self { + pub const fn new(types: BTreeSet) -> Self { AnyOf { types } } - pub fn types(&self) -> impl Iterator { + pub fn types(&self) -> impl Iterator { self.types.iter() } } -impl FromIterator for AnyOf { - fn from_iter>(iter: T) -> Self { +impl FromIterator for AnyOf { + fn from_iter>(iter: T) -> Self { AnyOf { types: iter.into_iter().collect(), } @@ -413,12 +573,12 @@ pub enum StructConstraint { #[allow(dead_code)] pub struct StructField { name: String, - ty: PartiqlType, + ty: PartiqlShape, } impl StructField { #[must_use] - pub fn new(name: &str, ty: PartiqlType) -> Self { + pub fn new(name: &str, ty: PartiqlShape) -> Self { StructField { name: name.to_string(), ty, @@ -431,13 +591,13 @@ impl StructField { } #[must_use] - pub fn ty(&self) -> &PartiqlType { + pub fn ty(&self) -> &PartiqlShape { &self.ty } } -impl From<(&str, PartiqlType)> for StructField { - fn from(value: (&str, PartiqlType)) -> Self { +impl From<(&str, PartiqlShape)> for StructField { + fn from(value: (&str, PartiqlShape)) -> Self { StructField { name: value.0.to_string(), ty: value.1, @@ -448,22 +608,22 @@ impl From<(&str, PartiqlType)> for StructField { #[derive(Debug, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] #[allow(dead_code)] pub struct BagType { - element_type: Box, + element_type: Box, } impl BagType { #[must_use] pub fn new_any() -> Self { - BagType::new(Box::new(PartiqlType(TypeKind::Any))) + BagType::new(Box::new(PartiqlShape::Dynamic)) } #[must_use] - pub fn new(typ: Box) -> Self { + pub fn new(typ: Box) -> Self { BagType { element_type: typ } } #[must_use] - pub fn element_type(&self) -> &PartiqlType { + pub fn element_type(&self) -> &PartiqlShape { &self.element_type } } @@ -471,7 +631,7 @@ impl BagType { #[derive(Debug, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] #[allow(dead_code)] pub struct ArrayType { - element_type: Box, + element_type: Box, // TODO Add Array constraint once we have Schema Specification: // https://github.com/partiql/partiql-spec/issues/49 } @@ -479,44 +639,44 @@ pub struct ArrayType { impl ArrayType { #[must_use] pub fn new_any() -> Self { - ArrayType::new(Box::new(PartiqlType(TypeKind::Any))) + ArrayType::new(Box::new(PartiqlShape::Dynamic)) } #[must_use] - pub fn new(typ: Box) -> Self { + pub fn new(typ: Box) -> Self { ArrayType { element_type: typ } } #[must_use] - pub fn element_type(&self) -> &PartiqlType { + pub fn element_type(&self) -> &PartiqlShape { &self.element_type } } #[cfg(test)] mod tests { - use crate::{PartiqlType, TYPE_INT, TYPE_REAL}; + use crate::{PartiqlShape, TYPE_INT, TYPE_REAL}; #[test] fn union() { let expect_int = TYPE_INT; assert_eq!(expect_int, TYPE_INT.union_with(TYPE_INT)); - let expect_nums = PartiqlType::any_of([TYPE_INT, TYPE_REAL]); + let expect_nums = PartiqlShape::any_of([TYPE_INT, TYPE_REAL]); assert_eq!(expect_nums, TYPE_INT.union_with(TYPE_REAL)); assert_eq!( expect_nums, - PartiqlType::any_of([ + PartiqlShape::any_of([ TYPE_INT.union_with(TYPE_REAL), TYPE_INT.union_with(TYPE_REAL) ]) ); assert_eq!( expect_nums, - PartiqlType::any_of([ + PartiqlShape::any_of([ TYPE_INT.union_with(TYPE_REAL), TYPE_INT.union_with(TYPE_REAL), - PartiqlType::any_of([ + PartiqlShape::any_of([ TYPE_INT.union_with(TYPE_REAL), TYPE_INT.union_with(TYPE_REAL) ])