From 3f0fb4a9edb5bd20756f82c2cb6b2c8d137de80c Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 15 Sep 2024 13:00:44 +0100 Subject: [PATCH] Upgrade sqlparser-rs to 0.51.0, support new interval logic from `sqlparser-rs` (#12222) * support new interval logic from sqlparser-rs * uprev sqlparser-rs branch * use sqlparser 51 * better extract logic and interval testing * revert unnecessary changes * revert unnecessary changes, more * cleanup * fix last failing test :fingerscrossed: --- Cargo.toml | 2 +- datafusion-cli/Cargo.lock | 4 +- .../functions/src/datetime/date_part.rs | 44 ++-- datafusion/sql/src/expr/mod.rs | 4 +- datafusion/sql/src/expr/unary_op.rs | 2 +- datafusion/sql/src/expr/value.rs | 195 +++++++----------- datafusion/sql/tests/cases/plan_to_sql.rs | 18 +- datafusion/sqllogictest/test_files/expr.slt | 15 ++ .../sqllogictest/test_files/interval.slt | 82 ++------ .../test_files/interval_mysql.slt | 71 +++++++ 10 files changed, 215 insertions(+), 222 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/interval_mysql.slt diff --git a/Cargo.toml b/Cargo.toml index d26734f08a3c..1332a05095de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -137,7 +137,7 @@ rand = "0.8" regex = "1.8" rstest = "0.22.0" serde_json = "1" -sqlparser = { version = "0.50.0", features = ["visitor"] } +sqlparser = { version = "0.51.0", features = ["visitor"] } tempfile = "3" thiserror = "1.0.44" tokio = { version = "1.36", features = ["macros", "rt", "sync"] } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index cdf0c661c258..65ea0a756b0d 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -3543,9 +3543,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqlparser" -version = "0.50.0" +version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2e5b515a2bd5168426033e9efbfd05500114833916f1d5c268f938b4ee130ac" +checksum =
"5fe11944a61da0da3f592e19a45ebe5ab92dc14a779907ff1f08fbb797bfefc7" dependencies = [ "log", "sqlparser_derive", diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index e24b11aeb71f..8ee82d872651 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -16,9 +16,11 @@ // under the License. use std::any::Any; +use std::str::FromStr; use std::sync::Arc; use arrow::array::{Array, ArrayRef, Float64Array}; +use arrow::compute::kernels::cast_utils::IntervalUnit; use arrow::compute::{binary, cast, date_part, DatePart}; use arrow::datatypes::DataType::{ Date32, Date64, Float64, Time32, Time64, Timestamp, Utf8, Utf8View, @@ -161,22 +163,32 @@ impl ScalarUDFImpl for DatePartFunc { return exec_err!("Date part '{part}' not supported"); } - let arr = match part_trim.to_lowercase().as_str() { - "year" => date_part_f64(array.as_ref(), DatePart::Year)?, - "quarter" => date_part_f64(array.as_ref(), DatePart::Quarter)?, - "month" => date_part_f64(array.as_ref(), DatePart::Month)?, - "week" => date_part_f64(array.as_ref(), DatePart::Week)?, - "day" => date_part_f64(array.as_ref(), DatePart::Day)?, - "doy" => date_part_f64(array.as_ref(), DatePart::DayOfYear)?, - "dow" => date_part_f64(array.as_ref(), DatePart::DayOfWeekSunday0)?, - "hour" => date_part_f64(array.as_ref(), DatePart::Hour)?, - "minute" => date_part_f64(array.as_ref(), DatePart::Minute)?, - "second" => seconds(array.as_ref(), Second)?, - "millisecond" => seconds(array.as_ref(), Millisecond)?, - "microsecond" => seconds(array.as_ref(), Microsecond)?, - "nanosecond" => seconds(array.as_ref(), Nanosecond)?, - "epoch" => epoch(array.as_ref())?, - _ => return exec_err!("Date part '{part}' not supported"), + // using IntervalUnit here means we hand off all the work of supporting plurals (like "seconds") + // and synonyms ( like "ms,msec,msecond,millisecond") to Arrow + let arr = if let Ok(interval_unit) = 
IntervalUnit::from_str(part_trim) { + match interval_unit { + IntervalUnit::Year => date_part_f64(array.as_ref(), DatePart::Year)?, + IntervalUnit::Month => date_part_f64(array.as_ref(), DatePart::Month)?, + IntervalUnit::Week => date_part_f64(array.as_ref(), DatePart::Week)?, + IntervalUnit::Day => date_part_f64(array.as_ref(), DatePart::Day)?, + IntervalUnit::Hour => date_part_f64(array.as_ref(), DatePart::Hour)?, + IntervalUnit::Minute => date_part_f64(array.as_ref(), DatePart::Minute)?, + IntervalUnit::Second => seconds(array.as_ref(), Second)?, + IntervalUnit::Millisecond => seconds(array.as_ref(), Millisecond)?, + IntervalUnit::Microsecond => seconds(array.as_ref(), Microsecond)?, + IntervalUnit::Nanosecond => seconds(array.as_ref(), Nanosecond)?, + // century and decade are not supported by `DatePart`, although they are supported in postgres + _ => return exec_err!("Date part '{part}' not supported"), + } + } else { + // special cases that can be extracted (in postgres) but are not interval units + match part_trim.to_lowercase().as_str() { + "qtr" | "quarter" => date_part_f64(array.as_ref(), DatePart::Quarter)?, + "doy" => date_part_f64(array.as_ref(), DatePart::DayOfYear)?, + "dow" => date_part_f64(array.as_ref(), DatePart::DayOfWeekSunday0)?, + "epoch" => epoch(array.as_ref())?, + _ => return exec_err!("Date part '{part}' not supported"), + } }; Ok(if is_scalar { diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index c79c6358be36..34e119c45fdf 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -201,9 +201,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema), - SQLExpr::Interval(interval) => { - self.sql_interval_to_expr(false, interval, schema, planner_context) - } + SQLExpr::Interval(interval) => self.sql_interval_to_expr(false, interval), SQLExpr::Identifier(id) => { self.sql_identifier_to_expr(id, schema, planner_context) } diff 
--git a/datafusion/sql/src/expr/unary_op.rs b/datafusion/sql/src/expr/unary_op.rs index 9fcee7a06124..2a341fb7c446 100644 --- a/datafusion/sql/src/expr/unary_op.rs +++ b/datafusion/sql/src/expr/unary_op.rs @@ -43,7 +43,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.parse_sql_number(&n, true) } SQLExpr::Interval(interval) => { - self.sql_interval_to_expr(true, interval, schema, planner_context) + self.sql_interval_to_expr(true, interval) } // not a literal, apply negative operator on expression _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr( diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index afcd182fa343..64575645bc44 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -26,7 +26,7 @@ use datafusion_expr::expr::{BinaryExpr, Placeholder}; use datafusion_expr::planner::PlannerResult; use datafusion_expr::{lit, Expr, Operator}; use log::debug; -use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; +use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, UnaryOperator, Value}; use sqlparser::parser::ParserError::ParserError; use std::borrow::Cow; @@ -168,12 +168,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Convert a SQL interval expression to a DataFusion logical plan /// expression + #[allow(clippy::only_used_in_recursion)] pub(super) fn sql_interval_to_expr( &self, negative: bool, interval: Interval, - schema: &DFSchema, - planner_context: &mut PlannerContext, ) -> Result { if interval.leading_precision.is_some() { return not_impl_err!( @@ -196,127 +195,42 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ); } - // Only handle string exprs for now - let value = match *interval.value { - SQLExpr::Value( - Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), - ) => { - if negative { - format!("-{s}") - } else { - s - } - } - // Support expressions like `interval '1 month' + date/timestamp`. 
- // Such expressions are parsed like this by sqlparser-rs - // - // Interval - // BinaryOp - // Value(StringLiteral) - // Cast - // Value(StringLiteral) - // - // This code rewrites them to the following: - // - // BinaryOp - // Interval - // Value(StringLiteral) - // Cast - // Value(StringLiteral) - SQLExpr::BinaryOp { left, op, right } => { - let df_op = match op { - BinaryOperator::Plus => Operator::Plus, - BinaryOperator::Minus => Operator::Minus, - BinaryOperator::Eq => Operator::Eq, - BinaryOperator::NotEq => Operator::NotEq, - BinaryOperator::Gt => Operator::Gt, - BinaryOperator::GtEq => Operator::GtEq, - BinaryOperator::Lt => Operator::Lt, - BinaryOperator::LtEq => Operator::LtEq, - _ => { - return not_impl_err!("Unsupported interval operator: {op:?}"); - } - }; - match ( - interval.leading_field.as_ref(), - left.as_ref(), - right.as_ref(), - ) { - (_, _, SQLExpr::Value(_)) => { - let left_expr = self.sql_interval_to_expr( - negative, - Interval { - value: left, - leading_field: interval.leading_field.clone(), - leading_precision: None, - last_field: None, - fractional_seconds_precision: None, - }, - schema, - planner_context, - )?; - let right_expr = self.sql_interval_to_expr( - false, - Interval { - value: right, - leading_field: interval.leading_field, - leading_precision: None, - last_field: None, - fractional_seconds_precision: None, - }, - schema, - planner_context, - )?; - return Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left_expr), - df_op, - Box::new(right_expr), - ))); - } - // In this case, the left node is part of the interval - // expr and the right node is an independent expr. - // - // Leading field is not supported when the right operand - // is not a value. 
- (None, _, _) => { - let left_expr = self.sql_interval_to_expr( - negative, - Interval { - value: left, - leading_field: None, - leading_precision: None, - last_field: None, - fractional_seconds_precision: None, - }, - schema, - planner_context, - )?; - let right_expr = self.sql_expr_to_logical_expr( - *right, - schema, - planner_context, - )?; - return Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left_expr), - df_op, - Box::new(right_expr), - ))); - } - _ => { - let value = SQLExpr::BinaryOp { left, op, right }; - return not_impl_err!( - "Unsupported interval argument. Expected string literal, got: {value:?}" - ); - } + if let SQLExpr::BinaryOp { left, op, right } = *interval.value { + let df_op = match op { + BinaryOperator::Plus => Operator::Plus, + BinaryOperator::Minus => Operator::Minus, + _ => { + return not_impl_err!("Unsupported interval operator: {op:?}"); } - } - _ => { - return not_impl_err!( - "Unsupported interval argument. Expected string literal, got: {:?}", - interval.value - ); - } - }; + }; + let left_expr = self.sql_interval_to_expr( + negative, + Interval { + value: left, + leading_field: interval.leading_field.clone(), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }, + )?; + let right_expr = self.sql_interval_to_expr( + false, + Interval { + value: right, + leading_field: interval.leading_field, + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }, + )?; + return Ok(Expr::BinaryExpr(BinaryExpr::new( + Box::new(left_expr), + df_op, + Box::new(right_expr), + ))); + } + + let value = interval_literal(*interval.value, negative)?; let value = if has_units(&value) { // If the interval already contains a unit @@ -343,6 +257,41 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } +fn interval_literal(interval_value: SQLExpr, negative: bool) -> Result { + let s = match interval_value { + SQLExpr::Value(Value::SingleQuotedString(s) | Value::DoubleQuotedString(s)) => s, + 
SQLExpr::Value(Value::Number(ref v, long)) => { + if long { + return not_impl_err!( + "Unsupported interval argument. Long number not supported: {interval_value:?}" + ); + } else { + v.to_string() + } + } + SQLExpr::UnaryOp { op, expr } => { + let negative = match op { + UnaryOperator::Minus => !negative, + UnaryOperator::Plus => negative, + _ => { + return not_impl_err!( + "Unsupported SQL unary operator in interval {op:?}" + ); + } + }; + interval_literal(*expr, negative)? + } + _ => { + return not_impl_err!("Unsupported interval argument. Expected string literal or number, got: {interval_value:?}"); + } + }; + if negative { + Ok(format!("-{s}")) + } else { + Ok(s) + } +} + // TODO make interval parsing better in arrow-rs / expose `IntervalType` fn has_units(val: &str) -> bool { let val = val.to_lowercase(); diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index fa95d05c3275..bd338e440e36 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -495,11 +495,17 @@ fn test_table_references_in_plan_to_sql() { assert_eq!(format!("{}", sql), expected_sql) } - test("catalog.schema.table", "SELECT catalog.\"schema\".\"table\".id, catalog.\"schema\".\"table\".\"value\" FROM catalog.\"schema\".\"table\""); - test("schema.table", "SELECT \"schema\".\"table\".id, \"schema\".\"table\".\"value\" FROM \"schema\".\"table\""); + test( + "catalog.schema.table", + r#"SELECT "catalog"."schema"."table".id, "catalog"."schema"."table"."value" FROM "catalog"."schema"."table""#, + ); + test( + "schema.table", + r#"SELECT "schema"."table".id, "schema"."table"."value" FROM "schema"."table""#, + ); test( "table", - "SELECT \"table\".id, \"table\".\"value\" FROM \"table\"", + r#"SELECT "table".id, "table"."value" FROM "table""#, ); } @@ -521,10 +527,10 @@ fn test_table_scan_with_no_projection_in_plan_to_sql() { test( "catalog.schema.table", - "SELECT * FROM catalog.\"schema\".\"table\"", + 
r#"SELECT * FROM "catalog"."schema"."table""#, ); - test("schema.table", "SELECT * FROM \"schema\".\"table\""); - test("table", "SELECT * FROM \"table\""); + test("schema.table", r#"SELECT * FROM "schema"."table""#); + test("table", r#"SELECT * FROM "table""#); } #[test] diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 002e8db2132d..a478e3617261 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -1355,6 +1355,16 @@ SELECT date_part('second', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanose ---- 50.123456789 +query R +select extract(second from '2024-08-09T12:13:14') +---- +14 + +query R +select extract(seconds from '2024-08-09T12:13:14') +---- +14 + query R SELECT extract(second from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ---- @@ -1381,6 +1391,11 @@ SELECT extract(microsecond from arrow_cast('23:32:50.123456789'::time, 'Time64(N ---- 50123456.789000005 +query R +SELECT extract(us from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456.789000005 + query R SELECT date_part('nanosecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ---- diff --git a/datafusion/sqllogictest/test_files/interval.slt b/datafusion/sqllogictest/test_files/interval.slt index 077f38d5d5bb..c73e340f9115 100644 --- a/datafusion/sqllogictest/test_files/interval.slt +++ b/datafusion/sqllogictest/test_files/interval.slt @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. - # Use `interval` SQL literal syntax # the types should be the same: https://github.com/apache/datafusion/issues/5801 query TT @@ -206,117 +205,60 @@ select interval '5 YEAR 5 MONTH 5 DAY 5 HOUR 5 MINUTE 5 SECOND 5 MILLISECOND 5 M ---- 65 mons 5 days 5 hours 5 mins 5.005005005 secs -# Interval with string literal addition -query ? 
-select interval '1 month' + '1 month' ----- -2 mons - -# Interval with string literal addition and leading field -query ? -select interval '1' + '1' month ----- -2 mons - -# Interval with nested string literal addition -query ? -select interval '1 month' + '1 month' + '1 month' ----- -3 mons - -# Interval with nested string literal addition and leading field -query ? -select interval '1' + '1' + '1' month ----- -3 mons - -# Interval mega nested string literal addition +# Interval mega nested literal addition query ? -select interval '1 year' + '1 month' + '1 day' + '1 hour' + '1 minute' + '1 second' + '1 millisecond' + '1 microsecond' + '1 nanosecond' +select interval '1 year' + interval '1 month' + interval '1 day' + interval '1 hour' + interval '1 minute' + interval '1 second' + interval '1 millisecond' + interval '1 microsecond' + interval '1 nanosecond' ---- 13 mons 1 days 1 hours 1 mins 1.001001001 secs # Interval with string literal subtraction query ? -select interval '1 month' - '1 day'; +select interval '1 month' - interval '1 day'; ---- 1 mons -1 days -# Interval with string literal subtraction and leading field -query ? -select interval '5' - '1' - '2' year; ----- -24 mons - # Interval with nested string literal subtraction query ? -select interval '1 month' - '1 day' - '1 hour'; +select interval '1 month' - interval '1 day' - interval '1 hour'; ---- 1 mons -1 days -1 hours -# Interval with nested string literal subtraction and leading field -query ? -select interval '10' - '1' - '1' month; ----- -8 mons - # Interval mega nested string literal subtraction query ? 
-select interval '1 year' - '1 month' - '1 day' - '1 hour' - '1 minute' - '1 second' - '1 millisecond' - '1 microsecond' - '1 nanosecond' +select interval '1 year' - interval '1 month' - interval '1 day' - interval '1 hour' - interval '1 minute' - interval '1 second' - interval '1 millisecond' - interval '1 microsecond' - interval '1 nanosecond' ---- 11 mons -1 days -1 hours -1 mins -1.001001001 secs -# Interval with string literal negation and leading field -query ? -select -interval '5' - '1' - '2' year; ----- --96 mons - -# Interval with nested string literal negation +# Interval with nested literal negation query ? -select -interval '1 month' + '1 day' + '1 hour'; +select -interval '1 month' + interval '1 day' + interval '1 hour'; ---- -1 mons 1 days 1 hours -# Interval with nested string literal negation and leading field -query ? -select -interval '10' - '1' - '1' month; ----- --12 mons - -# Interval mega nested string literal negation +# Interval mega nested literal negation query ? 
-select -interval '1 year' - '1 month' - '1 day' - '1 hour' - '1 minute' - '1 second' - '1 millisecond' - '1 microsecond' - '1 nanosecond' +select -interval '1 year' - interval '1 month' - interval '1 day' - interval '1 hour' - interval '1 minute' - interval '1 second' - interval '1 millisecond' - interval '1 microsecond' - interval '1 nanosecond' ---- -13 mons -1 days -1 hours -1 mins -1.001001001 secs # Interval string literal + date query D -select interval '1 month' + '1 day' + '2012-01-01'::date; ----- -2012-02-02 - -# Interval string literal parenthesized + date -query D -select ( interval '1 month' + '1 day' ) + '2012-01-01'::date; +select interval 1 month + interval 1 day + '2012-01-01'::date; ---- 2012-02-02 # Interval nested string literal + date query D -select interval '1 year' + '1 month' + '1 day' + '2012-01-01'::date +select interval 1 year + interval 1 month + interval 1 day + '2012-01-01'::date ---- 2013-02-02 # Interval nested string literal subtraction + date query D -select interval '1 year' - '1 month' + '1 day' + '2012-01-01'::date +select interval 1 year - interval 1 month + interval 1 day + '2012-01-01'::date ---- 2012-12-02 - - - # Use interval SQL type query TT select diff --git a/datafusion/sqllogictest/test_files/interval_mysql.slt b/datafusion/sqllogictest/test_files/interval_mysql.slt new file mode 100644 index 000000000000..c05bb007e5f1 --- /dev/null +++ b/datafusion/sqllogictest/test_files/interval_mysql.slt @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Use `interval` SQL literal syntax with MySQL dialect + +# this should fail with the generic dialect +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \+ Utf8 to valid types +select interval '1' + '1' month + +statement ok +set datafusion.sql_parser.dialect = 'Mysql'; + +# Interval with string literal addition and leading field +query ? +select interval '1' + '1' month +---- +2 mons + +# Interval with nested string literal addition +query ? +select interval 1 + 1 + 1 month +---- +3 mons + +# Interval with nested string literal addition and leading field +query ? +select interval '1' + '1' + '1' month +---- +3 mons + +# Interval with string literal subtraction and leading field +query ? +select interval '5' - '1' - '2' year; +---- +24 mons + +# Interval with nested string literal subtraction and leading field +query ? +select interval '10' - '1' - '1' month; +---- +8 mons + +# Interval with string literal negation and leading field +query ? +select -interval '5' - '1' - '2' year; +---- +-96 mons + +# Interval with nested string literal negation and leading field +query ? +select -interval '10' - '1' - '1' month; +---- +-12 mons + +# revert to standard dialect +statement ok +set datafusion.sql_parser.dialect = 'Generic';