Skip to content

Commit

Permalink
Infer struct types and output for Trino
Browse files Browse the repository at this point in the history
Co-authored-by: Dave Shirley <[email protected]>
  • Loading branch information
emk and dave-shirley-faraday committed Nov 9, 2023
1 parent 72cded9 commit 2520445
Show file tree
Hide file tree
Showing 5 changed files with 344 additions and 22 deletions.
77 changes: 68 additions & 9 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use crate::{
tokenize_sql, EmptyFile, Ident, Keyword, Literal, LiteralValue, PseudoKeyword, Punct,
RawToken, Span, Spanned, ToTokens, Token, TokenStream, TokenWriter,
},
types::TableType,
types::{StructType, TableType},
util::{is_c_ident, AnsiIdent, AnsiString},
};

Expand Down Expand Up @@ -1173,6 +1173,14 @@ impl Emit for ArrayExpression {
}
}

/// The type of the elements in an `ARRAY` expression.
///
/// The visible grammar rule `array_element_type` builds this from the token
/// sequence `<` data-type `>`, keeping both delimiters alongside the type.
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)]
pub struct ArrayElementType {
/// The opening `<` delimiter.
pub lt: Punct,
/// The declared data type of each array element.
pub elem_type: DataType,
/// The closing `>` delimiter.
pub gt: Punct,
}

/// An `ARRAY` definition. Either a `SELECT` expression or a list of
/// expressions.
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)]
Expand Down Expand Up @@ -1205,19 +1213,61 @@ pub struct ArraySelectExpression {
}

/// A struct expression.
///
/// The visible grammar rule `struct_expression` parses this as the `STRUCT`
/// keyword, an optional `<..>` list of field declarations, and a
/// parenthesized select list.
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)]
#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, Spanned, ToTokens)]
pub struct StructExpression {
/// Type information added later by inference.
///
/// The parser initializes this to `None`. The attributes below exclude it
/// from emitting, token conversion and AST traversal.
#[emit(skip)]
#[to_tokens(skip)]
#[drive(skip)]
pub ty: Option<StructType>,

/// The `STRUCT` keyword.
pub struct_token: Keyword,
/// Explicit field declarations from `STRUCT<..>(..)`, if any were written.
pub field_decls: Option<StructFieldDeclarations>,
/// The opening `(` before the field expressions.
pub paren1: Punct,
/// The field expressions, parsed as a select list.
pub fields: SelectList,
/// The closing `)` after the field expressions.
pub paren2: Punct,
}

/// The type of the elements in an `ARRAY` expression.
impl Emit for StructExpression {
    /// Write this struct expression for the target `t`.
    ///
    /// For `Target::Trino` the output is a `CAST(ROW(..) AS ..)` wrapper
    /// around the field expressions, using the struct type stored in
    /// `self.ty`. Every other target goes through `emit_default`.
    ///
    /// # Panics
    ///
    /// When emitting for Trino, panics if a select list item is not a plain
    /// expression, if `self.ty` is still `None`, or if the stored struct type
    /// cannot be converted into a `DataType`.
    fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> io::Result<()> {
        // Only Trino needs special handling.
        if !matches!(t, Target::Trino) {
            return self.emit_default(t, f);
        }

        f.write_token_start("CAST(")?;
        self.struct_token.ident.token.with_str("ROW").emit(t, f)?;
        self.paren1.emit(t, f)?;

        // Write each field expression, dropping any alias, with a separator
        // before every item except the first.
        let mut need_separator = false;
        for entry in self.fields.items.node_iter() {
            if need_separator {
                f.write_token_start(", ")?;
            }
            need_separator = true;

            if let SelectListItem::Expression { expression, .. } = entry {
                expression.emit(t, f)?;
            } else {
                panic!("Unexpected select list item (should have been caught by type checker): {:?}", entry);
            }
        }

        self.paren2.emit(t, f)?;
        f.write_token_start("AS")?;

        // The cast target is the struct type recorded on this node.
        let struct_ty = self
            .ty
            .as_ref()
            .expect("type should have been added by type checker");
        let row_type = DataType::try_from(struct_ty.clone())
            .expect("should be able to print data type of ROW");
        row_type.emit(t, f)?;
        f.write_token_start(")")
    }
}

/// Fields declared by a `STRUCT<..>(..)` expression.
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)]
pub struct ArrayElementType {
pub struct StructFieldDeclarations {
pub lt: Punct,
pub elem_type: DataType,
pub fields: NodeVec<StructField>,
pub gt: Punct,
}

Expand Down Expand Up @@ -2353,6 +2403,11 @@ peg::parser! {
}
}

// Parses the `<` data-type `>` element type annotation, keeping both
// delimiter tokens in the resulting `ArrayElementType`.
rule array_element_type() -> ArrayElementType
= lt:p("<") elem_type:data_type() gt:p(">") {
ArrayElementType { lt, elem_type, gt }
}

rule array_definition() -> ArrayDefinition
= select:array_select_expression() { ArrayDefinition::Query { select: Box::new(select) } }
/ expressions:sep(<expression()>, ",") { ArrayDefinition::Elements(expressions) }
Expand All @@ -2378,18 +2433,22 @@ peg::parser! {
}

rule struct_expression() -> StructExpression
= struct_token:k("STRUCT") paren1:p("(") fields:select_list() paren2:p(")") {
= struct_token:k("STRUCT") field_decls:struct_field_declarations()?
paren1:p("(") fields:select_list() paren2:p(")")
{
StructExpression {
ty: None,
struct_token,
field_decls,
paren1,
fields,
paren2,
}
}

rule array_element_type() -> ArrayElementType
= lt:p("<") elem_type:data_type() gt:p(">") {
ArrayElementType { lt, elem_type, gt }
rule struct_field_declarations() -> StructFieldDeclarations
= lt:p("<") fields:sep(<struct_field()>, ",") gt:p(">") {
StructFieldDeclarations { lt, fields, gt }
}

rule count_expression() -> CountExpression
Expand Down
89 changes: 86 additions & 3 deletions src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ use crate::{
ast::{self, ConditionJoinOperator, Emit, Expression, Name},
errors::{Error, Result},
scope::{ColumnSet, ColumnSetScope, Scope, ScopeGet, ScopeHandle},
tokenizer::{Ident, Literal, LiteralValue, Spanned},
types::{ArgumentType, ColumnType, SimpleType, TableType, Type, Unnested, ValueType},
tokenizer::{Ident, Keyword, Literal, LiteralValue, Punct, Spanned},
types::{
ArgumentType, ColumnType, ResolvedTypeVarsOnly, SimpleType, StructElementType, StructType,
TableType, Type, Unnested, ValueType,
},
unification::{UnificationTable, Unify},
};

Expand Down Expand Up @@ -649,6 +652,7 @@ impl InferTypes for ast::Expression {
ast::Expression::Literal(Literal { value, .. }) => value.infer_types(&()),
ast::Expression::BoolValue(_) => Ok(ArgumentType::bool()),
ast::Expression::Null { .. } => Ok(ArgumentType::null()),
ast::Expression::Interval(_) => Err(nyi(self, "INTERVAL expression")),
ast::Expression::ColumnName(name) => name.infer_types(scope),
ast::Expression::Cast(cast) => cast.infer_types(scope),
ast::Expression::Is(is) => is.infer_types(scope),
Expand All @@ -665,11 +669,13 @@ impl InferTypes for ast::Expression {
}
ast::Expression::Parens { expression, .. } => expression.infer_types(scope),
ast::Expression::Array(array) => array.infer_types(scope),
ast::Expression::Struct(struct_expr) => struct_expr.infer_types(scope),
ast::Expression::Count(count) => count.infer_types(scope),
ast::Expression::CurrentDate(_) => Err(nyi(self, "CURRENT_DATE expression")),
ast::Expression::ArrayAgg(array_agg) => array_agg.infer_types(scope),
ast::Expression::SpecialDateFunctionCall(_) => Err(nyi(self, "special date function")),
ast::Expression::FunctionCall(fcall) => fcall.infer_types(scope),
ast::Expression::Index(index) => index.infer_types(scope),
_ => Err(nyi(self, "expression")),
}
}
}
Expand Down Expand Up @@ -973,6 +979,83 @@ impl InferTypes for ast::ArrayDefinition {
}
}

impl InferTypes for ast::StructExpression {
type Scope = ColumnSetScope;
type Output = ArgumentType;

/// Infer the type of a `STRUCT(..)` / `STRUCT<..>(..)` expression.
///
/// Builds a struct type from the field expressions. If explicit field
/// declarations are present, it checks the inferred type against the
/// declared one and uses the declared type as the result. The chosen type
/// is stored in `self.ty` and returned wrapped as an
/// `ArgumentType::Value`.
///
/// # Errors
///
/// Returns an error if a field is not a plain expression, if any of the
/// helper checks called here fail, or if the resulting struct type has a
/// field whose type is `NULL`.
fn infer_types(&mut self, scope: &Self::Scope) -> Result<Self::Output> {
let ast::StructExpression {
ty,
struct_token,
field_decls,
fields,
..
} = self;

// Infer our struct type from the field expressions.
let mut field_types = vec![];
for field in fields.items.node_iter_mut() {
if let ast::SelectListItem::Expression { expression, alias } = field {
let field_ty = expression.infer_types(scope)?;
// NOTE(review): presumably rejects anything that is not a value
// type; confirm against `expect_value_type`.
let field_ty = field_ty.expect_value_type(expression)?;
// Field name: the explicit alias if there is one, otherwise
// whatever `infer_column_name` returns for the expression (possibly
// nothing).
let ident = alias
.clone()
.map(|a| a.ident)
.or_else(|| expression.infer_column_name());
field_types.push(StructElementType {
name: ident,
ty: field_ty.clone(),
});
} else {
// We could forbid this in the grammar if we were less lazy.
return Err(Error::annotated(
"struct field must be expression, not wildcard",
field.span(),
"expression required",
));
}
}
let actual_ty = StructType {
fields: field_types,
};

// If we have `field_decls`, use those to build our official type.
let return_ty = if let Some(field_decls) = field_decls {
// Build an `ast::DataType::Struct` from the declared fields so the
// existing `DataType` -> `ValueType` conversion can be reused.
let expected_ty =
ValueType::<ResolvedTypeVarsOnly>::try_from(&ast::DataType::Struct {
struct_token: Keyword::new("STRUCT", struct_token.span()),
lt: Punct::new("<", field_decls.lt.span()),
fields: field_decls.fields.clone(),
gt: Punct::new(">", field_decls.gt.span()),
})?;
let expected_ty = expected_ty.expect_struct_type(field_decls)?;
// The type inferred from the expressions is checked against the
// declared type, and the declared type is what gets returned.
actual_ty.expect_subtype_of(expected_ty, field_decls)?;
expected_ty.clone()
} else {
actual_ty
};

// The type is recorded on the node before the NULL-field check below, so
// it is set even when that check returns an error.
*ty = Some(return_ty.clone());
// Reject a struct type with a field whose type is NULL; the error message
// suggests declaring the fields explicitly.
if return_ty
.fields
.iter()
.any(|f| f.ty == ValueType::Simple(SimpleType::Null))
{
return Err(Error::annotated(
format!(
"NULL column in {}, try STRUCT<col1 type1, ..>(..)",
return_ty
),
struct_token.span(),
"contains a NULL field",
));
}
let return_ty = ValueType::Simple(SimpleType::Struct(return_ty));
// NOTE(review): `expect_inhabited` is defined elsewhere; presumably it
// errors for types that can have no values. Confirm.
return_ty.expect_inhabited(field_decls)?;
Ok(ArgumentType::Value(return_ty))
}
}

impl InferTypes for ast::CountExpression {
type Scope = ColumnSetScope;
type Output = ArgumentType;
Expand Down
Loading

0 comments on commit 2520445

Please sign in to comment.