trino: Rewrite APPROX_QUANTILES
This adds some new infrastructure that allows us to rewrite a function call
into a more complex function call expression. Normally, we are only
likely to need this for aggregate functions, because renaming + UDFs
will work for most other cases.
emk committed Feb 27, 2024
1 parent e89df84 commit 0dc9f03
Showing 5 changed files with 100 additions and 48 deletions.
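
For readers skimming the diff, here is a condensed sketch, assembled from the Trino driver changes below, of how a driver now registers a structural rewrite alongside ordinary renames. It leans on the crate's internal types (`FunctionCall`, `TokenStream`, `sql_quote!`, and the driver's static `FUNCTION_NAMES` map), so treat it as an illustration of the API shape rather than a standalone snippet.

// A condensed sketch of the new rewrite API (mirrors the Trino driver diff
// below; `FUNCTION_NAMES`, `FunctionCall`, `TokenStream`, and `sql_quote!`
// all come from this crate).
use joinery_macros::sql_quote;

use crate::{
    ast::FunctionCall,
    tokenizer::TokenStream,
    transforms::{RenameFunctionsBuilder, Transform},
};

/// Expand one BigQuery aggregate call into a more complex Trino expression,
/// e.g. `APPROX_QUANTILES(x, 4)` becomes
/// `APPROX_PERCENTILE(x, PERCENTILE_ARRAY_FOR_QUANTILES(4))`.
fn rewrite_approx_quantiles(call: &FunctionCall) -> TokenStream {
    let mut args = call.args.node_iter();
    let value_arg = args.next().expect("should be enforced by type checker");
    let quantiles_arg = args.next().expect("should be enforced by type checker");
    sql_quote! {
        APPROX_PERCENTILE(#value_arg, PERCENTILE_ARRAY_FOR_QUANTILES(#quantiles_arg))
    }
}

fn transforms() -> Vec<Box<dyn Transform>> {
    // Simple one-to-one renames still go through the static name table; calls
    // that need structural surgery register a rewriter keyed by the BigQuery
    // function name.
    let rename_functions = RenameFunctionsBuilder::new(&FUNCTION_NAMES)
        .rewrite_function_call("APPROX_QUANTILES", &rewrite_approx_quantiles)
        .build();
    vec![Box::new(rename_functions)]
}

The builder replaces the old three-argument `RenameFunctions::new(...)` constructor, which is why the Snowflake driver in this diff switches to the `.udf_table(...)` builder call instead.
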
11 changes: 5 additions & 6 deletions src/drivers/snowflake/mod.rs
@@ -13,7 +13,7 @@ use tracing::{debug, instrument};
use crate::{
ast::Target,
errors::{format_err, Context, Error, Result},
transforms::{self, Transform, Udf},
transforms::{self, RenameFunctionsBuilder, Transform, Udf},
util::AnsiIdent,
};

@@ -249,16 +249,15 @@ impl Driver for SnowflakeDriver {
}

fn transforms(&self) -> Vec<Box<dyn Transform>> {
let rename_functions = RenameFunctionsBuilder::new(&FUNCTION_NAMES)
.udf_table(&UDFS, &format_udf)
.build();
vec![
Box::new(transforms::CountifToCase),
Box::new(transforms::IfToCase),
Box::new(transforms::IndexFromZero),
Box::new(transforms::IsBoolToCase),
Box::new(transforms::RenameFunctions::new(
&FUNCTION_NAMES,
&UDFS,
&format_udf,
)),
Box::new(rename_functions),
]
}

42 changes: 22 additions & 20 deletions src/drivers/trino/mod.rs
@@ -4,16 +4,18 @@ use std::{fmt, str::FromStr, sync::Arc};

use async_trait::async_trait;
use codespan_reporting::{diagnostic::Diagnostic, files::Files};
use joinery_macros::sql_quote;
use once_cell::sync::Lazy;
use prusto::{error::Error as PrustoError, Client, ClientBuilder, Presto, QueryError, Row};
use regex::Regex;
use tracing::debug;

use crate::{
ast::Target,
ast::{FunctionCall, Target},
errors::{format_err, Context, Error, Result, SourceError},
known_files::KnownFiles,
transforms::{self, Transform, Udf},
tokenizer::TokenStream,
transforms::{self, RenameFunctionsBuilder, Transform},
util::AnsiIdent,
};

@@ -54,20 +56,21 @@ static FUNCTION_NAMES: phf::Map<&'static str, &'static str> = phf::phf_map! {
"TO_HEX" => "memory.joinery_compat.TO_HEX_COMPAT",
};

/// A `phf_map!` of BigQuery function names to UDFs.
///
/// TODO: I'm not even sure there's a way to define SQL UDFs in Trino.
static UDFS: phf::Map<&'static str, &'static Udf> = phf::phf_map! {};

/// Format a UDF.
///
/// TODO: I'm not even sure there's a way to define SQL UDFs in Trino.
fn format_udf(udf: &Udf) -> String {
format!(
"CREATE OR REPLACE TEMP FUNCTION {} AS $$\n{}\n$$\n",
udf.decl, udf.sql
)
/// Rewrite `APPROX_QUANTILES` to `APPROX_PERCENTILE`. We need to implement this
/// as a function call rewriter, because it's an aggregate function, and we
/// can't define those as an SQL UDF.
fn rewrite_approx_quantiles(call: &FunctionCall) -> TokenStream {
let mut args = call.args.node_iter();
let value_arg = args.next().expect("should be enforced by type checker");
let quantiles_arg = args.next().expect("should be enforced by type checker");
sql_quote! {
APPROX_PERCENTILE(
#value_arg,
PERCENTILE_ARRAY_FOR_QUANTILES(#quantiles_arg)
)
}
}

/// A locator for a Trino database. May or may not also work for Presto.
#[derive(Debug)]
pub struct TrinoLocator {
@@ -179,6 +182,9 @@ impl Driver for TrinoDriver {
}

fn transforms(&self) -> Vec<Box<dyn Transform>> {
let rename_functions = RenameFunctionsBuilder::new(&FUNCTION_NAMES)
.rewrite_function_call("APPROX_QUANTILES", &rewrite_approx_quantiles)
.build();
vec![
Box::new(transforms::QualifyToSubquery),
Box::<transforms::ExpandExcept>::default(),
@@ -187,11 +193,7 @@
Box::new(transforms::IndexFromOne),
Box::new(transforms::IsBoolToCase),
Box::new(transforms::OrReplaceToDropIfExists),
Box::new(transforms::RenameFunctions::new(
&FUNCTION_NAMES,
&UDFS,
&format_udf,
)),
Box::new(rename_functions),
Box::new(transforms::SpecialDateFunctionsToTrino),
Box::new(transforms::StandardizeCurrentTimeUnit::no_parens()),
Box::new(transforms::CleanUpTempManually {
2 changes: 1 addition & 1 deletion src/transforms/mod.rs
@@ -19,7 +19,7 @@ pub use self::{
is_bool_to_case::IsBoolToCase,
or_replace_to_drop_if_exists::OrReplaceToDropIfExists,
qualify_to_subquery::QualifyToSubquery,
rename_functions::{RenameFunctions, Udf},
rename_functions::{RenameFunctionsBuilder, Udf},
special_date_functions_to_trino::SpecialDateFunctionsToTrino,
standardize_current_time_unit::StandardizeCurrentTimeUnit,
wrap_nested_queries::WrapNestedQueries,
88 changes: 72 additions & 16 deletions src/transforms/rename_functions.rs
@@ -7,17 +7,77 @@ use derive_visitor::{DriveMut, VisitorMut};
use crate::{
ast::{self, CurrentTimeUnit, FunctionCall, Name},
errors::Result,
tokenizer::{Ident, Spanned},
tokenizer::{Ident, Spanned, TokenStream},
};

use super::{Transform, TransformExtra};

/// A Snowflake UDF (user-defined function).
/// A custom UDF (user-defined function).
pub struct Udf {
pub decl: &'static str,
pub sql: &'static str,
}

/// Given a `FunctionCall`, return a rewritten `TokenStream` that implements the
/// function in the target dialect. This is basically a proc macro for SQL
/// function calls.
pub type FunctionCallRewriter = &'static dyn Fn(&FunctionCall) -> TokenStream;

/// A builder type for [`RenameFunctions`].
pub struct RenameFunctionsBuilder {
function_table: &'static phf::Map<&'static str, &'static str>,
udf_table: &'static phf::Map<&'static str, &'static Udf>,
format_udf: &'static dyn Fn(&Udf) -> String,
function_call_rewriters: HashMap<String, FunctionCallRewriter>,
}

impl RenameFunctionsBuilder {
/// Create a new `RenameFunctionsBuilder`.
pub fn new(function_table: &'static phf::Map<&'static str, &'static str>) -> Self {
static EMPTY_UDFS: phf::Map<&'static str, &'static Udf> = phf::phf_map! {};
Self {
function_table,
udf_table: &EMPTY_UDFS,
format_udf: &|_| "".to_string(),
function_call_rewriters: HashMap::new(),
}
}

/// Set the UDF table and formatter, for databases that support
/// `WITH FUNCTION`-like syntax.
pub fn udf_table(
mut self,
udf_table: &'static phf::Map<&'static str, &'static Udf>,
format_udf: &'static dyn Fn(&Udf) -> String,
) -> Self {
self.udf_table = udf_table;
self.format_udf = format_udf;
self
}

/// Add a function call rewriter.
pub fn rewrite_function_call(
mut self,
name: &'static str,
rewriter: FunctionCallRewriter,
) -> Self {
self.function_call_rewriters
.insert(name.to_ascii_uppercase(), rewriter);
self
}

/// Build a `RenameFunctions` transform.
pub fn build(self) -> RenameFunctions {
RenameFunctions {
function_table: self.function_table,
udf_table: self.udf_table,
format_udf: self.format_udf,
udfs: HashMap::new(),
function_call_rewriters: self.function_call_rewriters,
}
}
}

#[derive(VisitorMut)]
#[visitor(CurrentTimeUnit(enter), FunctionCall(enter))]
pub struct RenameFunctions {
@@ -32,23 +92,12 @@ pub struct RenameFunctions {

// UDFs that we need to create, if we haven't already.
udfs: HashMap<String, &'static Udf>,

// Function call rewriters.
function_call_rewriters: HashMap<String, FunctionCallRewriter>,
}

impl RenameFunctions {
/// Create a new `RenameFunctions` visitor.
pub fn new(
function_table: &'static phf::Map<&'static str, &'static str>,
udf_table: &'static phf::Map<&'static str, &'static Udf>,
format_udf: &'static dyn Fn(&Udf) -> String,
) -> Self {
Self {
function_table,
udf_table,
format_udf,
udfs: HashMap::new(),
}
}

/// Allow renaming CURRENT_DATETIME, etc.
fn enter_current_time_unit(&mut self, current_time_unit: &mut ast::CurrentTimeUnit) {
let ident = &current_time_unit.current_time_unit_token.ident;
@@ -70,6 +119,13 @@ impl RenameFunctions {
// We'll need a UDF, so add it to our list if it isn't already
// there.
self.udfs.insert(name, udf);
} else if let Some(rewriter) = self.function_call_rewriters.get(&name) {
// Rewrite the function call.
let token_stream = rewriter(function_call);
let new_function_call = token_stream
.try_into_function_call()
.expect("could not parse rewritten function call");
*function_call = new_function_call;
}
}
}
5 changes: 0 additions & 5 deletions tests/sql/functions/aggregate/approx_quantiles.sql
@@ -1,10 +1,5 @@
-- pending: snowflake Use APPROX_PERCENTILE instead of APPROX_QUANTILES (complicated)
-- pending: sqlite3 No APPROX_QUANTILES function
-- pending: trino Use APPROX_PERCENTILE instead of APPROX_QUANTILES (complicated)

-- For Trino, we can specify a list of percentiles to get values at. So we can
-- use things like `APPROX_PERCENTILE(x, ARRAY[0.0, 0.25, 0.5, 0.75, 1.0])`
-- to get the quartiles.

CREATE TEMP TABLE quantile_data (x INT64);
INSERT INTO quantile_data VALUES (1), (2), (3), (4), (5);