diff --git a/crates/polars-core/src/chunked_array/logical/date.rs b/crates/polars-core/src/chunked_array/logical/date.rs index 4585d89b4af9..55df4f3d428c 100644 --- a/crates/polars-core/src/chunked_array/logical/date.rs +++ b/crates/polars-core/src/chunked_array/logical/date.rs @@ -20,11 +20,11 @@ impl LogicalType for DateChunked { } fn get_any_value(&self, i: usize) -> PolarsResult> { - self.0.get_any_value(i).map(|av| av.into_date()) + self.0.get_any_value(i).map(|av| av.as_date()) } unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { - self.0.get_any_value_unchecked(i).into_date() + self.0.get_any_value_unchecked(i).as_date() } fn cast(&self, dtype: &DataType) -> PolarsResult { diff --git a/crates/polars-core/src/chunked_array/logical/datetime.rs b/crates/polars-core/src/chunked_array/logical/datetime.rs index 337d18357f58..eef9e7e859cd 100644 --- a/crates/polars-core/src/chunked_array/logical/datetime.rs +++ b/crates/polars-core/src/chunked_array/logical/datetime.rs @@ -19,13 +19,13 @@ impl LogicalType for DatetimeChunked { fn get_any_value(&self, i: usize) -> PolarsResult> { self.0 .get_any_value(i) - .map(|av| av.into_datetime(self.time_unit(), self.time_zone())) + .map(|av| av.as_datetime(self.time_unit(), self.time_zone())) } unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { self.0 .get_any_value_unchecked(i) - .into_datetime(self.time_unit(), self.time_zone()) + .as_datetime(self.time_unit(), self.time_zone()) } fn cast(&self, dtype: &DataType) -> PolarsResult { diff --git a/crates/polars-core/src/chunked_array/logical/duration.rs b/crates/polars-core/src/chunked_array/logical/duration.rs index 64ef1620c3c0..63546969df79 100644 --- a/crates/polars-core/src/chunked_array/logical/duration.rs +++ b/crates/polars-core/src/chunked_array/logical/duration.rs @@ -19,12 +19,12 @@ impl LogicalType for DurationChunked { fn get_any_value(&self, i: usize) -> PolarsResult> { self.0 .get_any_value(i) - .map(|av| av.into_duration(self.time_unit())) + .map(|av| av.as_duration(self.time_unit())) } unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { self.0 .get_any_value_unchecked(i) - .into_duration(self.time_unit()) + .as_duration(self.time_unit()) } fn cast(&self, dtype: &DataType) -> PolarsResult { diff --git a/crates/polars-core/src/chunked_array/logical/time.rs b/crates/polars-core/src/chunked_array/logical/time.rs index 8dd4c6239ae9..3c546ef64ab5 100644 --- a/crates/polars-core/src/chunked_array/logical/time.rs +++ b/crates/polars-core/src/chunked_array/logical/time.rs @@ -22,10 +22,10 @@ impl LogicalType for TimeChunked { #[cfg(feature = "dtype-time")] fn get_any_value(&self, i: usize) -> PolarsResult> { - self.0.get_any_value(i).map(|av| av.into_time()) + self.0.get_any_value(i).map(|av| av.as_time()) } unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { - self.0.get_any_value_unchecked(i).into_time() + self.0.get_any_value_unchecked(i).as_time() } fn cast(&self, dtype: &DataType) -> PolarsResult { diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs index aa671546cbe3..36c65a14d1cb 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs @@ -26,7 +26,7 @@ use crate::series::IsSorted; /// Aggregations that return [`Series`] of unit length. Those can be used in broadcasting operations. pub trait ChunkAggSeries { /// Get the sum of the [`ChunkedArray`] as a new [`Series`] of length 1. - fn sum_as_series(&self) -> Series { + fn sum_as_series(&self) -> Scalar { unimplemented!() } /// Get the max of the [`ChunkedArray`] as a new [`Series`] of length 1. @@ -263,11 +263,9 @@ where Add::Simd> + compute::aggregate::Sum, ChunkedArray: IntoSeries, { - fn sum_as_series(&self) -> Series { - let v = self.sum(); - let mut ca: ChunkedArray = [v].iter().copied().collect(); - ca.rename(self.name()); - ca.into_series() + fn sum_as_series(&self) -> Scalar { + let v: Option = self.sum(); + Scalar::new(T::get_dtype().clone(), v.into()) } fn max_as_series(&self) -> Series { @@ -395,9 +393,9 @@ impl QuantileAggSeries for Float64Chunked { } impl ChunkAggSeries for BooleanChunked { - fn sum_as_series(&self) -> Series { + fn sum_as_series(&self) -> Scalar { let v = self.sum(); - Series::new(self.name(), [v]) + Scalar::new(IDX_DTYPE, v.into()) } fn max_as_series(&self) -> Series { let v = self.max(); @@ -459,8 +457,8 @@ impl StringChunked { } impl ChunkAggSeries for StringChunked { - fn sum_as_series(&self) -> Series { - StringChunked::full_null(self.name(), 1).into_series() + fn sum_as_series(&self) -> Scalar { + Scalar::new(DataType::String, AnyValue::Null) } fn max_as_series(&self) -> Series { Series::new(self.name(), &[self.max_str()]) @@ -590,7 +588,7 @@ impl BinaryChunked { } impl ChunkAggSeries for BinaryChunked { - fn sum_as_series(&self) -> Series { + fn sum_as_series(&self) -> Scalar { unimplemented!() } fn max_as_series(&self) -> Series { diff --git a/crates/polars-core/src/datatypes/any_value.rs b/crates/polars-core/src/datatypes/any_value.rs index c9f06add1d48..f9059c74f6c4 100644 --- a/crates/polars-core/src/datatypes/any_value.rs +++ b/crates/polars-core/src/datatypes/any_value.rs @@ -9,6 +9,25 @@ use polars_utils::sync::SyncPtr; use polars_utils::total_ord::ToTotalOrd; use polars_utils::unwrap::UnwrapUncheckedRelease; +pub struct Scalar { + dtype: DataType, + value: AnyValue<'static>, +} + +impl Scalar { + pub fn new(dtype: DataType, value: AnyValue<'static>) -> Self { + Self { dtype, value } + } + + pub fn value(&self) -> &AnyValue { + &self.value + } + + pub fn into_series(self, name: &str) -> Series { + Series::from_any_values_and_dtype(name, &[self.value], &self.dtype, true).unwrap() + } +} + use super::*; #[cfg(feature = "dtype-struct")] use crate::prelude::any_value::arr_to_any_value; @@ -338,6 +357,23 @@ impl<'a> Deserialize<'a> for AnyValue<'static> { } } +impl AnyValue<'static> { + pub fn zero(dtype: &DataType) -> Self { + match dtype { + DataType::String => AnyValue::StringOwned("".into()), + DataType::Boolean => AnyValue::Boolean(false), + // SAFETY: + // Numeric values are static, inform the compiler of this. + d if d.is_numeric() => unsafe { + std::mem::transmute::, AnyValue<'static>>( + AnyValue::UInt8(0).cast(dtype), + ) + }, + _ => AnyValue::Null, + } + } +} + impl<'a> AnyValue<'a> { /// Get the matching [`DataType`] for this [`AnyValue`]`. /// @@ -735,43 +771,43 @@ where impl<'a> AnyValue<'a> { #[cfg(any(feature = "dtype-date", feature = "dtype-datetime"))] - pub(crate) fn into_date(self) -> Self { + pub(crate) fn as_date(&self) -> AnyValue<'static> { match self { #[cfg(feature = "dtype-date")] - AnyValue::Int32(v) => AnyValue::Date(v), + AnyValue::Int32(v) => AnyValue::Date(*v), AnyValue::Null => AnyValue::Null, dt => panic!("cannot create date from other type. dtype: {dt}"), } } #[cfg(feature = "dtype-datetime")] - pub(crate) fn into_datetime(self, tu: TimeUnit, tz: &'a Option) -> Self { + pub(crate) fn as_datetime(&self, tu: TimeUnit, tz: &'a Option) -> AnyValue<'a> { match self { - AnyValue::Int64(v) => AnyValue::Datetime(v, tu, tz), + AnyValue::Int64(v) => AnyValue::Datetime(*v, tu, tz), AnyValue::Null => AnyValue::Null, dt => panic!("cannot create date from other type. dtype: {dt}"), } } #[cfg(feature = "dtype-duration")] - pub(crate) fn into_duration(self, tu: TimeUnit) -> Self { + pub(crate) fn as_duration(&self, tu: TimeUnit) -> AnyValue<'static> { match self { - AnyValue::Int64(v) => AnyValue::Duration(v, tu), + AnyValue::Int64(v) => AnyValue::Duration(*v, tu), AnyValue::Null => AnyValue::Null, dt => panic!("cannot create date from other type. dtype: {dt}"), } } #[cfg(feature = "dtype-time")] - pub(crate) fn into_time(self) -> Self { + pub(crate) fn as_time(&self) -> AnyValue<'static> { match self { - AnyValue::Int64(v) => AnyValue::Time(v), + AnyValue::Int64(v) => AnyValue::Time(*v), AnyValue::Null => AnyValue::Null, dt => panic!("cannot create date from other type. dtype: {dt}"), } } #[must_use] - pub fn add(&self, rhs: &AnyValue) -> Self { + pub fn add(&self, rhs: &AnyValue) -> AnyValue<'static> { use AnyValue::*; match (self, rhs) { (Null, _) => Null, diff --git a/crates/polars-core/src/series/implementations/boolean.rs b/crates/polars-core/src/series/implementations/boolean.rs index 08401055ff57..63f61710fa85 100644 --- a/crates/polars-core/src/series/implementations/boolean.rs +++ b/crates/polars-core/src/series/implementations/boolean.rs @@ -253,7 +253,7 @@ impl SeriesTrait for SeriesWrap { ChunkShift::shift(&self.0, periods).into_series() } - fn _sum_as_series(&self) -> PolarsResult { + fn _sum_as_series(&self) -> PolarsResult { Ok(ChunkAggSeries::sum_as_series(&self.0)) } fn max_as_series(&self) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index 59163ab5cbd1..f1d00c60db24 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -10,12 +10,16 @@ unsafe impl IntoSeries for DecimalChunked { impl private::PrivateSeriesNumeric for SeriesWrap {} impl SeriesWrap { - fn apply_physical Int128Chunked>(&self, f: F) -> Series { + fn apply_physical_to_s Int128Chunked>(&self, f: F) -> Series { f(&self.0) .into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series() } + fn apply_physical T>(&self, f: F) -> T { + f(&self.0) + } + fn agg_helper Series>(&self, f: F) -> Series { let agg_s = f(&self.0); match agg_s.dtype() { @@ -187,7 +191,7 @@ impl SeriesTrait for SeriesWrap { } fn slice(&self, offset: i64, length: usize) -> Series { - self.apply_physical(|ca| ca.slice(offset, length)) + self.apply_physical_to_s(|ca| ca.slice(offset, length)) } fn append(&mut self, other: &Series) -> PolarsResult<()> { @@ -301,31 +305,38 @@ impl SeriesTrait for SeriesWrap { } fn reverse(&self) -> Series { - self.apply_physical(|ca| ca.reverse()) + self.apply_physical_to_s(|ca| ca.reverse()) } fn shift(&self, periods: i64) -> Series { - self.apply_physical(|ca| ca.shift(periods)) + self.apply_physical_to_s(|ca| ca.shift(periods)) } fn clone_inner(&self) -> Arc { Arc::new(SeriesWrap(Clone::clone(&self.0))) } - fn _sum_as_series(&self) -> PolarsResult { + fn _sum_as_series(&self) -> PolarsResult { Ok(self.apply_physical(|ca| { let sum = ca.sum(); - Int128Chunked::from_slice_options(self.name(), &[sum]) + let DataType::Decimal(_, Some(scale)) = self.dtype() else { + unreachable!() + }; + let av = match sum { + None => AnyValue::Null, + Some(v) => AnyValue::Decimal(v, *scale), + }; + Scalar::new(self.dtype().clone(), av) })) } fn min_as_series(&self) -> PolarsResult { - Ok(self.apply_physical(|ca| { + Ok(self.apply_physical_to_s(|ca| { let min = ca.min(); Int128Chunked::from_slice_options(self.name(), &[min]) })) } fn max_as_series(&self) -> PolarsResult { - Ok(self.apply_physical(|ca| { + Ok(self.apply_physical_to_s(|ca| { let max = ca.max(); Int128Chunked::from_slice_options(self.name(), &[max]) })) diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index 7249a3f92950..f3c6237a01b7 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -393,8 +393,10 @@ impl SeriesTrait for SeriesWrap { .into_series() } - fn _sum_as_series(&self) -> PolarsResult { - Ok(self.0.sum_as_series().into_duration(self.0.time_unit())) + fn _sum_as_series(&self) -> PolarsResult { + let sc = self.0.sum_as_series(); + let v = sc.value().as_duration(self.0.time_unit()); + Ok(Scalar::new(self.dtype().clone(), v)) } fn max_as_series(&self) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index 2c21aec09a63..72bc27a94101 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -284,7 +284,7 @@ macro_rules! impl_dyn_series { ChunkShift::shift(&self.0, periods).into_series() } - fn _sum_as_series(&self) -> PolarsResult { + fn _sum_as_series(&self) -> PolarsResult { Ok(ChunkAggSeries::sum_as_series(&self.0)) } fn max_as_series(&self) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index d05797465617..ef36dd763676 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -387,7 +387,7 @@ macro_rules! impl_dyn_series { ChunkShift::shift(&self.0, periods).into_series() } - fn _sum_as_series(&self) -> PolarsResult { + fn _sum_as_series(&self) -> PolarsResult { Ok(ChunkAggSeries::sum_as_series(&self.0)) } fn max_as_series(&self) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/string.rs b/crates/polars-core/src/series/implementations/string.rs index b6137caca3b0..77ced4ef3c18 100644 --- a/crates/polars-core/src/series/implementations/string.rs +++ b/crates/polars-core/src/series/implementations/string.rs @@ -235,7 +235,7 @@ impl SeriesTrait for SeriesWrap { ChunkShift::shift(&self.0, periods).into_series() } - fn _sum_as_series(&self) -> PolarsResult { + fn _sum_as_series(&self) -> PolarsResult { Ok(ChunkAggSeries::sum_as_series(&self.0)) } fn max_as_series(&self) -> PolarsResult { diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index bc73d230f9de..43af08e02110 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -391,8 +391,9 @@ impl Series { where T: NumCast, { - let sum = self.sum_as_series()?.cast(&DataType::Float64)?; - Ok(T::from(sum.f64().unwrap().get(0).unwrap()).unwrap()) + let sum = self.sum_as_series()?; + let sum = sum.value().cast(&DataType::Float64); + Ok(T::from(sum.extract::().unwrap()).unwrap()) } /// Returns the minimum value in the array, according to the natural order. @@ -628,7 +629,7 @@ impl Series { /// /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is /// first cast to `Int64` to prevent overflow issues. - pub fn sum_as_series(&self) -> PolarsResult { + pub fn sum_as_series(&self) -> PolarsResult { use DataType::*; match self.dtype() { Int8 | UInt8 | Int16 | UInt16 => self.cast(&Int64).unwrap().sum_as_series(), diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index 583eeac2db11..3ea809265deb 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -417,7 +417,7 @@ pub trait SeriesTrait: /// /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is /// first cast to `Int64` to prevent overflow issues. - fn _sum_as_series(&self) -> PolarsResult { + fn _sum_as_series(&self) -> PolarsResult { polars_bail!(opq = sum, self._dtype()); } /// Get the max of the Series as a new Series of length 1. diff --git a/crates/polars-lazy/src/physical_plan/planner/expr.rs b/crates/polars-lazy/src/physical_plan/planner/expr.rs index 1e39c5898f9a..1815f0b25355 100644 --- a/crates/polars-lazy/src/physical_plan/planner/expr.rs +++ b/crates/polars-lazy/src/physical_plan/planner/expr.rs @@ -476,7 +476,12 @@ fn create_physical_expr_inner( let state = *state; SpecialEq::new(Arc::new(move |s: &mut [Series]| { let s = std::mem::take(&mut s[0]); - parallel_op_series(|s| s.sum_as_series(), s, None, state) + parallel_op_series( + |s| s.sum_as_series().map(|sc| sc.into_series(s.name())), + s, + None, + state, + ) }) as Arc) }, AAggExpr::Count(_, include_nulls) => { diff --git a/crates/polars-lazy/src/physical_plan/planner/lp.rs b/crates/polars-lazy/src/physical_plan/planner/lp.rs index 45acfaa1bb4e..66da47300825 100644 --- a/crates/polars-lazy/src/physical_plan/planner/lp.rs +++ b/crates/polars-lazy/src/physical_plan/planner/lp.rs @@ -348,6 +348,20 @@ fn create_physical_plan_impl( streamable, })) }, + Reduce { + exprs, + input, + schema, + } => { + let select = Select { + input, + expr: exprs.into(), + schema, + options: Default::default(), + }; + let node = lp_arena.add(select); + create_physical_plan(node, lp_arena, expr_arena) + }, DataFrameScan { df, projection, diff --git a/crates/polars-ops/src/chunked_array/list/sum_mean.rs b/crates/polars-ops/src/chunked_array/list/sum_mean.rs index 972d454364b0..9387f3e6702f 100644 --- a/crates/polars-ops/src/chunked_array/list/sum_mean.rs +++ b/crates/polars-ops/src/chunked_array/list/sum_mean.rs @@ -105,7 +105,7 @@ pub(super) fn sum_with_nulls(ca: &ListChunked, inner_dtype: &DataType) -> Polars }, // slowest sum_as_series path _ => ca - .try_apply_amortized(|s| s.as_ref().sum_as_series())? + .try_apply_amortized(|s| s.as_ref().sum_as_series().map(|sc| sc.into_series("")))? .explode() .unwrap() .into_series(), diff --git a/crates/polars-ops/src/series/ops/log.rs b/crates/polars-ops/src/series/ops/log.rs index d231066008ae..a63017eec009 100644 --- a/crates/polars-ops/src/series/ops/log.rs +++ b/crates/polars-ops/src/series/ops/log.rs @@ -92,7 +92,7 @@ pub trait LogSeries: SeriesSealed { let pk = s.as_ref(); let pk = if normalize { - let sum = pk.sum_as_series().unwrap(); + let sum = pk.sum_as_series().unwrap().into_series(""); if sum.get(0).unwrap().extract::().unwrap() != 1.0 { pk / &sum diff --git a/crates/polars-plan/src/lib.rs b/crates/polars-plan/src/lib.rs index 5cad9759d823..bdb073b70428 100644 --- a/crates/polars-plan/src/lib.rs +++ b/crates/polars-plan/src/lib.rs @@ -10,4 +10,5 @@ pub mod frame; pub mod global; pub mod logical_plan; pub mod prelude; +mod reduce; pub mod utils; diff --git a/crates/polars-plan/src/logical_plan/alp/dot.rs b/crates/polars-plan/src/logical_plan/alp/dot.rs index 62b162349c1a..a0692b7ef9d6 100644 --- a/crates/polars-plan/src/logical_plan/alp/dot.rs +++ b/crates/polars-plan/src/logical_plan/alp/dot.rs @@ -155,6 +155,11 @@ impl<'a> IRDotDisplay<'a> { self.with_root(*input)._format(f, Some(id), last)?; write_label(f, id, |f| write!(f, "WITH COLUMNS {exprs}"))?; }, + Reduce { input, exprs, .. } => { + let exprs = self.display_exprs(exprs); + self.with_root(*input)._format(f, Some(id), last)?; + write_label(f, id, |f| write!(f, "REDUCE {exprs}"))?; + }, Slice { input, offset, len } => { self.with_root(*input)._format(f, Some(id), last)?; write_label(f, id, |f| write!(f, "SLICE offset: {offset}; len: {len}"))?; diff --git a/crates/polars-plan/src/logical_plan/alp/format.rs b/crates/polars-plan/src/logical_plan/alp/format.rs index 95bfbc92f3ef..01c9d3eec49b 100644 --- a/crates/polars-plan/src/logical_plan/alp/format.rs +++ b/crates/polars-plan/src/logical_plan/alp/format.rs @@ -220,6 +220,13 @@ impl<'a> IRDisplay<'a> { selection, ) }, + Reduce { input, exprs, .. } => { + // @NOTE: Maybe there should be a clear delimiter here? + let default_exprs = self.display_expr_slice(exprs); + + write!(f, "{:indent$} REDUCE {default_exprs} FROM", "")?; + self.with_root(*input)._format(f, sub_indent) + }, Select { expr, input, .. } => { // @NOTE: Maybe there should be a clear delimiter here? let default_exprs = self.display_expr_slice(expr.default_exprs()); diff --git a/crates/polars-plan/src/logical_plan/alp/inputs.rs b/crates/polars-plan/src/logical_plan/alp/inputs.rs index c6e0c1c725a0..c389097e2ba9 100644 --- a/crates/polars-plan/src/logical_plan/alp/inputs.rs +++ b/crates/polars-plan/src/logical_plan/alp/inputs.rs @@ -31,6 +31,11 @@ impl IR { input: inputs[0], predicate: exprs.pop().unwrap(), }, + Reduce { schema, .. } => Reduce { + input: inputs[0], + exprs, + schema: schema.clone(), + }, Select { schema, options, .. } => Select { @@ -165,6 +170,7 @@ impl IR { Slice { .. } | Cache { .. } | Distinct { .. } | Union { .. } | MapFunction { .. } => {}, Sort { by_column, .. } => container.extend_from_slice(by_column), Filter { predicate, .. } => container.push(predicate.clone()), + Reduce { exprs, .. } => container.extend_from_slice(exprs), Select { expr, .. } => container.extend_from_slice(expr), GroupBy { keys, aggs, .. } => { let iter = keys.iter().cloned().chain(aggs.iter().cloned()); @@ -226,6 +232,7 @@ impl IR { Slice { input, .. } => *input, Filter { input, .. } => *input, Select { input, .. } => *input, + Reduce { input, .. } => *input, SimpleProjection { input, .. } => *input, Sort { input, .. } => *input, Cache { input, .. } => *input, diff --git a/crates/polars-plan/src/logical_plan/alp/mod.rs b/crates/polars-plan/src/logical_plan/alp/mod.rs index dd55f598a944..f29b991adaf4 100644 --- a/crates/polars-plan/src/logical_plan/alp/mod.rs +++ b/crates/polars-plan/src/logical_plan/alp/mod.rs @@ -72,6 +72,12 @@ pub enum IR { input: Node, columns: SchemaRef, }, + // Special case of `select` where all operations reduce to a single row. + Reduce { + input: Node, + exprs: Vec, + schema: SchemaRef, + }, // Polars' `select` operation. This may access full materialized data. Select { input: Node, diff --git a/crates/polars-plan/src/logical_plan/alp/schema.rs b/crates/polars-plan/src/logical_plan/alp/schema.rs index db4d77b61b03..6047fe6d5943 100644 --- a/crates/polars-plan/src/logical_plan/alp/schema.rs +++ b/crates/polars-plan/src/logical_plan/alp/schema.rs @@ -23,6 +23,7 @@ impl IR { Filter { .. } => "selection", DataFrameScan { .. } => "df", Select { .. } => "projection", + Reduce { .. } => "reduce", Sort { .. } => "sort", Cache { .. } => "cache", GroupBy { .. } => "aggregate", @@ -81,6 +82,7 @@ impl IR { } => output_schema.as_ref().unwrap_or(schema), Filter { input, .. } => return arena.get(*input).schema(arena), Select { schema, .. } => schema, + Reduce { schema, .. } => schema, SimpleProjection { columns, .. } => columns, GroupBy { schema, .. } => schema, Join { schema, .. } => schema, diff --git a/crates/polars-plan/src/logical_plan/alp/tree_format.rs b/crates/polars-plan/src/logical_plan/alp/tree_format.rs index 244280b24eb2..7337a6c33201 100644 --- a/crates/polars-plan/src/logical_plan/alp/tree_format.rs +++ b/crates/polars-plan/src/logical_plan/alp/tree_format.rs @@ -304,6 +304,14 @@ impl<'a> TreeFmtNode<'a> { .chain([self.lp_node(None, *input)]) .collect(), ), + Reduce { input, exprs, .. } => ND( + wh(h, "REDUCE"), + exprs + .iter() + .map(|expr| self.expr_node(Some("expression:".to_string()), expr)) + .chain([self.lp_node(None, *input)]) + .collect(), + ), Distinct { input, options } => ND( wh( h, diff --git a/crates/polars-plan/src/logical_plan/conversion/mod.rs b/crates/polars-plan/src/logical_plan/conversion/mod.rs index 230c2fd2b4e3..36a2021b5f37 100644 --- a/crates/polars-plan/src/logical_plan/conversion/mod.rs +++ b/crates/polars-plan/src/logical_plan/conversion/mod.rs @@ -119,6 +119,15 @@ impl IR { options, } }, + IR::Reduce { exprs, input, .. } => { + let i = convert_to_lp(input, lp_arena); + let expr = expr_irs_to_exprs(exprs, expr_arena); + DslPlan::Select { + expr, + input: Arc::new(i), + options: Default::default(), + } + }, IR::SimpleProjection { input, columns } => { let input = convert_to_lp(input, lp_arena); let expr = columns diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs index d6ec2e01ca64..37328471b59a 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs @@ -622,6 +622,7 @@ impl<'a> PredicatePushDown<'a> { }, lp @ HStack { .. } | lp @ Select { .. } + | lp @ Reduce { .. } | lp @ SimpleProjection { .. } | lp @ ExtContext { .. } => { self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, true) diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs index 56a3f74a1b37..6744dd38986a 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs @@ -325,6 +325,8 @@ impl ProjectionPushDown { use IR::*; match logical_plan { + // Should not yet be here + Reduce { .. } => unreachable!(), Select { expr, input, .. } => process_projection( self, input, diff --git a/crates/polars-plan/src/logical_plan/visitor/hash.rs b/crates/polars-plan/src/logical_plan/visitor/hash.rs index ad84b3cc2229..ed59499f3c4f 100644 --- a/crates/polars-plan/src/logical_plan/visitor/hash.rs +++ b/crates/polars-plan/src/logical_plan/visitor/hash.rs @@ -110,6 +110,13 @@ impl Hash for HashableEqLP<'_> { hash_exprs(expr.default_exprs(), self.expr_arena, state); options.hash(state); }, + IR::Reduce { + input: _, + exprs, + schema: _, + } => { + hash_exprs(exprs, self.expr_arena, state); + }, IR::Sort { input: _, by_column, diff --git a/crates/polars-plan/src/reduce/mod.rs b/crates/polars-plan/src/reduce/mod.rs new file mode 100644 index 000000000000..3bdc7be55966 --- /dev/null +++ b/crates/polars-plan/src/reduce/mod.rs @@ -0,0 +1,13 @@ +use polars_core::datatypes::Scalar; +use polars_core::prelude::Series; + +#[allow(dead_code)] +trait Reduction { + fn init(&mut self); + + fn update(&mut self, batch: &Series); + + fn combine(&mut self, other: &dyn Reduction); + + fn finalize(&mut self) -> Scalar; +} diff --git a/crates/polars/tests/it/core/rolling_window.rs b/crates/polars/tests/it/core/rolling_window.rs index a374523e1454..7d1b069ccbbc 100644 --- a/crates/polars/tests/it/core/rolling_window.rs +++ b/crates/polars/tests/it/core/rolling_window.rs @@ -177,7 +177,7 @@ fn test_rolling_map() { let out = ca .rolling_map( - &|s| s.sum_as_series().unwrap(), + &|s| s.sum_as_series().unwrap().into_series(s.name()), RollingOptionsFixedWindow { window_size: 3, min_periods: 3, diff --git a/py-polars/src/lazyframe/visitor/nodes.rs b/py-polars/src/lazyframe/visitor/nodes.rs index ba37cc3f4fd8..74e15cbc2fdd 100644 --- a/py-polars/src/lazyframe/visitor/nodes.rs +++ b/py-polars/src/lazyframe/visitor/nodes.rs @@ -432,6 +432,17 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { options: (), } .into_py(py), + IR::Reduce { + input, + exprs, + schema: _, + } => Select { + input: input.0, + expr: exprs.iter().map(|e| e.into()).collect(), + cse_expr: vec![], + options: (), + } + .into_py(py), IR::Distinct { input, options } => Distinct { input: input.0, // TODO, rest of options diff --git a/py-polars/src/series/aggregation.rs b/py-polars/src/series/aggregation.rs index ed6c973b960a..626d3eb7cef5 100644 --- a/py-polars/src/series/aggregation.rs +++ b/py-polars/src/series/aggregation.rs @@ -151,6 +151,7 @@ impl PySeries { Ok(Wrap( self.series .sum_as_series() + .map(|sc| sc.into_series("")) .map_err(PyPolarsErr::from)? .get(0) .map_err(PyPolarsErr::from)?,