Skip to content

Commit

Permalink
Re-run type inference as needed
Browse files Browse the repository at this point in the history
sql_quote and other transformations can introduce untyped AST nodes.

Re-run inference as necessary, in a fashion inspired by LLVM on-demand
re-analysis.

Co-authored-by: Dave Shirley <[email protected]>
  • Loading branch information
emk and dave-shirley-faraday committed Nov 7, 2023
1 parent b1d9eee commit 9c14b66
Show file tree
Hide file tree
Showing 21 changed files with 129 additions and 52 deletions.
14 changes: 13 additions & 1 deletion src/drivers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
use std::{borrow::Cow, collections::VecDeque, fmt, str::FromStr};

use async_trait::async_trait;
use tracing::debug;
use tracing::{debug, trace};

use crate::{
ast::{self, Emit, Target},
errors::{format_err, Error, Result},
infer::InferTypes,
scope::Scope,
transforms::{Transform, TransformExtra},
};

Expand Down Expand Up @@ -146,6 +148,8 @@ pub trait Driver: Send + Sync + 'static {
/// us to do less database-specific work in [`Emit::emit`], and more in the
/// database drivers themselves. This can't change lexical syntax, but it
/// can change the structure of the AST.
///
/// Before calling this, we assume that you have already run type inference.
fn rewrite_ast<'ast>(&self, ast: &'ast ast::SqlProgram) -> Result<RewrittenAst<'ast>> {
let transforms = self.transforms();
if transforms.is_empty() {
Expand All @@ -154,10 +158,18 @@ pub trait Driver: Send + Sync + 'static {
ast: Cow::Borrowed(ast),
});
} else {
// Start out assuming that we have types.
let mut has_types = true;
let mut rewritten = ast.clone();
let mut extra = TransformExtra::default();
for transform in transforms {
trace!(transform = %transform.name(), input = %rewritten.emit_to_string(Target::BigQuery), "transforming");
if transform.requires_types() && !has_types {
let scope = Scope::root();
rewritten.infer_types(&scope)?;
}
extra.extend(transform.transform(&mut rewritten)?);
has_types = false;
}
Ok(RewrittenAst {
extra,
Expand Down
4 changes: 2 additions & 2 deletions src/drivers/sqlite3/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,10 @@ impl Driver for SQLite3Driver {
/// executing it.
fn transforms(&self) -> Vec<Box<dyn Transform>> {
vec![
Box::new(transforms::BoolToInt),
Box::new(transforms::CountifToCase),
Box::new(transforms::QualifyToSubquery),
Box::<transforms::ExpandExcept>::default(),
Box::new(transforms::BoolToInt),
Box::new(transforms::CountifToCase),
Box::new(transforms::IfToCase),
Box::new(transforms::IndexFromZero),
Box::new(transforms::OrReplaceToDropIfExists),
Expand Down
6 changes: 3 additions & 3 deletions src/drivers/trino/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,12 @@ impl Driver for TrinoDriver {
fn transforms(&self) -> Vec<Box<dyn Transform>> {
vec![
Box::new(transforms::ArraySelectToSubquery),
Box::new(transforms::QualifyToSubquery),
Box::<transforms::ExpandExcept>::default(),
Box::new(transforms::InUnnestToInSelect),
Box::new(transforms::CountifToCase),
Box::new(transforms::IndexFromOne),
Box::new(transforms::InUnnestToInSelect),
Box::new(transforms::IsBoolToCase),
Box::new(transforms::QualifyToSubquery),
Box::<transforms::ExpandExcept>::default(),
Box::new(transforms::OrReplaceToDropIfExists),
Box::new(transforms::RenameFunctions::new(
&FUNCTION_NAMES,
Expand Down
18 changes: 17 additions & 1 deletion src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1149,8 +1149,16 @@ fn contains_aggregate(scope: &ColumnSetScope, node: &ast::NodeVec<ast::SelectLis
}

/// A struct that we use to walk an AST looking for aggregate functions.
///
/// TODO: We probably need to be a lot more careful about sub-queries here.
///
/// TODO: We may want to have `trait ContainsAggregate` at some point.
#[derive(Debug, Visitor)]
#[visitor(ast::FunctionCall(enter))]
#[visitor(
ast::ArrayAggExpression(enter),
ast::CountExpression(enter),
ast::FunctionCall(enter)
)]
struct ContainsAggregate<'scope> {
scope: &'scope ColumnSetScope,
contains_aggregate: bool,
Expand All @@ -1164,6 +1172,14 @@ impl<'scope> ContainsAggregate<'scope> {
}
}

/// Visitor hook called on entering an `ast::ArrayAggExpression` node.
///
/// Any such node marks the visited tree as containing an aggregate.
fn enter_array_agg_expression(&mut self, _array_agg: &ast::ArrayAggExpression) {
// The node's contents are not inspected; its presence alone is enough.
self.contains_aggregate = true;
}

/// Visitor hook called on entering an `ast::CountExpression` node.
///
/// Any such node marks the visited tree as containing an aggregate.
fn enter_count_expression(&mut self, _count: &ast::CountExpression) {
// The node's contents are not inspected; its presence alone is enough.
self.contains_aggregate = true;
}

fn enter_function_call(&mut self, fcall: &ast::FunctionCall) {
if self.contains_aggregate {
return;
Expand Down
8 changes: 8 additions & 0 deletions src/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ pub struct Scope {
impl Scope {
/// Return the root scope, which has no parent.
///
/// The scope is constructed by `Scope::build_root` the first time this is
/// called and then cached in a static, because building it is moderately
/// expensive. Every call hands back a clone of the cached handle.
pub fn root() -> ScopeHandle {
    use once_cell::sync::Lazy;

    // Lazily-initialized cache holding the one root scope.
    static CACHED_ROOT: Lazy<ScopeHandle> = Lazy::new(Scope::build_root);
    Lazy::force(&CACHED_ROOT).clone()
}

/// Helper function for `root`.
fn build_root() -> ScopeHandle {
let mut scope = Self {
parent: None,
names: BTreeMap::new(),
Expand Down
8 changes: 8 additions & 0 deletions src/transforms/array_select_to_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ impl ArraySelectToSubquery {
}

impl Transform for ArraySelectToSubquery {
/// Returns the fixed, human-readable name of this transform.
fn name(&self) -> &'static str {
"ArraySelectToSubquery"
}

/// Declares that this transform needs currently valid type information
/// before it runs.
fn requires_types(&self) -> bool {
true
}

fn transform(mut self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
sql_program.drive_mut(self.as_mut());
Ok(TransformExtra::default())
Expand Down
4 changes: 4 additions & 0 deletions src/transforms/bool_to_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ impl BoolToInt {
}

impl Transform for BoolToInt {
/// Returns the fixed, human-readable name of this transform.
fn name(&self) -> &'static str {
"BoolToInt"
}

fn transform(mut self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
sql_program.drive_mut(self.as_mut());
Ok(TransformExtra::default())
Expand Down
4 changes: 4 additions & 0 deletions src/transforms/clean_up_temp_manually.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ pub struct CleanUpTempManually {
}

impl Transform for CleanUpTempManually {
/// Returns the fixed, human-readable name of this transform.
fn name(&self) -> &'static str {
"CleanUpTempManually"
}

fn transform(self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
let mut native_teardown_sql = vec![];

Expand Down
4 changes: 4 additions & 0 deletions src/transforms/countif_to_case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ impl CountifToCase {
}

impl Transform for CountifToCase {
/// Returns the fixed, human-readable name of this transform.
fn name(&self) -> &'static str {
"CountifToCase"
}

fn transform(mut self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
sql_program.drive_mut(self.as_mut());
Ok(TransformExtra::default())
Expand Down
8 changes: 8 additions & 0 deletions src/transforms/expand_except.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ impl ExpandExcept {
}

impl Transform for ExpandExcept {
/// Returns the fixed, human-readable name of this transform.
fn name(&self) -> &'static str {
"ExpandExcept"
}

/// Declares that this transform needs currently valid type information
/// before it runs.
fn requires_types(&self) -> bool {
true
}

fn transform(mut self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
sql_program.drive_mut(self.as_mut());
if let Some(error) = self.error {
Expand Down
4 changes: 4 additions & 0 deletions src/transforms/if_to_case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ impl IfToCase {
}

impl Transform for IfToCase {
/// Returns the fixed, human-readable name of this transform.
fn name(&self) -> &'static str {
"IfToCase"
}

fn transform(mut self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
sql_program.drive_mut(self.as_mut());
Ok(TransformExtra::default())
Expand Down
4 changes: 4 additions & 0 deletions src/transforms/in_unnest_to_in_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ impl InUnnestToInSelect {
}

impl Transform for InUnnestToInSelect {
/// Returns the fixed, human-readable name of this transform.
fn name(&self) -> &'static str {
"InUnnestToInSelect"
}

fn transform(mut self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
sql_program.drive_mut(self.as_mut());
Ok(TransformExtra::default())
Expand Down
4 changes: 4 additions & 0 deletions src/transforms/index_from_one.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ impl IndexFromOne {
}

impl Transform for IndexFromOne {
/// Returns the fixed, human-readable name of this transform.
fn name(&self) -> &'static str {
"IndexFromOne"
}

fn transform(mut self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
sql_program.drive_mut(self.as_mut());
Ok(TransformExtra::default())
Expand Down
4 changes: 4 additions & 0 deletions src/transforms/index_from_zero.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ impl IndexFromZero {
}

impl Transform for IndexFromZero {
/// Returns the fixed, human-readable name of this transform.
fn name(&self) -> &'static str {
"IndexFromZero"
}

fn transform(mut self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
sql_program.drive_mut(self.as_mut());
Ok(TransformExtra::default())
Expand Down
4 changes: 4 additions & 0 deletions src/transforms/is_bool_to_case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ impl IsBoolToCase {
}

impl Transform for IsBoolToCase {
/// Returns the fixed, human-readable name of this transform.
fn name(&self) -> &'static str {
"IsBoolToCase"
}

fn transform(mut self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
sql_program.drive_mut(self.as_mut());
Ok(TransformExtra::default())
Expand Down
12 changes: 12 additions & 0 deletions src/transforms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ mod wrap_nested_queries;

/// A transform that modifies an [`SqlProgram`].
pub trait Transform {
/// A human-readable name for this transform.
///
/// This is a required method: there is no default implementation.
fn name(&self) -> &'static str;

/// Does this transform require currently valid type information?
///
/// This is sort of analogous to an LLVM analysis pass that can be
/// invalidated and recreated, except we can only perform a single type of
/// analysis.
///
/// The default implementation returns `false`; transforms that depend on
/// inferred types override it to return `true`.
fn requires_types(&self) -> bool {
false
}

/// Apply this transform to an [`SqlProgram`].
///
/// Returns a list of extra SQL statements that need to be executed before
Expand Down
4 changes: 4 additions & 0 deletions src/transforms/or_replace_to_drop_if_exists.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ use super::{Transform, TransformExtra};
pub struct OrReplaceToDropIfExists;

impl Transform for OrReplaceToDropIfExists {
/// Returns the fixed, human-readable name of this transform.
fn name(&self) -> &'static str {
"OrReplaceToDropIfExists"
}

fn transform(self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
let old_statements = sql_program.statements.take();
for mut node_or_sep in old_statements {
Expand Down
55 changes: 10 additions & 45 deletions src/transforms/qualify_to_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,9 @@ use derive_visitor::{DriveMut, VisitorMut};
use joinery_macros::sql_quote;

use crate::{
ast::{
self, FromClause, FromItem, Qualify, QueryExpression, QueryStatement, SelectExpression,
SelectList, SelectListItem,
},
ast::{self, Qualify, SelectExpression},
errors::Result,
tokenizer::{Ident, Spanned},
types::{ArgumentType, ColumnType},
unique_names::unique_name,
};

Expand All @@ -22,7 +18,6 @@ pub struct QualifyToSubquery;
impl QualifyToSubquery {
fn enter_select_expression(&mut self, expr: &mut SelectExpression) {
if let SelectExpression {
ty,
select_list,
qualify:
Some(Qualify {
Expand All @@ -42,17 +37,6 @@ impl QualifyToSubquery {
.expect("should be valid SQL");
select_list.items.push(select_list_item);

// Since we want to use `EXCEPT`, we'll need to patch up type
// inference later.
let mut wildcard_ty = ty
.clone()
.expect("should have run type inference before QUALIFY extraction");
wildcard_ty.columns.push(ColumnType {
name: Some(new_name.clone()),
ty: ArgumentType::bool(),
not_null: false,
});

// Remove the QUALIFY from the original SELECT. This has to happen after
// the manipulation above, because we need to drop a bunch of `&mut` before
// we can change `expr`.
Expand All @@ -64,48 +48,29 @@ impl QualifyToSubquery {
// Rewrite the original SELECT expression to be a subquery. The
// `#nested_expr` used here is a placeholder; we'll patch it up
// later with one that retains type information.
let mut new_expr = sql_quote! {
let new_expr = sql_quote! {
SELECT * EXCEPT (#new_name)
FROM (#nested_expr)
WHERE #new_name
}
.try_into_select_expression()
.expect("should be valid SQL");

// Patch up the type inference.
//
// TODO: Maybe we should just re-run type inference after this pass?
let SelectExpression {
select_list: SelectList { items },
from_clause:
Some(FromClause {
from_item: FromItem::Subquery { query, .. },
..
}),
..
} = &mut new_expr
else {
panic!("did not get expected SELECT expression");
};
let QueryStatement {
query_expression: QueryExpression::SelectExpression(new_nested_expr),
} = query.as_mut()
else {
panic!("did not get expected FROM clause");
};
*new_nested_expr = nested_expr;
let Some(SelectListItem::Wildcard { ty, .. }) = items.node_iter_mut().next() else {
panic!("did not get expected SELECT list item");
};
*ty = Some(wildcard_ty);

// Install our new SELECT expression.
*expr = new_expr;
}
}
}

impl Transform for QualifyToSubquery {
/// Returns the fixed, human-readable name of this transform.
fn name(&self) -> &'static str {
"QualifyToSubquery"
}

/// Declares that this transform needs currently valid type information
/// before it runs.
fn requires_types(&self) -> bool {
true
}

fn transform(mut self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
sql_program.drive_mut(self.as_mut());
Ok(TransformExtra::default())
Expand Down
4 changes: 4 additions & 0 deletions src/transforms/rename_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ impl RenameFunctions {
}

impl Transform for RenameFunctions {
/// Returns the fixed, human-readable name of this transform.
fn name(&self) -> &'static str {
"RenameFunctions"
}

fn transform(mut self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
// Walk the AST, renaming functions and collecting UDFs.
sql_program.drive_mut(self.as_mut());
Expand Down
4 changes: 4 additions & 0 deletions src/transforms/standardize_current_date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ impl StandardizeCurrentDate {
}

impl Transform for StandardizeCurrentDate {
/// Returns the fixed, human-readable name of this transform.
fn name(&self) -> &'static str {
"StandardizeCurrentDate"
}

fn transform(mut self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
sql_program.drive_mut(self.as_mut());
Ok(TransformExtra::default())
Expand Down
4 changes: 4 additions & 0 deletions src/transforms/wrap_nested_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ impl WrapNestedQueries {
}

impl Transform for WrapNestedQueries {
/// Returns the fixed, human-readable name of this transform.
fn name(&self) -> &'static str {
"WrapNestedQueries"
}

fn transform(mut self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<TransformExtra> {
sql_program.drive_mut(self.as_mut());
Ok(TransformExtra::default())
Expand Down

0 comments on commit 9c14b66

Please sign in to comment.