Skip to content

Commit

Permalink
feat: Add ignore_nulls for arr.join (pola-rs#13919)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored and r-brink committed Jan 24, 2024
1 parent 686b7d6 commit b798a74
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 34 deletions.
43 changes: 32 additions & 11 deletions crates/polars-ops/src/chunked_array/array/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ use polars_core::prelude::ArrayChunked;

use super::*;

fn join_literal(ca: &ArrayChunked, separator: &str) -> PolarsResult<StringChunked> {
fn join_literal(
ca: &ArrayChunked,
separator: &str,
ignore_nulls: bool,
) -> PolarsResult<StringChunked> {
let DataType::Array(_, _) = ca.dtype() else {
unreachable!()
};
Expand All @@ -13,47 +17,60 @@ fn join_literal(ca: &ArrayChunked, separator: &str) -> PolarsResult<StringChunke
let mut builder = StringChunkedBuilder::new(ca.name(), ca.len());

ca.for_each_amortized(|opt_s| {
let opt_val = opt_s.map(|s| {
let opt_val = opt_s.and_then(|s| {
// make sure that we don't write values of previous iteration
buf.clear();
let ca = s.as_ref().str().unwrap();

let iter = ca.into_iter().map(|opt_v| opt_v.unwrap_or("null"));
if ca.null_count() != 0 && !ignore_nulls {
return None;
}

let iter = ca.into_iter().flatten();

for val in iter {
buf.write_str(val).unwrap();
buf.write_str(separator).unwrap();
}
// last value should not have a separator, so slice that off
// saturating sub because there might have been nothing written.
&buf[..buf.len().saturating_sub(separator.len())]
Some(&buf[..buf.len().saturating_sub(separator.len())])
});
builder.append_option(opt_val)
});
Ok(builder.finish())
}

fn join_many(ca: &ArrayChunked, separator: &StringChunked) -> PolarsResult<StringChunked> {
fn join_many(
ca: &ArrayChunked,
separator: &StringChunked,
ignore_nulls: bool,
) -> PolarsResult<StringChunked> {
let mut buf = String::new();
let mut builder = StringChunkedBuilder::new(ca.name(), ca.len());

ca.amortized_iter()
.zip(separator)
.for_each(|(opt_s, opt_sep)| match opt_sep {
Some(separator) => {
let opt_val = opt_s.map(|s| {
let opt_val = opt_s.and_then(|s| {
// make sure that we don't write values of previous iteration
buf.clear();
let ca = s.as_ref().str().unwrap();
let iter = ca.into_iter().map(|opt_v| opt_v.unwrap_or("null"));

if ca.null_count() != 0 && !ignore_nulls {
return None;
}

let iter = ca.into_iter().flatten();

for val in iter {
buf.write_str(val).unwrap();
buf.write_str(separator).unwrap();
}
// last value should not have a separator, so slice that off
// saturating sub because there might have been nothing written.
&buf[..buf.len().saturating_sub(separator.len())]
Some(&buf[..buf.len().saturating_sub(separator.len())])
});
builder.append_option(opt_val)
},
Expand All @@ -64,14 +81,18 @@ fn join_many(ca: &ArrayChunked, separator: &StringChunked) -> PolarsResult<Strin

/// In case the inner dtype [`DataType::String`], the individual items will be joined into a
/// single string separated by `separator`.
pub fn array_join(ca: &ArrayChunked, separator: &StringChunked) -> PolarsResult<StringChunked> {
pub fn array_join(
ca: &ArrayChunked,
separator: &StringChunked,
ignore_nulls: bool,
) -> PolarsResult<StringChunked> {
match ca.inner_dtype() {
DataType::String => match separator.len() {
1 => match separator.get(0) {
Some(separator) => join_literal(ca, separator),
Some(separator) => join_literal(ca, separator, ignore_nulls),
_ => Ok(StringChunked::full_null(ca.name(), ca.len())),
},
_ => join_many(ca, separator),
_ => join_many(ca, separator, ignore_nulls),
},
dt => polars_bail!(op = "`array.join`", got = dt, expected = "String"),
}
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-ops/src/chunked_array/array/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ pub trait ArrayNameSpace: AsArray {
array_get(ca, index)
}

fn array_join(&self, separator: &StringChunked) -> PolarsResult<Series> {
fn array_join(&self, separator: &StringChunked, ignore_nulls: bool) -> PolarsResult<Series> {
let ca = self.as_array();
array_join(ca, separator).map(|ok| ok.into_series())
array_join(ca, separator, ignore_nulls).map(|ok| ok.into_series())
}

#[cfg(feature = "array_count")]
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ impl ArrayNameSpace {
/// Join all string items in a sub-array and place a separator between them.
/// # Error
/// Raise if inner type of array is not `DataType::String`.
pub fn join(self, separator: Expr) -> Expr {
pub fn join(self, separator: Expr, ignore_nulls: bool) -> Expr {
self.0.map_many_private(
FunctionExpr::ArrayExpr(ArrayFunction::Join),
FunctionExpr::ArrayExpr(ArrayFunction::Join(ignore_nulls)),
&[separator],
false,
false,
Expand Down
12 changes: 6 additions & 6 deletions crates/polars-plan/src/dsl/function_expr/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub enum ArrayFunction {
ArgMin,
ArgMax,
Get,
Join,
Join(bool),
#[cfg(feature = "is_in")]
Contains,
#[cfg(feature = "array_count")]
Expand All @@ -41,7 +41,7 @@ impl ArrayFunction {
Reverse => mapper.with_same_dtype(),
ArgMin | ArgMax => mapper.with_dtype(IDX_DTYPE),
Get => mapper.map_to_list_and_array_inner_dtype(),
Join => mapper.with_dtype(DataType::String),
Join(_) => mapper.with_dtype(DataType::String),
#[cfg(feature = "is_in")]
Contains => mapper.with_dtype(DataType::Boolean),
#[cfg(feature = "array_count")]
Expand Down Expand Up @@ -76,7 +76,7 @@ impl Display for ArrayFunction {
ArgMin => "arg_min",
ArgMax => "arg_max",
Get => "get",
Join => "join",
Join(_) => "join",
#[cfg(feature = "is_in")]
Contains => "contains",
#[cfg(feature = "array_count")]
Expand Down Expand Up @@ -104,7 +104,7 @@ impl From<ArrayFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
ArgMin => map!(arg_min),
ArgMax => map!(arg_max),
Get => map_as_slice!(get),
Join => map_as_slice!(join),
Join(ignore_nulls) => map_as_slice!(join, ignore_nulls),
#[cfg(feature = "is_in")]
Contains => map_as_slice!(contains),
#[cfg(feature = "array_count")]
Expand Down Expand Up @@ -173,10 +173,10 @@ pub(super) fn get(s: &[Series]) -> PolarsResult<Series> {
ca.array_get(index)
}

pub(super) fn join(s: &[Series]) -> PolarsResult<Series> {
pub(super) fn join(s: &[Series], ignore_nulls: bool) -> PolarsResult<Series> {
let ca = s[0].array()?;
let separator = s[1].str()?;
ca.array_join(separator)
ca.array_join(separator, ignore_nulls)
}

#[cfg(feature = "is_in")]
Expand Down
9 changes: 7 additions & 2 deletions py-polars/polars/expr/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def last(self) -> Expr:
"""
return self.get(-1)

def join(self, separator: IntoExprColumn) -> Expr:
def join(self, separator: IntoExprColumn, *, ignore_nulls: bool = True) -> Expr:
"""
Join all string items in a sub-array and place a separator between them.
Expand All @@ -442,6 +442,11 @@ def join(self, separator: IntoExprColumn) -> Expr:
----------
separator
string to separate the items with
ignore_nulls
Ignore null values (default).
If set to ``False``, null values will be propagated.
If the sub-list contains any null values, the output is ``None``.
Returns
-------
Expand Down Expand Up @@ -470,7 +475,7 @@ def join(self, separator: IntoExprColumn) -> Expr:
"""
separator = parse_as_expression(separator, str_as_lit=True)
return wrap_expr(self._pyexpr.arr_join(separator))
return wrap_expr(self._pyexpr.arr_join(separator, ignore_nulls))

def contains(
self, item: float | str | bool | int | date | datetime | time | IntoExprColumn
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/expr/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def join(self, separator: IntoExprColumn, *, ignore_nulls: bool = True) -> Expr:
Ignore null values (default).
If set to ``False``, null values will be propagated.
if the sub-list contains any null values, the output is ``None``.
If the sub-list contains any null values, the output is ``None``.
Returns
-------
Expand Down
7 changes: 6 additions & 1 deletion py-polars/polars/series/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def last(self) -> Series:
"""

def join(self, separator: IntoExprColumn) -> Series:
def join(self, separator: IntoExprColumn, *, ignore_nulls: bool = True) -> Series:
"""
Join all string items in a sub-array and place a separator between them.
Expand All @@ -356,6 +356,11 @@ def join(self, separator: IntoExprColumn) -> Series:
----------
separator
string to separate the items with
ignore_nulls
Ignore null values (default).
If set to ``False``, null values will be propagated.
If the sub-list contains any null values, the output is ``None``.
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def join(self, separator: IntoExprColumn, *, ignore_nulls: bool = True) -> Serie
Ignore null values (default).
If set to ``False``, null values will be propagated.
if the sub-list contains any null values, the output is ``None``.
If the sub-list contains any null values, the output is ``None``.
Returns
-------
Expand Down
8 changes: 6 additions & 2 deletions py-polars/src/expr/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,12 @@ impl PyExpr {
self.inner.clone().arr().get(index.inner).into()
}

fn arr_join(&self, separator: PyExpr) -> Self {
self.inner.clone().arr().join(separator.inner).into()
fn arr_join(&self, separator: PyExpr, ignore_nulls: bool) -> Self {
self.inner
.clone()
.arr()
.join(separator.inner, ignore_nulls)
.into()
}

#[cfg(feature = "is_in")]
Expand Down
35 changes: 29 additions & 6 deletions py-polars/tests/unit/namespaces/array/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,36 @@ def test_array_join() -> None:
},
)
out = df.select(pl.col("a").arr.join("-"))
assert out.to_dict(as_series=False) == {
"a": ["ab-c-d", "e-f-g", "null-null-null", None]
}
assert out.to_dict(as_series=False) == {"a": ["ab-c-d", "e-f-g", "", None]}
out = df.select(pl.col("a").arr.join(pl.col("separator")))
assert out.to_dict(as_series=False) == {
"a": ["ab&c&d", None, "null*null*null", None]
}
assert out.to_dict(as_series=False) == {"a": ["ab&c&d", None, "", None]}

# test ignore_nulls argument
df = pl.DataFrame(
{
"a": [
["a", None, "b", None],
None,
[None, None, None, None],
["c", "d", "e", "f"],
],
"separator": ["-", "&", " ", "@"],
},
schema={
"a": pl.Array(pl.String, 4),
"separator": pl.String,
},
)
# ignore nulls
out = df.select(pl.col("a").arr.join("-", ignore_nulls=True))
assert out.to_dict(as_series=False) == {"a": ["a-b", None, "", "c-d-e-f"]}
out = df.select(pl.col("a").arr.join(pl.col("separator"), ignore_nulls=True))
assert out.to_dict(as_series=False) == {"a": ["a-b", None, "", "c@d@e@f"]}
# propagate nulls
out = df.select(pl.col("a").arr.join("-", ignore_nulls=False))
assert out.to_dict(as_series=False) == {"a": [None, None, None, "c-d-e-f"]}
out = df.select(pl.col("a").arr.join(pl.col("separator"), ignore_nulls=False))
assert out.to_dict(as_series=False) == {"a": [None, None, None, "c@d@e@f"]}


@pytest.mark.parametrize(
Expand Down

0 comments on commit b798a74

Please sign in to comment.