From 26fe8ace8b1499614c39f3f94805d718dbd02d62 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Fri, 19 Jan 2024 18:15:33 +0800 Subject: [PATCH] feat: Expressify `str.zfill` (#13790) --- .../src/chunked_array/strings/namespace.rs | 2 +- .../src/chunked_array/strings/pad.rs | 80 +++++++++++-------- .../src/dsl/function_expr/strings.rs | 16 ++-- crates/polars-plan/src/dsl/string.rs | 5 +- py-polars/polars/expr/string.py | 3 +- py-polars/polars/series/string.py | 2 +- py-polars/src/expr/string.rs | 4 +- .../tests/unit/namespaces/string/test_pad.py | 22 +++++ 8 files changed, 87 insertions(+), 47 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/strings/namespace.rs b/crates/polars-ops/src/chunked_array/strings/namespace.rs index f5b0e9bd6757..8c59e1fdbeea 100644 --- a/crates/polars-ops/src/chunked_array/strings/namespace.rs +++ b/crates/polars-ops/src/chunked_array/strings/namespace.rs @@ -232,7 +232,7 @@ pub trait StringNameSpaceImpl: AsString { /// Strings with length equal to or greater than the given length are /// returned as-is. #[cfg(feature = "string_pad")] - fn zfill(&self, length: usize) -> StringChunked { + fn zfill(&self, length: &UInt64Chunked) -> StringChunked { let ca = self.as_string(); pad::zfill(ca, length) } diff --git a/crates/polars-ops/src/chunked_array/strings/pad.rs b/crates/polars-ops/src/chunked_array/strings/pad.rs index d776c435c137..8e1bbe4a1dba 100644 --- a/crates/polars-ops/src/chunked_array/strings/pad.rs +++ b/crates/polars-ops/src/chunked_array/strings/pad.rs @@ -1,6 +1,7 @@ use std::fmt::Write; -use polars_core::prelude::StringChunked; +use polars_core::prelude::arity::broadcast_binary_elementwise; +use polars_core::prelude::{StringChunked, UInt64Chunked}; pub(super) fn pad_end<'a>(ca: &'a StringChunked, length: usize, fill_char: char) -> StringChunked { // amortize allocation @@ -50,38 +51,51 @@ pub(super) fn pad_start<'a>( ca.apply_mut(f) } -pub(super) fn zfill<'a>(ca: &'a StringChunked, length: usize) -> StringChunked { +fn zfill_fn<'a>(s: Option<&'a str>, len: Option, buf: &mut String) -> Option<&'a str> { + match (s, len) { + (Some(s), Some(length)) => { + let length = length.saturating_sub(s.len() as u64); + if length == 0 { + return Some(s); + } + buf.clear(); + if let Some(stripped) = s.strip_prefix('-') { + write!( + buf, + "-{:0length$}{value}", + 0, + length = length as usize, + value = stripped + ) + .unwrap(); + } else { + write!( + buf, + "{:0length$}{value}", + 0, + length = length as usize, + value = s + ) + .unwrap(); + }; + // extend lifetime + // lifetime is bound to 'a + let slice = buf.as_str(); + Some(unsafe { std::mem::transmute::<&str, &'a str>(slice) }) + }, + _ => None, + } +} + +pub(super) fn zfill<'a>(ca: &'a StringChunked, length: &'a UInt64Chunked) -> StringChunked { // amortize allocation let mut buf = String::new(); - let f = |s: &'a str| { - let length = length.saturating_sub(s.len()); - if length == 0 { - return s; - } - buf.clear(); - if let Some(stripped) = s.strip_prefix('-') { - write!( - &mut buf, - "-{:0length$}{value}", - 0, - length = length, - value = stripped - ) - .unwrap(); - } else { - write!( - &mut buf, - "{:0length$}{value}", - 0, - length = length, - value = s - ) - .unwrap(); - }; - // extend lifetime - // lifetime is bound to 'a - let slice = buf.as_str(); - unsafe { std::mem::transmute::<&str, &'a str>(slice) } - }; - ca.apply_mut(f) + fn infer FnMut(Option<&'a str>, Option) -> Option<&'a str>>(f: F) -> F where { + f + } + broadcast_binary_elementwise( + ca, + length, + infer(|opt_s, opt_len| zfill_fn(opt_s, opt_len, &mut buf)), + ) } diff --git a/crates/polars-plan/src/dsl/function_expr/strings.rs b/crates/polars-plan/src/dsl/function_expr/strings.rs index 76e55e83272e..b673da76b182 100644 --- a/crates/polars-plan/src/dsl/function_expr/strings.rs +++ b/crates/polars-plan/src/dsl/function_expr/strings.rs @@ -108,7 +108,7 @@ pub enum StringFunction { Titlecase, Uppercase, #[cfg(feature = "string_pad")] - ZFill(usize), + ZFill, #[cfg(feature = "find_many")] ContainsMany { ascii_case_insensitive: bool, @@ -163,7 +163,7 @@ impl StringFunction { Uppercase | Lowercase | StripChars | StripCharsStart | StripCharsEnd | StripPrefix | StripSuffix | Slice => mapper.with_same_dtype(), #[cfg(feature = "string_pad")] - PadStart { .. } | PadEnd { .. } | ZFill { .. } => mapper.with_same_dtype(), + PadStart { .. } | PadEnd { .. } | ZFill => mapper.with_same_dtype(), #[cfg(feature = "dtype-struct")] SplitExact { n, .. } => mapper.with_dtype(DataType::Struct( (0..n + 1) @@ -257,7 +257,7 @@ impl Display for StringFunction { ToDecimal(_) => "to_decimal", Uppercase => "uppercase", #[cfg(feature = "string_pad")] - ZFill(_) => "zfill", + ZFill => "zfill", #[cfg(feature = "find_many")] ContainsMany { .. } => "contains_many", #[cfg(feature = "find_many")] @@ -298,8 +298,8 @@ impl From for SpecialEq> { map!(strings::pad_start, length, fill_char) }, #[cfg(feature = "string_pad")] - ZFill(alignment) => { - map!(strings::zfill, alignment) + ZFill => { + map_as_slice!(strings::zfill) }, #[cfg(feature = "temporal")] Strptime(dtype, options) => { @@ -472,8 +472,10 @@ pub(super) fn pad_end(s: &Series, length: usize, fill_char: char) -> PolarsResul } #[cfg(feature = "string_pad")] -pub(super) fn zfill(s: &Series, length: usize) -> PolarsResult { - let ca = s.str()?; +pub(super) fn zfill(s: &[Series]) -> PolarsResult { + let ca = s[0].str()?; + let length_s = s[1].strict_cast(&DataType::UInt64)?; + let length = length_s.u64()?; Ok(ca.zfill(length).into_series()) } diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index 5b4ae8a1f05e..42a7cb2471fd 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -191,8 +191,9 @@ impl StringNameSpace { /// Strings with length equal to or greater than the given length are /// returned as-is. #[cfg(feature = "string_pad")] - pub fn zfill(self, length: usize) -> Expr { - self.0.map_private(StringFunction::ZFill(length).into()) + pub fn zfill(self, length: Expr) -> Expr { + self.0 + .map_many_private(StringFunction::ZFill.into(), &[length], false, false) } /// Find the index of a literal substring within another string value. diff --git a/py-polars/polars/expr/string.py b/py-polars/polars/expr/string.py index 2fb77a061758..2b2f7b16baed 100644 --- a/py-polars/polars/expr/string.py +++ b/py-polars/polars/expr/string.py @@ -919,7 +919,7 @@ def pad_end(self, length: int, fill_char: str = " ") -> Expr: return wrap_expr(self._pyexpr.str_pad_end(length, fill_char)) @deprecate_renamed_parameter("alignment", "length", version="0.19.12") - def zfill(self, length: int) -> Expr: + def zfill(self, length: int | IntoExprColumn) -> Expr: """ Pad the start of the string with zeros until it reaches the given length. @@ -957,6 +957,7 @@ def zfill(self, length: int) -> Expr: │ null ┆ null │ └────────┴────────┘ """ + length = parse_as_expression(length) return wrap_expr(self._pyexpr.str_zfill(length)) def contains( diff --git a/py-polars/polars/series/string.py b/py-polars/polars/series/string.py index 3f801c35ec33..af4907c9a7d8 100644 --- a/py-polars/polars/series/string.py +++ b/py-polars/polars/series/string.py @@ -1416,7 +1416,7 @@ def pad_end(self, length: int, fill_char: str = " ") -> Series: """ @deprecate_renamed_parameter("alignment", "length", version="0.19.12") - def zfill(self, length: int) -> Series: + def zfill(self, length: int | IntoExprColumn) -> Series: """ Pad the start of the string with zeros until it reaches the given length. diff --git a/py-polars/src/expr/string.rs b/py-polars/src/expr/string.rs index aba3cc748244..e4e8b7bcceb7 100644 --- a/py-polars/src/expr/string.rs +++ b/py-polars/src/expr/string.rs @@ -157,8 +157,8 @@ impl PyExpr { self.inner.clone().str().pad_end(length, fill_char).into() } - fn str_zfill(&self, length: usize) -> Self { - self.inner.clone().str().zfill(length).into() + fn str_zfill(&self, length: Self) -> Self { + self.inner.clone().str().zfill(length.inner).into() } #[pyo3(signature = (pat, literal, strict))] diff --git a/py-polars/tests/unit/namespaces/string/test_pad.py b/py-polars/tests/unit/namespaces/string/test_pad.py index 2b8e5c032817..7364cf5fb9ba 100644 --- a/py-polars/tests/unit/namespaces/string/test_pad.py +++ b/py-polars/tests/unit/namespaces/string/test_pad.py @@ -68,6 +68,28 @@ def test_str_zfill() -> None: assert df["num"].cast(str).str.zfill(5).to_list() == out +def test_str_zfill_expr() -> None: + df = pl.DataFrame( + { + "num": ["-10", "-1", "0", "1", "10", None, "1"], + "len": [3, 4, 3, 2, 5, 3, None], + } + ) + out = df.select( + all_expr=pl.col("num").str.zfill(pl.col("len")), + str_lit=pl.lit("10").str.zfill(pl.col("len")), + len_lit=pl.col("num").str.zfill(5), + ) + expected = pl.DataFrame( + { + "all_expr": ["-10", "-001", "000", "01", "00010", None, None], + "str_lit": ["010", "0010", "010", "10", "00010", "010", None], + "len_lit": ["-0010", "-0001", "00000", "00001", "00010", None, "00001"], + } + ) + assert_frame_equal(out, expected) + + def test_str_ljust_deprecated() -> None: s = pl.Series(["a", "bc", "def"])