diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index faf9f3088fbec..f26c07fc5d2ee 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -71,9 +71,12 @@ impl PartialEq for DataType { use DataType::*; { match (self, other) { - // Don't include rev maps in comparisons #[cfg(feature = "dtype-categorical")] - (Categorical(_, _), Categorical(_, _)) => true, + (Categorical(rev_l, _), Categorical(rev_r, _)) => { + let is_l_enum = rev_l.as_ref().map_or(false, |x| x.is_enum()); + let is_r_enum = rev_r.as_ref().map_or(false, |x| x.is_enum()); + is_l_enum == is_r_enum + }, (Datetime(tu_l, tz_l), Datetime(tu_r, tz_r)) => tu_l == tu_r && tz_l == tz_r, (List(left_inner), List(right_inner)) => left_inner == right_inner, #[cfg(feature = "dtype-duration")] diff --git a/py-polars/tests/unit/datatypes/test_enum.py b/py-polars/tests/unit/datatypes/test_enum.py index 0130cb2fdf894..f07c506810c61 100644 --- a/py-polars/tests/unit/datatypes/test_enum.py +++ b/py-polars/tests/unit/datatypes/test_enum.py @@ -380,7 +380,7 @@ def test_enum_creating_col_expr() -> None: }, schema={ "col1": pl.Enum(["a", "b", "c"]), - "col2": pl.String, + "col2": pl.Categorical(), "col3": pl.Enum(["g", "h", "i"]), }, )