Skip to content

Commit

Permalink
Use registry to keep track of transform functions
Browse files Browse the repository at this point in the history
  • Loading branch information
has2k1 committed Oct 23, 2024
1 parent ec5ee42 commit e74f102
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 21 deletions.
40 changes: 20 additions & 20 deletions mizani/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
)

if TYPE_CHECKING:
from typing import Any, Callable, Sequence, Type
from typing import Any, Sequence, Type

from mizani.typing import (
BreaksFunction,
Expand Down Expand Up @@ -101,6 +101,7 @@
]

UTC = ZoneInfo("UTC")
REGISTRY: dict[str, Type[trans]] = {}


@dataclass(kw_only=True)
Expand All @@ -122,6 +123,11 @@ class trans(ABC):
minor_breaks_func: MinorBreaksFunction | None = None
"Callable to calculate minor breaks"

def __init_subclass__(cls, *args, **kwargs):
# Register all subclasses
super().__init_subclass__(*args, **kwargs)
REGISTRY[cls.__name__] = cls

# Use type variables for trans.transform and trans.inverse
# to help upstream packages avoid type mismatches. e.g.
# transform(tuple[float, float]) -> tuple[float, float]
Expand Down Expand Up @@ -907,36 +913,30 @@ def inverse(self, x: FloatArrayLike) -> NDArrayFloat:
return np.sign(x) * (np.exp(np.abs(x)) - 1) # type: ignore


def gettrans(
t: str | Callable[[], Type[trans]] | Type[trans] | trans | None = None,
):
def gettrans(t: str | Type[trans] | trans | None = None):
"""
Return a trans object
Parameters
----------
t : str | callable | type | trans
t : str | type | trans
Name of transformation function. If None, returns an
identity transform.
Returns
-------
out : trans
"""
obj = t
# Make sure trans object is instantiated
if t is None:
if isinstance(t, str):
names = (f"{t}_trans", t)
for name in names:
if t := REGISTRY.get(name):
return t()
elif isinstance(t, trans):
return t
elif isinstance(t, type) and issubclass(t, trans):
return t()
elif t is None:
return identity_trans()

if isinstance(obj, str):
name = "{}_trans".format(obj)
obj = globals()[name]()
if callable(obj):
obj = obj()
if isinstance(obj, type):
obj = obj()

if not isinstance(obj, trans):
raise ValueError("Could not get transform object.")

return obj
raise ValueError(f"Could not get transform object. {t}")
10 changes: 9 additions & 1 deletion tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
log2_trans,
log10_trans,
log_trans,
logit_trans,
modulus_trans,
pd_timedelta_trans,
probability_trans,
probit_trans,
pseudo_log_trans,
reciprocal_trans,
reverse_trans,
Expand All @@ -46,7 +48,9 @@ def test_gettrans():
t2 = gettrans(identity_trans)
t3 = gettrans("identity")
t4 = gettrans()
assert all(isinstance(x, identity_trans) for x in (t0, t1, t2, t3, t4))
assert all(
x.__class__.__name__ == "identity_trans" for x in (t0, t1, t2, t3, t4)
)

with pytest.raises(ValueError):
gettrans(object)
Expand Down Expand Up @@ -197,6 +201,10 @@ def test_probability_trans():
npt.assert_allclose(xt[:3], 1 - xt[-3:][::-1])
npt.assert_allclose(x, x2)

# Cover the paths these create as well
logit_trans()
probit_trans()


def test_datetime_trans():
UTC = ZoneInfo("UTC")
Expand Down

0 comments on commit e74f102

Please sign in to comment.