From 5e716bc17016479738f38a2f83b7943279bbb50d Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 7 Jun 2024 14:13:43 +0200 Subject: [PATCH] feat: Expose overflowing cast (#16805) --- .../src/compute/cast/binary_to.rs | 4 +- .../src/compute/cast/binview_to.rs | 4 +- .../src/compute/cast/dictionary_to.rs | 8 +- crates/polars-arrow/src/compute/cast/mod.rs | 20 +- .../src/compute/cast/primitive_to.rs | 4 +- crates/polars-core/src/chunked_array/cast.rs | 210 +++++++++++++----- .../chunked_array/comparison/categorical.rs | 5 +- .../chunked_array/logical/categorical/from.rs | 4 +- .../chunked_array/logical/categorical/mod.rs | 11 +- .../src/chunked_array/logical/date.rs | 10 +- .../src/chunked_array/logical/datetime.rs | 12 +- .../src/chunked_array/logical/decimal.rs | 12 +- .../src/chunked_array/logical/duration.rs | 8 +- .../src/chunked_array/logical/mod.rs | 7 +- .../src/chunked_array/logical/struct_/mod.rs | 21 +- .../src/chunked_array/logical/time.rs | 14 +- .../src/chunked_array/ops/apply.rs | 5 +- .../src/chunked_array/ops/decimal.rs | 6 +- .../polars-core/src/chunked_array/ops/mod.rs | 9 +- .../src/chunked_array/ops/rolling_window.rs | 3 +- .../frame/group_by/aggregations/boolean.rs | 5 +- .../src/frame/group_by/aggregations/mod.rs | 17 +- .../src/frame/group_by/into_groups.rs | 9 +- crates/polars-core/src/frame/mod.rs | 7 +- crates/polars-core/src/series/any_value.rs | 4 +- crates/polars-core/src/series/from.rs | 56 +++-- .../src/series/implementations/array.rs | 5 +- .../src/series/implementations/binary.rs | 5 +- .../series/implementations/binary_offset.rs | 4 +- .../src/series/implementations/boolean.rs | 23 +- .../src/series/implementations/categorical.rs | 4 +- .../src/series/implementations/date.rs | 20 +- .../src/series/implementations/datetime.rs | 10 +- .../src/series/implementations/decimal.rs | 4 +- .../src/series/implementations/duration.rs | 15 +- .../src/series/implementations/floats.rs | 8 +- .../src/series/implementations/list.rs | 4 +- .../src/series/implementations/mod.rs | 11 +- .../src/series/implementations/null.rs | 2 +- .../src/series/implementations/object.rs | 3 +- .../src/series/implementations/string.rs | 4 +- .../src/series/implementations/struct_.rs | 4 +- .../src/series/implementations/time.rs | 4 +- crates/polars-core/src/series/mod.rs | 39 +++- crates/polars-core/src/series/series_trait.rs | 3 +- crates/polars-expr/src/expressions/cast.rs | 9 +- crates/polars-expr/src/planner.rs | 4 +- crates/polars-plan/src/dsl/expr.rs | 5 +- .../src/dsl/functions/syntactic_sugar.rs | 4 +- crates/polars-plan/src/dsl/mod.rs | 14 +- .../src/logical_plan/aexpr/hash.rs | 4 +- .../polars-plan/src/logical_plan/aexpr/mod.rs | 3 +- .../src/logical_plan/alp/format.rs | 4 +- .../src/logical_plan/alp/tree_format.rs | 4 +- .../src/logical_plan/conversion/expr_to_ir.rs | 4 +- .../src/logical_plan/conversion/ir_to_dsl.rs | 4 +- .../conversion/type_coercion/binary.rs | 12 +- .../conversion/type_coercion/mod.rs | 24 +- crates/polars-plan/src/logical_plan/format.rs | 4 +- .../optimizer/simplify_functions.rs | 4 +- .../src/logical_plan/visitor/expr.rs | 6 +- .../polars-time/src/chunkedarray/datetime.rs | 4 +- .../polars/_utils/construction/dataframe.py | 12 +- .../polars/_utils/construction/series.py | 12 +- py-polars/polars/expr/expr.py | 12 +- py-polars/polars/series/series.py | 5 +- py-polars/src/dataframe/export.rs | 4 +- py-polars/src/expr/general.rs | 14 +- py-polars/src/lazyframe/visitor/expr_nodes.rs | 9 +- py-polars/src/series/mod.rs | 15 +- 70 files changed, 570 insertions(+), 273 deletions(-) diff --git a/crates/polars-arrow/src/compute/cast/binary_to.rs b/crates/polars-arrow/src/compute/cast/binary_to.rs index d5e8bfb30852..fa53a58f8bbb 100644 --- a/crates/polars-arrow/src/compute/cast/binary_to.rs +++ b/crates/polars-arrow/src/compute/cast/binary_to.rs @@ -1,6 +1,6 @@ use polars_error::PolarsResult; -use super::CastOptions; +use super::CastOptionsImpl; use crate::array::*; use crate::datatypes::ArrowDataType; use crate::offset::{Offset, Offsets}; @@ -118,7 +118,7 @@ where pub(super) fn binary_to_primitive_dyn( from: &dyn Array, to: &ArrowDataType, - options: CastOptions, + options: CastOptionsImpl, ) -> PolarsResult> where T: NativeType + Parse, diff --git a/crates/polars-arrow/src/compute/cast/binview_to.rs b/crates/polars-arrow/src/compute/cast/binview_to.rs index 1c157110ec49..5a12a14aaca7 100644 --- a/crates/polars-arrow/src/compute/cast/binview_to.rs +++ b/crates/polars-arrow/src/compute/cast/binview_to.rs @@ -3,7 +3,7 @@ use polars_error::PolarsResult; use crate::array::*; use crate::compute::cast::binary_to::Parse; -use crate::compute::cast::CastOptions; +use crate::compute::cast::CastOptionsImpl; #[cfg(feature = "dtype-decimal")] use crate::compute::decimal::deserialize_decimal; use crate::datatypes::{ArrowDataType, TimeUnit}; @@ -77,7 +77,7 @@ where pub(super) fn binview_to_primitive_dyn( from: &dyn Array, to: &ArrowDataType, - options: CastOptions, + options: CastOptionsImpl, ) -> PolarsResult> where T: NativeType + Parse, diff --git a/crates/polars-arrow/src/compute/cast/dictionary_to.rs b/crates/polars-arrow/src/compute/cast/dictionary_to.rs index 8ef67750dcdc..134c9af7991f 100644 --- a/crates/polars-arrow/src/compute/cast/dictionary_to.rs +++ b/crates/polars-arrow/src/compute/cast/dictionary_to.rs @@ -1,6 +1,6 @@ use polars_error::{polars_bail, PolarsResult}; -use super::{primitive_as_primitive, primitive_to_primitive, CastOptions}; +use super::{primitive_as_primitive, primitive_to_primitive, CastOptionsImpl}; use crate::array::{Array, DictionaryArray, DictionaryKey}; use crate::compute::cast::cast; use crate::datatypes::ArrowDataType; @@ -35,7 +35,7 @@ pub fn dictionary_to_dictionary_values( let values = from.values(); let length = values.len(); - let values = cast(values.as_ref(), values_type, CastOptions::default())?; + let values = cast(values.as_ref(), values_type, CastOptionsImpl::default())?; assert_eq!(values.len(), length); // this is guaranteed by `cast` unsafe { @@ -55,7 +55,7 @@ pub fn wrapping_dictionary_to_dictionary_values( let values = cast( values.as_ref(), values_type, - CastOptions { + CastOptionsImpl { wrapped: true, partial: false, }, @@ -127,7 +127,7 @@ where pub(super) fn dictionary_cast_dyn( array: &dyn Array, to_type: &ArrowDataType, - options: CastOptions, + options: CastOptionsImpl, ) -> PolarsResult> { let array = array.as_any().downcast_ref::>().unwrap(); let keys = array.keys(); diff --git a/crates/polars-arrow/src/compute/cast/mod.rs b/crates/polars-arrow/src/compute/cast/mod.rs index 787e925ec1b3..0afa67ec875a 100644 --- a/crates/polars-arrow/src/compute/cast/mod.rs +++ b/crates/polars-arrow/src/compute/cast/mod.rs @@ -33,7 +33,7 @@ use crate::temporal_conversions::utf8view_to_timestamp; /// options defining how Cast kernels behave #[derive(Clone, Copy, Debug, Default)] -pub struct CastOptions { +pub struct CastOptionsImpl { /// default to false /// whether an overflowing cast should be converted to `None` (default), or be wrapped (i.e. `256i16 as u8 = 0` vectorized). /// Settings this to `true` is 5-6x faster for numeric types. @@ -43,7 +43,7 @@ pub struct CastOptions { pub partial: bool, } -impl CastOptions { +impl CastOptionsImpl { pub fn unchecked() -> Self { Self { wrapped: true, @@ -52,7 +52,7 @@ impl CastOptions { } } -impl CastOptions { +impl CastOptionsImpl { fn with_wrapped(&self, v: bool) -> Self { let mut option = *self; option.wrapped = v; @@ -82,7 +82,7 @@ macro_rules! primitive_dyn { fn cast_struct( array: &StructArray, to_type: &ArrowDataType, - options: CastOptions, + options: CastOptionsImpl, ) -> PolarsResult { let values = array.values(); let fields = StructArray::get_fields(to_type); @@ -102,7 +102,7 @@ fn cast_struct( fn cast_list( array: &ListArray, to_type: &ArrowDataType, - options: CastOptions, + options: CastOptionsImpl, ) -> PolarsResult> { let values = array.values(); let new_values = cast( @@ -144,7 +144,7 @@ fn cast_large_to_list(array: &ListArray, to_type: &ArrowDataType) -> ListAr fn cast_fixed_size_list_to_list( fixed: &FixedSizeListArray, to_type: &ArrowDataType, - options: CastOptions, + options: CastOptionsImpl, ) -> PolarsResult> { let new_values = cast( fixed.values().as_ref(), @@ -170,7 +170,7 @@ fn cast_list_to_fixed_size_list( list: &ListArray, inner: &Field, size: usize, - options: CastOptions, + options: CastOptionsImpl, ) -> PolarsResult { let null_cnt = list.null_count(); let new_values = if null_cnt == 0 { @@ -245,7 +245,7 @@ pub fn cast_default(array: &dyn Array, to_type: &ArrowDataType) -> PolarsResult< } pub fn cast_unchecked(array: &dyn Array, to_type: &ArrowDataType) -> PolarsResult> { - cast(array, to_type, CastOptions::unchecked()) + cast(array, to_type, CastOptionsImpl::unchecked()) } /// Cast `array` to the provided data type and return a new [`Array`] with @@ -276,7 +276,7 @@ pub fn cast_unchecked(array: &dyn Array, to_type: &ArrowDataType) -> PolarsResul pub fn cast( array: &dyn Array, to_type: &ArrowDataType, - options: CastOptions, + options: CastOptionsImpl, ) -> PolarsResult> { use ArrowDataType::*; let from_type = array.data_type(); @@ -760,7 +760,7 @@ pub fn cast( fn cast_to_dictionary( array: &dyn Array, dict_value_type: &ArrowDataType, - options: CastOptions, + options: CastOptionsImpl, ) -> PolarsResult> { let array = cast(array, dict_value_type, options)?; let array = array.as_ref(); diff --git a/crates/polars-arrow/src/compute/cast/primitive_to.rs b/crates/polars-arrow/src/compute/cast/primitive_to.rs index 583b6ab19a96..98df4b7f6b83 100644 --- a/crates/polars-arrow/src/compute/cast/primitive_to.rs +++ b/crates/polars-arrow/src/compute/cast/primitive_to.rs @@ -3,7 +3,7 @@ use std::hash::Hash; use num_traits::{AsPrimitive, Float, ToPrimitive}; use polars_error::PolarsResult; -use super::CastOptions; +use super::CastOptionsImpl; use crate::array::*; use crate::bitmap::Bitmap; use crate::compute::arity::unary; @@ -142,7 +142,7 @@ where pub(super) fn primitive_to_primitive_dyn( from: &dyn Array, to_type: &ArrowDataType, - options: CastOptions, + options: CastOptionsImpl, ) -> PolarsResult> where I: NativeType + num_traits::NumCast + num_traits::AsPrimitive, diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index f028c5d2117e..ea74c39e64e7 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -1,6 +1,8 @@ //! Implementations of the ChunkCast Trait. -use arrow::compute::cast::CastOptions; +use arrow::compute::cast::CastOptionsImpl; +#[cfg(feature = "serde-lazy")] +use serde::{Deserialize, Serialize}; use crate::chunked_array::metadata::MetadataProperties; #[cfg(feature = "timezones")] @@ -9,24 +11,61 @@ use crate::chunked_array::temporal::validate_time_zone; use crate::prelude::DataType::Datetime; use crate::prelude::*; +#[derive(Copy, Clone, Debug, Default, PartialEq, Hash, Eq)] +#[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))] +#[repr(u8)] +pub enum CastOptions { + /// Raises on overflow + #[default] + Strict, + /// Overflow is replaced with null + NonStrict, + /// Allows wrapping overflow + Overflowing, +} + +impl CastOptions { + pub fn strict(&self) -> bool { + matches!(self, CastOptions::Strict) + } +} + +impl From for CastOptionsImpl { + fn from(value: CastOptions) -> Self { + let wrapped = match value { + CastOptions::Strict | CastOptions::NonStrict => false, + CastOptions::Overflowing => true, + }; + CastOptionsImpl { + wrapped, + partial: false, + } + } +} + pub(crate) fn cast_chunks( chunks: &[ArrayRef], dtype: &DataType, - checked: bool, + options: CastOptions, ) -> PolarsResult> { - let options = if checked { - Default::default() - } else { - CastOptions { - wrapped: true, - partial: false, - } - }; + let check_nulls = matches!(options, CastOptions::Strict); + let options = options.into(); let arrow_dtype = dtype.to_arrow(true); chunks .iter() - .map(|arr| arrow::compute::cast::cast(arr.as_ref(), &arrow_dtype, options)) + .map(|arr| { + let out = arrow::compute::cast::cast(arr.as_ref(), &arrow_dtype, options); + if check_nulls { + out.and_then(|new| { + polars_ensure!(arr.null_count() == new.null_count(), ComputeError: "strict cast failed"); + Ok(new) + }) + + } else { + out + } + }) .collect::>>() } @@ -34,9 +73,9 @@ fn cast_impl_inner( name: &str, chunks: &[ArrayRef], dtype: &DataType, - checked: bool, + options: CastOptions, ) -> PolarsResult { - let chunks = cast_chunks(chunks, &dtype.to_physical(), checked)?; + let chunks = cast_chunks(chunks, &dtype.to_physical(), options)?; let out = Series::try_from((name, chunks))?; use DataType::*; let out = match dtype { @@ -58,8 +97,13 @@ fn cast_impl_inner( Ok(out) } -fn cast_impl(name: &str, chunks: &[ArrayRef], dtype: &DataType) -> PolarsResult { - cast_impl_inner(name, chunks, dtype, true) +fn cast_impl( + name: &str, + chunks: &[ArrayRef], + dtype: &DataType, + options: CastOptions, +) -> PolarsResult { + cast_impl_inner(name, chunks, dtype, options) } #[cfg(feature = "dtype-struct")] @@ -67,12 +111,13 @@ fn cast_single_to_struct( name: &str, chunks: &[ArrayRef], fields: &[Field], + options: CastOptions, ) -> PolarsResult { let mut new_fields = Vec::with_capacity(fields.len()); // cast to first field dtype let mut fields = fields.iter(); let fld = fields.next().unwrap(); - let s = cast_impl_inner(&fld.name, chunks, &fld.dtype, true)?; + let s = cast_impl_inner(&fld.name, chunks, &fld.dtype, options)?; let length = s.len(); new_fields.push(s); @@ -87,7 +132,7 @@ impl ChunkedArray where T: PolarsNumericType, { - fn cast_impl(&self, data_type: &DataType, checked: bool) -> PolarsResult { + fn cast_impl(&self, data_type: &DataType, options: CastOptions) -> PolarsResult { if self.dtype() == data_type { // SAFETY: chunks are correct dtype let mut out = unsafe { @@ -119,7 +164,7 @@ where .clone() }, dt if dt.is_integer() => self - .cast(self.dtype())? + .cast_with_options(self.dtype(), options)? .strict_cast(&DataType::UInt32)? .u32()? .clone(), @@ -149,8 +194,10 @@ where } }, #[cfg(feature = "dtype-struct")] - DataType::Struct(fields) => cast_single_to_struct(self.name(), &self.chunks, fields), - _ => cast_impl_inner(self.name(), &self.chunks, data_type, checked).map(|mut s| { + DataType::Struct(fields) => { + cast_single_to_struct(self.name(), &self.chunks, fields, options) + }, + _ => cast_impl_inner(self.name(), &self.chunks, data_type, options).map(|mut s| { // maintain sorted if data types // - remain signed // - unsigned -> signed @@ -180,8 +227,12 @@ impl ChunkCast for ChunkedArray where T: PolarsNumericType, { - fn cast(&self, data_type: &DataType) -> PolarsResult { - self.cast_impl(data_type, true) + fn cast_with_options( + &self, + data_type: &DataType, + options: CastOptions, + ) -> PolarsResult { + self.cast_impl(data_type, options) } unsafe fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult { @@ -206,13 +257,17 @@ where polars_bail!(ComputeError: "cannot cast numeric types to 'Categorical'"); } }, - _ => self.cast_impl(data_type, false), + _ => self.cast_impl(data_type, CastOptions::Overflowing), } } } impl ChunkCast for StringChunked { - fn cast(&self, data_type: &DataType) -> PolarsResult { + fn cast_with_options( + &self, + data_type: &DataType, + options: CastOptions, + ) -> PolarsResult { match data_type { #[cfg(feature = "dtype-categorical")] DataType::Categorical(rev_map, ordering) => match rev_map { @@ -242,7 +297,9 @@ impl ChunkCast for StringChunked { }) }, #[cfg(feature = "dtype-struct")] - DataType::Struct(fields) => cast_single_to_struct(self.name(), &self.chunks, fields), + DataType::Struct(fields) => { + cast_single_to_struct(self.name(), &self.chunks, fields, options) + }, #[cfg(feature = "dtype-decimal")] DataType::Decimal(precision, scale) => match (precision, scale) { (precision, Some(scale)) => { @@ -264,7 +321,7 @@ impl ChunkCast for StringChunked { }, #[cfg(feature = "dtype-date")] DataType::Date => { - let result = cast_chunks(&self.chunks, data_type, true)?; + let result = cast_chunks(&self.chunks, data_type, options)?; let out = Series::try_from((self.name(), result))?; Ok(out) }, @@ -277,24 +334,27 @@ impl ChunkCast for StringChunked { let result = cast_chunks( &self.chunks, &Datetime(time_unit.to_owned(), Some(time_zone.clone())), - true, + options, )?; Series::try_from((self.name(), result)) }, _ => { - let result = - cast_chunks(&self.chunks, &Datetime(time_unit.to_owned(), None), true)?; + let result = cast_chunks( + &self.chunks, + &Datetime(time_unit.to_owned(), None), + options, + )?; Series::try_from((self.name(), result)) }, }; out }, - _ => cast_impl(self.name(), &self.chunks, data_type), + _ => cast_impl(self.name(), &self.chunks, data_type, options), } } unsafe fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult { - self.cast(data_type) + self.cast_with_options(data_type, CastOptions::Overflowing) } } @@ -335,54 +395,76 @@ impl StringChunked { } impl ChunkCast for BinaryChunked { - fn cast(&self, data_type: &DataType) -> PolarsResult { + fn cast_with_options( + &self, + data_type: &DataType, + options: CastOptions, + ) -> PolarsResult { match data_type { #[cfg(feature = "dtype-struct")] - DataType::Struct(fields) => cast_single_to_struct(self.name(), &self.chunks, fields), - _ => cast_impl(self.name(), &self.chunks, data_type), + DataType::Struct(fields) => { + cast_single_to_struct(self.name(), &self.chunks, fields, options) + }, + _ => cast_impl(self.name(), &self.chunks, data_type, options), } } unsafe fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult { match data_type { DataType::String => unsafe { Ok(self.to_string_unchecked().into_series()) }, - _ => self.cast(data_type), + _ => self.cast_with_options(data_type, CastOptions::Overflowing), } } } impl ChunkCast for BinaryOffsetChunked { - fn cast(&self, data_type: &DataType) -> PolarsResult { + fn cast_with_options( + &self, + data_type: &DataType, + options: CastOptions, + ) -> PolarsResult { match data_type { #[cfg(feature = "dtype-struct")] - DataType::Struct(fields) => cast_single_to_struct(self.name(), &self.chunks, fields), - _ => cast_impl(self.name(), &self.chunks, data_type), + DataType::Struct(fields) => { + cast_single_to_struct(self.name(), &self.chunks, fields, options) + }, + _ => cast_impl(self.name(), &self.chunks, data_type, options), } } unsafe fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult { - self.cast(data_type) + self.cast_with_options(data_type, CastOptions::Overflowing) } } impl ChunkCast for BooleanChunked { - fn cast(&self, data_type: &DataType) -> PolarsResult { + fn cast_with_options( + &self, + data_type: &DataType, + options: CastOptions, + ) -> PolarsResult { match data_type { #[cfg(feature = "dtype-struct")] - DataType::Struct(fields) => cast_single_to_struct(self.name(), &self.chunks, fields), - _ => cast_impl(self.name(), &self.chunks, data_type), + DataType::Struct(fields) => { + cast_single_to_struct(self.name(), &self.chunks, fields, options) + }, + _ => cast_impl(self.name(), &self.chunks, data_type, options), } } unsafe fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult { - self.cast(data_type) + self.cast_with_options(data_type, CastOptions::Overflowing) } } /// We cannot cast anything to or from List/LargeList /// So this implementation casts the inner type impl ChunkCast for ListChunked { - fn cast(&self, data_type: &DataType) -> PolarsResult { + fn cast_with_options( + &self, + data_type: &DataType, + options: CastOptions, + ) -> PolarsResult { use DataType::*; match data_type { List(child_type) => { @@ -396,7 +478,7 @@ impl ChunkCast for ListChunked { }, _ => { // ensure the inner logical type bubbles up - let (arr, child_type) = cast_list(self, child_type)?; + let (arr, child_type) = cast_list(self, child_type, options)?; // SAFETY: we just casted so the dtype matches. // we must take this path to correct for physical types. unsafe { @@ -418,7 +500,7 @@ impl ChunkCast for ListChunked { polars_ensure!(!matches!(&**child_type, Categorical(_, _)), InvalidOperation: "array of categorical is not yet supported"); // cast to the physical type to avoid logical chunks. - let chunks = cast_chunks(self.chunks(), &physical_type, true)?; + let chunks = cast_chunks(self.chunks(), &physical_type, options)?; // SAFETY: we just casted so the dtype matches. // we must take this path to correct for physical types. unsafe { @@ -443,7 +525,7 @@ impl ChunkCast for ListChunked { use DataType::*; match data_type { List(child_type) => cast_list_unchecked(self, child_type), - _ => self.cast(data_type), + _ => self.cast_with_options(data_type, CastOptions::Overflowing), } } } @@ -452,7 +534,11 @@ impl ChunkCast for ListChunked { /// So this implementation casts the inner type #[cfg(feature = "dtype-array")] impl ChunkCast for ArrayChunked { - fn cast(&self, data_type: &DataType) -> PolarsResult { + fn cast_with_options( + &self, + data_type: &DataType, + options: CastOptions, + ) -> PolarsResult { use DataType::*; match data_type { Array(child_type, width) => { @@ -469,7 +555,7 @@ impl ChunkCast for ArrayChunked { }, _ => { // ensure the inner logical type bubbles up - let (arr, child_type) = cast_fixed_size_list(self, child_type)?; + let (arr, child_type) = cast_fixed_size_list(self, child_type, options)?; // SAFETY: we just casted so the dtype matches. // we must take this path to correct for physical types. unsafe { @@ -485,7 +571,7 @@ impl ChunkCast for ArrayChunked { List(child_type) => { let physical_type = data_type.to_physical(); // cast to the physical type to avoid logical chunks. - let chunks = cast_chunks(self.chunks(), &physical_type, true)?; + let chunks = cast_chunks(self.chunks(), &physical_type, options)?; // SAFETY: we just casted so the dtype matches. // we must take this path to correct for physical types. unsafe { @@ -507,13 +593,17 @@ impl ChunkCast for ArrayChunked { } unsafe fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult { - self.cast(data_type) + self.cast_with_options(data_type, CastOptions::Overflowing) } } // Returns inner data type. This is needed because a cast can instantiate the dtype inner // values for instance with categoricals -fn cast_list(ca: &ListChunked, child_type: &DataType) -> PolarsResult<(ArrayRef, DataType)> { +fn cast_list( + ca: &ListChunked, + child_type: &DataType, + options: CastOptions, +) -> PolarsResult<(ArrayRef, DataType)> { // We still rechunk because we must bubble up a single data-type // TODO!: consider a version that works on chunks and merges the data-types and arrays. let ca = ca.rechunk(); @@ -522,7 +612,7 @@ fn cast_list(ca: &ListChunked, child_type: &DataType) -> PolarsResult<(ArrayRef, let s = unsafe { Series::from_chunks_and_dtype_unchecked("", vec![arr.values().clone()], ca.inner_dtype()) }; - let new_inner = s.cast(child_type)?; + let new_inner = s.cast_with_options(child_type, options)?; let inner_dtype = new_inner.dtype().clone(); debug_assert_eq!(&inner_dtype, child_type); @@ -571,6 +661,7 @@ unsafe fn cast_list_unchecked(ca: &ListChunked, child_type: &DataType) -> Polars fn cast_fixed_size_list( ca: &ArrayChunked, child_type: &DataType, + options: CastOptions, ) -> PolarsResult<(ArrayRef, DataType)> { let ca = ca.rechunk(); let arr = ca.downcast_iter().next().unwrap(); @@ -578,7 +669,7 @@ fn cast_fixed_size_list( let s = unsafe { Series::from_chunks_and_dtype_unchecked("", vec![arr.values().clone()], ca.inner_dtype()) }; - let new_inner = s.cast(child_type)?; + let new_inner = s.cast_with_options(child_type, options)?; let inner_dtype = new_inner.dtype().clone(); debug_assert_eq!(&inner_dtype, child_type); @@ -593,6 +684,7 @@ fn cast_fixed_size_list( #[cfg(test)] mod test { + use crate::chunked_array::cast::CastOptions; use crate::prelude::*; #[test] @@ -603,7 +695,10 @@ mod test { builder.append_opt_slice(Some(&[1i32, 2, 3])); let ca = builder.finish(); - let new = ca.cast(&DataType::List(DataType::Float64.into()))?; + let new = ca.cast_with_options( + &DataType::List(DataType::Float64.into()), + CastOptions::Strict, + )?; assert_eq!(new.dtype(), &DataType::List(DataType::Float64.into())); Ok(()) @@ -615,7 +710,10 @@ mod test { // check if we can cast categorical twice without panic let ca = StringChunked::new("foo", &["bar", "ham"]); let out = ca - .cast(&DataType::Categorical(None, Default::default())) + .cast_with_options( + &DataType::Categorical(None, Default::default()), + CastOptions::Strict, + ) .unwrap(); let out = out .cast(&DataType::Categorical(None, Default::default())) diff --git a/crates/polars-core/src/chunked_array/comparison/categorical.rs b/crates/polars-core/src/chunked_array/comparison/categorical.rs index c288929b07cd..77ddf45a5a69 100644 --- a/crates/polars-core/src/chunked_array/comparison/categorical.rs +++ b/crates/polars-core/src/chunked_array/comparison/categorical.rs @@ -2,6 +2,7 @@ use arrow::bitmap::Bitmap; use arrow::legacy::utils::FromTrustedLenIterator; use polars_compute::comparisons::TotalOrdKernel; +use crate::chunked_array::cast::CastOptions; use crate::prelude::nulls::replace_non_null; use crate::prelude::*; @@ -178,7 +179,7 @@ where }, } } else { - let lhs_string = lhs.cast(&DataType::String)?; + let lhs_string = lhs.cast_with_options(&DataType::String, CastOptions::NonStrict)?; Ok(str_compare_function(lhs_string.str().unwrap(), rhs)) } } @@ -211,7 +212,7 @@ where ), } } else { - let lhs_string = lhs.cast(&DataType::String)?; + let lhs_string = lhs.cast_with_options(&DataType::String, CastOptions::NonStrict)?; Ok(str_compare_function(lhs_string.str().unwrap(), rhs)) } } diff --git a/crates/polars-core/src/chunked_array/logical/categorical/from.rs b/crates/polars-core/src/chunked_array/logical/categorical/from.rs index 4be936d79555..568d5650ba8e 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/from.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/from.rs @@ -1,4 +1,4 @@ -use arrow::compute::cast::{cast, utf8view_to_utf8, CastOptions}; +use arrow::compute::cast::{cast, utf8view_to_utf8, CastOptionsImpl}; use arrow::datatypes::IntegerType; use super::*; @@ -72,7 +72,7 @@ impl CategoricalChunked { unsafe { DictionaryArray::try_new_unchecked( dtype, - cast(keys, &ArrowDataType::Int64, CastOptions::unchecked()) + cast(keys, &ArrowDataType::Int64, CastOptionsImpl::unchecked()) .unwrap() .as_any() .downcast_ref::>() diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index 9936c49d520e..30ab29b5c78a 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -13,6 +13,7 @@ use polars_utils::sync::SyncPtr; pub use revmap::*; use super::*; +use crate::chunked_array::cast::CastOptions; use crate::chunked_array::metadata::MetadataFlags; use crate::prelude::*; use crate::series::IsSorted; @@ -331,7 +332,7 @@ impl LogicalType for CategoricalChunked { } } - fn cast(&self, dtype: &DataType) -> PolarsResult { + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { match dtype { DataType::String => { let mapping = &**self.get_rev_map(); @@ -395,11 +396,13 @@ impl LogicalType for CategoricalChunked { self.physical.name(), self.get_rev_map().get_categories().clone(), ); - let casted_series = categories.cast(dtype)?; + let casted_series = categories.cast_with_options(dtype, options)?; #[cfg(feature = "bigidx")] { - let s = self.physical.cast(&DataType::UInt64)?; + let s = self + .physical + .cast_with_options(&DataType::UInt64, options)?; Ok(unsafe { casted_series.take_unchecked(s.u64()?) }) } #[cfg(not(feature = "bigidx"))] @@ -408,7 +411,7 @@ impl LogicalType for CategoricalChunked { Ok(unsafe { casted_series.take_unchecked(&self.physical) }) } }, - _ => self.physical.cast(dtype), + _ => self.physical.cast_with_options(dtype, options), } } } diff --git a/crates/polars-core/src/chunked_array/logical/date.rs b/crates/polars-core/src/chunked_array/logical/date.rs index 0b8a4b11b888..881044acac86 100644 --- a/crates/polars-core/src/chunked_array/logical/date.rs +++ b/crates/polars-core/src/chunked_array/logical/date.rs @@ -27,13 +27,17 @@ impl LogicalType for DateChunked { self.0.get_any_value_unchecked(i).as_date() } - fn cast(&self, dtype: &DataType) -> PolarsResult { + fn cast_with_options( + &self, + dtype: &DataType, + cast_options: CastOptions, + ) -> PolarsResult { use DataType::*; match dtype { Date => Ok(self.clone().into_series()), #[cfg(feature = "dtype-datetime")] Datetime(tu, tz) => { - let casted = self.0.cast(dtype)?; + let casted = self.0.cast_with_options(dtype, cast_options)?; let casted = casted.datetime().unwrap(); let conversion = match tu { TimeUnit::Nanoseconds => NS_IN_DAY, @@ -44,7 +48,7 @@ impl LogicalType for DateChunked { .into_datetime(*tu, tz.clone()) .into_series()) }, - dt if dt.is_numeric() => self.0.cast(dtype), + dt if dt.is_numeric() => self.0.cast_with_options(dtype, cast_options), dt => { polars_bail!( InvalidOperation: diff --git a/crates/polars-core/src/chunked_array/logical/datetime.rs b/crates/polars-core/src/chunked_array/logical/datetime.rs index c3ba7702106d..1be7a1cf9747 100644 --- a/crates/polars-core/src/chunked_array/logical/datetime.rs +++ b/crates/polars-core/src/chunked_array/logical/datetime.rs @@ -28,7 +28,11 @@ impl LogicalType for DatetimeChunked { .as_datetime(self.time_unit(), self.time_zone()) } - fn cast(&self, dtype: &DataType) -> PolarsResult { + fn cast_with_options( + &self, + dtype: &DataType, + cast_options: CastOptions, + ) -> PolarsResult { use DataType::*; use TimeUnit::*; let out = match dtype { @@ -43,7 +47,7 @@ impl LogicalType for DatetimeChunked { (Nanoseconds, Milliseconds) => (None, Some(1_000_000i64)), (Nanoseconds, Microseconds) => (None, Some(1_000i64)), (Microseconds, Milliseconds) => (None, Some(1_000i64)), - _ => return self.0.cast(dtype), + _ => return self.0.cast_with_options(dtype, cast_options), }; let result = match multiplier { // scale to higher precision (eg: ms → us, ms → ns, us → ns) @@ -68,7 +72,7 @@ impl LogicalType for DatetimeChunked { let mut dt = self .0 .apply_values(|v| v.div_euclid(tu_in_day)) - .cast(&Int32) + .cast_with_options(&Int32, cast_options) .unwrap() .into_date() .into_series(); @@ -97,7 +101,7 @@ impl LogicalType for DatetimeChunked { .into_time() .into_series()); }, - dt if dt.is_numeric() => return self.0.cast(dtype), + dt if dt.is_numeric() => return self.0.cast_with_options(dtype, cast_options), dt => { polars_bail!( InvalidOperation: diff --git a/crates/polars-core/src/chunked_array/logical/decimal.rs b/crates/polars-core/src/chunked_array/logical/decimal.rs index d21d3909b415..4526f0d63e99 100644 --- a/crates/polars-core/src/chunked_array/logical/decimal.rs +++ b/crates/polars-core/src/chunked_array/logical/decimal.rs @@ -80,7 +80,11 @@ impl LogicalType for DecimalChunked { } } - fn cast(&self, dtype: &DataType) -> PolarsResult { + fn cast_with_options( + &self, + dtype: &DataType, + cast_options: CastOptions, + ) -> PolarsResult { let (precision_src, scale_src) = (self.precision(), self.scale()); if let &DataType::Decimal(precision_dst, scale_dst) = dtype { let scale_dst = scale_dst.unwrap_or(scale_src); @@ -94,10 +98,10 @@ impl LogicalType for DecimalChunked { }; if scale_src == scale_dst && is_widen { let dtype = &DataType::Decimal(precision_dst, Some(scale_dst)); - return self.0.cast(dtype); // no conversion or checks needed + return self.0.cast_with_options(dtype, cast_options); // no conversion or checks needed } } - let chunks = cast_chunks(&self.chunks, dtype, true)?; + let chunks = cast_chunks(&self.chunks, dtype, cast_options)?; unsafe { Ok(Series::from_chunks_and_dtype_unchecked( self.name(), @@ -129,7 +133,7 @@ impl DecimalChunked { } let dtype = DataType::Decimal(None, Some(scale)); - let chunks = cast_chunks(&self.chunks, &dtype, true)?; + let chunks = cast_chunks(&self.chunks, &dtype, CastOptions::NonStrict)?; let mut dt = Self::new_logical(unsafe { Int128Chunked::from_chunks(self.name(), chunks) }); dt.2 = Some(dtype); Ok(Cow::Owned(dt)) diff --git a/crates/polars-core/src/chunked_array/logical/duration.rs b/crates/polars-core/src/chunked_array/logical/duration.rs index 873fb805d0fe..ca0347d87b5a 100644 --- a/crates/polars-core/src/chunked_array/logical/duration.rs +++ b/crates/polars-core/src/chunked_array/logical/duration.rs @@ -27,7 +27,11 @@ impl LogicalType for DurationChunked { .as_duration(self.time_unit()) } - fn cast(&self, dtype: &DataType) -> PolarsResult { + fn cast_with_options( + &self, + dtype: &DataType, + cast_options: CastOptions, + ) -> PolarsResult { use DataType::*; use TimeUnit::*; match dtype { @@ -50,7 +54,7 @@ impl LogicalType for DurationChunked { }; Ok(out.into_duration(to_unit).into_series()) }, - dt if dt.is_numeric() => self.0.cast(dtype), + dt if dt.is_numeric() => self.0.cast_with_options(dtype, cast_options), dt => { polars_bail!( InvalidOperation: diff --git a/crates/polars-core/src/chunked_array/logical/mod.rs b/crates/polars-core/src/chunked_array/logical/mod.rs index 0108742ef743..af77577fde01 100644 --- a/crates/polars-core/src/chunked_array/logical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/mod.rs @@ -31,6 +31,7 @@ pub use struct_::*; #[cfg(feature = "dtype-time")] pub use time::*; +use crate::chunked_array::cast::CastOptions; use crate::prelude::*; /// Maps a logical type to a chunked array implementation of the physical type. @@ -84,7 +85,11 @@ pub trait LogicalType { unimplemented!() } - fn cast(&self, dtype: &DataType) -> PolarsResult; + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult; + + fn cast(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::NonStrict) + } } impl Logical diff --git a/crates/polars-core/src/chunked_array/logical/struct_/mod.rs b/crates/polars-core/src/chunked_array/logical/struct_/mod.rs index 7156e63db56e..07bc343aee21 100644 --- a/crates/polars-core/src/chunked_array/logical/struct_/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/struct_/mod.rs @@ -312,7 +312,12 @@ impl StructChunked { )) } - unsafe fn cast_impl(&self, dtype: &DataType, unchecked: bool) -> PolarsResult { + unsafe fn cast_impl( + &self, + dtype: &DataType, + cast_options: CastOptions, + unchecked: bool, + ) -> PolarsResult { match dtype { DataType::Struct(dtype_fields) => { let map = BTreeMap::from_iter(self.fields().iter().map(|s| (s.name(), s))); @@ -324,7 +329,7 @@ impl StructChunked { if unchecked { s.cast_unchecked(&new_field.dtype) } else { - s.cast(&new_field.dtype) + s.cast_with_options(&new_field.dtype, cast_options) } }, None => Ok(Series::full_null( @@ -398,7 +403,7 @@ impl StructChunked { if unchecked { s.cast_unchecked(dtype) } else { - s.cast(dtype) + s.cast_with_options(dtype, cast_options) } }) .collect::>>()?; @@ -411,7 +416,7 @@ impl StructChunked { if dtype == self.dtype() { return Ok(self.clone().into_series()); } - self.cast_impl(dtype, true) + self.cast_impl(dtype, CastOptions::Overflowing, true) } pub fn rows_encode(&self) -> PolarsResult { @@ -450,8 +455,12 @@ impl LogicalType for StructChunked { } // in case of a struct, a cast will coerce the inner types - fn cast(&self, dtype: &DataType) -> PolarsResult { - unsafe { self.cast_impl(dtype, false) } + fn cast_with_options( + &self, + dtype: &DataType, + cast_options: CastOptions, + ) -> PolarsResult { + unsafe { self.cast_impl(dtype, cast_options, false) } } } diff --git a/crates/polars-core/src/chunked_array/logical/time.rs b/crates/polars-core/src/chunked_array/logical/time.rs index c3e6e1f74df7..7e404ad98b61 100644 --- a/crates/polars-core/src/chunked_array/logical/time.rs +++ b/crates/polars-core/src/chunked_array/logical/time.rs @@ -28,15 +28,21 @@ impl LogicalType for TimeChunked { self.0.get_any_value_unchecked(i).as_time() } - fn cast(&self, dtype: &DataType) -> PolarsResult { + fn cast_with_options( + &self, + dtype: &DataType, + cast_options: CastOptions, + ) -> PolarsResult { use DataType::*; match dtype { Time => Ok(self.clone().into_series()), #[cfg(feature = "dtype-duration")] Duration(tu) => { - let out = self.0.cast(&DataType::Duration(TimeUnit::Nanoseconds)); + let out = self + .0 + .cast_with_options(&DataType::Duration(TimeUnit::Nanoseconds), cast_options); if !matches!(tu, TimeUnit::Nanoseconds) { - out?.cast(dtype) + out?.cast_with_options(dtype, cast_options) } else { out } @@ -49,7 +55,7 @@ impl LogicalType for TimeChunked { self.dtype(), dtype ) }, - dt if dt.is_numeric() => self.0.cast(dtype), + dt if dt.is_numeric() => self.0.cast_with_options(dtype, cast_options), _ => { polars_bail!( InvalidOperation: diff --git a/crates/polars-core/src/chunked_array/ops/apply.rs b/crates/polars-core/src/chunked_array/ops/apply.rs index a6fbdc7da0ea..f6d7db426cd1 100644 --- a/crates/polars-core/src/chunked_array/ops/apply.rs +++ b/crates/polars-core/src/chunked_array/ops/apply.rs @@ -1,6 +1,7 @@ //! Implementations of the ChunkApply Trait. use std::borrow::Cow; +use crate::chunked_array::cast::CastOptions; use crate::prelude::*; use crate::series::IsSorted; @@ -165,7 +166,9 @@ impl ChunkedArray { // this will ensure we have a single ref count // and we can mutate in place let chunks = { - let s = self.cast(&S::get_dtype()).unwrap(); + let s = self + .cast_with_options(&S::get_dtype(), CastOptions::Overflowing) + .unwrap(); s.chunks().clone() }; apply_in_place_impl(self.name(), chunks, f) diff --git a/crates/polars-core/src/chunked_array/ops/decimal.rs b/crates/polars-core/src/chunked_array/ops/decimal.rs index 09928baec612..e2f9c5845429 100644 --- a/crates/polars-core/src/chunked_array/ops/decimal.rs +++ b/crates/polars-core/src/chunked_array/ops/decimal.rs @@ -1,3 +1,4 @@ +use crate::chunked_array::cast::CastOptions; use crate::prelude::*; impl StringChunked { @@ -21,7 +22,10 @@ impl StringChunked { } } - self.cast(&DataType::Decimal(None, Some(scale as usize))) + self.cast_with_options( + &DataType::Decimal(None, Some(scale as usize)), + CastOptions::NonStrict, + ) } } diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index 95a064d3f200..fb4dafaf3037 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -42,6 +42,7 @@ pub mod zip; use serde::{Deserialize, Serialize}; pub use sort::options::*; +use crate::chunked_array::cast::CastOptions; use crate::series::IsSorted; #[cfg(feature = "reinterpret")] pub trait Reinterpret { @@ -182,7 +183,13 @@ pub trait ChunkSet<'a, A, B> { /// Cast `ChunkedArray` to `ChunkedArray` pub trait ChunkCast { /// Cast a [`ChunkedArray`] to [`DataType`] - fn cast(&self, data_type: &DataType) -> PolarsResult; + fn cast(&self, data_type: &DataType) -> PolarsResult { + self.cast_with_options(data_type, CastOptions::NonStrict) + } + + /// Cast a [`ChunkedArray`] to [`DataType`] + fn cast_with_options(&self, data_type: &DataType, options: CastOptions) + -> PolarsResult; /// Does not check if the cast is a valid one and may over/underflow /// diff --git a/crates/polars-core/src/chunked_array/ops/rolling_window.rs b/crates/polars-core/src/chunked_array/ops/rolling_window.rs index 43d4f0bd6cab..ef3c07529504 100644 --- a/crates/polars-core/src/chunked_array/ops/rolling_window.rs +++ b/crates/polars-core/src/chunked_array/ops/rolling_window.rs @@ -53,6 +53,7 @@ mod inner_mod { use num_traits::{Float, Zero}; use polars_utils::float::IsFloat; + use crate::chunked_array::cast::CastOptions; use crate::prelude::*; /// utility @@ -97,7 +98,7 @@ mod inner_mod { if options.weights.is_some() && !matches!(self.dtype(), DataType::Float64 | DataType::Float32) { - let s = self.cast(&DataType::Float64)?; + let s = self.cast_with_options(&DataType::Float64, CastOptions::NonStrict)?; return s.rolling_map(f, options); } diff --git a/crates/polars-core/src/frame/group_by/aggregations/boolean.rs b/crates/polars-core/src/frame/group_by/aggregations/boolean.rs index a7be2a845f6d..fd7e537dc0ab 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/boolean.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/boolean.rs @@ -1,4 +1,5 @@ use super::*; +use crate::chunked_array::cast::CastOptions; pub fn _agg_helper_idx_bool(groups: &GroupsIdx, f: F) -> Series where @@ -105,6 +106,8 @@ impl BooleanChunked { } } pub(crate) unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { - self.cast(&IDX_DTYPE).unwrap().agg_sum(groups) + self.cast_with_options(&IDX_DTYPE, CastOptions::Overflowing) + .unwrap() + .agg_sum(groups) } } diff --git a/crates/polars-core/src/frame/group_by/aggregations/mod.rs b/crates/polars-core/src/frame/group_by/aggregations/mod.rs index cd012f20636b..0e624ceb8fe8 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/mod.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/mod.rs @@ -23,6 +23,7 @@ use polars_utils::idx_vec::IdxVec; use polars_utils::ord::{compare_fn_nan_max, compare_fn_nan_min}; use rayon::prelude::*; +use crate::chunked_array::cast::CastOptions; #[cfg(feature = "object")] use crate::chunked_array::object::extension::create_extension; use crate::frame::group_by::GroupsIdx; @@ -376,7 +377,9 @@ where GroupsProxy::Slice { groups, .. } => { if _use_rolling_kernels(groups, ca.chunks()) { // this cast is a no-op for floats - let s = ca.cast(&K::get_dtype()).unwrap(); + let s = ca + .cast_with_options(&K::get_dtype(), CastOptions::Overflowing) + .unwrap(); let ca: &ChunkedArray = s.as_ref().as_ref(); let arr = ca.downcast_iter().next().unwrap(); let values = arr.values().as_slice(); @@ -987,7 +990,9 @@ where .. } => { if _use_rolling_kernels(groups_slice, self.chunks()) { - let ca = self.cast(&DataType::Float64).unwrap(); + let ca = self + .cast_with_options(&DataType::Float64, CastOptions::Overflowing) + .unwrap(); ca.agg_mean(groups) } else { _agg_helper_slice::(groups_slice, |[first, len]| { @@ -1029,7 +1034,9 @@ where .. } => { if _use_rolling_kernels(groups_slice, self.chunks()) { - let ca = self.cast(&DataType::Float64).unwrap(); + let ca = self + .cast_with_options(&DataType::Float64, CastOptions::Overflowing) + .unwrap(); ca.agg_var(groups, ddof) } else { _agg_helper_slice::(groups_slice, |[first, len]| { @@ -1077,7 +1084,9 @@ where .. } => { if _use_rolling_kernels(groups_slice, self.chunks()) { - let ca = self.cast(&DataType::Float64).unwrap(); + let ca = self + .cast_with_options(&DataType::Float64, CastOptions::Overflowing) + .unwrap(); ca.agg_std(groups, ddof) } else { _agg_helper_slice::(groups_slice, |[first, len]| { diff --git a/crates/polars-core/src/frame/group_by/into_groups.rs b/crates/polars-core/src/frame/group_by/into_groups.rs index f7066c7b6bcd..365545b0bbb5 100644 --- a/crates/polars-core/src/frame/group_by/into_groups.rs +++ b/crates/polars-core/src/frame/group_by/into_groups.rs @@ -2,6 +2,7 @@ use arrow::legacy::kernels::sort_partition::{create_clean_partitions, partition_ use polars_utils::total_ord::{ToTotalOrd, TotalHash}; use super::*; +use crate::chunked_array::cast::CastOptions; use crate::config::verbose; use crate::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca_unordered; use crate::utils::flatten::flatten_par; @@ -235,13 +236,17 @@ impl IntoGroupsProxy for BooleanChunked { fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { #[cfg(feature = "performant")] { - let ca = self.cast(&DataType::UInt8).unwrap(); + let ca = self + .cast_with_options(&DataType::UInt8, CastOptions::Overflowing) + .unwrap(); let ca = ca.u8().unwrap(); ca.group_tuples(multithreaded, sorted) } #[cfg(not(feature = "performant"))] { - let ca = self.cast(&DataType::UInt32).unwrap(); + let ca = self + .cast_with_options(&DataType::UInt32, CastOptions::Overflowing) + .unwrap(); let ca = ca.u32().unwrap(); ca.group_tuples(multithreaded, sorted) } diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index b3530be45c11..93f965867110 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -28,6 +28,7 @@ use arrow::record_batch::RecordBatch; use serde::{Deserialize, Serialize}; use smartstring::alias::String as SmartString; +use crate::chunked_array::cast::CastOptions; #[cfg(feature = "row_hash")] use crate::hashing::_df_rows_to_hashes_threaded_vertical; #[cfg(feature = "zip_with")] @@ -2578,7 +2579,11 @@ impl DataFrame { numeric_df .columns .par_iter() - .map(|s| s.is_null().cast(&DataType::UInt32).unwrap()) + .map(|s| { + s.is_null() + .cast_with_options(&DataType::UInt32, CastOptions::NonStrict) + .unwrap() + }) .reduce_with(|l, r| &l + &r) // we can unwrap the option, because we are certain there is a column // we started this operation on 2 columns diff --git a/crates/polars-core/src/series/any_value.rs b/crates/polars-core/src/series/any_value.rs index 87de80903f65..f0f6f5c46df0 100644 --- a/crates/polars-core/src/series/any_value.rs +++ b/crates/polars-core/src/series/any_value.rs @@ -1,5 +1,7 @@ use std::fmt::Write; +#[cfg(feature = "dtype-categorical")] +use crate::chunked_array::cast::CastOptions; #[cfg(feature = "object")] use crate::chunked_array::object::registry::ObjectRegistry; use crate::prelude::*; @@ -432,7 +434,7 @@ fn any_values_to_categorical( if strict { ca.into_series().strict_cast(dtype) } else { - ca.cast(dtype) + ca.cast_with_options(dtype, CastOptions::NonStrict) } } diff --git a/crates/polars-core/src/series/from.rs b/crates/polars-core/src/series/from.rs index 7d0ce20f6fa9..9a82ed0e2506 100644 --- a/crates/polars-core/src/series/from.rs +++ b/crates/polars-core/src/series/from.rs @@ -11,7 +11,7 @@ use arrow::legacy::kernels::concatenate::concatenate_owned_unchecked; use arrow::temporal_conversions::*; use polars_error::feature_gated; -use crate::chunked_array::cast::cast_chunks; +use crate::chunked_array::cast::{cast_chunks, CastOptions}; #[cfg(feature = "object")] use crate::chunked_array::object::extension::polars_extension::PolarsExtension; #[cfg(feature = "object")] @@ -155,7 +155,8 @@ impl Series { match dtype { ArrowDataType::Utf8View => Ok(StringChunked::from_chunks(name, chunks).into_series()), ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => { - let chunks = cast_chunks(&chunks, &DataType::String, false).unwrap(); + let chunks = + cast_chunks(&chunks, &DataType::String, CastOptions::NonStrict).unwrap(); Ok(StringChunked::from_chunks(name, chunks).into_series()) }, ArrowDataType::BinaryView => Ok(BinaryChunked::from_chunks(name, chunks).into_series()), @@ -165,11 +166,13 @@ impl Series { return Ok(BinaryOffsetChunked::from_chunks(name, chunks).into_series()); } } - let chunks = cast_chunks(&chunks, &DataType::Binary, false).unwrap(); + let chunks = + cast_chunks(&chunks, &DataType::Binary, CastOptions::NonStrict).unwrap(); Ok(BinaryChunked::from_chunks(name, chunks).into_series()) }, ArrowDataType::Binary => { - let chunks = cast_chunks(&chunks, &DataType::Binary, false).unwrap(); + let chunks = + cast_chunks(&chunks, &DataType::Binary, CastOptions::NonStrict).unwrap(); Ok(BinaryChunked::from_chunks(name, chunks).into_series()) }, ArrowDataType::List(_) | ArrowDataType::LargeList(_) => { @@ -205,21 +208,24 @@ impl Series { ArrowDataType::Int32 => Ok(Int32Chunked::from_chunks(name, chunks).into_series()), ArrowDataType::Int64 => Ok(Int64Chunked::from_chunks(name, chunks).into_series()), ArrowDataType::Float16 => { - let chunks = cast_chunks(&chunks, &DataType::Float32, false).unwrap(); + let chunks = + cast_chunks(&chunks, &DataType::Float32, CastOptions::NonStrict).unwrap(); Ok(Float32Chunked::from_chunks(name, chunks).into_series()) }, ArrowDataType::Float32 => Ok(Float32Chunked::from_chunks(name, chunks).into_series()), ArrowDataType::Float64 => Ok(Float64Chunked::from_chunks(name, chunks).into_series()), #[cfg(feature = "dtype-date")] ArrowDataType::Date32 => { - let chunks = cast_chunks(&chunks, &DataType::Int32, false).unwrap(); + let chunks = + cast_chunks(&chunks, &DataType::Int32, CastOptions::Overflowing).unwrap(); Ok(Int32Chunked::from_chunks(name, chunks) .into_date() .into_series()) }, #[cfg(feature = "dtype-datetime")] ArrowDataType::Date64 => { - let chunks = cast_chunks(&chunks, &DataType::Int64, false).unwrap(); + let chunks = + cast_chunks(&chunks, &DataType::Int64, CastOptions::Overflowing).unwrap(); let ca = Int64Chunked::from_chunks(name, chunks); Ok(ca.into_datetime(TimeUnit::Milliseconds, None).into_series()) }, @@ -234,7 +240,8 @@ impl Series { }, _ => canonical_tz, }; - let chunks = cast_chunks(&chunks, &DataType::Int64, false).unwrap(); + let chunks = + cast_chunks(&chunks, &DataType::Int64, CastOptions::NonStrict).unwrap(); let s = Int64Chunked::from_chunks(name, chunks) .into_datetime(tu.into(), tz) .into_series(); @@ -247,7 +254,8 @@ impl Series { }, #[cfg(feature = "dtype-duration")] ArrowDataType::Duration(tu) => { - let chunks = cast_chunks(&chunks, &DataType::Int64, false).unwrap(); + let chunks = + cast_chunks(&chunks, &DataType::Int64, CastOptions::NonStrict).unwrap(); let s = Int64Chunked::from_chunks(name, chunks) .into_duration(tu.into()) .into_series(); @@ -262,9 +270,11 @@ impl Series { ArrowDataType::Time64(tu) | ArrowDataType::Time32(tu) => { let mut chunks = chunks; if matches!(dtype, ArrowDataType::Time32(_)) { - chunks = cast_chunks(&chunks, &DataType::Int32, false).unwrap(); + chunks = + cast_chunks(&chunks, &DataType::Int32, CastOptions::NonStrict).unwrap(); } - let chunks = cast_chunks(&chunks, &DataType::Int64, false).unwrap(); + let chunks = + cast_chunks(&chunks, &DataType::Int64, CastOptions::NonStrict).unwrap(); let s = Int64Chunked::from_chunks(name, chunks) .into_time() .into_series(); @@ -444,7 +454,7 @@ impl Series { Ok(StructChunked::new_unchecked(name, &fields).into_series()) }, ArrowDataType::FixedSizeBinary(_) => { - let chunks = cast_chunks(&chunks, &DataType::Binary, true)?; + let chunks = cast_chunks(&chunks, &DataType::Binary, CastOptions::NonStrict)?; Ok(BinaryChunked::from_chunks(name, chunks).into_series()) }, ArrowDataType::Decimal(precision, scale) @@ -459,9 +469,12 @@ impl Series { #[cfg(feature = "python")] { let (precision, scale) = (Some(*precision), *scale); - let chunks = - cast_chunks(&chunks, &DataType::Decimal(precision, Some(scale)), false) - .unwrap(); + let chunks = cast_chunks( + &chunks, + &DataType::Decimal(precision, Some(scale)), + CastOptions::NonStrict, + ) + .unwrap(); Ok(Int128Chunked::from_chunks(name, chunks) .into_decimal_unchecked(precision, scale) .into_series()) @@ -470,9 +483,12 @@ impl Series { #[cfg(not(feature = "python"))] { let (precision, scale) = (Some(*precision), *scale); - let chunks = - cast_chunks(&chunks, &DataType::Decimal(precision, Some(scale)), false) - .unwrap(); + let chunks = cast_chunks( + &chunks, + &DataType::Decimal(precision, Some(scale)), + CastOptions::NonStrict, + ) + .unwrap(); // or DecimalChunked? Ok(Int128Chunked::from_chunks(name, chunks) .into_decimal_unchecked(precision, scale) @@ -519,11 +535,11 @@ unsafe fn to_physical_and_dtype( ) -> (Vec, DataType) { match arrays[0].data_type() { ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => { - let chunks = cast_chunks(&arrays, &DataType::String, false).unwrap(); + let chunks = cast_chunks(&arrays, &DataType::String, CastOptions::NonStrict).unwrap(); (chunks, DataType::String) }, ArrowDataType::Binary | ArrowDataType::LargeBinary | ArrowDataType::FixedSizeBinary(_) => { - let chunks = cast_chunks(&arrays, &DataType::Binary, false).unwrap(); + let chunks = cast_chunks(&arrays, &DataType::Binary, CastOptions::NonStrict).unwrap(); (chunks, DataType::Binary) }, #[allow(unused_variables)] diff --git a/crates/polars-core/src/series/implementations/array.rs b/crates/polars-core/src/series/implementations/array.rs index a660c68db7a2..7a862fae2a61 100644 --- a/crates/polars-core/src/series/implementations/array.rs +++ b/crates/polars-core/src/series/implementations/array.rs @@ -2,6 +2,7 @@ use std::any::Any; use std::borrow::Cow; use super::{private, MetadataFlags}; +use crate::chunked_array::cast::CastOptions; use crate::chunked_array::comparison::*; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::AsSinglePtr; @@ -140,8 +141,8 @@ impl SeriesTrait for SeriesWrap { ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast(&self, data_type: &DataType) -> PolarsResult { - self.0.cast(data_type) + fn cast(&self, data_type: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(data_type, options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/binary.rs b/crates/polars-core/src/series/implementations/binary.rs index 1aab1ba9bc00..f1362c1a0af3 100644 --- a/crates/polars-core/src/series/implementations/binary.rs +++ b/crates/polars-core/src/series/implementations/binary.rs @@ -1,4 +1,5 @@ use super::*; +use crate::chunked_array::cast::CastOptions; use crate::chunked_array::comparison::*; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; @@ -164,8 +165,8 @@ impl SeriesTrait for SeriesWrap { ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast(&self, data_type: &DataType) -> PolarsResult { - self.0.cast(data_type) + fn cast(&self, data_type: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(data_type, options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/binary_offset.rs b/crates/polars-core/src/series/implementations/binary_offset.rs index 32ffc1a5cc11..f0d1ccf0b593 100644 --- a/crates/polars-core/src/series/implementations/binary_offset.rs +++ b/crates/polars-core/src/series/implementations/binary_offset.rs @@ -134,8 +134,8 @@ impl SeriesTrait for SeriesWrap { ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast(&self, data_type: &DataType) -> PolarsResult { - self.0.cast(data_type) + fn cast(&self, data_type: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(data_type, options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/boolean.rs b/crates/polars-core/src/series/implementations/boolean.rs index 1c55fe69898e..78aafff05eee 100644 --- a/crates/polars-core/src/series/implementations/boolean.rs +++ b/crates/polars-core/src/series/implementations/boolean.rs @@ -71,14 +71,14 @@ impl private::PrivateSeries for SeriesWrap { #[cfg(feature = "algorithm_group_by")] unsafe fn agg_std(&self, groups: &GroupsProxy, _ddof: u8) -> Series { self.0 - .cast(&DataType::Float64) + .cast_with_options(&DataType::Float64, CastOptions::Overflowing) .unwrap() .agg_std(groups, _ddof) } #[cfg(feature = "algorithm_group_by")] unsafe fn agg_var(&self, groups: &GroupsProxy, _ddof: u8) -> Series { self.0 - .cast(&DataType::Float64) + .cast_with_options(&DataType::Float64, CastOptions::Overflowing) .unwrap() .agg_var(groups, _ddof) } @@ -189,8 +189,8 @@ impl SeriesTrait for SeriesWrap { ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast(&self, data_type: &DataType) -> PolarsResult { - self.0.cast(data_type) + fn cast(&self, data_type: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(data_type, options) } fn get(&self, index: usize) -> PolarsResult { @@ -263,21 +263,30 @@ impl SeriesTrait for SeriesWrap { Ok(ChunkAggSeries::min_reduce(&self.0)) } fn median_reduce(&self) -> PolarsResult { - let ca = self.0.cast(&DataType::Int8).unwrap(); + let ca = self + .0 + .cast_with_options(&DataType::Int8, CastOptions::Overflowing) + .unwrap(); let sc = ca.median_reduce()?; let v = sc.value().cast(&DataType::Float64); Ok(Scalar::new(DataType::Float64, v)) } /// Get the variance of the Series as a new Series of length 1. fn var_reduce(&self, _ddof: u8) -> PolarsResult { - let ca = self.0.cast(&DataType::Int8).unwrap(); + let ca = self + .0 + .cast_with_options(&DataType::Int8, CastOptions::Overflowing) + .unwrap(); let sc = ca.var_reduce(_ddof)?; let v = sc.value().cast(&DataType::Float64); Ok(Scalar::new(DataType::Float64, v)) } /// Get the standard deviation of the Series as a new Series of length 1. fn std_reduce(&self, _ddof: u8) -> PolarsResult { - let ca = self.0.cast(&DataType::Int8).unwrap(); + let ca = self + .0 + .cast_with_options(&DataType::Int8, CastOptions::Overflowing) + .unwrap(); let sc = ca.std_reduce(_ddof)?; let v = sc.value().cast(&DataType::Float64); Ok(Scalar::new(DataType::Float64, v)) diff --git a/crates/polars-core/src/series/implementations/categorical.rs b/crates/polars-core/src/series/implementations/categorical.rs index a0a71f40feaf..f47d4837c3ee 100644 --- a/crates/polars-core/src/series/implementations/categorical.rs +++ b/crates/polars-core/src/series/implementations/categorical.rs @@ -217,8 +217,8 @@ impl SeriesTrait for SeriesWrap { .into_series() } - fn cast(&self, data_type: &DataType) -> PolarsResult { - self.0.cast(data_type) + fn cast(&self, data_type: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(data_type, options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/date.rs b/crates/polars-core/src/series/implementations/date.rs index 2b175ade5021..735a4b307056 100644 --- a/crates/polars-core/src/series/implementations/date.rs +++ b/crates/polars-core/src/series/implementations/date.rs @@ -84,12 +84,15 @@ impl private::PrivateSeries for SeriesWrap { match rhs.dtype() { DataType::Date => { let dt = DataType::Datetime(TimeUnit::Milliseconds, None); - let lhs = self.cast(&dt)?; + let lhs = self.cast(&dt, CastOptions::NonStrict)?; let rhs = rhs.cast(&dt)?; lhs.subtract(&rhs) }, DataType::Duration(_) => ((&self - .cast(&DataType::Datetime(TimeUnit::Milliseconds, None)) + .cast( + &DataType::Datetime(TimeUnit::Milliseconds, None), + CastOptions::NonStrict, + ) .unwrap()) - rhs) .cast(&DataType::Date), @@ -100,7 +103,10 @@ impl private::PrivateSeries for SeriesWrap { fn add_to(&self, rhs: &Series) -> PolarsResult { match rhs.dtype() { DataType::Duration(_) => ((&self - .cast(&DataType::Datetime(TimeUnit::Milliseconds, None)) + .cast( + &DataType::Datetime(TimeUnit::Milliseconds, None), + CastOptions::NonStrict, + ) .unwrap()) + rhs) .cast(&DataType::Date), @@ -224,7 +230,7 @@ impl SeriesTrait for SeriesWrap { .into_series() } - fn cast(&self, data_type: &DataType) -> PolarsResult { + fn cast(&self, data_type: &DataType, cast_options: CastOptions) -> PolarsResult { match data_type { DataType::String => Ok(self .0 @@ -236,11 +242,13 @@ impl SeriesTrait for SeriesWrap { .into_series()), #[cfg(feature = "dtype-datetime")] DataType::Datetime(_, _) => { - let mut out = self.0.cast(data_type)?; + let mut out = self + .0 + .cast_with_options(data_type, CastOptions::NonStrict)?; out.set_sorted_flag(self.0.is_sorted_flag()); Ok(out) }, - _ => self.0.cast(data_type), + _ => self.0.cast_with_options(data_type, cast_options), } } diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index b2575a432c00..04040dda8527 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -90,13 +90,13 @@ impl private::PrivateSeries for SeriesWrap { (DataType::Datetime(tu, tz), DataType::Datetime(tur, tzr)) => { assert_eq!(tu, tur); assert_eq!(tz, tzr); - let lhs = self.cast(&DataType::Int64).unwrap(); + let lhs = self.cast(&DataType::Int64, CastOptions::NonStrict).unwrap(); let rhs = rhs.cast(&DataType::Int64).unwrap(); Ok(lhs.subtract(&rhs)?.into_duration(*tu).into_series()) }, (DataType::Datetime(tu, tz), DataType::Duration(tur)) => { assert_eq!(tu, tur); - let lhs = self.cast(&DataType::Int64).unwrap(); + let lhs = self.cast(&DataType::Int64, CastOptions::NonStrict).unwrap(); let rhs = rhs.cast(&DataType::Int64).unwrap(); Ok(lhs .subtract(&rhs)? @@ -110,7 +110,7 @@ impl private::PrivateSeries for SeriesWrap { match (self.dtype(), rhs.dtype()) { (DataType::Datetime(tu, tz), DataType::Duration(tur)) => { assert_eq!(tu, tur); - let lhs = self.cast(&DataType::Int64).unwrap(); + let lhs = self.cast(&DataType::Int64, CastOptions::NonStrict).unwrap(); let rhs = rhs.cast(&DataType::Int64).unwrap(); Ok(lhs .add_to(&rhs)? @@ -246,7 +246,7 @@ impl SeriesTrait for SeriesWrap { .into_series() } - fn cast(&self, data_type: &DataType) -> PolarsResult { + fn cast(&self, data_type: &DataType, cast_options: CastOptions) -> PolarsResult { match (data_type, self.0.time_unit()) { (DataType::String, TimeUnit::Milliseconds) => { Ok(self.0.to_string("%F %T%.3f")?.into_series()) @@ -257,7 +257,7 @@ impl SeriesTrait for SeriesWrap { (DataType::String, TimeUnit::Nanoseconds) => { Ok(self.0.to_string("%F %T%.9f")?.into_series()) }, - _ => self.0.cast(data_type), + _ => self.0.cast_with_options(data_type, cast_options), } } diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index 12b3e6925958..b30a0a9a4256 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -263,8 +263,8 @@ impl SeriesTrait for SeriesWrap { .into_series() } - fn cast(&self, data_type: &DataType) -> PolarsResult { - self.0.cast(data_type) + fn cast(&self, data_type: &DataType, cast_options: CastOptions) -> PolarsResult { + self.0.cast_with_options(data_type, cast_options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index d85f30addf3a..8caafe33c041 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -127,7 +127,7 @@ impl private::PrivateSeries for SeriesWrap { match (self.dtype(), rhs.dtype()) { (DataType::Duration(tu), DataType::Duration(tur)) => { polars_ensure!(tu == tur, InvalidOperation: "units are different"); - let lhs = self.cast(&DataType::Int64).unwrap(); + let lhs = self.cast(&DataType::Int64, CastOptions::NonStrict).unwrap(); let rhs = rhs.cast(&DataType::Int64).unwrap(); Ok(lhs.subtract(&rhs)?.into_duration(*tu).into_series()) }, @@ -138,7 +138,7 @@ impl private::PrivateSeries for SeriesWrap { match (self.dtype(), rhs.dtype()) { (DataType::Duration(tu), DataType::Duration(tur)) => { polars_ensure!(tu == tur, InvalidOperation: "units are different"); - let lhs = self.cast(&DataType::Int64).unwrap(); + let lhs = self.cast(&DataType::Int64, CastOptions::NonStrict).unwrap(); let rhs = rhs.cast(&DataType::Int64).unwrap(); Ok(lhs.add_to(&rhs)?.into_duration(*tu).into_series()) }, @@ -148,7 +148,8 @@ impl private::PrivateSeries for SeriesWrap { TimeUnit::Microseconds => 86_400_000_000, TimeUnit::Nanoseconds => 86_400_000_000_000, }; - let lhs = self.cast(&DataType::Int64).unwrap() / one_day_in_tu; + let lhs = + self.cast(&DataType::Int64, CastOptions::NonStrict).unwrap() / one_day_in_tu; let rhs = rhs .cast(&DataType::Int32) .unwrap() @@ -162,7 +163,7 @@ impl private::PrivateSeries for SeriesWrap { }, (DataType::Duration(tu), DataType::Datetime(tur, tz)) => { polars_ensure!(tu == tur, InvalidOperation: "units are different"); - let lhs = self.cast(&DataType::Int64).unwrap(); + let lhs = self.cast(&DataType::Int64, CastOptions::NonStrict).unwrap(); let rhs = rhs.cast(&DataType::Int64).unwrap(); Ok(lhs .add_to(&rhs)? @@ -182,7 +183,7 @@ impl private::PrivateSeries for SeriesWrap { } fn remainder(&self, rhs: &Series) -> PolarsResult { polars_ensure!(self.dtype() == rhs.dtype(), InvalidOperation: "dtypes and units must be equal in duration arithmetic"); - let lhs = self.cast(&DataType::Int64).unwrap(); + let lhs = self.cast(&DataType::Int64, CastOptions::NonStrict).unwrap(); let rhs = rhs.cast(&DataType::Int64).unwrap(); Ok(lhs .remainder(&rhs)? @@ -317,8 +318,8 @@ impl SeriesTrait for SeriesWrap { .into_series() } - fn cast(&self, data_type: &DataType) -> PolarsResult { - self.0.cast(data_type) + fn cast(&self, data_type: &DataType, cast_options: CastOptions) -> PolarsResult { + self.0.cast_with_options(data_type, cast_options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index 0c4ad48ed650..64d35f4b741d 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -220,8 +220,12 @@ macro_rules! impl_dyn_series { ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast(&self, data_type: &DataType) -> PolarsResult { - self.0.cast(data_type) + fn cast( + &self, + data_type: &DataType, + cast_options: CastOptions, + ) -> PolarsResult { + self.0.cast_with_options(data_type, cast_options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/list.rs b/crates/polars-core/src/series/implementations/list.rs index 86d63e29334a..e02c670e14df 100644 --- a/crates/polars-core/src/series/implementations/list.rs +++ b/crates/polars-core/src/series/implementations/list.rs @@ -125,8 +125,8 @@ impl SeriesTrait for SeriesWrap { ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast(&self, data_type: &DataType) -> PolarsResult { - self.0.cast(data_type) + fn cast(&self, data_type: &DataType, cast_options: CastOptions) -> PolarsResult { + self.0.cast_with_options(data_type, cast_options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index 117e91f93768..cbc62481235d 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -147,7 +147,10 @@ macro_rules! impl_dyn_series { unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { use DataType::*; match self.dtype() { - Int8 | UInt8 | Int16 | UInt16 => self.cast(&Int64).unwrap().agg_sum(groups), + Int8 | UInt8 | Int16 | UInt16 => self + .cast(&Int64, CastOptions::Overflowing) + .unwrap() + .agg_sum(groups), _ => self.0.agg_sum(groups), } } @@ -321,8 +324,8 @@ macro_rules! impl_dyn_series { ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast(&self, data_type: &DataType) -> PolarsResult { - self.0.cast(data_type) + fn cast(&self, data_type: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(data_type, options) } fn get(&self, index: usize) -> PolarsResult { @@ -464,7 +467,7 @@ impl private::PrivateSeriesNumeric for SeriesWrap { } fn bit_repr_small(&self) -> UInt32Chunked { self.0 - .cast(&DataType::UInt32) + .cast_with_options(&DataType::UInt32, CastOptions::NonStrict) .unwrap() .u32() .unwrap() diff --git a/crates/polars-core/src/series/implementations/null.rs b/crates/polars-core/src/series/implementations/null.rs index 02e06abc3c42..2fed6a1923c9 100644 --- a/crates/polars-core/src/series/implementations/null.rs +++ b/crates/polars-core/src/series/implementations/null.rs @@ -196,7 +196,7 @@ impl SeriesTrait for NullChunked { NullChunked::new(self.name.clone(), 0).into_series() } - fn cast(&self, data_type: &DataType) -> PolarsResult { + fn cast(&self, data_type: &DataType, _cast_options: CastOptions) -> PolarsResult { Ok(Series::full_null(self.name.as_ref(), self.len(), data_type)) } diff --git a/crates/polars-core/src/series/implementations/object.rs b/crates/polars-core/src/series/implementations/object.rs index b7f59999ee21..d65d0100cd2f 100644 --- a/crates/polars-core/src/series/implementations/object.rs +++ b/crates/polars-core/src/series/implementations/object.rs @@ -4,6 +4,7 @@ use std::borrow::Cow; use ahash::RandomState; use super::MetadataFlags; +use crate::chunked_array::cast::CastOptions; use crate::chunked_array::object::PolarsObjectSafe; use crate::chunked_array::ops::compare_inner::{IntoTotalEqInner, TotalEqInner}; use crate::prelude::*; @@ -148,7 +149,7 @@ where ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast(&self, data_type: &DataType) -> PolarsResult { + fn cast(&self, data_type: &DataType, _cast_options: CastOptions) -> PolarsResult { if matches!(data_type, DataType::Object(_, None)) { Ok(self.0.clone().into_series()) } else { diff --git a/crates/polars-core/src/series/implementations/string.rs b/crates/polars-core/src/series/implementations/string.rs index 9b2fec66a71c..87d1401a8988 100644 --- a/crates/polars-core/src/series/implementations/string.rs +++ b/crates/polars-core/src/series/implementations/string.rs @@ -171,8 +171,8 @@ impl SeriesTrait for SeriesWrap { ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast(&self, data_type: &DataType) -> PolarsResult { - self.0.cast(data_type) + fn cast(&self, data_type: &DataType, cast_options: CastOptions) -> PolarsResult { + self.0.cast_with_options(data_type, cast_options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/struct_.rs b/crates/polars-core/src/series/implementations/struct_.rs index 971ef90904d4..e3f0620dbbe7 100644 --- a/crates/polars-core/src/series/implementations/struct_.rs +++ b/crates/polars-core/src/series/implementations/struct_.rs @@ -218,8 +218,8 @@ impl SeriesTrait for SeriesWrap { .into_series() } - fn cast(&self, dtype: &DataType) -> PolarsResult { - self.0.cast(dtype) + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, cast_options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/time.rs b/crates/polars-core/src/series/implementations/time.rs index 463d635fdf30..d6a1bcbd7932 100644 --- a/crates/polars-core/src/series/implementations/time.rs +++ b/crates/polars-core/src/series/implementations/time.rs @@ -206,7 +206,7 @@ impl SeriesTrait for SeriesWrap { .into_series() } - fn cast(&self, data_type: &DataType) -> PolarsResult { + fn cast(&self, data_type: &DataType, cast_options: CastOptions) -> PolarsResult { match data_type { DataType::String => Ok(self .0 @@ -216,7 +216,7 @@ impl SeriesTrait for SeriesWrap { .unwrap() .to_string("%T") .into_series()), - _ => self.0.cast(data_type), + _ => self.0.cast_with_options(data_type, cast_options), } } diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index 13d2f5f8ec62..4de7924fa334 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -26,6 +26,7 @@ use num_traits::NumCast; use rayon::prelude::*; pub use series_trait::{IsSorted, *}; +use crate::chunked_array::cast::CastOptions; use crate::chunked_array::metadata::{Metadata, MetadataFlags}; #[cfg(feature = "zip_with")] use crate::series::arithmetic::coerce_lhs_rhs; @@ -361,8 +362,12 @@ impl Series { self._get_inner_mut().as_single_ptr() } - /// Cast `[Series]` to another `[DataType]`. pub fn cast(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::NonStrict) + } + + /// Cast `[Series]` to another `[DataType]`. + pub fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { use DataType as D; let do_clone = match dtype { @@ -428,12 +433,30 @@ impl Series { Some(ref dtype) => dtype, }; - let ret = self.0.cast(dtype); + // Always allow casting all nulls to other all nulls. let len = self.len(); if self.null_count() == len { return Ok(Series::full_null(self.name(), len, dtype)); } - ret + + let new_options = match options { + // Strictness is handled on this level to improve error messages. + CastOptions::Strict => CastOptions::NonStrict, + opt => opt, + }; + + let ret = self.0.cast(dtype, new_options); + + match options { + CastOptions::NonStrict | CastOptions::Overflowing => ret, + CastOptions::Strict => { + let ret = ret?; + if self.null_count() != ret.null_count() { + handle_casting_failures(self, &ret)?; + } + Ok(ret) + }, + } } /// Cast from physical to logical types without any checks on the validity of the cast. @@ -452,7 +475,7 @@ impl Series { }) }, DataType::Binary => self.binary().unwrap().cast_unchecked(dtype), - _ => self.cast(dtype), + _ => self.cast_with_options(dtype, CastOptions::Overflowing), } } @@ -460,7 +483,7 @@ impl Series { pub fn to_float(&self) -> PolarsResult { match self.dtype() { DataType::Float32 | DataType::Float64 => Ok(self.clone()), - _ => self.cast(&DataType::Float64), + _ => self.cast_with_options(&DataType::Float64, CastOptions::Overflowing), } } @@ -753,11 +776,7 @@ impl Series { /// Cast throws an error if conversion had overflows pub fn strict_cast(&self, dtype: &DataType) -> PolarsResult { - let s = self.cast(dtype)?; - if self.null_count() != s.null_count() { - handle_casting_failures(self, &s)?; - } - Ok(s) + self.cast_with_options(dtype, CastOptions::Strict) } #[cfg(feature = "dtype-time")] diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index cc897c775537..7aa12d23c350 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -4,6 +4,7 @@ use std::borrow::Cow; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use crate::chunked_array::cast::CastOptions; #[cfg(feature = "object")] use crate::chunked_array::object::PolarsObjectSafe; use crate::prelude::*; @@ -321,7 +322,7 @@ pub trait SeriesTrait: /// ``` fn new_from_index(&self, _index: usize, _length: usize) -> Series; - fn cast(&self, _data_type: &DataType) -> PolarsResult; + fn cast(&self, _data_type: &DataType, options: CastOptions) -> PolarsResult; /// Get a single value by index. Don't use this operation for loops as a runtime cast is /// needed for every iteration. diff --git a/crates/polars-expr/src/expressions/cast.rs b/crates/polars-expr/src/expressions/cast.rs index d2e463a5cad2..f78f1f7c11f7 100644 --- a/crates/polars-expr/src/expressions/cast.rs +++ b/crates/polars-expr/src/expressions/cast.rs @@ -1,3 +1,4 @@ +use polars_core::chunked_array::cast::CastOptions; use polars_core::prelude::*; use super::*; @@ -7,16 +8,12 @@ pub struct CastExpr { pub(crate) input: Arc, pub(crate) data_type: DataType, pub(crate) expr: Expr, - pub(crate) strict: bool, + pub(crate) options: CastOptions, } impl CastExpr { fn finish(&self, input: &Series) -> PolarsResult { - if self.strict { - input.strict_cast(&self.data_type) - } else { - input.cast(&self.data_type) - } + input.cast_with_options(&self.data_type, self.options) } } diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index 9a993a9ef740..01756efd9f09 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -428,14 +428,14 @@ fn create_physical_expr_inner( Cast { expr, data_type, - strict, + options, } => { let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?; Ok(Arc::new(CastExpr { input: phys_expr, data_type: data_type.clone(), expr: node_to_expr(expression, expr_arena), - strict: *strict, + options: *options, })) }, Ternary { diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index b9407a19bdbb..2f569330c24c 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -1,6 +1,7 @@ use std::fmt::{Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; +use polars_core::chunked_array::cast::CastOptions; use polars_core::prelude::*; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -81,7 +82,7 @@ pub enum Expr { Cast { expr: Arc, data_type: DataType, - strict: bool, + options: CastOptions, }, Sort { expr: Arc, @@ -192,7 +193,7 @@ impl Hash for Expr { Expr::Cast { expr, data_type, - strict, + options: strict, } => { expr.hash(state); data_type.hash(state); diff --git a/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs index 5315709da4cf..5e1e45a0124c 100644 --- a/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs +++ b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs @@ -1,3 +1,5 @@ +use polars_core::chunked_array::cast::CastOptions; + use super::*; /// Sum all the values in the column named `name`. Shorthand for `col(name).sum()`. @@ -59,6 +61,6 @@ pub fn cast(expr: Expr, data_type: DataType) -> Expr { Expr::Cast { expr: Arc::new(expr), data_type, - strict: false, + options: CastOptions::NonStrict, } } diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 876c4e72fda8..46b8973209f7 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -56,6 +56,7 @@ pub use list::*; pub use meta::*; pub use name::*; pub use options::*; +use polars_core::chunked_array::cast::CastOptions; use polars_core::error::feature_gated; use polars_core::prelude::*; #[cfg(feature = "diff")] @@ -379,7 +380,7 @@ impl Expr { Expr::Cast { expr: Arc::new(self), data_type, - strict: true, + options: CastOptions::Strict, } } @@ -388,7 +389,16 @@ impl Expr { Expr::Cast { expr: Arc::new(self), data_type, - strict: false, + options: CastOptions::NonStrict, + } + } + + /// Cast expression to another data type. + pub fn cast_with_options(self, data_type: DataType, cast_options: CastOptions) -> Self { + Expr::Cast { + expr: Arc::new(self), + data_type, + options: cast_options, } } diff --git a/crates/polars-plan/src/logical_plan/aexpr/hash.rs b/crates/polars-plan/src/logical_plan/aexpr/hash.rs index ec9e297faec6..3860419dac47 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/hash.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/hash.rs @@ -26,7 +26,9 @@ impl Hash for AExpr { }, AExpr::Agg(agg) => agg.hash(state), AExpr::SortBy { sort_options, .. } => sort_options.hash(state), - AExpr::Cast { strict, .. } => strict.hash(state), + AExpr::Cast { + options: strict, .. + } => strict.hash(state), AExpr::Window { options, .. } => options.hash(state), AExpr::BinaryExpr { op, .. } => op.hash(state), _ => {}, diff --git a/crates/polars-plan/src/logical_plan/aexpr/mod.rs b/crates/polars-plan/src/logical_plan/aexpr/mod.rs index 21458ac6bb8d..a0efdbf73c44 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/mod.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/mod.rs @@ -7,6 +7,7 @@ use std::hash::{Hash, Hasher}; #[cfg(feature = "cse")] pub(super) use hash::traverse_and_hash_aexpr; +use polars_core::chunked_array::cast::CastOptions; use polars_core::prelude::*; use polars_core::utils::{get_time_units, try_get_supertype}; use polars_utils::arena::{Arena, Node}; @@ -136,7 +137,7 @@ pub enum AExpr { Cast { expr: Node, data_type: DataType, - strict: bool, + options: CastOptions, }, Sort { expr: Node, diff --git a/crates/polars-plan/src/logical_plan/alp/format.rs b/crates/polars-plan/src/logical_plan/alp/format.rs index 672a42dfd2e1..e78daa1cd3dc 100644 --- a/crates/polars-plan/src/logical_plan/alp/format.rs +++ b/crates/polars-plan/src/logical_plan/alp/format.rs @@ -582,10 +582,10 @@ impl<'a> Display for ExprIRDisplay<'a> { Cast { expr, data_type, - strict, + options, } => { self.with_root(expr).fmt(f)?; - if *strict { + if options.strict() { write!(f, ".strict_cast({data_type:?})") } else { write!(f, ".cast({data_type:?})") diff --git a/crates/polars-plan/src/logical_plan/alp/tree_format.rs b/crates/polars-plan/src/logical_plan/alp/tree_format.rs index 7e69dca163a3..a260d09bf466 100644 --- a/crates/polars-plan/src/logical_plan/alp/tree_format.rs +++ b/crates/polars-plan/src/logical_plan/alp/tree_format.rs @@ -31,9 +31,9 @@ impl fmt::Display for TreeFmtAExpr<'_> { AExpr::Literal(lv) => return write!(f, "lit({lv:?})"), AExpr::BinaryExpr { op, .. } => return write!(f, "binary: {}", op), AExpr::Cast { - data_type, strict, .. + data_type, options, .. } => { - return if *strict { + return if options.strict() { write!(f, "strict cast({})", data_type) } else { write!(f, "cast({})", data_type) diff --git a/crates/polars-plan/src/logical_plan/conversion/expr_to_ir.rs b/crates/polars-plan/src/logical_plan/conversion/expr_to_ir.rs index 391165573ec8..f61357d203b7 100644 --- a/crates/polars-plan/src/logical_plan/conversion/expr_to_ir.rs +++ b/crates/polars-plan/src/logical_plan/conversion/expr_to_ir.rs @@ -145,11 +145,11 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta Expr::Cast { expr, data_type, - strict, + options, } => AExpr::Cast { expr: to_aexpr_impl(owned(expr), arena, state), data_type, - strict, + options, }, Expr::Gather { expr, diff --git a/crates/polars-plan/src/logical_plan/conversion/ir_to_dsl.rs b/crates/polars-plan/src/logical_plan/conversion/ir_to_dsl.rs index 8204264d02b1..a7c2fac17edf 100644 --- a/crates/polars-plan/src/logical_plan/conversion/ir_to_dsl.rs +++ b/crates/polars-plan/src/logical_plan/conversion/ir_to_dsl.rs @@ -25,13 +25,13 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { AExpr::Cast { expr, data_type, - strict, + options: strict, } => { let exp = node_to_expr(expr, expr_arena); Expr::Cast { expr: Arc::new(exp), data_type, - strict, + options: strict, } }, AExpr::Sort { expr, options } => { diff --git a/crates/polars-plan/src/logical_plan/conversion/type_coercion/binary.rs b/crates/polars-plan/src/logical_plan/conversion/type_coercion/binary.rs index 6c02440ff3d8..dbe603ca62dd 100644 --- a/crates/polars-plan/src/logical_plan/conversion/type_coercion/binary.rs +++ b/crates/polars-plan/src/logical_plan/conversion/type_coercion/binary.rs @@ -97,7 +97,7 @@ fn process_list_arithmetic( let new_node_right = expr_arena.add(AExpr::Cast { expr: node_right, data_type: *inner.clone(), - strict: false, + options: CastOptions::NonStrict, }); Ok(Some(AExpr::BinaryExpr { @@ -114,7 +114,7 @@ fn process_list_arithmetic( let new_node_left = expr_arena.add(AExpr::Cast { expr: node_left, data_type: *inner.clone(), - strict: false, + options: CastOptions::NonStrict, }); Ok(Some(AExpr::BinaryExpr { @@ -147,7 +147,7 @@ fn process_struct_numeric_arithmetic( let new_node_right = expr_arena.add(AExpr::Cast { expr: node_right, data_type: DataType::Struct(vec![first.clone()]), - strict: false, + options: CastOptions::NonStrict, }); Ok(Some(AExpr::BinaryExpr { left: node_left, @@ -163,7 +163,7 @@ fn process_struct_numeric_arithmetic( let new_node_left = expr_arena.add(AExpr::Cast { expr: node_left, data_type: DataType::Struct(vec![first.clone()]), - strict: false, + options: CastOptions::NonStrict, }); Ok(Some(AExpr::BinaryExpr { @@ -337,7 +337,7 @@ pub(super) fn process_binary( expr_arena.add(AExpr::Cast { expr: node_left, data_type: st.clone(), - strict: false, + options: CastOptions::NonStrict, }) } else { node_left @@ -346,7 +346,7 @@ pub(super) fn process_binary( expr_arena.add(AExpr::Cast { expr: node_right, data_type: st, - strict: false, + options: CastOptions::NonStrict, }) } else { node_right diff --git a/crates/polars-plan/src/logical_plan/conversion/type_coercion/mod.rs b/crates/polars-plan/src/logical_plan/conversion/type_coercion/mod.rs index b86fb13f2254..ccaa3ba64e3a 100644 --- a/crates/polars-plan/src/logical_plan/conversion/type_coercion/mod.rs +++ b/crates/polars-plan/src/logical_plan/conversion/type_coercion/mod.rs @@ -4,6 +4,7 @@ use std::borrow::Cow; use arrow::legacy::utils::CustomIterTools; use binary::process_binary; +use polars_core::chunked_array::cast::CastOptions; use polars_core::prelude::*; use polars_core::utils::{get_supertype, materialize_dyn_int}; use polars_utils::idx_vec::UnitVec; @@ -120,11 +121,18 @@ impl OptimizationRule for TypeCoercionRule { AExpr::Cast { expr, ref data_type, - ref strict, + options, } => { let input = expr_arena.get(expr); - inline_or_prune_cast(input, data_type, *strict, lp_node, lp_arena, expr_arena)? + inline_or_prune_cast( + input, + data_type, + options.strict(), + lp_node, + lp_arena, + expr_arena, + )? }, AExpr::Ternary { truthy: truthy_node, @@ -151,7 +159,7 @@ impl OptimizationRule for TypeCoercionRule { expr_arena.add(AExpr::Cast { expr: truthy_node, data_type: st.clone(), - strict: true, + options: CastOptions::Strict, }) } else { truthy_node @@ -161,7 +169,7 @@ impl OptimizationRule for TypeCoercionRule { expr_arena.add(AExpr::Cast { expr: falsy_node, data_type: st, - strict: true, + options: CastOptions::Strict, }) } else { falsy_node @@ -206,7 +214,7 @@ impl OptimizationRule for TypeCoercionRule { (_, DataType::Null) => AExpr::Cast { expr: other_e.node(), data_type: type_left, - strict: false, + options: CastOptions::NonStrict, }, #[cfg(feature = "dtype-categorical")] (DataType::Categorical(_, _) | DataType::Enum(_, _), DataType::String) => { @@ -308,7 +316,7 @@ impl OptimizationRule for TypeCoercionRule { expr_arena.add(AExpr::Cast { expr: left_node, data_type: super_type.clone(), - strict: false, + options: CastOptions::NonStrict, }) } else { left_node @@ -318,7 +326,7 @@ impl OptimizationRule for TypeCoercionRule { expr_arena.add(AExpr::Cast { expr: fill_value_node, data_type: super_type.clone(), - strict: false, + options: CastOptions::NonStrict, }) } else { fill_value_node @@ -414,7 +422,7 @@ impl OptimizationRule for TypeCoercionRule { let n = expr_arena.add(AExpr::Cast { expr: e.node(), data_type: super_type.clone(), - strict: false, + options: CastOptions::NonStrict, }); e.set_node(n); } diff --git a/crates/polars-plan/src/logical_plan/format.rs b/crates/polars-plan/src/logical_plan/format.rs index 1e62cf112b07..980fbc883756 100644 --- a/crates/polars-plan/src/logical_plan/format.rs +++ b/crates/polars-plan/src/logical_plan/format.rs @@ -125,9 +125,9 @@ impl fmt::Debug for Expr { Cast { expr, data_type, - strict, + options, } => { - if *strict { + if options.strict() { write!(f, "{expr:?}.strict_cast({data_type:?})") } else { write!(f, "{expr:?}.cast({data_type:?})") diff --git a/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs b/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs index 47a32b5fb79a..396af32afef3 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs @@ -1,3 +1,5 @@ +use polars_core::chunked_array::cast::CastOptions; + use super::*; pub(super) fn optimize_functions( @@ -70,7 +72,7 @@ pub(super) fn optimize_functions( Some(AExpr::Cast { expr: input[0].node(), data_type: DataType::Boolean, - strict: false, + options: CastOptions::NonStrict, }) } else { None diff --git a/crates/polars-plan/src/logical_plan/visitor/expr.rs b/crates/polars-plan/src/logical_plan/visitor/expr.rs index 72b41efc51ef..db6752dead9a 100644 --- a/crates/polars-plan/src/logical_plan/visitor/expr.rs +++ b/crates/polars-plan/src/logical_plan/visitor/expr.rs @@ -51,7 +51,7 @@ impl TreeWalker for Expr { BinaryExpr { left, op, right } => { BinaryExpr { left: am(left, &mut f)? , op, right: am(right, f)?} }, - Cast { expr, data_type, strict } => Cast { expr: am(expr, f)?, data_type, strict }, + Cast { expr, data_type, options: strict } => Cast { expr: am(expr, f)?, data_type, options: strict }, Sort { expr, options } => Sort { expr: am(expr, f)?, options }, Gather { expr, idx, returns_scalar } => Gather { expr: am(expr, &mut f)?, idx: am(idx, f)?, returns_scalar }, SortBy { expr, by, sort_options } => SortBy { expr: am(expr, &mut f)?, by: by.into_iter().map(f).collect::>()?, sort_options }, @@ -165,12 +165,12 @@ impl AExpr { (Window { options: l, .. }, Window { options: r, .. }) => l == r, ( Cast { - strict: strict_l, + options: strict_l, data_type: dtl, .. }, Cast { - strict: strict_r, + options: strict_r, data_type: dtr, .. }, diff --git a/crates/polars-time/src/chunkedarray/datetime.rs b/crates/polars-time/src/chunkedarray/datetime.rs index f0111be9047f..7a2bf97d88e7 100644 --- a/crates/polars-time/src/chunkedarray/datetime.rs +++ b/crates/polars-time/src/chunkedarray/datetime.rs @@ -1,5 +1,5 @@ use arrow::array::{Array, PrimitiveArray}; -use arrow::compute::cast::{cast, CastOptions}; +use arrow::compute::cast::{cast, CastOptionsImpl}; use arrow::compute::temporal; use polars_core::prelude::*; @@ -17,7 +17,7 @@ fn cast_and_apply< let arr = cast( arr, &dtype, - CastOptions { + CastOptionsImpl { wrapped: true, partial: false, }, diff --git a/py-polars/polars/_utils/construction/dataframe.py b/py-polars/polars/_utils/construction/dataframe.py index b105c371c8fa..f05cbb8b9412 100644 --- a/py-polars/polars/_utils/construction/dataframe.py +++ b/py-polars/polars/_utils/construction/dataframe.py @@ -629,7 +629,7 @@ def _sequence_of_series_to_pydf( s = s.alias(column_names[i]) new_dtype = schema_overrides.get(column_names[i]) if new_dtype and new_dtype != s.dtype: - s = s.cast(new_dtype, strict=strict) + s = s.cast(new_dtype, strict=strict, allow_overflow=False) data_series.append(s._s) data_series = _handle_columns_arg(data_series, columns=column_names) @@ -766,7 +766,7 @@ def _sequence_of_pandas_to_pydf( pyseries = plc.pandas_to_pyseries(name=name, values=s) dtype = schema_overrides.get(name) if dtype is not None and dtype != pyseries.dtype(): - pyseries = pyseries.cast(dtype, strict=strict) + pyseries = pyseries.cast(dtype, strict=strict, allow_overflow=False) data_series.append(pyseries) return PyDataFrame(data_series) @@ -1353,7 +1353,9 @@ def series_to_pydf( if schema_overrides: new_dtype = next(iter(schema_overrides.values())) if new_dtype != data.dtype: - data_series[0] = data_series[0].cast(new_dtype, strict=strict) + data_series[0] = data_series[0].cast( + new_dtype, strict=strict, allow_overflow=False + ) data_series = _handle_columns_arg(data_series, columns=column_names) return PyDataFrame(data_series) @@ -1378,7 +1380,9 @@ def dataframe_to_pydf( existing_schema = data.schema for name, new_dtype in schema_overrides.items(): if new_dtype != existing_schema[name]: - data_series[name] = data_series[name].cast(new_dtype, strict=strict) + data_series[name] = data_series[name].cast( + new_dtype, strict=strict, allow_overflow=False + ) series_cols = _handle_columns_arg(list(data_series.values()), columns=column_names) return PyDataFrame(series_cols) diff --git a/py-polars/polars/_utils/construction/series.py b/py-polars/polars/_utils/construction/series.py index 1a670846c3df..aee1da92abd3 100644 --- a/py-polars/polars/_utils/construction/series.py +++ b/py-polars/polars/_utils/construction/series.py @@ -150,7 +150,7 @@ def sequence_to_pyseries( Decimal, ): if pyseries.dtype() != dtype: - pyseries = pyseries.cast(dtype, strict=strict) + pyseries = pyseries.cast(dtype, strict=strict, allow_overflow=False) return pyseries elif dtype == Struct: @@ -289,7 +289,7 @@ def sequence_to_pyseries( name, values, dtype, strict=strict ) if dtype != pyseries.dtype(): - pyseries = pyseries.cast(dtype, strict=False) + pyseries = pyseries.cast(dtype, strict=False, allow_overflow=False) return pyseries elif python_dtype == pl.Series: @@ -308,7 +308,7 @@ def sequence_to_pyseries( np.bool_(True), np.generic ): dtype = numpy_char_code_to_dtype(np.dtype(python_dtype).char) - return srs.cast(dtype, strict=strict) + return srs.cast(dtype, strict=strict, allow_overflow=False) else: return srs @@ -480,7 +480,11 @@ def arrow_to_pyseries( if rechunk: pys.rechunk(in_place=True) - return pys.cast(dtype, strict=strict) if dtype is not None else pys + return ( + pys.cast(dtype, strict=strict, allow_overflow=False) + if dtype is not None + else pys + ) def numpy_to_pyseries( diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 349deb89daca..908844ffa7a1 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -1703,7 +1703,13 @@ def mode(self) -> Self: """ return self._from_pyexpr(self._pyexpr.mode()) - def cast(self, dtype: PolarsDataType | type[Any], *, strict: bool = True) -> Self: + def cast( + self, + dtype: PolarsDataType | type[Any], + *, + strict: bool = True, + allow_overflow: bool = False, + ) -> Self: """ Cast between data types. @@ -1714,6 +1720,8 @@ def cast(self, dtype: PolarsDataType | type[Any], *, strict: bool = True) -> Sel strict Throw an error if a cast could not be done (for instance, due to an overflow). + allow_overflow + Don't check for numeric overflow and instead do a `wrapping` numeric cast. Examples -------- @@ -1739,7 +1747,7 @@ def cast(self, dtype: PolarsDataType | type[Any], *, strict: bool = True) -> Sel └─────┴─────┘ """ dtype = py_type_to_dtype(dtype) - return self._from_pyexpr(self._pyexpr.cast(dtype, strict)) + return self._from_pyexpr(self._pyexpr.cast(dtype, strict, allow_overflow)) def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: """ diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 9b481d947aec..8d9fc4bacbef 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -3878,6 +3878,7 @@ def cast( dtype: PolarsDataType | type[int] | type[float] | type[str] | type[bool], *, strict: bool = True, + allow_overflow: bool = False, ) -> Self: """ Cast between data types. @@ -3889,6 +3890,8 @@ def cast( strict Throw an error if a cast could not be done (for instance, due to an overflow). + allow_overflow + Don't check for numeric overflow and instead do a `wrapping` numeric cast. Examples -------- @@ -3913,7 +3916,7 @@ def cast( """ # Do not dispatch cast as it is expensive and used in other functions. dtype = py_type_to_dtype(dtype) - return self._from_pyseries(self._s.cast(dtype, strict)) + return self._from_pyseries(self._s.cast(dtype, strict, allow_overflow)) def to_physical(self) -> Series: """ diff --git a/py-polars/src/dataframe/export.rs b/py-polars/src/dataframe/export.rs index ac7dcc58d72f..1c6d9bfc7698 100644 --- a/py-polars/src/dataframe/export.rs +++ b/py-polars/src/dataframe/export.rs @@ -1,6 +1,6 @@ use polars::export::arrow::record_batch::RecordBatch; use polars_core::export::arrow::datatypes::IntegerType; -use polars_core::utils::arrow::compute::cast::CastOptions; +use polars_core::utils::arrow::compute::cast::CastOptionsImpl; use pyo3::prelude::*; use pyo3::types::{PyList, PyTuple}; @@ -116,7 +116,7 @@ impl PyDataFrame { Box::new(ArrowDataType::LargeUtf8), false, ), - CastOptions::default(), + CastOptionsImpl::default(), ) .unwrap(); *arr = out; diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index 57a43f1ff672..059df272829b 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -4,6 +4,7 @@ use std::ops::Neg; use polars::lazy::dsl; use polars::prelude::*; use polars::series::ops::NullBehavior; +use polars_core::chunked_array::cast::CastOptions; use polars_core::series::IsSorted; use pyo3::class::basic::CompareOp; use pyo3::prelude::*; @@ -259,13 +260,18 @@ impl PyExpr { fn null_count(&self) -> Self { self.inner.clone().null_count().into() } - fn cast(&self, data_type: Wrap, strict: bool) -> Self { + fn cast(&self, data_type: Wrap, strict: bool, allow_overflow: bool) -> Self { let dt = data_type.0; - let expr = if strict { - self.inner.clone().strict_cast(dt) + + let options = if allow_overflow { + CastOptions::Overflowing + } else if strict { + CastOptions::Strict } else { - self.inner.clone().cast(dt) + CastOptions::NonStrict }; + + let expr = self.inner.clone().cast_with_options(dt, options); expr.into() } fn sort_with(&self, descending: bool, nulls_last: bool) -> Self { diff --git a/py-polars/src/lazyframe/visitor/expr_nodes.rs b/py-polars/src/lazyframe/visitor/expr_nodes.rs index 116274d8138f..3ce8c70eccbc 100644 --- a/py-polars/src/lazyframe/visitor/expr_nodes.rs +++ b/py-polars/src/lazyframe/visitor/expr_nodes.rs @@ -270,8 +270,11 @@ pub struct Cast { expr: usize, #[pyo3(get)] dtype: PyObject, + // 0: strict + // 1: non-strict + // 2: overflow #[pyo3(get)] - strict: bool, + options: u8, } #[pyclass] @@ -580,11 +583,11 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { AExpr::Cast { expr, data_type, - strict, + options, } => Cast { expr: expr.0, dtype: Wrap(data_type.clone()).to_object(py), - strict: *strict, + options: *options as u8, } .into_py(py), AExpr::Sort { expr, options } => Sort { diff --git a/py-polars/src/series/mod.rs b/py-polars/src/series/mod.rs index 9205e4483ad6..d25c40d5d2d7 100644 --- a/py-polars/src/series/mod.rs +++ b/py-polars/src/series/mod.rs @@ -10,6 +10,7 @@ mod scatter; use std::io::Cursor; +use polars_core::chunked_array::cast::CastOptions; use polars_core::series::IsSorted; use polars_core::utils::flatten::flatten_series; use polars_core::{with_match_physical_numeric_polars_type, with_match_physical_numeric_type}; @@ -709,13 +710,17 @@ impl PySeries { Ok(out) } - fn cast(&self, dtype: Wrap, strict: bool) -> PyResult { - let dtype = dtype.0; - let out = if strict { - self.series.strict_cast(&dtype) + fn cast(&self, dtype: Wrap, strict: bool, allow_overflow: bool) -> PyResult { + let options = if allow_overflow { + CastOptions::Overflowing + } else if strict { + CastOptions::Strict } else { - self.series.cast(&dtype) + CastOptions::NonStrict }; + + let dtype = dtype.0; + let out = self.series.cast_with_options(&dtype, options); let out = out.map_err(PyPolarsErr::from)?; Ok(out.into()) }