Skip to content

Commit

Permalink
Upgrade sqlparser-rs to 0.51.0, support new interval logic from `sqlp…
Browse files Browse the repository at this point in the history
…arse-rs` (#12222)

* support new interval logic from sqlparser-rs

* uprev sqlparser-rs branch

* use sqlparser 51

* better extract logic and interval testing

* revert unnecessary changes

* revert unnecessary changes, more

* cleanup

* fix last failing test :fingerscrossed:
  • Loading branch information
samuelcolvin authored Sep 15, 2024
1 parent f48e0b2 commit 3f0fb4a
Show file tree
Hide file tree
Showing 10 changed files with 215 additions and 222 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ rand = "0.8"
regex = "1.8"
rstest = "0.22.0"
serde_json = "1"
sqlparser = { version = "0.50.0", features = ["visitor"] }
sqlparser = { version = "0.51.0", features = ["visitor"] }
tempfile = "3"
thiserror = "1.0.44"
tokio = { version = "1.36", features = ["macros", "rt", "sync"] }
Expand Down
4 changes: 2 additions & 2 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 28 additions & 16 deletions datafusion/functions/src/datetime/date_part.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
// under the License.

use std::any::Any;
use std::str::FromStr;
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Float64Array};
use arrow::compute::kernels::cast_utils::IntervalUnit;
use arrow::compute::{binary, cast, date_part, DatePart};
use arrow::datatypes::DataType::{
Date32, Date64, Float64, Time32, Time64, Timestamp, Utf8, Utf8View,
Expand Down Expand Up @@ -161,22 +163,32 @@ impl ScalarUDFImpl for DatePartFunc {
return exec_err!("Date part '{part}' not supported");
}

let arr = match part_trim.to_lowercase().as_str() {
"year" => date_part_f64(array.as_ref(), DatePart::Year)?,
"quarter" => date_part_f64(array.as_ref(), DatePart::Quarter)?,
"month" => date_part_f64(array.as_ref(), DatePart::Month)?,
"week" => date_part_f64(array.as_ref(), DatePart::Week)?,
"day" => date_part_f64(array.as_ref(), DatePart::Day)?,
"doy" => date_part_f64(array.as_ref(), DatePart::DayOfYear)?,
"dow" => date_part_f64(array.as_ref(), DatePart::DayOfWeekSunday0)?,
"hour" => date_part_f64(array.as_ref(), DatePart::Hour)?,
"minute" => date_part_f64(array.as_ref(), DatePart::Minute)?,
"second" => seconds(array.as_ref(), Second)?,
"millisecond" => seconds(array.as_ref(), Millisecond)?,
"microsecond" => seconds(array.as_ref(), Microsecond)?,
"nanosecond" => seconds(array.as_ref(), Nanosecond)?,
"epoch" => epoch(array.as_ref())?,
_ => return exec_err!("Date part '{part}' not supported"),
// Using IntervalUnit here means we hand off all the work of supporting plurals (like "seconds")
// and synonyms (like "ms", "msec", "msecond", "millisecond") to Arrow.
let arr = if let Ok(interval_unit) = IntervalUnit::from_str(part_trim) {
match interval_unit {
IntervalUnit::Year => date_part_f64(array.as_ref(), DatePart::Year)?,
IntervalUnit::Month => date_part_f64(array.as_ref(), DatePart::Month)?,
IntervalUnit::Week => date_part_f64(array.as_ref(), DatePart::Week)?,
IntervalUnit::Day => date_part_f64(array.as_ref(), DatePart::Day)?,
IntervalUnit::Hour => date_part_f64(array.as_ref(), DatePart::Hour)?,
IntervalUnit::Minute => date_part_f64(array.as_ref(), DatePart::Minute)?,
IntervalUnit::Second => seconds(array.as_ref(), Second)?,
IntervalUnit::Millisecond => seconds(array.as_ref(), Millisecond)?,
IntervalUnit::Microsecond => seconds(array.as_ref(), Microsecond)?,
IntervalUnit::Nanosecond => seconds(array.as_ref(), Nanosecond)?,
// century and decade are not supported by `DatePart`, although they are supported in postgres
_ => return exec_err!("Date part '{part}' not supported"),
}
} else {
// special cases that can be extracted (in postgres) but are not interval units
match part_trim.to_lowercase().as_str() {
"qtr" | "quarter" => date_part_f64(array.as_ref(), DatePart::Quarter)?,
"doy" => date_part_f64(array.as_ref(), DatePart::DayOfYear)?,
"dow" => date_part_f64(array.as_ref(), DatePart::DayOfWeekSunday0)?,
"epoch" => epoch(array.as_ref())?,
_ => return exec_err!("Date part '{part}' not supported"),
}
};

Ok(if is_scalar {
Expand Down
4 changes: 1 addition & 3 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}

SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema),
SQLExpr::Interval(interval) => {
self.sql_interval_to_expr(false, interval, schema, planner_context)
}
SQLExpr::Interval(interval) => self.sql_interval_to_expr(false, interval),
SQLExpr::Identifier(id) => {
self.sql_identifier_to_expr(id, schema, planner_context)
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/expr/unary_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
self.parse_sql_number(&n, true)
}
SQLExpr::Interval(interval) => {
self.sql_interval_to_expr(true, interval, schema, planner_context)
self.sql_interval_to_expr(true, interval)
}
// not a literal, apply negative operator on expression
_ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr(
Expand Down
195 changes: 72 additions & 123 deletions datafusion/sql/src/expr/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use datafusion_expr::expr::{BinaryExpr, Placeholder};
use datafusion_expr::planner::PlannerResult;
use datafusion_expr::{lit, Expr, Operator};
use log::debug;
use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value};
use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, UnaryOperator, Value};
use sqlparser::parser::ParserError::ParserError;
use std::borrow::Cow;

Expand Down Expand Up @@ -168,12 +168,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

/// Convert a SQL interval expression to a DataFusion logical plan
/// expression
#[allow(clippy::only_used_in_recursion)]
pub(super) fn sql_interval_to_expr(
&self,
negative: bool,
interval: Interval,
schema: &DFSchema,
planner_context: &mut PlannerContext,
) -> Result<Expr> {
if interval.leading_precision.is_some() {
return not_impl_err!(
Expand All @@ -196,127 +195,42 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
);
}

// Only handle string exprs for now
let value = match *interval.value {
SQLExpr::Value(
Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
) => {
if negative {
format!("-{s}")
} else {
s
}
}
// Support expressions like `interval '1 month' + date/timestamp`.
// Such expressions are parsed like this by sqlparser-rs
//
// Interval
// BinaryOp
// Value(StringLiteral)
// Cast
// Value(StringLiteral)
//
// This code rewrites them to the following:
//
// BinaryOp
// Interval
// Value(StringLiteral)
// Cast
// Value(StringLiteral)
SQLExpr::BinaryOp { left, op, right } => {
let df_op = match op {
BinaryOperator::Plus => Operator::Plus,
BinaryOperator::Minus => Operator::Minus,
BinaryOperator::Eq => Operator::Eq,
BinaryOperator::NotEq => Operator::NotEq,
BinaryOperator::Gt => Operator::Gt,
BinaryOperator::GtEq => Operator::GtEq,
BinaryOperator::Lt => Operator::Lt,
BinaryOperator::LtEq => Operator::LtEq,
_ => {
return not_impl_err!("Unsupported interval operator: {op:?}");
}
};
match (
interval.leading_field.as_ref(),
left.as_ref(),
right.as_ref(),
) {
(_, _, SQLExpr::Value(_)) => {
let left_expr = self.sql_interval_to_expr(
negative,
Interval {
value: left,
leading_field: interval.leading_field.clone(),
leading_precision: None,
last_field: None,
fractional_seconds_precision: None,
},
schema,
planner_context,
)?;
let right_expr = self.sql_interval_to_expr(
false,
Interval {
value: right,
leading_field: interval.leading_field,
leading_precision: None,
last_field: None,
fractional_seconds_precision: None,
},
schema,
planner_context,
)?;
return Ok(Expr::BinaryExpr(BinaryExpr::new(
Box::new(left_expr),
df_op,
Box::new(right_expr),
)));
}
// In this case, the left node is part of the interval
// expr and the right node is an independent expr.
//
// Leading field is not supported when the right operand
// is not a value.
(None, _, _) => {
let left_expr = self.sql_interval_to_expr(
negative,
Interval {
value: left,
leading_field: None,
leading_precision: None,
last_field: None,
fractional_seconds_precision: None,
},
schema,
planner_context,
)?;
let right_expr = self.sql_expr_to_logical_expr(
*right,
schema,
planner_context,
)?;
return Ok(Expr::BinaryExpr(BinaryExpr::new(
Box::new(left_expr),
df_op,
Box::new(right_expr),
)));
}
_ => {
let value = SQLExpr::BinaryOp { left, op, right };
return not_impl_err!(
"Unsupported interval argument. Expected string literal, got: {value:?}"
);
}
if let SQLExpr::BinaryOp { left, op, right } = *interval.value {
let df_op = match op {
BinaryOperator::Plus => Operator::Plus,
BinaryOperator::Minus => Operator::Minus,
_ => {
return not_impl_err!("Unsupported interval operator: {op:?}");
}
}
_ => {
return not_impl_err!(
"Unsupported interval argument. Expected string literal, got: {:?}",
interval.value
);
}
};
};
let left_expr = self.sql_interval_to_expr(
negative,
Interval {
value: left,
leading_field: interval.leading_field.clone(),
leading_precision: None,
last_field: None,
fractional_seconds_precision: None,
},
)?;
let right_expr = self.sql_interval_to_expr(
false,
Interval {
value: right,
leading_field: interval.leading_field,
leading_precision: None,
last_field: None,
fractional_seconds_precision: None,
},
)?;
return Ok(Expr::BinaryExpr(BinaryExpr::new(
Box::new(left_expr),
df_op,
Box::new(right_expr),
)));
}

let value = interval_literal(*interval.value, negative)?;

let value = if has_units(&value) {
// If the interval already contains a unit
Expand All @@ -343,6 +257,41 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}

/// Render the value expression of a SQL `INTERVAL` as a plain string,
/// prefixing it with `-` when the overall sign is negative.
///
/// Supported inputs:
/// * single- or double-quoted string literals, used verbatim;
/// * numeric literals (those flagged as "long" are rejected);
/// * unary `+` / `-` wrapped around any supported input, where each unary
///   minus flips `negative` and unary plus leaves it unchanged.
///
/// # Errors
///
/// Returns a "not implemented" error for long numbers, for unary operators
/// other than `+` / `-`, and for any other kind of expression.
fn interval_literal(interval_value: SQLExpr, negative: bool) -> Result<String> {
    let s = match interval_value {
        SQLExpr::Value(Value::SingleQuotedString(s) | Value::DoubleQuotedString(s)) => s,
        SQLExpr::Value(Value::Number(ref v, long)) => {
            if long {
                return not_impl_err!(
                    "Unsupported interval argument. Long number not supported: {interval_value:?}"
                );
            }
            v.to_string()
        }
        SQLExpr::UnaryOp { op, expr } => {
            let negative = match op {
                UnaryOperator::Minus => !negative,
                UnaryOperator::Plus => negative,
                _ => {
                    return not_impl_err!(
                        "Unsupported SQL unary operator in interval {op:?}"
                    );
                }
            };
            // The recursive call already renders the combined sign, so return
            // its result as-is. Falling through to the sign handling below
            // would apply the caller's `negative` a second time and produce
            // results such as `--5`, or `-5` for a doubly negated value.
            return interval_literal(*expr, negative);
        }
        _ => {
            return not_impl_err!("Unsupported interval argument. Expected string literal or number, got: {interval_value:?}");
        }
    };
    if negative {
        Ok(format!("-{s}"))
    } else {
        Ok(s)
    }
}

// TODO make interval parsing better in arrow-rs / expose `IntervalType`
fn has_units(val: &str) -> bool {
let val = val.to_lowercase();
Expand Down
18 changes: 12 additions & 6 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,11 +495,17 @@ fn test_table_references_in_plan_to_sql() {
assert_eq!(format!("{}", sql), expected_sql)
}

test("catalog.schema.table", "SELECT catalog.\"schema\".\"table\".id, catalog.\"schema\".\"table\".\"value\" FROM catalog.\"schema\".\"table\"");
test("schema.table", "SELECT \"schema\".\"table\".id, \"schema\".\"table\".\"value\" FROM \"schema\".\"table\"");
test(
"catalog.schema.table",
r#"SELECT "catalog"."schema"."table".id, "catalog"."schema"."table"."value" FROM "catalog"."schema"."table""#,
);
test(
"schema.table",
r#"SELECT "schema"."table".id, "schema"."table"."value" FROM "schema"."table""#,
);
test(
"table",
"SELECT \"table\".id, \"table\".\"value\" FROM \"table\"",
r#"SELECT "table".id, "table"."value" FROM "table""#,
);
}

Expand All @@ -521,10 +527,10 @@ fn test_table_scan_with_no_projection_in_plan_to_sql() {

test(
"catalog.schema.table",
"SELECT * FROM catalog.\"schema\".\"table\"",
r#"SELECT * FROM "catalog"."schema"."table""#,
);
test("schema.table", "SELECT * FROM \"schema\".\"table\"");
test("table", "SELECT * FROM \"table\"");
test("schema.table", r#"SELECT * FROM "schema"."table""#);
test("table", r#"SELECT * FROM "table""#);
}

#[test]
Expand Down
15 changes: 15 additions & 0 deletions datafusion/sqllogictest/test_files/expr.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1355,6 +1355,16 @@ SELECT date_part('second', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanose
----
50.123456789

query R
select extract(second from '2024-08-09T12:13:14')
----
14

query R
select extract(seconds from '2024-08-09T12:13:14')
----
14

query R
SELECT extract(second from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
----
Expand All @@ -1381,6 +1391,11 @@ SELECT extract(microsecond from arrow_cast('23:32:50.123456789'::time, 'Time64(N
----
50123456.789000005

query R
SELECT extract(us from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
----
50123456.789000005

query R
SELECT date_part('nanosecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
----
Expand Down
Loading

0 comments on commit 3f0fb4a

Please sign in to comment.