Skip to content

Commit

Permalink
Convert breaks classes to dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
has2k1 committed Oct 22, 2024
1 parent cf3cb44 commit ff36853
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 70 deletions.
130 changes: 63 additions & 67 deletions mizani/breaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
from __future__ import annotations

import sys
from dataclasses import KW_ONLY, dataclass, field
from datetime import datetime, timedelta
from itertools import product
from typing import TYPE_CHECKING
from warnings import warn

import numpy as np
import pandas as pd
Expand All @@ -29,7 +31,7 @@
from .utils import NANOSECONDS, SECONDS, log, min_max

if TYPE_CHECKING:
from typing import Callable, Literal, Sequence
from typing import Literal, Sequence

from mizani.typing import (
DatetimeBreaksUnits,
Expand All @@ -53,6 +55,7 @@
]


@dataclass
class breaks_log:
"""
Integer breaks on log transformed scales
Expand All @@ -76,9 +79,8 @@ class breaks_log:
array([0.1, 0.3, 1. , 3. ])
"""

def __init__(self, n: int = 5, base: float = 10):
self.n = n
self.base = base
n: int = 5
base: float = 10

def __call__(self, limits: tuple[float, float]) -> NDArrayFloat:
"""
Expand Down Expand Up @@ -124,6 +126,7 @@ def __call__(self, limits: tuple[float, float]) -> NDArrayFloat:
return _breaks_log_sub(n=n, base=base)(limits)


@dataclass
class _breaks_log_sub:
"""
Breaks for log transformed scales
Expand All @@ -144,9 +147,8 @@ class _breaks_log_sub:
algorithm in the r-scales package.
"""

def __init__(self, n: int = 5, base: float = 10):
self.n = n
self.base = base
n: int = 5
base: float = 10

def __call__(self, limits: tuple[float, float]) -> NDArrayFloat:
base = self.base
Expand Down Expand Up @@ -204,6 +206,7 @@ def delta(x):
return breaks_extended(n=n)(limits)


@dataclass
class minor_breaks:
"""
Compute minor breaks
Expand Down Expand Up @@ -234,8 +237,7 @@ class minor_breaks:
array([1.25, 1.5 , 1.75])
"""

def __init__(self, n: int = 1):
self.n = n
n: int = 1

def __call__(
self,
Expand Down Expand Up @@ -293,6 +295,7 @@ def __call__(
return minor


@dataclass
class minor_breaks_trans:
"""
Compute minor breaks for transformed scales
Expand Down Expand Up @@ -335,9 +338,8 @@ class minor_breaks_trans:
array([2.8, 4.6, 6.4, 8.2])
"""

def __init__(self, trans: Trans, n: int = 1):
self.trans = trans
self.n = n
trans: Trans
n: int = 1

def __call__(
self,
Expand Down Expand Up @@ -399,6 +401,7 @@ def _extend_breaks(self, major: FloatArrayLike) -> FloatArrayLike:
return major


@dataclass
class breaks_date:
"""
Regularly spaced dates
Expand Down Expand Up @@ -426,27 +429,35 @@ class breaks_date:
Breaks at 4 year intervals
>>> breaks = breaks_date('4 year')
>>> breaks = breaks_date(width='4 year')
>>> [d.year for d in breaks(limits)]
[2010, 2014, 2018, 2022, 2026]
"""

n: int
width: int | None = None
units: DatetimeBreaksUnits | None = None

def __init__(self, n: int = 5, width: str | None = None):
if isinstance(n, str):
width = n

self.n = n
n: int = 5
_: KW_ONLY
width: str | None = None

_width: int | None = field(init=False, default=None)
_units: DatetimeBreaksUnits | None = field(init=False, default=None)

def __post_init__(self):
# For backwards compatibility
if isinstance(self.n, str) and self.width is None:
warn(
"Passing the width as the parameter has been deprecated "
"and will not work in a future version. "
'Use breaks_date(width="4 years")',
FutureWarning,
)
self.width = self.n

if width:
if self.width:
# Parse the width specification
# e.g. '10 months' => (10, month)
_w, units = width.strip().lower().split()
self.width = int(_w)
self.units = units.rstrip("s") # type: ignore
_w, units = self.width.strip().lower().split()
self._width = int(_w)
self._units = units.rstrip("s") # type: ignore

def __call__(
self, limits: tuple[datetime, datetime]
Expand All @@ -472,14 +483,15 @@ def __call__(
):
limits = limits[0].astype(object), limits[1].astype(object)

if self.units and self.width:
if self._units and self._width:
return calculate_date_breaks_byunits(
limits, self.units, self.width
limits, self._units, self._width
)
else:
return calculate_date_breaks_auto(limits, self.n)


@dataclass
class breaks_timedelta:
"""
Timedelta breaks
Expand All @@ -502,10 +514,11 @@ class breaks_timedelta:
[0.0, 5.0, 10.0, 15.0, 20.0, 25.0]
"""

_calculate_breaks: Callable[[tuple[float, float]], NDArrayFloat]
n: int = 5
Q: Sequence[float] = (1, 2, 5, 10)

def __init__(self, n: int = 5, Q: Sequence[float] = (1, 2, 5, 10)):
self._calculate_breaks = breaks_extended(n=n, Q=Q)
def __post_init__(self):
self._calculate_breaks = breaks_extended(n=self.n, Q=self.Q)

def __call__(
self, limits: tuple[Timedelta, Timedelta]
Expand Down Expand Up @@ -534,6 +547,7 @@ def __call__(


# This could be cleaned up, state overload?
@dataclass
class timedelta_helper:
"""
Helper for computing timedelta breaks
Expand Down Expand Up @@ -561,22 +575,14 @@ class timedelta_helper:
"""

x: TimedeltaArrayLike
units: DurationUnit
limits: tuple[float, float]
package: Literal["pandas", "cpython"]
factor: float
units: DurationUnit | None = None

def __init__(
self,
x: TimedeltaArrayLike,
units: DurationUnit | None = None,
):
self.x = x
self.package = self.determine_package(x[0])
_limits = min(x), max(x)
self.limits = self.value(_limits[0]), self.value(_limits[1])
self.units = units or self.best_units(_limits)
self.factor = self.get_scaling_factor(self.units)
def __post_init__(self):
l, h = min(self.x), max(self.x)
self.package = self.determine_package(self.x[0])
self.limits = self.value(l), self.value(h)
self._units: DurationUnit = self.units or self.best_units((l, h))
self.factor = self.get_scaling_factor(self._units)

@classmethod
def determine_package(cls, td: Timedelta) -> Literal["pandas", "cpython"]:
Expand All @@ -594,7 +600,7 @@ def format_info(
cls, x: TimedeltaArrayLike, units: DurationUnit | None = None
) -> tuple[NDArrayFloat, DurationUnit]:
helper = cls(x, units)
return helper.timedelta_to_numeric(x), helper.units
return helper.timedelta_to_numeric(x), helper._units

def best_units(self, x: TimedeltaArrayLike) -> DurationUnit:
"""
Expand Down Expand Up @@ -691,11 +697,12 @@ def to_numeric(self, td: Timedelta) -> float:
determined with the object is initialised.
"""
if isinstance(td, pd.Timedelta):
return td.value / NANOSECONDS[self.units]
return td.value / NANOSECONDS[self._units]
else:
return td.total_seconds() / SECONDS[self.units]
return td.total_seconds() / SECONDS[self._units]


@dataclass
class breaks_extended:
"""
An extension of Wilkinson's tick position algorithm
Expand Down Expand Up @@ -732,19 +739,14 @@ class breaks_extended:
implementation is almost entirely based.
"""

def __init__(
self,
n: int = 5,
Q: Sequence[float] = (1, 5, 2, 2.5, 4, 3),
only_inside: bool = False,
w: Sequence[float] = (0.25, 0.2, 0.5, 0.05),
):
self.Q = Q
self.only_inside = only_inside
self.w = w
self.n = n
n: int = 5
Q: Sequence[float] = (1, 5, 2, 2.5, 4, 3)
only_inside: bool = False
w: Sequence[float] = (0.25, 0.2, 0.5, 0.05)

def __post_init__(self):
# Used for lookups during the computations
self.Q_index = {q: i for i, q in enumerate(Q)}
self.Q_index = {q: i for i, q in enumerate(self.Q)}

def coverage(
self, dmin: float, dmax: float, lmin: float, lmax: float
Expand Down Expand Up @@ -909,12 +911,6 @@ class breaks_symlog:
"""
Breaks for the Symmetric Logarithm Transform
Parameters
----------
n : int
Desired number of breaks
base : int
Base of logarithm
Examples
--------
Expand Down
10 changes: 7 additions & 3 deletions tests/test_breaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,20 +182,24 @@ class square_trans(trans):
def test_breaks_date():
# cpython
limits = (datetime(2010, 1, 1), datetime(2026, 1, 1))
breaks = breaks_date("5 Years")
breaks = breaks_date(width="5 Years")
assert [d.year for d in breaks(limits)] == [2010, 2015, 2020, 2025, 2030]

breaks = breaks_date("10 Years")(limits)
breaks = breaks_date(width="10 Years")(limits)
assert [d.year for d in breaks] == [2010, 2020, 2030]

with pytest.warns(FutureWarning):
breaks = breaks_date("10 Years")(limits)
assert [d.year for d in breaks] == [2010, 2020, 2030]

# numpy datetime64
limits = (np.datetime64("1973"), np.datetime64("1997"))
breaks = breaks_date(width="10 Years")(limits)
assert [d.year for d in breaks] == [1970, 1980, 1990, 2000]

# NaT
limits = np.datetime64("NaT"), datetime(2017, 1, 1)
breaks = breaks_date("10 Years")(limits)
breaks = breaks_date(width="10 Years")(limits)
assert len(breaks) == 0

# automatic monthly breaks
Expand Down

0 comments on commit ff36853

Please sign in to comment.