From 25df65809aaef7c3a8bbaeaa27d9be701dbf54c4 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sun, 2 Jun 2024 23:50:38 +0400 Subject: [PATCH] feat: Support use of ordinal values in SQL `ORDER BY` clause --- crates/polars-sql/src/context.rs | 95 ++++++++++++++++------- py-polars/tests/unit/sql/test_group_by.py | 4 +- py-polars/tests/unit/sql/test_order_by.py | 51 ++++++++++++ 3 files changed, 119 insertions(+), 31 deletions(-) diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 86fcceb03048..375e29f66db9 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -197,6 +197,59 @@ impl SQLContext { }) } + fn expr_or_ordinal( + &mut self, + e: &SQLExpr, + schema: Option<&Schema>, + exprs: &[Expr], + clause: &str, + ) -> PolarsResult { + match e { + SQLExpr::UnaryOp { + op: UnaryOperator::Minus, + expr, + } if matches!(**expr, SQLExpr::Value(SQLValue::Number(_, _))) => { + if let SQLExpr::Value(SQLValue::Number(ref idx, _)) = **expr { + Err(polars_err!( + SQLSyntax: + "negative ordinals values are invalid for {}; found -{}", + clause, + idx + )) + } else { + unreachable!() + } + }, + SQLExpr::Value(SQLValue::Number(idx, _)) => { + // note: sql queries are 1-indexed + let idx = idx.parse::().map_err(|_| { + polars_err!( + SQLSyntax: + "negative ordinals values are invalid for {}; found {}", + clause, + idx + ) + })?; + Ok(exprs + .get(idx - 1) + .ok_or_else(|| { + polars_err!( + SQLInterface: + "{} ordinal value must refer to a valid column; found {}", + clause, + idx + ) + })? + .clone()) + }, + SQLExpr::Value(v) => Err(polars_err!( + SQLSyntax: + "{} requires a valid expression or positive ordinal; found {}", clause, v, + )), + _ => parse_sql_expr(e, self, schema), + } + } + pub(super) fn resolve_name(&self, tbl_name: &str, column_name: &str) -> String { if self.joined_aliases.borrow().contains_key(tbl_name) { self.joined_aliases @@ -473,36 +526,12 @@ impl SQLContext { // Check for "GROUP BY ..." (after projections, as there may be ordinal/position ints). let mut group_by_keys: Vec = Vec::new(); match &select_stmt.group_by { - // Standard "GROUP BY x, y, z" syntax + // Standard "GROUP BY x, y, z" syntax (also recognising ordinal values) GroupByExpr::Expressions(group_by_exprs) => { + // translate the group expressions, allowing ordinal values group_by_keys = group_by_exprs .iter() - .map(|e| match e { - SQLExpr::UnaryOp { - op: UnaryOperator::Minus, - expr, - } if matches!(**expr, SQLExpr::Value(SQLValue::Number(_, _))) => { - if let SQLExpr::Value(SQLValue::Number(ref idx, _)) = **expr { - Err(polars_err!( - SQLSyntax: - "GROUP BY error: expected a positive integer or valid expression; got -{}", - idx - )) - } else { - unreachable!() - } - }, - SQLExpr::Value(SQLValue::Number(idx, _)) => { - // note: sql queries are 1-indexed - let idx = idx.parse::().unwrap(); - Ok(projections[idx - 1].clone()) - }, - SQLExpr::Value(v) => Err(polars_err!( - SQLSyntax: - "GROUP BY error: expected a positive integer or valid expression; got {}", v, - )), - _ => parse_sql_expr(e, self, schema.as_deref()), - }) + .map(|e| self.expr_or_ordinal(e, schema.as_deref(), &projections, "GROUP BY")) .collect::>()? }, // "GROUP BY ALL" syntax; automatically adds expressions that do not contain @@ -838,7 +867,6 @@ impl SQLContext { .unwrap_or_else(|| tbl_name); self.table_map.insert(tbl_name.clone(), lf.clone()); - Ok((tbl_name, lf)) } @@ -852,13 +880,22 @@ impl SQLContext { let mut nulls_last = Vec::with_capacity(ob.len()); let schema = Some(lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?); + let column_names = schema + .clone() + .unwrap() + .iter_names() + .map(|e| col(e)) + .collect::>(); + for ob in ob { // note: if not specified 'NULLS FIRST' is default for DESC, 'NULLS LAST' otherwise // https://www.postgresql.org/docs/current/queries-order.html - by.push(parse_sql_expr(&ob.expr, self, schema.as_deref())?); let desc_order = !ob.asc.unwrap_or(true); nulls_last.push(!ob.nulls_first.unwrap_or(desc_order)); descending.push(desc_order); + + // translate order expression, allowing ordinal values + by.push(self.expr_or_ordinal(&ob.expr, schema.as_deref(), &column_names, "ORDER BY")?) } Ok(lf.sort_by_exprs( &by, diff --git a/py-polars/tests/unit/sql/test_group_by.py b/py-polars/tests/unit/sql/test_group_by.py index d07895e793b6..80955a647d40 100644 --- a/py-polars/tests/unit/sql/test_group_by.py +++ b/py-polars/tests/unit/sql/test_group_by.py @@ -223,13 +223,13 @@ def test_group_by_errors() -> None: with pytest.raises( SQLSyntaxError, - match=r"expected a positive integer or valid expression; got -99", + match=r"negative ordinals values are invalid for GROUP BY; found -99", ): df.sql("SELECT a, SUM(b) FROM self GROUP BY -99, a") with pytest.raises( SQLSyntaxError, - match=r"expected a positive integer or valid expression; got '!!!'", + match=r"GROUP BY requires a valid expression or positive ordinal; found '!!!'", ): df.sql("SELECT a, SUM(b) FROM self GROUP BY a, '!!!'") diff --git a/py-polars/tests/unit/sql/test_order_by.py b/py-polars/tests/unit/sql/test_order_by.py index 170c4bb90b6b..364beb5a7583 100644 --- a/py-polars/tests/unit/sql/test_order_by.py +++ b/py-polars/tests/unit/sql/test_order_by.py @@ -5,6 +5,7 @@ import pytest import polars as pl +from polars.exceptions import SQLInterfaceError, SQLSyntaxError @pytest.fixture() @@ -189,3 +190,53 @@ def test_order_by_multi_nulls_first_last() -> None: "x": [None, None, 1, 3], "y": [None, 3, 2, 1], } + + +def test_order_by_ordinal() -> None: + df = pl.DataFrame({"x": [None, 1, None, 3], "y": [3, 2, None, 1]}) + + res = df.sql("SELECT * FROM self ORDER BY 1, 2") + assert res.to_dict(as_series=False) == { + "x": [1, 3, None, None], + "y": [2, 1, 3, None], + } + + res = df.sql("SELECT * FROM self ORDER BY 1 DESC, 2") + assert res.to_dict(as_series=False) == { + "x": [None, None, 3, 1], + "y": [3, None, 1, 2], + } + + res = df.sql("SELECT * FROM self ORDER BY 1 DESC NULLS LAST, 2 ASC") + assert res.to_dict(as_series=False) == { + "x": [3, 1, None, None], + "y": [1, 2, 3, None], + } + + res = df.sql("SELECT * FROM self ORDER BY 1 DESC NULLS LAST, 2 ASC NULLS FIRST") + assert res.to_dict(as_series=False) == { + "x": [3, 1, None, None], + "y": [1, 2, None, 3], + } + + res = df.sql("SELECT * FROM self ORDER BY 1 DESC, 2 DESC NULLS FIRST") + assert res.to_dict(as_series=False) == { + "x": [None, None, 3, 1], + "y": [None, 3, 1, 2], + } + + +def test_order_by_errors() -> None: + df = pl.DataFrame({"a": ["w", "x", "y", "z"], "b": [1, 2, 3, 4]}) + + with pytest.raises( + SQLInterfaceError, + match="ORDER BY ordinal value must refer to a valid column; found 99", + ): + df.sql("SELECT * FROM self ORDER BY 99") + + with pytest.raises( + SQLSyntaxError, + match="negative ordinals values are invalid for ORDER BY; found -1", + ): + df.sql("SELECT * FROM self ORDER BY -1")