Skip to content

Commit

Permalink
fix: Correct wildcard expansion for functions (#19449)
Browse files Browse the repository at this point in the history
  • Loading branch information
siddharth-vi authored Oct 26, 2024
1 parent e058bd3 commit d616866
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 23 deletions.
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/functions/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ pub fn duration(args: DurationArgs) -> Expr {
function: FunctionExpr::TemporalExpr(TemporalFunction::Duration(args.time_unit)),
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION,
flags: FunctionFlags::default(),
..Default::default()
},
}
Expand Down
34 changes: 12 additions & 22 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,34 +332,24 @@ impl ListNameSpace {
pub fn contains<E: Into<Expr>>(self, other: E) -> Expr {
let other = other.into();

self.0
.map_many_private(
FunctionExpr::ListExpr(ListFunction::Contains),
&[other],
false,
None,
)
.with_function_options(|mut options| {
options.flags |= FunctionFlags::INPUT_WILDCARD_EXPANSION;
options
})
self.0.map_many_private(
FunctionExpr::ListExpr(ListFunction::Contains),
&[other],
false,
None,
)
}
#[cfg(feature = "list_count")]
/// Count how often the value produced by ``element`` occurs.
pub fn count_matches<E: Into<Expr>>(self, element: E) -> Expr {
let other = element.into();

self.0
.map_many_private(
FunctionExpr::ListExpr(ListFunction::CountMatches),
&[other],
false,
None,
)
.with_function_options(|mut options| {
options.flags |= FunctionFlags::INPUT_WILDCARD_EXPANSION;
options
})
self.0.map_many_private(
FunctionExpr::ListExpr(ListFunction::CountMatches),
&[other],
false,
None,
)
}

#[cfg(feature = "list_sets")]
Expand Down
9 changes: 9 additions & 0 deletions py-polars/tests/unit/functions/as_datatype/test_duration.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,12 @@ def test_duration_time_unit_ms() -> None:
result = pl.duration(milliseconds=4)
expected = pl.duration(milliseconds=4, time_unit="us")
assert_frame_equal(pl.select(result), pl.select(expected))


def test_duration_wildcard_expansion() -> None:
# Test that wildcard expansions occurs correctly in pl.duration
# https://github.com/pola-rs/polars/issues/19007
df = df = pl.DataFrame({"a": [1], "b": [2]})
assert df.select(pl.duration(hours=pl.all()).name.keep()).to_dict(
as_series=False
) == {"a": [timedelta(seconds=3600)], "b": [timedelta(seconds=7200)]}
20 changes: 20 additions & 0 deletions py-polars/tests/unit/operations/namespaces/list/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,16 @@ def test_list_contains_invalid_datatype() -> None:
df.select(pl.col("a").list.contains(2))


def test_list_contains_wildcard_expansion() -> None:
# Test that wildcard expansions occurs correctly in list.contains
# https://github.com/pola-rs/polars/issues/18968
df = pl.DataFrame({"a": [[1, 2]], "b": [[3, 4]]})
assert df.select(pl.all().list.contains(3)).to_dict(as_series=False) == {
"a": [False],
"b": [True],
}


def test_list_concat() -> None:
df = pl.DataFrame({"a": [[1, 2], [1], [1, 2, 3]]})

Expand Down Expand Up @@ -686,6 +696,16 @@ def test_list_count_matches_boolean_nulls_9141() -> None:
assert a.select(pl.col("a").list.count_matches(True))["a"].to_list() == [1]


def test_list_count_matches_wildcard_expansion() -> None:
# Test that wildcard expansions occurs correctly in list.count_match
# https://github.com/pola-rs/polars/issues/18968
df = pl.DataFrame({"a": [[1, 2]], "b": [[3, 4]]})
assert df.select(pl.all().list.count_matches(3)).to_dict(as_series=False) == {
"a": [0],
"b": [1],
}


def test_list_gather_oob_10079() -> None:
df = pl.DataFrame(
{
Expand Down

0 comments on commit d616866

Please sign in to comment.