Skip to content

Commit

Permalink
Improve code for datetime breaks
Browse files Browse the repository at this point in the history
  • Loading branch information
has2k1 committed Oct 23, 2024
1 parent e74f102 commit dc9a34c
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 31 deletions.
6 changes: 6 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ API Changes

- `mizani.transforms.trans_new` function has been deprecated.

Enhancements
************

- `~mizani.breaks.breaks_date` has been slightly improved for the case when it
generates monthly breaks.

New
***

Expand Down
25 changes: 18 additions & 7 deletions mizani/_core/date_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,6 @@ class Interval:
end: datetime

def __post_init__(self):
if isinstance(self.start, date):
self.start = datetime.fromisoformat(self.start.isoformat())

if isinstance(self.end, date):
self.end = datetime.fromisoformat(self.end.isoformat())

self._delta = relativedelta(self.end, self.start)
self._tdelta = self.end - self.start

Expand Down Expand Up @@ -149,7 +143,7 @@ def limits_year(self) -> tuple[datetime, datetime]:
return floor_year(self.start), ceil_year(self.end)

def limits_month(self) -> tuple[datetime, datetime]:
return round_month(self.start), round_month(self.end)
return floor_month(self.start), ceil_month(self.end)

def limits_week(self) -> tuple[datetime, datetime]:
return floor_week(self.start), ceil_week(self.end)
Expand Down Expand Up @@ -481,3 +475,20 @@ def expand_datetime_limits(
end = end.replace(y2)

return start, end


def as_datetime(
tup: tuple[datetime, datetime] | tuple[date, date],
) -> tuple[datetime, datetime]:
"""
Ensure that a tuple of datetime values
"""
l, h = tup

if not isinstance(l, datetime):
l = datetime.fromisoformat(l.isoformat())

if not isinstance(h, datetime):
h = datetime.fromisoformat(h.isoformat())

return l, h
14 changes: 11 additions & 3 deletions mizani/_core/dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
from dateutil.rrule import rrule

from ..utils import get_timezone, isclose_abs
from .date_utils import Interval, align_limits, expand_datetime_limits
from .date_utils import (
Interval,
align_limits,
as_datetime,
expand_datetime_limits,
)
from .types import DateFrequency, date_breaks_info

if TYPE_CHECKING:
Expand Down Expand Up @@ -316,10 +321,13 @@ def calculate_date_breaks_info(
return res


def calculate_date_breaks_auto(limits, n: int = 5) -> Sequence[datetime]:
def calculate_date_breaks_auto(
limits: tuple[datetime, datetime], n: int = 5
) -> Sequence[datetime]:
"""
Calcuate date breaks using appropriate units
"""
limits = as_datetime(limits)
info = calculate_date_breaks_info(limits, n=n)
lookup = {
DF.YEARLY: yearly_breaks,
Expand All @@ -334,7 +342,7 @@ def calculate_date_breaks_auto(limits, n: int = 5) -> Sequence[datetime]:


def calculate_date_breaks_byunits(
limits,
limits: tuple[datetime, datetime],
units: DatetimeBreaksUnits,
width: int,
max_breaks: int | None = None,
Expand Down
6 changes: 4 additions & 2 deletions mizani/breaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@

import sys
from dataclasses import KW_ONLY, dataclass, field
from datetime import datetime, timedelta
from datetime import date, datetime, timedelta
from itertools import product
from typing import TYPE_CHECKING
from warnings import warn

import numpy as np
import pandas as pd

from mizani._core.date_utils import as_datetime
from mizani._core.dates import (
calculate_date_breaks_auto,
calculate_date_breaks_byunits,
Expand Down Expand Up @@ -460,7 +461,7 @@ def __post_init__(self):
self._units = units.rstrip("s") # type: ignore

def __call__(
self, limits: tuple[datetime, datetime]
self, limits: tuple[datetime, datetime] | tuple[date, date]
) -> Sequence[datetime]:
"""
Compute breaks
Expand All @@ -483,6 +484,7 @@ def __call__(
):
limits = limits[0].astype(object), limits[1].astype(object)

limits = as_datetime(limits)
if self._units and self._width:
return calculate_date_breaks_byunits(
limits, self._units, self._width
Expand Down
20 changes: 6 additions & 14 deletions mizani/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import math
import sys
from datetime import datetime, timezone
from typing import TYPE_CHECKING, cast, overload
from datetime import datetime
from typing import TYPE_CHECKING, overload
from warnings import warn

import numpy as np
Expand All @@ -22,7 +22,6 @@
NDArrayFloat,
NullType,
NumericUFunction,
SeqDatetime,
)

T = TypeVar("T")
Expand Down Expand Up @@ -327,28 +326,21 @@ def log(x, base):
return res


def get_timezone(x: SeqDatetime) -> tzinfo | None:
def get_timezone(x: Sequence[datetime]) -> tzinfo | None:
"""
Return a single timezone for the sequence of datetimes
Returns the timezone of first item and warns if any other items
have a different timezone
"""

# Ref: https://en.wikipedia.org/wiki/List_of_tz_database_time_zones
x0 = next(iter(x))
if not isinstance(x0, datetime):
if not len(x) or x[0].tzinfo is None:
return None

x = cast(list[datetime], x)
info = x0.tzinfo
if info is None:
return timezone.utc

# Consistency check
tzname0 = info.tzname(x0)
info = x[0].tzinfo
tzname0 = info.tzname(x[0])
tznames = (dt.tzinfo.tzname(dt) if dt.tzinfo else None for dt in x)

if any(tzname0 != name for name in tznames):
msg = (
"Dates in column have different time zones. "
Expand Down
4 changes: 2 additions & 2 deletions tests/test_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def test_squish_infinite():
squish_infinite(a, (-100, 100)), [-100, 100, -100, 100]
)

b = np.array([5, -np.inf, 2, 3, 6])
b = pd.Series([5, -np.inf, 2, 3, 6])
npt.assert_allclose(squish_infinite(b, (1, 10)), [5, 1, 2, 3, 6])


Expand All @@ -270,7 +270,7 @@ def test_squish():
b = np.array([5, 0, -2, 3, 10])
npt.assert_allclose(squish(b, (0, 5)), [5, 0, 0, 3, 5])

c = np.array([5, -np.inf, 2, 3, 6])
c = pd.Series([5, -np.inf, 2, 3, 6])
npt.assert_allclose(squish(c, (1, 10), only_finite=False), [5, 1, 2, 3, 6])
npt.assert_allclose(squish(c, (1, 10)), c)

Expand Down
6 changes: 5 additions & 1 deletion tests/test_breaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def test_breaks_date():
# automatic monthly breaks with rounding
limits = (datetime(2019, 12, 27), datetime(2020, 6, 3))
breaks = breaks_date()(limits)
assert [dt.month for dt in breaks] == [1, 3, 5]
assert [dt.month for dt in breaks] == [12, 2, 4, 6]

# automatic day breaks
limits = (datetime(2020, 1, 1), datetime(2020, 1, 15))
Expand Down Expand Up @@ -246,6 +246,10 @@ def test_breaks_date():
breaks = breaks_date()(limits)
assert breaks[0].tzinfo == UG

# date
limits = (date(2000, 4, 23), date(2000, 6, 15))
breaks = breaks_date()(limits)

# Special cases
limits = (datetime(2039, 12, 17), datetime(2045, 12, 16))
breaks = breaks_date()(limits)
Expand Down
9 changes: 9 additions & 0 deletions tests/test_date_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from mizani._core.date_utils import (
ceil_month,
expand_datetime_limits,
round_month,
shift_limits_down,
)

Expand Down Expand Up @@ -31,3 +32,11 @@ def test_ceil_month():

d = datetime(2020, 1, 1)
assert ceil_month(d) == d


def test_round_month():
d = datetime(2000, 4, 23)
assert round_month(d) == datetime(2000, 5, 1)

d = datetime(2000, 4, 14)
assert round_month(d) == datetime(2000, 4, 1)
4 changes: 2 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import date, datetime
from datetime import datetime
from zoneinfo import ZoneInfo

import pandas as pd
Expand Down Expand Up @@ -133,7 +133,7 @@ def test_get_timezone():
UTC = ZoneInfo("UTC")
UG = ZoneInfo("Africa/Kampala")

x = [date(2022, 1, 1), date(2022, 12, 1)]
x = [datetime(2022, 1, 1), datetime(2022, 12, 1)]
assert get_timezone(x) is None

x = [datetime(2022, 1, 1, tzinfo=UTC), datetime(2022, 12, 1, tzinfo=UG)]
Expand Down

0 comments on commit dc9a34c

Please sign in to comment.