Skip to content

Commit

Permalink
feat(rust,python,cli): support negative indexing and expressions for …
Browse files Browse the repository at this point in the history
…`LEFT`,`RIGHT` and `SUBSTR` string funcs
  • Loading branch information
alexander-beedie committed Jan 21, 2024
1 parent 3259c29 commit 52bd9bc
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 64 deletions.
93 changes: 55 additions & 38 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use polars_core::prelude::{polars_bail, polars_err, PolarsResult};
use polars_core::prelude::{polars_bail, polars_err, DataType, PolarsResult};
use polars_lazy::dsl::Expr;
#[cfg(feature = "list_eval")]
use polars_lazy::dsl::ListNameSpaceExtension;
Expand Down Expand Up @@ -860,13 +860,21 @@ impl SQLFunctionVisitor<'_> {
#[cfg(feature = "nightly")]
InitCap => self.visit_unary(|e| e.str().to_titlecase()),
Left => self.try_visit_binary(|e, length| {
Ok(e.str().slice(lit(0), match length {
Expr::Literal(LiteralValue::Int64(n)) => lit(n as u64),
Ok(match length {
Expr::Literal(Null) => lit(Null),
Expr::Literal(LiteralValue::Int64(0)) => lit(""),
Expr::Literal(LiteralValue::Int64(n)) => {
let len = if n > 0 { lit(n) } else { (e.clone().str().len_chars() + lit(n)).clip_min(lit(0)) };
e.str().slice(lit(0), len)
},
Expr::Literal(_) => polars_bail!(InvalidOperation: "Invalid 'n_chars' for Left: {}", function.args[1]),
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Left: {}", function.args[1]);
when(length.clone().gt_eq(lit(0)))
.then(e.clone().str().slice(lit(0), length.clone().abs()))
.otherwise(e.clone().str().slice(lit(0), (e.clone().str().len_chars() + length.clone()).clip_min(lit(0))))
}
}))
}),
}
)}),
Length => self.visit_unary(|e| e.str().len_chars()),
Lower => self.visit_unary(|e| e.str().to_lowercase()),
LTrim => match function.args.len() {
Expand Down Expand Up @@ -902,51 +910,60 @@ impl SQLFunctionVisitor<'_> {
3 => self.try_visit_ternary(|e, old, new| {
Ok(e.str().replace_all(old, new, true))
}),
_ => polars_bail!(InvalidOperation:
"Invalid number of arguments for Replace: {}",
function.args.len()
),
_ => polars_bail!(InvalidOperation: "Invalid number of arguments for Replace: {}", function.args.len()),
},
Reverse => self.visit_unary(|e| e.str().reverse()),
Right => self.try_visit_binary(|e, length| {
Ok(e.str().slice( match length {
Expr::Literal(LiteralValue::Int64(n)) => lit(-n),
Ok(match length {
Expr::Literal(Null) => lit(Null),
Expr::Literal(LiteralValue::Int64(0)) => lit(""),
Expr::Literal(LiteralValue::Int64(n)) => {
let offset = if n < 0 { lit(n.abs()) } else { e.clone().str().len_chars().cast(DataType::Int32) - lit(n) };
e.str().slice(offset, lit(Null))
},
Expr::Literal(_) => polars_bail!(InvalidOperation: "Invalid 'n_chars' for Right: {}", function.args[1]),
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Right: {}", function.args[1]);
when(length.clone().lt(lit(0)))
.then(e.clone().str().slice(length.clone().abs(), lit(Null)))
.otherwise(e.clone().str().slice(e.clone().str().len_chars().cast(DataType::Int32) - length.clone(), lit(Null)))
}
}, lit(Null)))
}),
}
)}),
RTrim => match function.args.len() {
1 => self.visit_unary(|e| e.str().strip_chars_end(lit(Null))),
2 => self.visit_binary(|e, s| e.str().strip_chars_end(s)),
_ => polars_bail!(InvalidOperation:
"Invalid number of arguments for RTrim: {}",
function.args.len()
),
_ => polars_bail!(InvalidOperation: "Invalid number of arguments for RTrim: {}", function.args.len()),
},
StartsWith => self.visit_binary(|e, s| e.str().starts_with(s)),
Substring => match function.args.len() {
// note that SQL is 1-indexed, not 0-indexed
// note that SQL is 1-indexed, not 0-indexed, hence the need for adjustments
2 => self.try_visit_binary(|e, start| {
Ok(e.str().slice(
match start {
Expr::Literal(LiteralValue::Int64(n)) => lit(n - 1) ,
_ => polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]),
}, lit(Null)))
Ok(match start {
Expr::Literal(Null) => lit(Null),
Expr::Literal(LiteralValue::Int64(n)) if n <= 0 => e,
Expr::Literal(LiteralValue::Int64(n)) => e.str().slice(lit(n - 1), lit(Null)),
Expr::Literal(_) => polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]),
_ => start.clone() + lit(1),
})
}),
3 => self.try_visit_ternary(|e, start, length| {
Ok(e.str().slice(
match start {
Expr::Literal(LiteralValue::Int64(n)) => lit(n - 1),
_ => {
polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]);
}
}, match length {
Expr::Literal(LiteralValue::Int64(n)) => lit(n as u64),
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Substring: {}", function.args[2]);
}
}))
3 => self.try_visit_ternary(|e: Expr, start: Expr, length: Expr| {
Ok(match (start.clone(), length.clone()) {
(Expr::Literal(Null), _) | (_, Expr::Literal(Null)) => lit(Null),
(Expr::Literal(LiteralValue::Int64(n)), _) if n > 0 => e.str().slice(lit(n - 1), length.clone()),
(Expr::Literal(LiteralValue::Int64(n)), _) => {
e.str().slice(lit(0), (length.clone() + lit(n - 1)).clip_min(lit(0)))
},
(Expr::Literal(_), _) => polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]),
(_, Expr::Literal(LiteralValue::Float64(_))) => {
polars_bail!(InvalidOperation: "Invalid 'length' for Substring: {}", function.args[1])
},
_ => {
let adjusted_start = start.clone() - lit(1);
when(adjusted_start.clone().lt(lit(0)))
.then(e.clone().str().slice(lit(0), (length.clone() + adjusted_start.clone()).clip_min(lit(0))))
.otherwise(e.clone().str().slice(adjusted_start.clone(), length.clone()))
}
})
}),
_ => polars_bail!(InvalidOperation: "Invalid number of arguments for Substring: {}", function.args.len()),
}
Expand Down
21 changes: 10 additions & 11 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,16 @@ impl SQLExprVisitor<'_> {
/// e.g. +column or -column
fn visit_unary_op(&mut self, op: &UnaryOperator, expr: &SQLExpr) -> PolarsResult<Expr> {
let expr = self.visit_expr(expr)?;
Ok(match op {
UnaryOperator::Plus => lit(0) + expr,
UnaryOperator::Minus => lit(0) - expr,
UnaryOperator::Not => expr.not(),
Ok(match (op, expr.clone()) {
// simplify the parse tree by special-casing common unary +/- ops
(UnaryOperator::Plus, Expr::Literal(LiteralValue::Int64(n))) => lit(n),
(UnaryOperator::Plus, Expr::Literal(LiteralValue::Float64(n))) => lit(n),
(UnaryOperator::Minus, Expr::Literal(LiteralValue::Int64(n))) => lit(-n),
(UnaryOperator::Minus, Expr::Literal(LiteralValue::Float64(n))) => lit(-n),
// general case
(UnaryOperator::Plus, _) => lit(0) + expr,
(UnaryOperator::Minus, _) => lit(0) - expr,
(UnaryOperator::Not, _) => expr.not(),
other => polars_bail!(InvalidOperation: "Unary operator {:?} is not supported", other),
})
}
Expand Down Expand Up @@ -609,27 +615,20 @@ impl SQLExprVisitor<'_> {
/// Visit a SQL `ARRAY_AGG` expression.
fn visit_arr_agg(&mut self, expr: &ArrayAgg) -> PolarsResult<Expr> {
let mut base = self.visit_expr(&expr.expr)?;

if let Some(order_by) = expr.order_by.as_ref() {
let (order_by, descending) = self.visit_order_by(order_by)?;
base = base.sort_by(order_by, descending);
}

if let Some(limit) = &expr.limit {
let limit = match self.visit_expr(limit)? {
Expr::Literal(LiteralValue::UInt32(n)) => n as usize,
Expr::Literal(LiteralValue::UInt64(n)) => n as usize,
Expr::Literal(LiteralValue::Int32(n)) => n as usize,
Expr::Literal(LiteralValue::Int64(n)) => n as usize,
_ => polars_bail!(ComputeError: "limit in ARRAY_AGG must be a positive integer"),
};
base = base.head(Some(limit));
}

if expr.distinct {
base = base.unique_stable();
}

polars_ensure!(
!expr.within_group,
ComputeError: "ARRAY_AGG WITHIN GROUP is not yet supported"
Expand Down
101 changes: 86 additions & 15 deletions py-polars/tests/unit/sql/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,82 @@ def test_string_left_right_reverse() -> None:
"r": ["de", "bc", "a", None],
"rev": ["edcba", "cba", "a", None],
}
for func, invalid in (("LEFT", "'xyz'"), ("RIGHT", "-1")):
for func, invalid in (("LEFT", "'xyz'"), ("RIGHT", "6.66")):
with pytest.raises(
InvalidOperationError,
match=f"Invalid 'length' for {func.capitalize()}: {invalid}",
match=f"Invalid 'n_chars' for {func.capitalize()}: {invalid}",
):
ctx.execute(f"""SELECT {func}(txt,{invalid}) FROM df""").collect()


def test_string_left_negative_expr() -> None:
# negative values and expressions
df = pl.DataFrame({"s": ["alphabet", "alphabet"], "n": [-6, 6]})
with pl.SQLContext(df=df, eager_execution=True) as sql:
res = sql.execute(
"""
SELECT
LEFT("s",-50) AS l0, -- empty string
LEFT("s",-3) AS l1, -- all but last three chars
LEFT("s",SIGN(-1)) AS l2, -- all but last char (expr => -1)
LEFT("s",0) AS l3, -- empty string
LEFT("s",NULL) AS l4, -- null
LEFT("s",1) AS l5, -- first char
LEFT("s",SIGN(1)) AS l6, -- first char (expr => 1)
LEFT("s",3) AS l7, -- first three chars
LEFT("s",50) AS l8, -- entire string
LEFT("s","n") AS l9, -- from other col
FROM df
"""
)
assert res.to_dict(as_series=False) == {
"l0": ["", ""],
"l1": ["alpha", "alpha"],
"l2": ["alphabe", "alphabe"],
"l3": ["", ""],
"l4": [None, None],
"l5": ["a", "a"],
"l6": ["a", "a"],
"l7": ["alp", "alp"],
"l8": ["alphabet", "alphabet"],
"l9": ["al", "alphab"],
}


def test_string_right_negative_expr() -> None:
# negative values and expressions
df = pl.DataFrame({"s": ["alphabet", "alphabet"], "n": [-6, 6]})
with pl.SQLContext(df=df, eager_execution=True) as sql:
res = sql.execute(
"""
SELECT
RIGHT("s",-50) AS l0, -- empty string
RIGHT("s",-3) AS l1, -- all but first three chars
RIGHT("s",SIGN(-1)) AS l2, -- all but first char (expr => -1)
RIGHT("s",0) AS l3, -- empty string
RIGHT("s",NULL) AS l4, -- null
RIGHT("s",1) AS l5, -- last char
RIGHT("s",SIGN(1)) AS l6, -- last char (expr => 1)
RIGHT("s",3) AS l7, -- last three chars
RIGHT("s",50) AS l8, -- entire string
RIGHT("s","n") AS l9, -- from other col
FROM df
"""
)
assert res.to_dict(as_series=False) == {
"l0": ["", ""],
"l1": ["habet", "habet"],
"l2": ["lphabet", "lphabet"],
"l3": ["", ""],
"l4": [None, None],
"l5": ["t", "t"],
"l6": ["t", "t"],
"l7": ["bet", "bet"],
"l8": ["alphabet", "alphabet"],
"l9": ["et", "phabet"],
}


def test_string_lengths() -> None:
df = pl.DataFrame({"words": ["Café", None, "東京", ""]})

Expand Down Expand Up @@ -260,12 +328,17 @@ def test_string_substr() -> None:
"""
SELECT
-- note: sql is 1-indexed
SUBSTR(scol,1) AS s1,
SUBSTR(scol,2) AS s2,
SUBSTR(scol,3) AS s3,
SUBSTR(scol,1,5) AS s1_5,
SUBSTR(scol,2,2) AS s2_2,
SUBSTR(scol,3,1) AS s3_1,
SUBSTR(scol,1) AS s1,
SUBSTR(scol,2) AS s2,
SUBSTR(scol,3) AS s3,
SUBSTR(scol,1,5) AS s1_5,
SUBSTR(scol,2,2) AS s2_2,
SUBSTR(scol,3,1) AS s3_1,
SUBSTR(scol,-3) AS "s-3",
SUBSTR(scol,-3,3) AS "s-3_3",
SUBSTR(scol,-3,4) AS "s-3_4",
SUBSTR(scol,-3,5) AS "s-3_5",
SUBSTR(scol,-10,13) AS "s-10_13"
FROM df
"""
).collect()
Expand All @@ -277,15 +350,13 @@ def test_string_substr() -> None:
"s1_5": ["abcde", "abcde", "abc", None],
"s2_2": ["bc", "bc", "bc", None],
"s3_1": ["c", "c", "c", None],
"s-3": ["abcdefg", "abcde", "abc", None],
"s-3_3": ["", "", "", None],
"s-3_4": ["", "", "", None],
"s-3_5": ["a", "a", "a", None],
"s-10_13": ["ab", "ab", "ab", None],
}

# negative indexes are expected to be invalid
with pytest.raises(
InvalidOperationError,
match="Invalid 'start' for Substring: -1",
), pl.SQLContext(df=df) as ctx:
ctx.execute("SELECT SUBSTR(scol,-1) FROM df")


def test_string_trim(foods_ipc_path: Path) -> None:
lf = pl.scan_ipc(foods_ipc_path)
Expand Down

0 comments on commit 52bd9bc

Please sign in to comment.