diff --git a/src/ast.rs b/src/ast.rs index 88091ff..b9fb191 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -40,7 +40,7 @@ use crate::{ tokenize_sql, EmptyFile, Ident, Keyword, Literal, LiteralValue, PseudoKeyword, Punct, RawToken, Span, Spanned, ToTokens, Token, TokenStream, TokenWriter, }, - types::TableType, + types::{StructType, TableType}, util::{is_c_ident, AnsiIdent, AnsiString}, }; @@ -1173,6 +1173,14 @@ impl Emit for ArrayExpression { } } +/// The type of the elements in an `ARRAY` expression. +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)] +pub struct ArrayElementType { + pub lt: Punct, + pub elem_type: DataType, + pub gt: Punct, +} + /// An `ARRAY` definition. Either a `SELECT` expression or a list of /// expressions. #[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)] @@ -1205,19 +1213,61 @@ pub struct ArraySelectExpression { } /// A struct expression. -#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)] +#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, Spanned, ToTokens)] pub struct StructExpression { + /// Type information added later by inference. + #[emit(skip)] + #[to_tokens(skip)] + #[drive(skip)] + pub ty: Option, + pub struct_token: Keyword, + pub field_decls: Option, pub paren1: Punct, pub fields: SelectList, pub paren2: Punct, } -/// The type of the elements in an `ARRAY` expression. +impl Emit for StructExpression { + fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> io::Result<()> { + match t { + Target::Trino => { + f.write_token_start("CAST(")?; + self.struct_token.ident.token.with_str("ROW").emit(t, f)?; + self.paren1.emit(t, f)?; + // Loop over fields, emitting them without any aliases. + for (i, item) in self.fields.items.node_iter().enumerate() { + if i > 0 { + f.write_token_start(", ")?; + } + match item { + SelectListItem::Expression { expression, .. } => { + expression.emit(t, f)?; + } + _ => panic!("Unexpected select list item (should have been caught by type checker): {:?}", item), + } + } + self.paren2.emit(t, f)?; + f.write_token_start("AS")?; + let ty = self + .ty + .as_ref() + .expect("type should have been added by type checker"); + let data_type = DataType::try_from(ty.clone()) + .expect("should be able to print data type of ROW"); + data_type.emit(t, f)?; + f.write_token_start(")") + } + _ => self.emit_default(t, f), + } + } +} + +/// Fields declared by a `STRUCT<..>(..)` expression. #[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)] -pub struct ArrayElementType { +pub struct StructFieldDeclarations { pub lt: Punct, - pub elem_type: DataType, + pub fields: NodeVec, pub gt: Punct, } @@ -2353,6 +2403,11 @@ peg::parser! { } } + rule array_element_type() -> ArrayElementType + = lt:p("<") elem_type:data_type() gt:p(">") { + ArrayElementType { lt, elem_type, gt } + } + rule array_definition() -> ArrayDefinition = select:array_select_expression() { ArrayDefinition::Query { select: Box::new(select) } } / expressions:sep(, ",") { ArrayDefinition::Elements(expressions) } @@ -2378,18 +2433,22 @@ peg::parser! { } rule struct_expression() -> StructExpression - = struct_token:k("STRUCT") paren1:p("(") fields:select_list() paren2:p(")") { + = struct_token:k("STRUCT") field_decls:struct_field_declarations()? + paren1:p("(") fields:select_list() paren2:p(")") + { StructExpression { + ty: None, struct_token, + field_decls, paren1, fields, paren2, } } - rule array_element_type() -> ArrayElementType - = lt:p("<") elem_type:data_type() gt:p(">") { - ArrayElementType { lt, elem_type, gt } + rule struct_field_declarations() -> StructFieldDeclarations + = lt:p("<") fields:sep(, ",") gt:p(">") { + StructFieldDeclarations { lt, fields, gt } } rule count_expression() -> CountExpression diff --git a/src/infer/mod.rs b/src/infer/mod.rs index b3015c5..267cf1a 100644 --- a/src/infer/mod.rs +++ b/src/infer/mod.rs @@ -8,8 +8,11 @@ use crate::{ ast::{self, ConditionJoinOperator, Emit, Expression, Name}, errors::{Error, Result}, scope::{ColumnSet, ColumnSetScope, Scope, ScopeGet, ScopeHandle}, - tokenizer::{Ident, Literal, LiteralValue, Spanned}, - types::{ArgumentType, ColumnType, SimpleType, TableType, Type, Unnested, ValueType}, + tokenizer::{Ident, Keyword, Literal, LiteralValue, Punct, Spanned}, + types::{ + ArgumentType, ColumnType, ResolvedTypeVarsOnly, SimpleType, StructElementType, StructType, + TableType, Type, Unnested, ValueType, + }, unification::{UnificationTable, Unify}, }; @@ -649,6 +652,7 @@ impl InferTypes for ast::Expression { ast::Expression::Literal(Literal { value, .. }) => value.infer_types(&()), ast::Expression::BoolValue(_) => Ok(ArgumentType::bool()), ast::Expression::Null { .. } => Ok(ArgumentType::null()), + ast::Expression::Interval(_) => Err(nyi(self, "INTERVAL expression")), ast::Expression::ColumnName(name) => name.infer_types(scope), ast::Expression::Cast(cast) => cast.infer_types(scope), ast::Expression::Is(is) => is.infer_types(scope), @@ -665,11 +669,13 @@ impl InferTypes for ast::Expression { } ast::Expression::Parens { expression, .. } => expression.infer_types(scope), ast::Expression::Array(array) => array.infer_types(scope), + ast::Expression::Struct(struct_expr) => struct_expr.infer_types(scope), ast::Expression::Count(count) => count.infer_types(scope), + ast::Expression::CurrentDate(_) => Err(nyi(self, "CURRENT_DATE expression")), ast::Expression::ArrayAgg(array_agg) => array_agg.infer_types(scope), + ast::Expression::SpecialDateFunctionCall(_) => Err(nyi(self, "special date function")), ast::Expression::FunctionCall(fcall) => fcall.infer_types(scope), ast::Expression::Index(index) => index.infer_types(scope), - _ => Err(nyi(self, "expression")), } } } @@ -973,6 +979,83 @@ impl InferTypes for ast::ArrayDefinition { } } +impl InferTypes for ast::StructExpression { + type Scope = ColumnSetScope; + type Output = ArgumentType; + + fn infer_types(&mut self, scope: &Self::Scope) -> Result { + let ast::StructExpression { + ty, + struct_token, + field_decls, + fields, + .. + } = self; + + // Infer our struct type from the field expressions. + let mut field_types = vec![]; + for field in fields.items.node_iter_mut() { + if let ast::SelectListItem::Expression { expression, alias } = field { + let field_ty = expression.infer_types(scope)?; + let field_ty = field_ty.expect_value_type(expression)?; + let ident = alias + .clone() + .map(|a| a.ident) + .or_else(|| expression.infer_column_name()); + field_types.push(StructElementType { + name: ident, + ty: field_ty.clone(), + }); + } else { + // We could forbid this in the grammar if we were less lazy. + return Err(Error::annotated( + "struct field must be expression, not wildcard", + field.span(), + "expression required", + )); + } + } + let actual_ty = StructType { + fields: field_types, + }; + + // If we have `field_decls`, use those to build our official type. + let return_ty = if let Some(field_decls) = field_decls { + let expected_ty = + ValueType::::try_from(&ast::DataType::Struct { + struct_token: Keyword::new("STRUCT", struct_token.span()), + lt: Punct::new("<", field_decls.lt.span()), + fields: field_decls.fields.clone(), + gt: Punct::new(">", field_decls.gt.span()), + })?; + let expected_ty = expected_ty.expect_struct_type(field_decls)?; + actual_ty.expect_subtype_of(expected_ty, field_decls)?; + expected_ty.clone() + } else { + actual_ty + }; + + *ty = Some(return_ty.clone()); + if return_ty + .fields + .iter() + .any(|f| f.ty == ValueType::Simple(SimpleType::Null)) + { + return Err(Error::annotated( + format!( + "NULL column in {}, try STRUCT(..)", + return_ty + ), + struct_token.span(), + "contains a NULL field", + )); + } + let return_ty = ValueType::Simple(SimpleType::Struct(return_ty)); + return_ty.expect_inhabited(field_decls)?; + Ok(ArgumentType::Value(return_ty)) + } +} + impl InferTypes for ast::CountExpression { type Scope = ColumnSetScope; type Output = ArgumentType; diff --git a/src/types.rs b/src/types.rs index 55a1095..2e46dc5 100644 --- a/src/types.rs +++ b/src/types.rs @@ -22,7 +22,7 @@ use crate::{ errors::{format_err, Error, Result}, known_files::{FileId, KnownFiles}, scope::{ColumnSet, ColumnSetColumn, ColumnSetColumnName}, - tokenizer::{Ident, Span, Spanned}, + tokenizer::{Ident, Keyword, PseudoKeyword, Punct, Span, Spanned}, unification::{UnificationTable, Unify}, util::is_c_ident, }; @@ -364,12 +364,11 @@ impl ValueType { // an ARRAY is not a subtype of ARRAY, as (ValueType::Array(SimpleType::Bottom), ValueType::Array(_)) => true, - // TODO: Structs with anonymous fields may be subtype of structs - // with named fields. - - // TODO: Tables with unknown column names, built by `SELECT` and - // combined with `UNION`, may be subtypes of tables with known - // column names. + // Structs may be subtypes of other structs. + ( + ValueType::Simple(SimpleType::Struct(a)), + ValueType::Simple(SimpleType::Struct(b)), + ) => a.is_subtype_of(b), // Otherwise, assume it isn't a subtype. _ => false, @@ -416,6 +415,18 @@ impl ValueType { _ => Ok(()), } } + + /// Expect this value type to be a struct type. + pub fn expect_struct_type(&self, spanned: &dyn Spanned) -> Result<&StructType> { + match self { + ValueType::Simple(SimpleType::Struct(s)) => Ok(s), + _ => Err(Error::annotated( + format!("expected struct type, found {}", self), + spanned.span(), + "type mismatch", + )), + } + } } impl ValueType { @@ -516,6 +527,22 @@ impl Unify for ValueType { } } +impl TryFrom for ast::DataType { + type Error = Error; + + fn try_from(value: ValueType) -> Result { + match value { + ValueType::Simple(t) => t.try_into(), + ValueType::Array(t) => Ok(ast::DataType::Array { + array_token: Keyword::new("ARRAY", Span::Unknown), + lt: Punct::new("<", Span::Unknown), + data_type: Box::new(t.try_into()?), + gt: Punct::new(">", Span::Unknown), + }), + } + } +} + impl fmt::Display for ValueType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -668,12 +695,78 @@ impl fmt::Display for SimpleType { } } +impl TryFrom for ast::DataType { + type Error = Error; + + fn try_from(value: SimpleType) -> Result { + let pk = |s| PseudoKeyword::new(s, Span::Unknown); + match value { + SimpleType::Bool => Ok(ast::DataType::Bool(pk("BOOL"))), + SimpleType::Bottom => Err(format_err!( + "cannot convert unknown type to a printable type" + )), + SimpleType::Bytes => Ok(ast::DataType::Bytes(pk("BYTES"))), + SimpleType::Date => Ok(ast::DataType::Date(pk("DATE"))), + SimpleType::Datepart => Err(format_err!( + "cannot convert datepart type to a printable type" + )), + SimpleType::Datetime => Ok(ast::DataType::Datetime(pk("DATETIME"))), + SimpleType::Float64 => Ok(ast::DataType::Float64(pk("FLOAT64"))), + SimpleType::Geography => Ok(ast::DataType::Geography(pk("GEOGRAPHY"))), + SimpleType::Int64 => Ok(ast::DataType::Int64(pk("INT64"))), + SimpleType::Interval => Err(format_err!( + "cannot convert interval type to a printable type" + )), + SimpleType::Null => Err(format_err!( + "cannot convert unknown type to a printable type" + )), + SimpleType::Numeric => Ok(ast::DataType::Numeric(pk("NUMERIC"))), + SimpleType::String => Ok(ast::DataType::String(pk("STRING"))), + SimpleType::Time => Ok(ast::DataType::Time(pk("TIME"))), + SimpleType::Timestamp => Ok(ast::DataType::Timestamp(pk("TIMESTAMP"))), + SimpleType::Struct(s) => s.try_into(), + SimpleType::Parameter(_) => { + unreachable!("SimpleType::Parameter should contain no values") + } + } + } +} + /// A struct type. #[derive(Clone, Debug, PartialEq, Eq)] pub struct StructType { pub fields: Vec>, } +impl StructType { + /// Is this a subtype of `other`? + pub fn is_subtype_of(&self, other: &StructType) -> bool { + // We are a subtype of `other` if we have the same fields, and each of + // our fields is a subtype of the corresponding field in `other`. + if self.fields.len() != other.fields.len() { + return false; + } + for (a, b) in self.fields.iter().zip(&other.fields) { + if !a.is_subtype_of(b) { + return false; + } + } + true + } + + /// Return an error if we are not a subtype of `other`. + pub fn expect_subtype_of(&self, other: &StructType, spanned: &dyn Spanned) -> Result<()> { + if !self.is_subtype_of(other) { + return Err(Error::annotated( + format!("expected {}, found {}", other, self), + spanned.span(), + "type mismatch", + )); + } + Ok(()) + } +} + impl Unify for StructType { type Resolved = StructType; @@ -705,6 +798,23 @@ impl Unify for StructType { } } +impl TryFrom for ast::DataType { + type Error = Error; + + fn try_from(value: StructType) -> Result { + let mut fields = ast::NodeVec::new(","); + for field in value.fields { + fields.push(field.try_into()?); + } + Ok(ast::DataType::Struct { + struct_token: Keyword::new("STRUCT", Span::Unknown), + lt: Punct::new("<", Span::Unknown), + fields, + gt: Punct::new(">", Span::Unknown), + }) + } +} + impl fmt::Display for StructType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "STRUCT<")?; @@ -726,6 +836,36 @@ pub struct StructElementType { pub ty: ValueType, } +impl StructElementType { + /// Is this a subtype of `other`? + pub fn is_subtype_of(&self, other: &StructElementType) -> bool { + // We are a subtype of `other` if we have the same name, or if we have + // no name and `other` has a name. + // + // TODO: We may need to refine this carefully to match the expected + // behavior. Some of these combinations can't occur in the `STRUCT(..)` + // syntax, because it doesn't allow `STRUCT<..>(..) to use `AS`. + self.ty.is_subtype_of(&other.ty) + && match (&self.name, &other.name) { + (Some(a), Some(b)) => a == b, + (None, Some(_)) => true, + (None, None) => true, + (Some(_), None) => false, + } + } +} + +impl TryFrom for ast::StructField { + type Error = Error; + + fn try_from(value: StructElementType) -> Result { + Ok(ast::StructField { + name: value.name, + data_type: value.ty.try_into()?, + }) + } +} + impl fmt::Display for StructElementType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if let Some(name) = &self.name { diff --git a/tests/sql/data_types/structs.sql b/tests/sql/data_types/structs.sql new file mode 100644 index 0000000..cc1d757 --- /dev/null +++ b/tests/sql/data_types/structs.sql @@ -0,0 +1,43 @@ +-- pending: snowflake Need to emulate using OBJECT. May be challenging. +-- pending: sqlite3 Need to build structs from scratch + +CREATE OR REPLACE TABLE __result1 AS +WITH t AS (SELECT 1 AS a) +SELECT + -- Not allowed on Trino. + -- STRUCT() AS empty_struct, + STRUCT(1) AS anon_value, + STRUCT(1 AS a) AS named_value, + STRUCT(a) AS inferred_name, + STRUCT(1 AS a, 2 AS b) AS named_values, + STRUCT(NULL) AS anon_value_with_type, + STRUCT(1, 2) AS named_values_with_type, + STRUCT([1] AS arr) AS struct_with_array, + STRUCT(STRUCT(1 AS a) AS `inner`) AS struct_with_struct, + --STRUCT(1 AS a).a AS struct_field_access, +FROM t; + +CREATE OR REPLACE TABLE __expected1 ( + -- empty_struct STRUCT<>, + anon_value STRUCT, + named_value STRUCT, + inferred_name STRUCT, + named_values STRUCT, + anon_value_with_type STRUCT, + named_values_with_type STRUCT, + struct_with_array STRUCT>, + struct_with_struct STRUCT<`inner` STRUCT>, + --struct_field_access INT64, +); +INSERT INTO __expected1 +SELECT + -- STRUCT(), -- empty_struct + STRUCT(1), -- anon_value + STRUCT(1), -- named_value + STRUCT(1), -- inferred_name + STRUCT(1, 2), -- named_values + STRUCT(NULL), -- anon_value_with_type + STRUCT(1, 2), -- named_values_with_type + STRUCT([1]), -- struct_with_array + STRUCT(STRUCT(1)); -- struct_with_struct + --1; -- struct_field_access diff --git a/tests/sql/pending/structs.sql b/tests/sql/pending/structs.sql deleted file mode 100644 index 1452079..0000000 --- a/tests/sql/pending/structs.sql +++ /dev/null @@ -1,3 +0,0 @@ --- pending: snowflake Need to emulate using OBJECT. May be challenging. --- pending: sqlite3 Need to build structs from scratch --- pending: trino ROW may be almost a drop-in replacement, plus SELECT AS STRUCT