Skip to content

Commit

Permalink
refactor: Add IR::Reduce (not yet implemented) (pola-rs#16216)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored May 16, 2024
1 parent 00963a8 commit 709143d
Show file tree
Hide file tree
Showing 33 changed files with 197 additions and 50 deletions.
4 changes: 2 additions & 2 deletions crates/polars-core/src/chunked_array/logical/date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ impl LogicalType for DateChunked {
}

fn get_any_value(&self, i: usize) -> PolarsResult<AnyValue<'_>> {
self.0.get_any_value(i).map(|av| av.into_date())
self.0.get_any_value(i).map(|av| av.as_date())
}

unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> {
self.0.get_any_value_unchecked(i).into_date()
self.0.get_any_value_unchecked(i).as_date()
}

fn cast(&self, dtype: &DataType) -> PolarsResult<Series> {
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-core/src/chunked_array/logical/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ impl LogicalType for DatetimeChunked {
fn get_any_value(&self, i: usize) -> PolarsResult<AnyValue<'_>> {
self.0
.get_any_value(i)
.map(|av| av.into_datetime(self.time_unit(), self.time_zone()))
.map(|av| av.as_datetime(self.time_unit(), self.time_zone()))
}

unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> {
self.0
.get_any_value_unchecked(i)
.into_datetime(self.time_unit(), self.time_zone())
.as_datetime(self.time_unit(), self.time_zone())
}

fn cast(&self, dtype: &DataType) -> PolarsResult<Series> {
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-core/src/chunked_array/logical/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ impl LogicalType for DurationChunked {
fn get_any_value(&self, i: usize) -> PolarsResult<AnyValue<'_>> {
self.0
.get_any_value(i)
.map(|av| av.into_duration(self.time_unit()))
.map(|av| av.as_duration(self.time_unit()))
}
unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> {
self.0
.get_any_value_unchecked(i)
.into_duration(self.time_unit())
.as_duration(self.time_unit())
}

fn cast(&self, dtype: &DataType) -> PolarsResult<Series> {
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-core/src/chunked_array/logical/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ impl LogicalType for TimeChunked {

#[cfg(feature = "dtype-time")]
fn get_any_value(&self, i: usize) -> PolarsResult<AnyValue<'_>> {
self.0.get_any_value(i).map(|av| av.into_time())
self.0.get_any_value(i).map(|av| av.as_time())
}
unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> {
self.0.get_any_value_unchecked(i).into_time()
self.0.get_any_value_unchecked(i).as_time()
}

fn cast(&self, dtype: &DataType) -> PolarsResult<Series> {
Expand Down
20 changes: 9 additions & 11 deletions crates/polars-core/src/chunked_array/ops/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::series::IsSorted;
/// Aggregations that return [`Series`] of unit length. Those can be used in broadcasting operations.
pub trait ChunkAggSeries {
/// Get the sum of the [`ChunkedArray`] as a [`Scalar`].
fn sum_as_series(&self) -> Series {
fn sum_as_series(&self) -> Scalar {
unimplemented!()
}
/// Get the max of the [`ChunkedArray`] as a new [`Series`] of length 1.
Expand Down Expand Up @@ -263,11 +263,9 @@ where
Add<Output = <T::Native as Simd>::Simd> + compute::aggregate::Sum<T::Native>,
ChunkedArray<T>: IntoSeries,
{
fn sum_as_series(&self) -> Series {
let v = self.sum();
let mut ca: ChunkedArray<T> = [v].iter().copied().collect();
ca.rename(self.name());
ca.into_series()
fn sum_as_series(&self) -> Scalar {
let v: Option<T::Native> = self.sum();
Scalar::new(T::get_dtype().clone(), v.into())
}

fn max_as_series(&self) -> Series {
Expand Down Expand Up @@ -395,9 +393,9 @@ impl QuantileAggSeries for Float64Chunked {
}

impl ChunkAggSeries for BooleanChunked {
fn sum_as_series(&self) -> Series {
fn sum_as_series(&self) -> Scalar {
let v = self.sum();
Series::new(self.name(), [v])
Scalar::new(IDX_DTYPE, v.into())
}
fn max_as_series(&self) -> Series {
let v = self.max();
Expand Down Expand Up @@ -459,8 +457,8 @@ impl StringChunked {
}

impl ChunkAggSeries for StringChunked {
fn sum_as_series(&self) -> Series {
StringChunked::full_null(self.name(), 1).into_series()
fn sum_as_series(&self) -> Scalar {
Scalar::new(DataType::String, AnyValue::Null)
}
fn max_as_series(&self) -> Series {
Series::new(self.name(), &[self.max_str()])
Expand Down Expand Up @@ -590,7 +588,7 @@ impl BinaryChunked {
}

impl ChunkAggSeries for BinaryChunked {
fn sum_as_series(&self) -> Series {
fn sum_as_series(&self) -> Scalar {
unimplemented!()
}
fn max_as_series(&self) -> Series {
Expand Down
54 changes: 45 additions & 9 deletions crates/polars-core/src/datatypes/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,25 @@ use polars_utils::sync::SyncPtr;
use polars_utils::total_ord::ToTotalOrd;
use polars_utils::unwrap::UnwrapUncheckedRelease;

/// A single owned value paired with the [`DataType`] it is meant to be interpreted as.
///
/// Build one with [`Scalar::new`] and turn it into a [`Series`] with
/// [`Scalar::into_series`].
#[derive(Clone, Debug)]
pub struct Scalar {
    /// Data type handed to the series constructor in [`Scalar::into_series`].
    dtype: DataType,
    /// The owned value.
    value: AnyValue<'static>,
}

impl Scalar {
    /// Create a scalar from a data type and an owned value.
    ///
    /// The constructor stores both as given; it does not check that `value`
    /// is compatible with `dtype`.
    pub fn new(dtype: DataType, value: AnyValue<'static>) -> Self {
        Self { dtype, value }
    }

    /// Borrow the stored value.
    ///
    /// The `'static` lifetime is spelled out so the returned value is not
    /// artificially tied to the borrow of `self` (e.g. a clone of it can
    /// outlive the `Scalar`).
    pub fn value(&self) -> &AnyValue<'static> {
        &self.value
    }

    /// Consume the scalar and build a [`Series`] called `name` from its single value
    /// and its data type.
    ///
    /// # Panics
    ///
    /// Panics if `Series::from_any_values_and_dtype` returns an error for the stored
    /// value/dtype pair (it is called with the final flag set to `true`).
    pub fn into_series(self, name: &str) -> Series {
        Series::from_any_values_and_dtype(name, &[self.value], &self.dtype, true).unwrap()
    }
}

use super::*;
#[cfg(feature = "dtype-struct")]
use crate::prelude::any_value::arr_to_any_value;
Expand Down Expand Up @@ -338,6 +357,23 @@ impl<'a> Deserialize<'a> for AnyValue<'static> {
}
}

impl AnyValue<'static> {
    /// Return the "zero" value for `dtype`.
    ///
    /// * `String` gives an empty owned string.
    /// * `Boolean` gives `false`.
    /// * Any dtype for which `is_numeric()` holds gives `UInt8(0)` cast to that dtype.
    /// * Every other dtype gives [`AnyValue::Null`].
    pub fn zero(dtype: &DataType) -> Self {
        if matches!(dtype, DataType::String) {
            return AnyValue::StringOwned("".into());
        }
        if matches!(dtype, DataType::Boolean) {
            return AnyValue::Boolean(false);
        }
        if !dtype.is_numeric() {
            return AnyValue::Null;
        }

        let seed = AnyValue::UInt8(0);
        let zero = seed.cast(dtype);
        // SAFETY:
        // Numeric values are static, inform the compiler of this.
        unsafe { std::mem::transmute::<AnyValue<'_>, AnyValue<'static>>(zero) }
    }
}

impl<'a> AnyValue<'a> {
/// Get the matching [`DataType`] for this [`AnyValue`].
///
Expand Down Expand Up @@ -735,43 +771,43 @@ where

impl<'a> AnyValue<'a> {
#[cfg(any(feature = "dtype-date", feature = "dtype-datetime"))]
pub(crate) fn into_date(self) -> Self {
pub(crate) fn as_date(&self) -> AnyValue<'static> {
match self {
#[cfg(feature = "dtype-date")]
AnyValue::Int32(v) => AnyValue::Date(v),
AnyValue::Int32(v) => AnyValue::Date(*v),
AnyValue::Null => AnyValue::Null,
dt => panic!("cannot create date from other type. dtype: {dt}"),
}
}
#[cfg(feature = "dtype-datetime")]
pub(crate) fn into_datetime(self, tu: TimeUnit, tz: &'a Option<TimeZone>) -> Self {
pub(crate) fn as_datetime(&self, tu: TimeUnit, tz: &'a Option<TimeZone>) -> AnyValue<'a> {
match self {
AnyValue::Int64(v) => AnyValue::Datetime(v, tu, tz),
AnyValue::Int64(v) => AnyValue::Datetime(*v, tu, tz),
AnyValue::Null => AnyValue::Null,
dt => panic!("cannot create date from other type. dtype: {dt}"),
}
}

#[cfg(feature = "dtype-duration")]
pub(crate) fn into_duration(self, tu: TimeUnit) -> Self {
pub(crate) fn as_duration(&self, tu: TimeUnit) -> AnyValue<'static> {
match self {
AnyValue::Int64(v) => AnyValue::Duration(v, tu),
AnyValue::Int64(v) => AnyValue::Duration(*v, tu),
AnyValue::Null => AnyValue::Null,
dt => panic!("cannot create date from other type. dtype: {dt}"),
}
}

#[cfg(feature = "dtype-time")]
pub(crate) fn into_time(self) -> Self {
pub(crate) fn as_time(&self) -> AnyValue<'static> {
match self {
AnyValue::Int64(v) => AnyValue::Time(v),
AnyValue::Int64(v) => AnyValue::Time(*v),
AnyValue::Null => AnyValue::Null,
dt => panic!("cannot create date from other type. dtype: {dt}"),
}
}

#[must_use]
pub fn add(&self, rhs: &AnyValue) -> Self {
pub fn add(&self, rhs: &AnyValue) -> AnyValue<'static> {
use AnyValue::*;
match (self, rhs) {
(Null, _) => Null,
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/series/implementations/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ impl SeriesTrait for SeriesWrap<BooleanChunked> {
ChunkShift::shift(&self.0, periods).into_series()
}

fn _sum_as_series(&self) -> PolarsResult<Series> {
fn _sum_as_series(&self) -> PolarsResult<Scalar> {
Ok(ChunkAggSeries::sum_as_series(&self.0))
}
fn max_as_series(&self) -> PolarsResult<Series> {
Expand Down
27 changes: 19 additions & 8 deletions crates/polars-core/src/series/implementations/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@ unsafe impl IntoSeries for DecimalChunked {
impl private::PrivateSeriesNumeric for SeriesWrap<DecimalChunked> {}

impl SeriesWrap<DecimalChunked> {
fn apply_physical<F: Fn(&Int128Chunked) -> Int128Chunked>(&self, f: F) -> Series {
fn apply_physical_to_s<F: Fn(&Int128Chunked) -> Int128Chunked>(&self, f: F) -> Series {
f(&self.0)
.into_decimal_unchecked(self.0.precision(), self.0.scale())
.into_series()
}

/// Call `f` with the inner array viewed as an [`Int128Chunked`] and hand back
/// whatever `f` returns, unchanged.
fn apply_physical<T, F>(&self, f: F) -> T
where
    F: Fn(&Int128Chunked) -> T,
{
    let physical: &Int128Chunked = &self.0;
    f(physical)
}

fn agg_helper<F: Fn(&Int128Chunked) -> Series>(&self, f: F) -> Series {
let agg_s = f(&self.0);
match agg_s.dtype() {
Expand Down Expand Up @@ -187,7 +191,7 @@ impl SeriesTrait for SeriesWrap<DecimalChunked> {
}

fn slice(&self, offset: i64, length: usize) -> Series {
self.apply_physical(|ca| ca.slice(offset, length))
self.apply_physical_to_s(|ca| ca.slice(offset, length))
}

fn append(&mut self, other: &Series) -> PolarsResult<()> {
Expand Down Expand Up @@ -301,31 +305,38 @@ impl SeriesTrait for SeriesWrap<DecimalChunked> {
}

fn reverse(&self) -> Series {
self.apply_physical(|ca| ca.reverse())
self.apply_physical_to_s(|ca| ca.reverse())
}

fn shift(&self, periods: i64) -> Series {
self.apply_physical(|ca| ca.shift(periods))
self.apply_physical_to_s(|ca| ca.shift(periods))
}

fn clone_inner(&self) -> Arc<dyn SeriesTrait> {
Arc::new(SeriesWrap(Clone::clone(&self.0)))
}

fn _sum_as_series(&self) -> PolarsResult<Series> {
fn _sum_as_series(&self) -> PolarsResult<Scalar> {
Ok(self.apply_physical(|ca| {
let sum = ca.sum();
Int128Chunked::from_slice_options(self.name(), &[sum])
let DataType::Decimal(_, Some(scale)) = self.dtype() else {
unreachable!()
};
let av = match sum {
None => AnyValue::Null,
Some(v) => AnyValue::Decimal(v, *scale),
};
Scalar::new(self.dtype().clone(), av)
}))
}
fn min_as_series(&self) -> PolarsResult<Series> {
Ok(self.apply_physical(|ca| {
Ok(self.apply_physical_to_s(|ca| {
let min = ca.min();
Int128Chunked::from_slice_options(self.name(), &[min])
}))
}
fn max_as_series(&self) -> PolarsResult<Series> {
Ok(self.apply_physical(|ca| {
Ok(self.apply_physical_to_s(|ca| {
let max = ca.max();
Int128Chunked::from_slice_options(self.name(), &[max])
}))
Expand Down
6 changes: 4 additions & 2 deletions crates/polars-core/src/series/implementations/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,10 @@ impl SeriesTrait for SeriesWrap<DurationChunked> {
.into_series()
}

fn _sum_as_series(&self) -> PolarsResult<Series> {
Ok(self.0.sum_as_series().into_duration(self.0.time_unit()))
fn _sum_as_series(&self) -> PolarsResult<Scalar> {
let sc = self.0.sum_as_series();
let v = sc.value().as_duration(self.0.time_unit());
Ok(Scalar::new(self.dtype().clone(), v))
}

fn max_as_series(&self) -> PolarsResult<Series> {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/series/implementations/floats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ macro_rules! impl_dyn_series {
ChunkShift::shift(&self.0, periods).into_series()
}

fn _sum_as_series(&self) -> PolarsResult<Series> {
fn _sum_as_series(&self) -> PolarsResult<Scalar> {
Ok(ChunkAggSeries::sum_as_series(&self.0))
}
fn max_as_series(&self) -> PolarsResult<Series> {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/series/implementations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ macro_rules! impl_dyn_series {
ChunkShift::shift(&self.0, periods).into_series()
}

fn _sum_as_series(&self) -> PolarsResult<Series> {
fn _sum_as_series(&self) -> PolarsResult<Scalar> {
Ok(ChunkAggSeries::sum_as_series(&self.0))
}
fn max_as_series(&self) -> PolarsResult<Series> {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/series/implementations/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ impl SeriesTrait for SeriesWrap<StringChunked> {
ChunkShift::shift(&self.0, periods).into_series()
}

fn _sum_as_series(&self) -> PolarsResult<Series> {
fn _sum_as_series(&self) -> PolarsResult<Scalar> {
Ok(ChunkAggSeries::sum_as_series(&self.0))
}
fn max_as_series(&self) -> PolarsResult<Series> {
Expand Down
7 changes: 4 additions & 3 deletions crates/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,8 +391,9 @@ impl Series {
where
T: NumCast,
{
let sum = self.sum_as_series()?.cast(&DataType::Float64)?;
Ok(T::from(sum.f64().unwrap().get(0).unwrap()).unwrap())
let sum = self.sum_as_series()?;
let sum = sum.value().cast(&DataType::Float64);
Ok(T::from(sum.extract::<f64>().unwrap()).unwrap())
}

/// Returns the minimum value in the array, according to the natural order.
Expand Down Expand Up @@ -628,7 +629,7 @@ impl Series {
///
/// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is
/// first cast to `Int64` to prevent overflow issues.
pub fn sum_as_series(&self) -> PolarsResult<Series> {
pub fn sum_as_series(&self) -> PolarsResult<Scalar> {
use DataType::*;
match self.dtype() {
Int8 | UInt8 | Int16 | UInt16 => self.cast(&Int64).unwrap().sum_as_series(),
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/series/series_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ pub trait SeriesTrait:
///
/// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is
/// first cast to `Int64` to prevent overflow issues.
fn _sum_as_series(&self) -> PolarsResult<Series> {
fn _sum_as_series(&self) -> PolarsResult<Scalar> {
polars_bail!(opq = sum, self._dtype());
}
/// Get the max of the Series as a new Series of length 1.
Expand Down
7 changes: 6 additions & 1 deletion crates/polars-lazy/src/physical_plan/planner/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,12 @@ fn create_physical_expr_inner(
let state = *state;
SpecialEq::new(Arc::new(move |s: &mut [Series]| {
let s = std::mem::take(&mut s[0]);
parallel_op_series(|s| s.sum_as_series(), s, None, state)
parallel_op_series(
|s| s.sum_as_series().map(|sc| sc.into_series(s.name())),
s,
None,
state,
)
}) as Arc<dyn SeriesUdf>)
},
AAggExpr::Count(_, include_nulls) => {
Expand Down
Loading

0 comments on commit 709143d

Please sign in to comment.