Skip to content

Commit

Permalink
Combine TableName, FunctionName, etc. into Name
Browse files Browse the repository at this point in the history
We can use a single Name type for all of this. We need this before
overhauling scopes.
  • Loading branch information
emk committed Nov 1, 2023
1 parent 1cc83df commit c7bce75
Show file tree
Hide file tree
Showing 8 changed files with 235 additions and 450 deletions.
406 changes: 150 additions & 256 deletions src/ast.rs

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/cmd/sql_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,14 +223,14 @@ impl PendingTestInfo {
/// Tables output by a test suite. This normally stores `ast::TableName`s, but
/// we use `Option<TableName>` while extracting the table names from the AST.
#[derive(Clone, Debug, Default)]
struct OutputTablePair<Name: Clone + fmt::Debug = ast::TableName> {
struct OutputTablePair<Name: Clone + fmt::Debug = ast::Name> {
result: Name,
expected: Name,
}

/// Find the names of all tables that are output by this query.
fn find_output_tables(ast: &ast::SqlProgram) -> Result<Vec<OutputTablePair>> {
let mut tables = Vec::<OutputTablePair<Option<ast::TableName>>>::default();
let mut tables = Vec::<OutputTablePair<Option<ast::Name>>>::default();

for s in ast.statements.node_iter() {
let name = match s {
Expand Down
127 changes: 50 additions & 77 deletions src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
use std::collections::HashSet;

use crate::{
ast,
ast::{self, Name},
errors::{Error, Result},
scope::{CaseInsensitiveIdent, Scope, ScopeHandle},
scope::{Scope, ScopeHandle},
tokenizer::{Ident, Literal, LiteralValue, Spanned},
types::{ArgumentType, ColumnType, SimpleType, TableType, Type, ValueType},
unification::{UnificationTable, Unify},
Expand Down Expand Up @@ -126,7 +126,7 @@ impl InferTypes for ast::CreateTableStatement {
}
.name_anonymous_columns(table_name.span());
table_type.expect_creatable(table_name)?;
scope.add(ident_from_table_name(table_name)?, Type::Table(table_type))?;
scope.add(table_name.clone(), Type::Table(table_type))?;
Ok(((), scope.into_handle()))
}
ast::CreateTableStatement {
Expand All @@ -141,7 +141,7 @@ impl InferTypes for ast::CreateTableStatement {
let ty = ty.name_anonymous_columns(table_name.span());
ty.expect_creatable(table_name)?;
let mut scope = Scope::new(scope);
scope.add(ident_from_table_name(table_name)?, Type::Table(ty))?;
scope.add(table_name.clone(), Type::Table(ty))?;
Ok(((), scope.into_handle()))
}
}
Expand All @@ -152,10 +152,8 @@ impl InferTypes for ast::DropTableStatement {
type Type = ();

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
let ast::DropTableStatement { table_name, .. } = self;
let mut scope = Scope::new(scope);
let table = ident_from_table_name(table_name)?;
scope.hide(&table)?;
scope.hide(&self.table_name)?;
Ok(((), scope.into_handle()))
}
}
Expand All @@ -164,15 +162,11 @@ impl InferTypes for ast::InsertIntoStatement {
type Type = ();

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
let ast::InsertIntoStatement {
table_name,
inserted_data,
..
} = self;
let table = ident_from_table_name(table_name)?;
let table_type = scope.get_or_err(&table)?.try_as_table_type(&table)?;
let table_type = scope
.get_or_err(&self.table_name)?
.try_as_table_type(&self.table_name)?;

match inserted_data {
match &mut self.inserted_data {
ast::InsertedData::Values { rows, .. } => {
for row in rows.node_iter_mut() {
let (ty, _scope) = row.infer_types(scope)?;
Expand Down Expand Up @@ -314,8 +308,8 @@ impl InferTypes for ast::SelectExpression {
let except = except_set(except);
for column in &table_type.columns {
if let Some(column_name) = &column.name {
let column_name = CaseInsensitiveIdent::from(column_name.clone());
if !except.contains(&column_name) {
let name = Name::from(column_name.clone());
if !except.contains(&name) {
cols.push(ColumnType {
name: column.name.clone(),
ty: column.ty.to_owned(),
Expand Down Expand Up @@ -359,8 +353,9 @@ impl InferTypes for ast::SelectExpression {
ast::SelectListItem::TableNameWildcard {
table_name, except, ..
} => {
let table = ident_from_table_name(table_name)?;
let table_type = scope.get_or_err(&table)?.try_as_table_type(&table)?;
let table_type = scope
.get_or_err(table_name)?
.try_as_table_type(table_name)?;
add_table_cols(&mut cols, table_type, except);
}
}
Expand Down Expand Up @@ -394,11 +389,12 @@ impl InferTypes for ast::FromItem {
fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
match self {
ast::FromItem::TableName { table_name, alias } => {
let table = ident_from_table_name(table_name)?;
let table_type = scope.get_or_err(&table)?.try_as_table_type(&table)?;
let table_type = scope
.get_or_err(table_name)?
.try_as_table_type(table_name)?;
let name = match alias {
Some(alias) => CaseInsensitiveIdent::from(alias.ident.clone()),
None => table,
Some(alias) => alias.ident.clone().into(),
None => table_name.clone(),
};

let mut scope = Scope::new(scope);
Expand Down Expand Up @@ -427,8 +423,7 @@ impl InferTypes for ast::Expression {
ast::Expression::Literal(Literal { value, .. }) => value.infer_types(scope),
ast::Expression::BoolValue(_) => Ok((ArgumentType::bool(), scope.clone())),
ast::Expression::Null { .. } => Ok((ArgumentType::null(), scope.clone())),
ast::Expression::ColumnName(ident) => ident.infer_types(scope),
ast::Expression::TableAndColumnName(name) => name.infer_types(scope),
ast::Expression::ColumnName(name) => name.infer_types(scope),
ast::Expression::Cast(cast) => cast.infer_types(scope),
ast::Expression::Is(is) => is.infer_types(scope),
ast::Expression::In(in_expr) => in_expr.infer_types(scope),
Expand Down Expand Up @@ -472,20 +467,21 @@ impl InferTypes for Ident {
}
}

impl InferTypes for ast::TableAndColumnName {
impl InferTypes for ast::Name {
type Type = ArgumentType;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
let ast::TableAndColumnName {
table_name,
column_name,
..
} = self;

let table = ident_from_table_name(table_name)?;
let table_type = scope.get_or_err(&table)?.try_as_table_type(&table)?;
let column_type = table_type.column_by_name_or_err(column_name)?;
Ok((column_type.ty.to_owned(), scope.clone()))
let (table_name, column_name) = self.split_table_and_column();
if let Some(table_name) = table_name {
let table_type = scope
.get_or_err(&table_name)?
.try_as_table_type(&table_name)?;
let column_type = table_type.column_by_name_or_err(&column_name)?;
Ok((column_type.ty.to_owned(), scope.clone()))
} else {
let ty = scope.get_or_err(self)?.try_as_argument_type(self)?;
Ok((ty.to_owned(), scope.clone()))
}
}
}

Expand All @@ -511,7 +507,7 @@ impl InferTypes for ast::IsExpression {
fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
// We need to do this manually because our second argument isn't an
// expression.
let func_name = &CaseInsensitiveIdent::new("%IS", self.is_token.span());
let func_name = &Name::new("%IS", self.is_token.span());
let func_ty = scope
.get_or_err(func_name)?
.try_as_function_type(func_name)?;
Expand Down Expand Up @@ -548,7 +544,7 @@ impl InferTypes for ast::InExpression {
fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
// We need to do this manually because our second argument isn't an
// expression.
let func_name = &CaseInsensitiveIdent::new("%IN", self.in_token.span());
let func_name = &Name::new("%IN", self.in_token.span());
let func_ty = scope
.get_or_err(func_name)?
.try_as_function_type(func_name)?;
Expand Down Expand Up @@ -600,7 +596,7 @@ impl InferTypes for ast::BetweenExpression {
type Type = ArgumentType;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
let func_name = &CaseInsensitiveIdent::new("%BETWEEN", self.between_token.span());
let func_name = &Name::new("%BETWEEN", self.between_token.span());
let args = [
self.left.as_mut(),
self.middle.as_mut(),
Expand All @@ -614,7 +610,7 @@ impl InferTypes for ast::KeywordBinopExpression {
type Type = ArgumentType;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
let func_name = &CaseInsensitiveIdent::new(
let func_name = &Name::new(
&format!("%{}", self.op_keyword.ident.token.as_str()),
self.op_keyword.span(),
);
Expand All @@ -627,7 +623,7 @@ impl InferTypes for ast::NotExpression {
type Type = ArgumentType;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
let func_name = &CaseInsensitiveIdent::new("%NOT", self.not_token.span());
let func_name = &Name::new("%NOT", self.not_token.span());
let args = [self.expression.as_mut()];
infer_call(func_name, args, scope)
}
Expand All @@ -642,11 +638,7 @@ impl InferTypes for ast::IfExpression {
self.then_expression.as_mut(),
self.else_expression.as_mut(),
];
infer_call(
&CaseInsensitiveIdent::new("%IF", self.if_token.span()),
args,
scope,
)
infer_call(&Name::new("%IF", self.if_token.span()), args, scope)
}
}

Expand Down Expand Up @@ -697,7 +689,7 @@ impl InferTypes for ast::BinopExpression {
type Type = ArgumentType;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
let prim_name = CaseInsensitiveIdent::new(
let prim_name = Name::new(
&format!("%{}", self.op_token.token.as_str()),
self.op_token.span(),
);
Expand Down Expand Up @@ -740,7 +732,7 @@ impl InferTypes for ast::ArrayDefinition {
ast::ArrayDefinition::Elements(exprs) => {
// We can use infer_call if we're careful.
let span = exprs.items.span();
let func_name = &CaseInsensitiveIdent::new("%ARRAY", span);
let func_name = &Name::new("%ARRAY", span);
let (elem_ty, _) = infer_call(func_name, exprs.node_iter_mut(), scope)?;
let elem_ty = elem_ty.expect_array_type_returning_elem_type(self)?;
let elem_ty = ArgumentType::Value(ValueType::Simple(elem_ty.clone()));
Expand All @@ -754,19 +746,18 @@ impl InferTypes for ast::FunctionCall {
type Type = ArgumentType;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
let name = ident_from_function_name(&self.name)?;
if self.over_clause.is_some() {
return Err(nyi(&self.over_clause, "over clause"));
}
infer_call(&name, self.args.node_iter_mut(), scope)
infer_call(&self.name, self.args.node_iter_mut(), scope)
}
}

impl InferTypes for ast::IndexExpression {
type Type = ArgumentType;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
let func_name = &CaseInsensitiveIdent::new("%[]", self.bracket1.span());
let func_name = &Name::new("%[]", self.bracket1.span());
let index_expr = match &mut self.index {
ast::IndexOffset::Simple(expression)
| ast::IndexOffset::Offset { expression, .. }
Expand Down Expand Up @@ -795,10 +786,10 @@ impl<T: InferColumnName> InferColumnName for Option<T> {
impl InferColumnName for ast::Expression {
fn infer_column_name(&mut self) -> Option<Ident> {
match self {
ast::Expression::ColumnName(ident) => Some(ident.clone()),
ast::Expression::TableAndColumnName(ast::TableAndColumnName {
column_name, ..
}) => Some(column_name.clone()),
ast::Expression::ColumnName(name) => {
let (_table, col) = name.split_table_and_column();
Some(col)
}
_ => None,
}
}
Expand All @@ -810,24 +801,8 @@ impl InferColumnName for ast::Alias {
}
}

/// Convert a table name to an identifier.
fn ident_from_table_name(table_name: &ast::TableName) -> Result<CaseInsensitiveIdent> {
match table_name {
ast::TableName::Table { table, .. } => Ok(table.clone().into()),
_ => Err(nyi(table_name, "dotted name")),
}
}

/// Convert a function name to an identifier.
fn ident_from_function_name(function_name: &ast::FunctionName) -> Result<CaseInsensitiveIdent> {
match function_name {
ast::FunctionName::Function { function, .. } => Ok(function.clone().into()),
_ => Err(nyi(function_name, "dotted name")),
}
}

/// Build a set from an optional [`ast::Except`] clause.
fn except_set(except: &Option<ast::Except>) -> HashSet<CaseInsensitiveIdent> {
fn except_set(except: &Option<ast::Except>) -> HashSet<Name> {
let mut set = HashSet::new();
if let Some(except) = except {
for ident in except.columns.node_iter() {
Expand All @@ -839,7 +814,7 @@ fn except_set(except: &Option<ast::Except>) -> HashSet<CaseInsensitiveIdent> {

/// Infer types a function-like expression (including primitives).
fn infer_call<'args, ArgExprs>(
func_name: &CaseInsensitiveIdent,
func_name: &Name,
args: ArgExprs,
scope: &ScopeHandle,
) -> Result<(ArgumentType, ScopeHandle)>
Expand Down Expand Up @@ -873,7 +848,7 @@ mod tests {
use crate::{
ast::parse_sql,
known_files::KnownFiles,
scope::{CaseInsensitiveIdent, Scope, ScopeValue},
scope::{Scope, ScopeValue},
tokenizer::Span,
types::tests::ty,
};
Expand Down Expand Up @@ -901,9 +876,7 @@ mod tests {
}

fn lookup(scope: &ScopeHandle, name: &str) -> Option<ScopeValue> {
scope
.get(&CaseInsensitiveIdent::new(name, Span::Unknown))
.cloned()
scope.get(&Name::new(name, Span::Unknown)).cloned()
}

macro_rules! assert_defines {
Expand Down
Loading

0 comments on commit c7bce75

Please sign in to comment.