From 8dbbcc21d844e3e4261d063f94c2d77703633bf3 Mon Sep 17 00:00:00 2001 From: chielP Date: Wed, 24 Jan 2024 16:07:17 +0100 Subject: [PATCH] cat casting --- .../chunked_array/logical/categorical/mod.rs | 21 +++++++++++++++++++ .../tests/unit/datatypes/test_categorical.py | 12 +++++++++++ 2 files changed, 33 insertions(+) diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index 9d79dc28de8a9..7eaec38f72016 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -355,6 +355,27 @@ impl LogicalType for CategoricalChunked { // Otherwise we do nothing Ok(self.clone().set_ordering(*ordering, true).into_series()) }, + dt if dt.is_numeric() => { + // Apply the cast to to the categories and then index into the casted series + let categories = + StringChunked::with_chunk("", self.get_rev_map().get_categories().clone()); + let casted_series = categories.cast(dtype)?; + + macro_rules! get_elements { + ($ca:expr) => {{ + Ok(self + .physical() + .into_iter() + .map(|opt_el| { + opt_el.map(|el: u32| unsafe { + $ca.get_unchecked(el as usize).unwrap() + }) + }) + .collect()) + }}; + } + downcast_as_macro_arg_physical!(casted_series, get_elements) + }, _ => self.physical.cast(dtype), } } diff --git a/py-polars/tests/unit/datatypes/test_categorical.py b/py-polars/tests/unit/datatypes/test_categorical.py index 07f7a20263059..4e02decb8fe9f 100644 --- a/py-polars/tests/unit/datatypes/test_categorical.py +++ b/py-polars/tests/unit/datatypes/test_categorical.py @@ -800,3 +800,15 @@ def test_sort_categorical_retain_none( "foo", "ham", ] + + +def test_cast_from_cat_to_numeric() -> None: + cat_series = pl.Series( + "cat_series", + ["0.69845702", "0.69317475", "2.43642724", "-0.95303469", "0.60684237"], + ).cast(pl.Categorical) + maximum = cat_series.cast(pl.Float32).max() + assert abs(maximum - 2.43642724) < 1e-6 # type: ignore[operator] + + s = pl.Series(["1", "2", "3"], dtype=pl.Categorical) + assert s.cast(pl.UInt8).sum() == 6