Skip to content

Commit

Permalink
cat casting
Browse files Browse the repository at this point in the history
  • Loading branch information
c-peters committed Jan 24, 2024
1 parent 25537c2 commit 8dbbcc2
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
21 changes: 21 additions & 0 deletions crates/polars-core/src/chunked_array/logical/categorical/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,27 @@ impl LogicalType for CategoricalChunked {
// Otherwise we do nothing
Ok(self.clone().set_ordering(*ordering, true).into_series())
},
dt if dt.is_numeric() => {
// Apply the cast to to the categories and then index into the casted series
let categories =
StringChunked::with_chunk("", self.get_rev_map().get_categories().clone());
let casted_series = categories.cast(dtype)?;

macro_rules! get_elements {
($ca:expr) => {{
Ok(self
.physical()
.into_iter()
.map(|opt_el| {
opt_el.map(|el: u32| unsafe {
$ca.get_unchecked(el as usize).unwrap()
})
})
.collect())
}};
}
downcast_as_macro_arg_physical!(casted_series, get_elements)
},
_ => self.physical.cast(dtype),
}
}
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/datatypes/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,3 +800,15 @@ def test_sort_categorical_retain_none(
"foo",
"ham",
]


def test_cast_from_cat_to_numeric() -> None:
cat_series = pl.Series(
"cat_series",
["0.69845702", "0.69317475", "2.43642724", "-0.95303469", "0.60684237"],
).cast(pl.Categorical)
maximum = cat_series.cast(pl.Float32).max()
assert abs(maximum - 2.43642724) < 1e-6 # type: ignore[operator]

s = pl.Series(["1", "2", "3"], dtype=pl.Categorical)
assert s.cast(pl.UInt8).sum() == 6

0 comments on commit 8dbbcc2

Please sign in to comment.