Skip to content

Commit

Permalink
feat: Support use of ordinal values in SQL ORDER BY clause
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Jun 5, 2024
1 parent 6f3fd8e commit 25df658
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 31 deletions.
95 changes: 66 additions & 29 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,59 @@ impl SQLContext {
})
}

fn expr_or_ordinal(
&mut self,
e: &SQLExpr,
schema: Option<&Schema>,
exprs: &[Expr],
clause: &str,
) -> PolarsResult<Expr> {
match e {
SQLExpr::UnaryOp {
op: UnaryOperator::Minus,
expr,
} if matches!(**expr, SQLExpr::Value(SQLValue::Number(_, _))) => {
if let SQLExpr::Value(SQLValue::Number(ref idx, _)) = **expr {
Err(polars_err!(
SQLSyntax:
"negative ordinals values are invalid for {}; found -{}",
clause,
idx
))
} else {
unreachable!()
}
},
SQLExpr::Value(SQLValue::Number(idx, _)) => {
// note: sql queries are 1-indexed
let idx = idx.parse::<usize>().map_err(|_| {
polars_err!(
SQLSyntax:
"negative ordinals values are invalid for {}; found {}",
clause,
idx
)
})?;
Ok(exprs
.get(idx - 1)
.ok_or_else(|| {
polars_err!(
SQLInterface:
"{} ordinal value must refer to a valid column; found {}",
clause,
idx
)
})?
.clone())
},
SQLExpr::Value(v) => Err(polars_err!(
SQLSyntax:
"{} requires a valid expression or positive ordinal; found {}", clause, v,
)),
_ => parse_sql_expr(e, self, schema),
}
}

pub(super) fn resolve_name(&self, tbl_name: &str, column_name: &str) -> String {
if self.joined_aliases.borrow().contains_key(tbl_name) {
self.joined_aliases
Expand Down Expand Up @@ -473,36 +526,12 @@ impl SQLContext {
// Check for "GROUP BY ..." (after projections, as there may be ordinal/position ints).
let mut group_by_keys: Vec<Expr> = Vec::new();
match &select_stmt.group_by {
// Standard "GROUP BY x, y, z" syntax
// Standard "GROUP BY x, y, z" syntax (also recognising ordinal values)
GroupByExpr::Expressions(group_by_exprs) => {
// translate the group expressions, allowing ordinal values
group_by_keys = group_by_exprs
.iter()
.map(|e| match e {
SQLExpr::UnaryOp {
op: UnaryOperator::Minus,
expr,
} if matches!(**expr, SQLExpr::Value(SQLValue::Number(_, _))) => {
if let SQLExpr::Value(SQLValue::Number(ref idx, _)) = **expr {
Err(polars_err!(
SQLSyntax:
"GROUP BY error: expected a positive integer or valid expression; got -{}",
idx
))
} else {
unreachable!()
}
},
SQLExpr::Value(SQLValue::Number(idx, _)) => {
// note: sql queries are 1-indexed
let idx = idx.parse::<usize>().unwrap();
Ok(projections[idx - 1].clone())
},
SQLExpr::Value(v) => Err(polars_err!(
SQLSyntax:
"GROUP BY error: expected a positive integer or valid expression; got {}", v,
)),
_ => parse_sql_expr(e, self, schema.as_deref()),
})
.map(|e| self.expr_or_ordinal(e, schema.as_deref(), &projections, "GROUP BY"))
.collect::<PolarsResult<_>>()?
},
// "GROUP BY ALL" syntax; automatically adds expressions that do not contain
Expand Down Expand Up @@ -838,7 +867,6 @@ impl SQLContext {
.unwrap_or_else(|| tbl_name);

self.table_map.insert(tbl_name.clone(), lf.clone());

Ok((tbl_name, lf))
}

Expand All @@ -852,13 +880,22 @@ impl SQLContext {
let mut nulls_last = Vec::with_capacity(ob.len());

let schema = Some(lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?);
let column_names = schema
.clone()
.unwrap()
.iter_names()
.map(|e| col(e))
.collect::<Vec<_>>();

for ob in ob {
// note: if not specified 'NULLS FIRST' is default for DESC, 'NULLS LAST' otherwise
// https://www.postgresql.org/docs/current/queries-order.html
by.push(parse_sql_expr(&ob.expr, self, schema.as_deref())?);
let desc_order = !ob.asc.unwrap_or(true);
nulls_last.push(!ob.nulls_first.unwrap_or(desc_order));
descending.push(desc_order);

// translate order expression, allowing ordinal values
by.push(self.expr_or_ordinal(&ob.expr, schema.as_deref(), &column_names, "ORDER BY")?)
}
Ok(lf.sort_by_exprs(
&by,
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/sql/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,13 @@ def test_group_by_errors() -> None:

with pytest.raises(
SQLSyntaxError,
match=r"expected a positive integer or valid expression; got -99",
match=r"negative ordinals values are invalid for GROUP BY; found -99",
):
df.sql("SELECT a, SUM(b) FROM self GROUP BY -99, a")

with pytest.raises(
SQLSyntaxError,
match=r"expected a positive integer or valid expression; got '!!!'",
match=r"GROUP BY requires a valid expression or positive ordinal; found '!!!'",
):
df.sql("SELECT a, SUM(b) FROM self GROUP BY a, '!!!'")

Expand Down
51 changes: 51 additions & 0 deletions py-polars/tests/unit/sql/test_order_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import polars as pl
from polars.exceptions import SQLInterfaceError, SQLSyntaxError


@pytest.fixture()
Expand Down Expand Up @@ -189,3 +190,53 @@ def test_order_by_multi_nulls_first_last() -> None:
"x": [None, None, 1, 3],
"y": [None, 3, 2, 1],
}


def test_order_by_ordinal() -> None:
df = pl.DataFrame({"x": [None, 1, None, 3], "y": [3, 2, None, 1]})

res = df.sql("SELECT * FROM self ORDER BY 1, 2")
assert res.to_dict(as_series=False) == {
"x": [1, 3, None, None],
"y": [2, 1, 3, None],
}

res = df.sql("SELECT * FROM self ORDER BY 1 DESC, 2")
assert res.to_dict(as_series=False) == {
"x": [None, None, 3, 1],
"y": [3, None, 1, 2],
}

res = df.sql("SELECT * FROM self ORDER BY 1 DESC NULLS LAST, 2 ASC")
assert res.to_dict(as_series=False) == {
"x": [3, 1, None, None],
"y": [1, 2, 3, None],
}

res = df.sql("SELECT * FROM self ORDER BY 1 DESC NULLS LAST, 2 ASC NULLS FIRST")
assert res.to_dict(as_series=False) == {
"x": [3, 1, None, None],
"y": [1, 2, None, 3],
}

res = df.sql("SELECT * FROM self ORDER BY 1 DESC, 2 DESC NULLS FIRST")
assert res.to_dict(as_series=False) == {
"x": [None, None, 3, 1],
"y": [None, 3, 1, 2],
}


def test_order_by_errors() -> None:
df = pl.DataFrame({"a": ["w", "x", "y", "z"], "b": [1, 2, 3, 4]})

with pytest.raises(
SQLInterfaceError,
match="ORDER BY ordinal value must refer to a valid column; found 99",
):
df.sql("SELECT * FROM self ORDER BY 99")

with pytest.raises(
SQLSyntaxError,
match="negative ordinals values are invalid for ORDER BY; found -1",
):
df.sql("SELECT * FROM self ORDER BY -1")

0 comments on commit 25df658

Please sign in to comment.