From 76350caab2369c61f199f34a4aa51cdf1ca9fee0 Mon Sep 17 00:00:00 2001 From: chielP Date: Wed, 24 Jan 2024 16:07:17 +0100 Subject: [PATCH 1/2] cat casting --- .../chunked_array/logical/categorical/mod.rs | 21 +++++++++++++++++++ .../tests/unit/datatypes/test_categorical.py | 12 +++++++++++ 2 files changed, 33 insertions(+) diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index 9d79dc28de8a..7eaec38f7201 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -355,6 +355,27 @@ impl LogicalType for CategoricalChunked { // Otherwise we do nothing Ok(self.clone().set_ordering(*ordering, true).into_series()) }, + dt if dt.is_numeric() => { + // Apply the cast to the categories and then index into the casted series + let categories = + StringChunked::with_chunk("", self.get_rev_map().get_categories().clone()); + let casted_series = categories.cast(dtype)?; + + macro_rules! 
get_elements { + ($ca:expr) => {{ + Ok(self + .physical() + .into_iter() + .map(|opt_el| { + opt_el.map(|el: u32| unsafe { + $ca.get_unchecked(el as usize).unwrap() + }) + }) + .collect()) + }}; + } + downcast_as_macro_arg_physical!(casted_series, get_elements) + }, _ => self.physical.cast(dtype), } } diff --git a/py-polars/tests/unit/datatypes/test_categorical.py b/py-polars/tests/unit/datatypes/test_categorical.py index 07f7a2026305..4e02decb8fe9 100644 --- a/py-polars/tests/unit/datatypes/test_categorical.py +++ b/py-polars/tests/unit/datatypes/test_categorical.py @@ -800,3 +800,15 @@ def test_sort_categorical_retain_none( "foo", "ham", ] + + +def test_cast_from_cat_to_numeric() -> None: + cat_series = pl.Series( + "cat_series", + ["0.69845702", "0.69317475", "2.43642724", "-0.95303469", "0.60684237"], + ).cast(pl.Categorical) + maximum = cat_series.cast(pl.Float32).max() + assert abs(maximum - 2.43642724) < 1e-6 # type: ignore[operator] + + s = pl.Series(["1", "2", "3"], dtype=pl.Categorical) + assert s.cast(pl.UInt8).sum() == 6 From 38d2f701f84af9230ba22670cd08b0502e9679fc Mon Sep 17 00:00:00 2001 From: chielP Date: Fri, 26 Jan 2024 14:15:43 +0100 Subject: [PATCH 2/2] gather --- .../chunked_array/logical/categorical/mod.rs | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index 7eaec38f7201..7b286e476e09 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -361,20 +361,16 @@ impl LogicalType for CategoricalChunked { StringChunked::with_chunk("", self.get_rev_map().get_categories().clone()); let casted_series = categories.cast(dtype)?; - macro_rules! 
get_elements { - ($ca:expr) => {{ - Ok(self - .physical() - .into_iter() - .map(|opt_el| { - opt_el.map(|el: u32| unsafe { - $ca.get_unchecked(el as usize).unwrap() - }) - }) - .collect()) - }}; + #[cfg(feature = "bigidx")] + { + let s = self.physical.cast(&DataType::UInt64)?; + Ok(unsafe { casted_series.take_unchecked(s.u64()?) }) + } + #[cfg(not(feature = "bigidx"))] + { + // SAFETY: the categorical invariant guarantees the indices are in bounds + Ok(unsafe { casted_series.take_unchecked(&self.physical) }) } - downcast_as_macro_arg_physical!(casted_series, get_elements) }, _ => self.physical.cast(dtype), }