Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Added input validation for explode operation in the array namespace #19163

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions crates/polars-ops/src/chunked_array/array/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ pub trait ArrayNameSpace: AsArray {
};
Ok(out.into_series())
}
fn array_explode(&self) -> PolarsResult<Series> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need it on the namespace.

let ca = self.as_array();
ca.explode()
}
}

impl ArrayNameSpace for ArrayChunked {}
5 changes: 5 additions & 0 deletions crates/polars-plan/src/dsl/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,9 @@ impl ArrayNameSpace {
None,
)
}
/// Returns a column with a separate row for every array element.
pub fn explode(self) -> Expr {
self.0
.map_private(FunctionExpr::ArrayExpr(ArrayFunction::Explode))
}
}
8 changes: 8 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub enum ArrayFunction {
#[cfg(feature = "array_count")]
CountMatches,
Shift,
Explode,
}

impl ArrayFunction {
Expand All @@ -56,6 +57,7 @@ impl ArrayFunction {
#[cfg(feature = "array_count")]
CountMatches => mapper.with_dtype(IDX_DTYPE),
Shift => mapper.with_same_dtype(),
Explode => mapper.map_to_list_and_array_inner_dtype(),
Copy link
Collaborator

@nameexhaustion nameexhaustion Oct 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can do the dtype validation here - e.g.

Explode => mapper.try_map_to_array_inner_dtype()?,

Where try_map_to_array_inner_dtype returns an error if the dtype isn't Array

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, we could maybe even add a check before this match block to ensure the dtype is Array

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that's needed. As that check would be required on every get_dtype. We already have this

fn check_namespace(function: &FunctionExpr, first_dtype: &DataType) -> PolarsResult<()> {

So I think it will resolve itself. During simplification of expressions we can then rewrite Array(Explode) to Explode.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see - then adding Explode under ArrayFunction should already add input dtype validation

Copy link
Contributor Author

@barak1412 barak1412 Oct 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nameexhaustion
Thanks, I was thinking about this solution, but I prefered one validation to all namespaces operations.

@ritchie46
It's not going to wok because we don't hit check_namespace due to an extra condition in the match - if options.cast_to_supertypes.is_some():
https://github.com/pola-rs/polars/blob/main/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs#L275

In addition, if we convert to Expr:Explode, we may not hit that match at all, I need to verify it.

So I propose -

  1. Make check_namespace hit all FunctionExpr.
  2. If the Array(Explode) -> Explode conversion won't hit, we will implement @nameexhaustion 's idea.

What do you think?

Edit:
@ritchie46 I saw you added to Expr the possibility to know its output type, so I will use it during conversion.

On Separate PR I will make all AExpr::Function hit the namespace validation.

}
}
}
Expand Down Expand Up @@ -96,6 +98,7 @@ impl Display for ArrayFunction {
#[cfg(feature = "array_count")]
CountMatches => "count_matches",
Shift => "shift",
Explode => "explode",
};
write!(f, "arr.{name}")
}
Expand Down Expand Up @@ -129,6 +132,7 @@ impl From<ArrayFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
#[cfg(feature = "array_count")]
CountMatches => map_as_slice!(count_matches),
Shift => map_as_slice!(shift),
Explode => map!(explode),
}
}
}
Expand Down Expand Up @@ -249,3 +253,7 @@ pub(super) fn shift(s: &[Column]) -> PolarsResult<Column> {

ca.array_shift(n.as_materialized_series()).map(Column::from)
}

pub(super) fn explode(s: &Column) -> PolarsResult<Column> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not needed.

s.array()?.array_explode().map(Column::from)
}
4 changes: 4 additions & 0 deletions crates/polars-python/src/expr/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,8 @@ impl PyExpr {
fn arr_shift(&self, n: PyExpr) -> Self {
self.inner.clone().arr().shift(n.inner).into()
}

fn arr_explode(&self) -> Self {
self.inner.clone().arr().explode().into()
}
}
2 changes: 1 addition & 1 deletion py-polars/polars/expr/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def explode(self) -> Expr:
│ 6 │
└─────┘
"""
return wrap_expr(self._pyexpr.explode())
return wrap_expr(self._pyexpr.arr_explode())

def contains(
self, item: float | str | bool | int | date | datetime | time | IntoExprColumn
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@expr_dispatch
class ArrayNameSpace:
"""Namespace for list related methods."""
"""Namespace for array related methods."""

_accessor = "arr"

Expand Down
Loading