From ac2337bc4763417a980053929ef9b54c00350e38 Mon Sep 17 00:00:00 2001 From: Eric Kidd Date: Mon, 30 Oct 2023 15:29:50 -0400 Subject: [PATCH] Infer more parts of SELECT --- src/infer.rs | 77 ++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 66 insertions(+), 11 deletions(-) diff --git a/src/infer.rs b/src/infer.rs index f8da877..4ce1aec 100644 --- a/src/infer.rs +++ b/src/infer.rs @@ -1,5 +1,7 @@ //! Our type inference subsystem. +use std::collections::HashSet; + use crate::{ ast, errors::{Error, Result}, @@ -297,11 +299,34 @@ impl InferTypes for ast::SelectExpression { .. } = self; + // See if we have a FROM clause. + let mut from_type = None; let mut scope = scope.to_owned(); if let Some(from_clause) = from_clause { - ((), scope) = from_clause.infer_types(&scope)?; + let (new_from_type, new_scope) = from_clause.infer_types(&scope)?; + from_type = Some(new_from_type); + scope = new_scope; } + // Helper function to add columns from a table type to a list of columns. + let add_table_cols = + |cols: &mut Vec<_>, table_type: &TableType, except: &Option| { + let except = except_set(except); + for column in &table_type.columns { + if let Some(column_name) = &column.name { + let column_name = CaseInsensitiveIdent::from(column_name.clone()); + if !except.contains(&column_name) { + cols.push(ColumnType { + name: column.name.clone(), + ty: column.ty.to_owned(), + not_null: false, + }); + } + } + } + }; + + // Iterate over the select list, adding columns to the scope. let mut cols = vec![]; for item in select_list.node_iter_mut() { match item { @@ -320,7 +345,24 @@ impl InferTypes for ast::SelectExpression { not_null: false, }); } - _ => return Err(nyi(item, "select list item")), + ast::SelectListItem::Wildcard { star, except } => { + if let Some(from_type) = &from_type { + add_table_cols(&mut cols, from_type, except); + } else { + return Err(Error::annotated( + "cannot use * in SELECT without a FROM clause", + star.span(), + "no FROM clause", + )); + } + } + ast::SelectListItem::TableNameWildcard { + table_name, except, .. + } => { + let table = ident_from_table_name(table_name)?; + let table_type = scope.get_or_err(&table)?.try_as_table_type(&table)?; + add_table_cols(&mut cols, table_type, except); + } } } let table_type = TableType { columns: cols }; @@ -329,7 +371,7 @@ impl InferTypes for ast::SelectExpression { } impl InferTypes for ast::FromClause { - type Type = (); + type Type = TableType; fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> { let ast::FromClause { @@ -337,28 +379,30 @@ impl InferTypes for ast::FromClause { join_operations, .. } = self; - let ((), scope) = from_item.infer_types(scope)?; + let (table_type, scope) = from_item.infer_types(scope)?; if !join_operations.is_empty() { return Err(nyi(self, "join operations")); } - Ok(((), scope)) + Ok((table_type, scope)) } } impl InferTypes for ast::FromItem { - type Type = (); + /// We return a table type for use by `SELECT *`. + type Type = TableType; fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> { match self { ast::FromItem::TableName { table_name, alias } => { let table = ident_from_table_name(table_name)?; let table_type = scope.get_or_err(&table)?.try_as_table_type(&table)?; - - if alias.is_some() { - return Err(nyi(alias, "from with alias")); - } + let name = match alias { + Some(alias) => CaseInsensitiveIdent::from(alias.ident.clone()), + None => table, + }; let mut scope = Scope::new(scope); + scope.add(name, Type::Table(table_type.clone()))?; for column in &table_type.columns { if let Some(column_name) = &column.name { scope.add( @@ -367,7 +411,7 @@ impl InferTypes for ast::FromItem { )?; } } - Ok(((), scope.into_handle())) + Ok((table_type.clone(), scope.into_handle())) } ast::FromItem::Subquery { .. } => Err(nyi(self, "from subquery")), ast::FromItem::Unnest { .. } => Err(nyi(self, "from unnest")), @@ -782,6 +826,17 @@ fn ident_from_function_name(function_name: &ast::FunctionName) -> Result) -> HashSet { + let mut set = HashSet::new(); + if let Some(except) = except { + for ident in except.columns.node_iter() { + set.insert(ident.clone().into()); + } + } + set +} + /// Infer types a function-like expression (including primitives). fn infer_call<'args, ArgExprs>( func_name: &CaseInsensitiveIdent,