Skip to content

Commit

Permalink
feat: improve type inference for WindowFrame (#13059)
Browse files Browse the repository at this point in the history
* feat: improve type inference for WindowFrame

Closes #11432

* Support Interval for groups and rows

* Remove case for SingleQuotedString
  • Loading branch information
notfilippo authored Oct 25, 2024
1 parent 2322933 commit 6a3c0b0
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 135 deletions.
191 changes: 154 additions & 37 deletions datafusion/expr/src/window_frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
//! - An ending frame boundary,
//! - An EXCLUDE clause.

use crate::{expr::Sort, lit};
use arrow::datatypes::DataType;
use std::fmt::{self, Formatter};
use std::hash::Hash;

use crate::{expr::Sort, lit};

use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue};
use sqlparser::ast;
use sqlparser::parser::ParserError::ParserError;
Expand Down Expand Up @@ -119,9 +119,9 @@ impl TryFrom<ast::WindowFrame> for WindowFrame {
type Error = DataFusionError;

fn try_from(value: ast::WindowFrame) -> Result<Self> {
let start_bound = value.start_bound.try_into()?;
let start_bound = WindowFrameBound::try_parse(value.start_bound, &value.units)?;
let end_bound = match value.end_bound {
Some(value) => value.try_into()?,
Some(bound) => WindowFrameBound::try_parse(bound, &value.units)?,
None => WindowFrameBound::CurrentRow,
};

Expand All @@ -138,6 +138,7 @@ impl TryFrom<ast::WindowFrame> for WindowFrame {
)?
}
};

let units = value.units.into();
Ok(Self::new_bounds(units, start_bound, end_bound))
}
Expand Down Expand Up @@ -334,51 +335,84 @@ impl WindowFrameBound {
}
}

impl TryFrom<ast::WindowFrameBound> for WindowFrameBound {
type Error = DataFusionError;

fn try_from(value: ast::WindowFrameBound) -> Result<Self> {
impl WindowFrameBound {
fn try_parse(
value: ast::WindowFrameBound,
units: &ast::WindowFrameUnits,
) -> Result<Self> {
Ok(match value {
ast::WindowFrameBound::Preceding(Some(v)) => {
Self::Preceding(convert_frame_bound_to_scalar_value(*v)?)
Self::Preceding(convert_frame_bound_to_scalar_value(*v, units)?)
}
ast::WindowFrameBound::Preceding(None) => Self::Preceding(ScalarValue::Null),
ast::WindowFrameBound::Following(Some(v)) => {
Self::Following(convert_frame_bound_to_scalar_value(*v)?)
Self::Following(convert_frame_bound_to_scalar_value(*v, units)?)
}
ast::WindowFrameBound::Following(None) => Self::Following(ScalarValue::Null),
ast::WindowFrameBound::CurrentRow => Self::CurrentRow,
})
}
}

pub fn convert_frame_bound_to_scalar_value(v: ast::Expr) -> Result<ScalarValue> {
Ok(ScalarValue::Utf8(Some(match v {
ast::Expr::Value(ast::Value::Number(value, false))
| ast::Expr::Value(ast::Value::SingleQuotedString(value)) => value,
ast::Expr::Interval(ast::Interval {
value,
leading_field,
..
}) => {
let result = match *value {
ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item,
e => {
return sql_err!(ParserError(format!(
"INTERVAL expression cannot be {e:?}"
)));
fn convert_frame_bound_to_scalar_value(
v: ast::Expr,
units: &ast::WindowFrameUnits,
) -> Result<ScalarValue> {
match units {
// For ROWS and GROUPS we are sure that the ScalarValue must be a non-negative integer ...
ast::WindowFrameUnits::Rows | ast::WindowFrameUnits::Groups => match v {
ast::Expr::Value(ast::Value::Number(value, false)) => {
Ok(ScalarValue::try_from_string(value, &DataType::UInt64)?)
},
ast::Expr::Interval(ast::Interval {
value,
leading_field: None,
leading_precision: None,
last_field: None,
fractional_seconds_precision: None,
}) => {
let value = match *value {
ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item,
e => {
return sql_err!(ParserError(format!(
"INTERVAL expression cannot be {e:?}"
)));
}
};
Ok(ScalarValue::try_from_string(value, &DataType::UInt64)?)
}
_ => plan_err!(
"Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers"
),
},
// ... instead for RANGE it could be anything depending on the type of the ORDER BY clause,
// so we use a ScalarValue::Utf8.
ast::WindowFrameUnits::Range => Ok(ScalarValue::Utf8(Some(match v {
ast::Expr::Value(ast::Value::Number(value, false)) => value,
ast::Expr::Interval(ast::Interval {
value,
leading_field,
..
}) => {
let result = match *value {
ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item,
e => {
return sql_err!(ParserError(format!(
"INTERVAL expression cannot be {e:?}"
)));
}
};
if let Some(leading_field) = leading_field {
format!("{result} {leading_field}")
} else {
result
}
};
if let Some(leading_field) = leading_field {
format!("{result} {leading_field}")
} else {
result
}
}
_ => plan_err!(
"Invalid window frame: frame offsets must be non negative integers"
)?,
})))
_ => plan_err!(
"Invalid window frame: frame offsets for RANGE must be either a numeric value, a string value or an interval"
)?,
}))),
}
}

impl fmt::Display for WindowFrameBound {
Expand Down Expand Up @@ -479,8 +513,91 @@ mod tests {
ast::Expr::Value(ast::Value::Number("1".to_string(), false)),
)))),
};
let result = WindowFrame::try_from(window_frame);
assert!(result.is_ok());

let window_frame = WindowFrame::try_from(window_frame)?;
assert_eq!(window_frame.units, WindowFrameUnits::Rows);
assert_eq!(
window_frame.start_bound,
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2)))
);
assert_eq!(
window_frame.end_bound,
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1)))
);

Ok(())
}

macro_rules! test_bound {
($unit:ident, $value:expr, $expected:expr) => {
let preceding = WindowFrameBound::try_parse(
ast::WindowFrameBound::Preceding($value),
&ast::WindowFrameUnits::$unit,
)?;
assert_eq!(preceding, WindowFrameBound::Preceding($expected));
let following = WindowFrameBound::try_parse(
ast::WindowFrameBound::Following($value),
&ast::WindowFrameUnits::$unit,
)?;
assert_eq!(following, WindowFrameBound::Following($expected));
};
}

macro_rules! test_bound_err {
($unit:ident, $value:expr, $expected:expr) => {
let err = WindowFrameBound::try_parse(
ast::WindowFrameBound::Preceding($value),
&ast::WindowFrameUnits::$unit,
)
.unwrap_err();
assert_eq!(err.strip_backtrace(), $expected);
let err = WindowFrameBound::try_parse(
ast::WindowFrameBound::Following($value),
&ast::WindowFrameUnits::$unit,
)
.unwrap_err();
assert_eq!(err.strip_backtrace(), $expected);
};
}

#[test]
fn test_window_frame_bound_creation() -> Result<()> {
// Unbounded
test_bound!(Rows, None, ScalarValue::Null);
test_bound!(Groups, None, ScalarValue::Null);
test_bound!(Range, None, ScalarValue::Null);

// Number
let number = Some(Box::new(ast::Expr::Value(ast::Value::Number(
"42".to_string(),
false,
))));
test_bound!(Rows, number.clone(), ScalarValue::UInt64(Some(42)));
test_bound!(Groups, number.clone(), ScalarValue::UInt64(Some(42)));
test_bound!(
Range,
number.clone(),
ScalarValue::Utf8(Some("42".to_string()))
);

// Interval
let number = Some(Box::new(ast::Expr::Interval(ast::Interval {
value: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString(
"1".to_string(),
))),
leading_field: Some(ast::DateTimeField::Day),
fractional_seconds_precision: None,
last_field: None,
leading_precision: None,
})));
test_bound_err!(Rows, number.clone(), "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers");
test_bound_err!(Groups, number.clone(), "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers");
test_bound!(
Range,
number.clone(),
ScalarValue::Utf8(Some("1 DAY".to_string()))
);

Ok(())
}
}
23 changes: 12 additions & 11 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,20 +696,20 @@ fn coerce_window_frame(
expressions: &[Sort],
) -> Result<WindowFrame> {
let mut window_frame = window_frame;
let current_types = expressions
.iter()
.map(|s| s.expr.get_type(schema))
.collect::<Result<Vec<_>>>()?;
let target_type = match window_frame.units {
WindowFrameUnits::Range => {
if let Some(col_type) = current_types.first() {
let current_types = expressions
.first()
.map(|s| s.expr.get_type(schema))
.transpose()?;
if let Some(col_type) = current_types {
if col_type.is_numeric()
|| is_utf8_or_large_utf8(col_type)
|| is_utf8_or_large_utf8(&col_type)
|| matches!(col_type, DataType::Null)
{
col_type
} else if is_datetime(col_type) {
&DataType::Interval(IntervalUnit::MonthDayNano)
} else if is_datetime(&col_type) {
DataType::Interval(IntervalUnit::MonthDayNano)
} else {
return internal_err!(
"Cannot run range queries on datatype: {col_type:?}"
Expand All @@ -719,10 +719,11 @@ fn coerce_window_frame(
return internal_err!("ORDER BY column cannot be empty");
}
}
WindowFrameUnits::Rows | WindowFrameUnits::Groups => &DataType::UInt64,
WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64,
};
window_frame.start_bound = coerce_frame_bound(target_type, window_frame.start_bound)?;
window_frame.end_bound = coerce_frame_bound(target_type, window_frame.end_bound)?;
window_frame.start_bound =
coerce_frame_bound(&target_type, window_frame.start_bound)?;
window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?;
Ok(window_frame)
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,7 @@ fn test_aggregation_to_sql() {
FROM person
GROUP BY id, first_name;"#,
r#"SELECT person.id, person.first_name,
sum(person.id) AS total_sum, sum(person.id) OVER (PARTITION BY person.first_name ROWS BETWEEN '5' PRECEDING AND '2' FOLLOWING) AS moving_sum,
sum(person.id) AS total_sum, sum(person.id) OVER (PARTITION BY person.first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum,
max(sum(person.id)) OVER (PARTITION BY person.first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total,
rank() OVER (PARTITION BY (grouping(person.id) + grouping(person.age)), CASE WHEN (grouping(person.age) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_1,
rank() OVER (PARTITION BY (grouping(person.age) + grouping(person.id)), CASE WHEN (CAST(grouping(person.age) AS BIGINT) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_2
Expand Down
32 changes: 28 additions & 4 deletions datafusion/sqllogictest/test_files/window.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2208,7 +2208,7 @@ physical_plan
01)ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2]
02)--SortExec: TopK(fetch=5), expr=[c9@2 ASC NULLS LAST], preserve_partitioning=[false]
03)----ProjectionExec: expr=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum1, sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING@4 as sum2, c9@1 as c9]
04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING: Ok(Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(3)), is_causal: false }], mode=[Sorted]
04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING: Ok(Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(3)), is_causal: true }], mode=[Sorted]
05)--------ProjectionExec: expr=[c1@0 as c1, c9@2 as c9, c12@3 as c12, sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING]
06)----------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted]
07)------------SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST], preserve_partitioning=[false]
Expand Down Expand Up @@ -2378,17 +2378,41 @@ SELECT c9, rn1 FROM (SELECT c9,


# invalid window frame. null as preceding
statement error DataFusion error: Error during planning: Invalid window frame: frame offsets must be non negative integers
statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers
select row_number() over (rows between null preceding and current row) from (select 1 a) x

# invalid window frame. null as preceding
statement error DataFusion error: Error during planning: Invalid window frame: frame offsets must be non negative integers
statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers
select row_number() over (rows between null preceding and current row) from (select 1 a) x

# invalid window frame. negative as following
statement error DataFusion error: Error during planning: Invalid window frame: frame offsets must be non negative integers
statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers
select row_number() over (rows between current row and -1 following) from (select 1 a) x

# invalid window frame. null as preceding
statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers
select row_number() over (order by a groups between null preceding and current row) from (select 1 a) x

# invalid window frame. null as preceding
statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers
select row_number() over (order by a groups between null preceding and current row) from (select 1 a) x

# invalid window frame. negative as following
statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers
select row_number() over (order by a groups between current row and -1 following) from (select 1 a) x

# interval for rows
query I
select row_number() over (rows between '1' preceding and current row) from (select 1 a) x
----
1

# interval for groups
query I
select row_number() over (order by a groups between '1' preceding and current row) from (select 1 a) x
----
1

# This test shows that ordering satisfy considers ordering equivalences,
# and can simplify (reduce expression size) multi expression requirements during normalization
# For the example below, requirement rn1 ASC, c9 DESC should be simplified to the rn1 ASC.
Expand Down
Loading

0 comments on commit 6a3c0b0

Please sign in to comment.