Skip to content

Commit

Permalink
fix: Return appropriate data type for time mean and median (#14471)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller authored Apr 13, 2024
1 parent ce39151 commit 4b1b945
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,14 @@ impl Series {
Float64 => SeriesWrap(self.f64().unwrap().clone()).agg_median(groups),
dt if dt.is_numeric() => apply_method_physical_integer!(self, agg_median, groups),
#[cfg(feature = "dtype-datetime")]
dt @ (Datetime(_, _) | Duration(_)) => self
dt @ (Datetime(_, _) | Duration(_) | Time) => self
.to_physical_repr()
.agg_median(groups)
.cast(&Int64)
.unwrap()
.cast(dt)
.unwrap(),
dt @ (Date | Time) => {
dt @ Date => {
let ca = self.to_physical_repr();
let physical_type = ca.dtype();
let s = apply_method_physical_integer!(ca, agg_median, groups);
Expand Down Expand Up @@ -174,14 +174,14 @@ impl Series {
Float64 => SeriesWrap(self.f64().unwrap().clone()).agg_mean(groups),
dt if dt.is_numeric() => apply_method_physical_integer!(self, agg_mean, groups),
#[cfg(feature = "dtype-datetime")]
dt @ (Datetime(_, _) | Duration(_)) => self
dt @ (Datetime(_, _) | Duration(_) | Time) => self
.to_physical_repr()
.agg_mean(groups)
.cast(&Int64)
.unwrap()
.cast(dt)
.unwrap(),
dt @ (Date | Time) => {
dt @ Date => {
let ca = self.to_physical_repr();
let physical_type = ca.dtype();
let s = apply_method_physical_integer!(ca, agg_mean, groups);
Expand Down
3 changes: 3 additions & 0 deletions crates/polars-core/src/series/implementations/dates_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,9 @@ macro_rules! impl_dyn_series {
/// Returns the minimum of this series as a `Series`.
///
/// Delegates to the wrapped array's `min_as_series` and converts the result
/// with the macro-supplied `$into_logical` method; this path always yields `Ok`.
fn min_as_series(&self) -> PolarsResult<Series> {
Ok(self.0.min_as_series().$into_logical())
}
/// Returns the median of this series as a one-element `Series` cast to
/// `self.dtype()`.
///
/// The optional value from `self.median()` is converted with `as i64`
/// (NOTE(review): presumably an `f64`, so any fractional part is dropped —
/// confirm), placed in a new `Series` named after `self`, and cast back to
/// this series' dtype. A missing median stays `None` through `map`, and any
/// cast error is returned to the caller.
fn median_as_series(&self) -> PolarsResult<Series> {
Series::new(self.name(), &[self.median().map(|v| v as i64)]).cast(self.dtype())
}

fn clone_inner(&self) -> Arc<dyn SeriesTrait> {
Arc::new(SeriesWrap(Clone::clone(&self.0)))
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,10 @@ impl Series {
.cast(dt)
.unwrap()
},
#[cfg(feature = "dtype-time")]
dt @ DataType::Time => Series::new(self.name(), &[self.mean().map(|v| v as i64)])
.cast(dt)
.unwrap(),
_ => return Series::full_null(self.name(), 1, self.dtype()),
}
}
Expand Down
10 changes: 8 additions & 2 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1402,7 +1402,10 @@ impl LazyFrame {
dt.is_numeric()
|| matches!(
dt,
DataType::Boolean | DataType::Duration(_) | DataType::Datetime(_, _)
DataType::Boolean
| DataType::Duration(_)
| DataType::Datetime(_, _)
| DataType::Time
)
},
|name| col(name).mean(),
Expand All @@ -1420,7 +1423,10 @@ impl LazyFrame {
dt.is_numeric()
|| matches!(
dt,
DataType::Boolean | DataType::Duration(_) | DataType::Datetime(_, _)
DataType::Boolean
| DataType::Duration(_)
| DataType::Datetime(_, _)
| DataType::Time
)
},
|name| col(name).median(),
Expand Down
6 changes: 3 additions & 3 deletions py-polars/polars/series/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from polars._utils.deprecation import deprecate_function, deprecate_renamed_function
from polars._utils.unstable import unstable
from polars._utils.wrap import wrap_s
from polars.datatypes import Date, Datetime, Duration
from polars.datatypes import Date, Datetime, Duration, Time
from polars.series.utils import expr_dispatch

if TYPE_CHECKING:
Expand Down Expand Up @@ -88,7 +88,7 @@ def median(self) -> TemporalLiteral | float | None:
if out is not None:
if s.dtype == Date:
return to_py_date(int(out)) # type: ignore[arg-type]
elif s.dtype in (Datetime, Duration):
elif s.dtype in (Datetime, Duration, Time):
return out # type: ignore[return-value]
else:
return to_py_datetime(int(out), s.dtype.time_unit) # type: ignore[arg-type, attr-defined]
Expand All @@ -112,7 +112,7 @@ def mean(self) -> TemporalLiteral | float | None:
if out is not None:
if s.dtype == Date:
return to_py_date(int(out)) # type: ignore[arg-type]
elif s.dtype in (Datetime, Duration):
elif s.dtype in (Datetime, Duration, Time):
return out # type: ignore[return-value]
else:
return to_py_datetime(int(out), s.dtype.time_unit) # type: ignore[arg-type, attr-defined]
Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/series/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl PySeries {
.map_err(PyPolarsErr::from)?,
)
.into_py(py)),
DataType::Datetime(_, _) | DataType::Duration(_) => Ok(Wrap(
DataType::Datetime(_, _) | DataType::Duration(_) | DataType::Time => Ok(Wrap(
self.series
.mean_as_series()
.get(0)
Expand All @@ -77,7 +77,7 @@ impl PySeries {
.map_err(PyPolarsErr::from)?,
)
.into_py(py)),
DataType::Datetime(_, _) | DataType::Duration(_) => Ok(Wrap(
DataType::Datetime(_, _) | DataType::Duration(_) | DataType::Time => Ok(Wrap(
self.series
.median_as_series()
.map_err(PyPolarsErr::from)?
Expand Down
153 changes: 124 additions & 29 deletions py-polars/tests/unit/namespaces/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,9 @@ def test_weekday(time_unit: TimeUnit) -> None:
([timedelta(days=1)], timedelta(days=1)),
([timedelta(days=1), timedelta(days=2), timedelta(days=3)], timedelta(days=2)),
([timedelta(days=1), timedelta(days=2), timedelta(days=15)], timedelta(days=2)),
([time(hour=1)], time(hour=1)),
([time(hour=1), time(hour=2), time(hour=3)], time(hour=2)),
([time(hour=1), time(hour=2), time(hour=15)], time(hour=2)),
],
ids=[
"empty",
Expand All @@ -995,6 +998,9 @@ def test_weekday(time_unit: TimeUnit) -> None:
"single_dur",
"spread_even_dur",
"spread_skewed_dur",
"single_time",
"spread_even_time",
"spread_skewed_time",
],
)
def test_median(
Expand All @@ -1003,7 +1009,7 @@ def test_median(
s = pl.Series(values)
assert s.dt.median() == expected_median

if s.dtype == pl.Datetime:
if s.dtype in (pl.Datetime, pl.Duration, pl.Time):
assert s.median() == expected_median


Expand All @@ -1027,6 +1033,9 @@ def test_median(
([timedelta(days=1)], timedelta(days=1)),
([timedelta(days=1), timedelta(days=2), timedelta(days=3)], timedelta(days=2)),
([timedelta(days=1), timedelta(days=2), timedelta(days=15)], timedelta(days=6)),
([time(hour=1)], time(hour=1)),
([time(hour=1), time(hour=2), time(hour=3)], time(hour=2)),
([time(hour=1), time(hour=2), time(hour=15)], time(hour=6)),
],
ids=[
"empty",
Expand All @@ -1040,6 +1049,9 @@ def test_median(
"single_duration",
"spread_even_duration",
"spread_skewed_duration",
"single_time",
"spread_even_time",
"spread_skewed_time",
],
)
def test_mean(
Expand All @@ -1048,62 +1060,91 @@ def test_mean(
s = pl.Series(values)
assert s.dt.mean() == expected_mean

if s.dtype == pl.Datetime:
if s.dtype in (pl.Datetime, pl.Duration, pl.Time):
assert s.mean() == expected_mean


@pytest.mark.parametrize(
("values", "expected_mean"),
[
([None], None),
(
[datetime(2022, 1, 1), datetime(2022, 1, 2), datetime(2024, 5, 15)],
datetime(2022, 10, 16, 16, 0, 0),
),
],
ids=["spread_skewed_dt"],
ids=["None_dt", "spread_skewed_dt"],
)
def test_datetime_mean_with_tu(values: list[datetime], expected_mean: datetime) -> None:
assert pl.Series(values, dtype=pl.Duration("ms")).mean() == expected_mean
assert pl.Series(values, dtype=pl.Duration("ms")).dt.mean() == expected_mean
assert pl.Series(values, dtype=pl.Duration("us")).mean() == expected_mean
assert pl.Series(values, dtype=pl.Duration("us")).dt.mean() == expected_mean
assert pl.Series(values, dtype=pl.Duration("ns")).mean() == expected_mean
assert pl.Series(values, dtype=pl.Duration("ns")).dt.mean() == expected_mean
@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
def test_datetime_mean_with_tu(
    values: list[datetime], expected_mean: datetime, time_unit: TimeUnit
) -> None:
    """Check that the mean of a Datetime series is a datetime for each time unit."""
    # The parametrized values are datetimes, so the series must be built with a
    # Datetime dtype (not Duration) for the requested time unit to be exercised.
    s = pl.Series(values, dtype=pl.Datetime(time_unit))
    assert s.mean() == expected_mean
    assert s.dt.mean() == expected_mean


@pytest.mark.parametrize(
    ("values", "expected_median"),
    [
        ([None], None),
        (
            [datetime(2022, 1, 1), datetime(2022, 1, 2), datetime(2024, 5, 15)],
            datetime(2022, 1, 2),
        ),
    ],
    ids=["None_dt", "spread_skewed_dt"],
)
@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
def test_datetime_median_with_tu(
    values: list[datetime], expected_median: datetime, time_unit: TimeUnit
) -> None:
    """Check that the median of a Datetime series is a datetime for each time unit."""
    # The parametrized values are datetimes, so the series must be built with a
    # Datetime dtype (not Duration) for the requested time unit to be exercised.
    s = pl.Series(values, dtype=pl.Datetime(time_unit))
    assert s.median() == expected_median
    assert s.dt.median() == expected_median


@pytest.mark.parametrize(
("values", "expected_mean"),
[([timedelta(days=1), timedelta(days=2), timedelta(days=15)], timedelta(days=6))],
ids=["spread_skewed_dur"],
[
([None], None),
(
[timedelta(days=1), timedelta(days=2), timedelta(days=15)],
timedelta(days=6),
),
],
ids=["None_dur", "spread_skewed_dur"],
)
@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
def test_duration_mean_with_tu(
values: list[timedelta], expected_mean: timedelta
values: list[timedelta], expected_mean: timedelta, time_unit: TimeUnit
) -> None:
assert pl.Series(values, dtype=pl.Duration("ms")).mean() == expected_mean
assert pl.Series(values, dtype=pl.Duration("ms")).dt.mean() == expected_mean
assert pl.Series(values, dtype=pl.Duration("us")).mean() == expected_mean
assert pl.Series(values, dtype=pl.Duration("us")).dt.mean() == expected_mean
assert pl.Series(values, dtype=pl.Duration("ns")).mean() == expected_mean
assert pl.Series(values, dtype=pl.Duration("ns")).dt.mean() == expected_mean
assert pl.Series(values, dtype=pl.Duration(time_unit)).mean() == expected_mean
assert pl.Series(values, dtype=pl.Duration(time_unit)).dt.mean() == expected_mean


@pytest.mark.parametrize(
("values", "expected_median"),
[([timedelta(days=1), timedelta(days=2), timedelta(days=15)], timedelta(days=2))],
ids=["spread_skewed_dur"],
[
([None], None),
(
[timedelta(days=1), timedelta(days=2), timedelta(days=15)],
timedelta(days=2),
),
],
ids=["None_dur", "spread_skewed_dur"],
)
@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
def test_duration_median_with_tu(
values: list[timedelta], expected_median: timedelta
values: list[timedelta], expected_median: timedelta, time_unit: TimeUnit
) -> None:
assert pl.Series(values, dtype=pl.Duration("ms")).median() == expected_median
assert pl.Series(values, dtype=pl.Duration("ms")).dt.median() == expected_median
assert pl.Series(values, dtype=pl.Duration("us")).median() == expected_median
assert pl.Series(values, dtype=pl.Duration("us")).dt.median() == expected_median
assert pl.Series(values, dtype=pl.Duration("ns")).median() == expected_median
assert pl.Series(values, dtype=pl.Duration("ns")).dt.median() == expected_median
assert pl.Series(values, dtype=pl.Duration(time_unit)).median() == expected_median
assert (
pl.Series(values, dtype=pl.Duration(time_unit)).dt.median() == expected_median
)


def test_agg_expr() -> None:
def test_agg_mean_expr() -> None:
df = pl.DataFrame(
{
"datetime_ms": pl.Series(
Expand All @@ -1130,6 +1171,10 @@ def test_agg_expr() -> None:
[timedelta(days=1), timedelta(days=2), timedelta(days=4)],
dtype=pl.Duration("ns"),
),
"time": pl.Series(
[time(hour=1), time(hour=2), time(hour=4)],
dtype=pl.Time,
),
}
)

Expand All @@ -1153,7 +1198,57 @@ def test_agg_expr() -> None:
"duration_ns": pl.Series(
[timedelta(days=2, hours=8)], dtype=pl.Duration("ns")
),
"time": pl.Series([time(hour=2, minute=20)], dtype=pl.Time),
}
)

assert_frame_equal(df.select(pl.all().mean()), expected)


def test_agg_median_expr() -> None:
    """`pl.all().median()` yields one row per column, keeping each column's dtype."""
    time_units: tuple[TimeUnit, ...] = ("ms", "us", "ns")
    datetimes = [datetime(2023, 1, 1), datetime(2023, 1, 2), datetime(2023, 1, 4)]
    durations = [timedelta(days=1), timedelta(days=2), timedelta(days=4)]
    times = [time(hour=1), time(hour=2), time(hour=4)]

    # Dict insertion order fixes the column order compared by assert_frame_equal:
    # the three datetime columns, then the three duration columns, then "time".
    data: dict[str, pl.Series] = {}
    medians: dict[str, pl.Series] = {}

    for tu in time_units:
        column = f"datetime_{tu}"
        data[column] = pl.Series(datetimes, dtype=pl.Datetime(tu))
        medians[column] = pl.Series([datetime(2023, 1, 2)], dtype=pl.Datetime(tu))

    for tu in time_units:
        column = f"duration_{tu}"
        data[column] = pl.Series(durations, dtype=pl.Duration(tu))
        medians[column] = pl.Series([timedelta(days=2)], dtype=pl.Duration(tu))

    data["time"] = pl.Series(times, dtype=pl.Time)
    medians["time"] = pl.Series([time(hour=2)], dtype=pl.Time)

    df = pl.DataFrame(data)
    expected = pl.DataFrame(medians)

    assert_frame_equal(df.select(pl.all().median()), expected)

0 comments on commit 4b1b945

Please sign in to comment.