Skip to content

Commit

Permalink
refactor(enums): move config to general (#1233)
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Oct 5, 2024
1 parent 6480a2c commit 738652a
Show file tree
Hide file tree
Showing 12 changed files with 133 additions and 103 deletions.
68 changes: 68 additions & 0 deletions openfisca_core/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import annotations

from typing import Final, final

import dataclasses
import datetime

from openfisca_core import indexed_enums as enum

from . import types as t


@final
@dataclasses.dataclass(frozen=True)
class ValueType:
dtype: t.DTypeLike
default: object
json_type: str
formatted_value_type: str
is_period_size_independent: bool


value_types: Final = {
bool: ValueType(
dtype=t.DTypeBool,
default=False,
json_type="boolean",
formatted_value_type="Boolean",
is_period_size_independent=True,
),
int: ValueType(
dtype=t.DTypeInt,
default=0,
json_type="integer",
formatted_value_type="Int",
is_period_size_independent=False,
),
float: ValueType(
dtype=t.DTypeFloat,
default=0,
json_type="number",
formatted_value_type="Float",
is_period_size_independent=False,
),
str: ValueType(
dtype=t.DTypeStr,
default="",
json_type="string",
formatted_value_type="String",
is_period_size_independent=True,
),
enum.Enum: ValueType(
dtype=t.DTypeEnum,
default=None,
json_type="string",
formatted_value_type="String",
is_period_size_independent=True,
),
datetime.date: ValueType(
dtype="datetime64[D]",
default=None,
json_type="string",
formatted_value_type="Date",
is_period_size_independent=True,
),
}

__all__ = ["value_types"]
1 change: 0 additions & 1 deletion openfisca_core/entities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class RoleParams(TypedDict, total=False):
"GroupEntity",
"Role",
"RoleKey",
"RoleParams",
"RolePlural",
"SingleEntity",
"TaxBenefitSystem",
Expand Down
11 changes: 4 additions & 7 deletions openfisca_core/indexed_enums/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ def __init__(self, name: str) -> None:
def encode(
cls,
array: (
EnumArray
| t.Array[numpy.str_]
| t.Array[numpy.int16]
| t.Array[numpy.int32]
EnumArray | t.Array[t.DTypeEnum] | t.Array[t.DTypeInt] | t.Array[t.DTypeStr]
),
) -> EnumArray:
"""Encode a string numpy array, an enum item numpy array, or an int numpy
Expand Down Expand Up @@ -68,7 +65,7 @@ def encode(
array = numpy.select(
[array == item.name for item in cls],
[item.index for item in cls],
).astype(numpy.int16)
).astype(t.DTypeEnum)

# Enum items arrays
elif array.dtype.kind == "O":
Expand All @@ -92,9 +89,9 @@ def encode(
array = numpy.select(
[array == item for item in klass],
[item.index for item in klass],
).astype(numpy.int16)
).astype(t.DTypeEnum)

array = numpy.asarray(array, dtype=numpy.int16)
array = numpy.asarray(array, dtype=t.DTypeEnum)
return EnumArray(array, cls)


Expand Down
28 changes: 14 additions & 14 deletions openfisca_core/indexed_enums/enum_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,28 +32,28 @@ def __new__(

# See previous comment
def __array_finalize__(
self, obj: None | t.EnumArray | t.Array[numpy.generic]
self, obj: None | t.EnumArray | t.Array[t.DTypeObject]
) -> None:
if obj is None:
return None
return
if isinstance(obj, EnumArray):
self.possible_values = obj.possible_values
return None
return

@overload # type: ignore[override]
def __eq__(self, other: None | t.Enum | type[t.Enum]) -> t.Array[numpy.bool_]: ...
def __eq__(self, other: None | t.Enum | type[t.Enum]) -> t.Array[t.DTypeBool]: ...

@overload
def __eq__(self, other: object) -> t.Array[numpy.bool_] | bool: ...
def __eq__(self, other: object) -> t.Array[t.DTypeBool] | bool: ...

def __eq__(self, other: object) -> t.Array[numpy.bool_] | bool:
boolean_array: t.Array[numpy.bool_]
def __eq__(self, other: object) -> t.Array[t.DTypeBool] | bool:
boolean_array: t.Array[t.DTypeBool]
boolean: bool

if self.possible_values is None:
return NotImplemented

view: t.Array[numpy.int16] = self.view(numpy.ndarray)
view: t.Array[t.DTypeEnum] = self.view(numpy.ndarray)

if other is None or self._is_an_enum_type(other):
boolean_array = view == other
Expand All @@ -71,12 +71,12 @@ def __eq__(self, other: object) -> t.Array[numpy.bool_] | bool:
return boolean

@overload # type: ignore[override]
def __ne__(self, other: None | t.Enum | type[t.Enum]) -> t.Array[numpy.bool_]: ...
def __ne__(self, other: None | t.Enum | type[t.Enum]) -> t.Array[t.DTypeBool]: ...

@overload
def __ne__(self, other: object) -> t.Array[numpy.bool_] | bool: ...
def __ne__(self, other: object) -> t.Array[t.DTypeBool] | bool: ...

def __ne__(self, other: object) -> t.Array[numpy.bool_] | bool:
def __ne__(self, other: object) -> t.Array[t.DTypeBool] | bool:
return numpy.logical_not(self == other)

def _forbidden_operation(self, other: object) -> NoReturn:
Expand All @@ -97,7 +97,7 @@ def _forbidden_operation(self, other: object) -> NoReturn:
__and__ = _forbidden_operation # type: ignore[assignment]
__or__ = _forbidden_operation # type: ignore[assignment]

def decode(self) -> t.Array[numpy.int16]:
def decode(self) -> t.Array[t.DTypeEnum]:
"""Return the array of enum items corresponding to self.
For instance:
Expand All @@ -116,10 +116,10 @@ def decode(self) -> t.Array[numpy.int16]:

return numpy.select(
[self == item.index for item in self.possible_values],
[item for item in self.possible_values], # type: ignore[misc]
list(self.possible_values), # pyright: ignore[reportArgumentType]
)

def decode_to_str(self) -> t.Array[numpy.str_]:
def decode_to_str(self) -> t.Array[t.DTypeStr]:
"""Return the array of string identifiers corresponding to self.
For instance:
Expand Down
22 changes: 20 additions & 2 deletions openfisca_core/indexed_enums/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
from openfisca_core.types import Array, DTypeEnum, Enum, EnumArray
from openfisca_core.types import (
Array,
DTypeBool,
DTypeEnum,
DTypeInt,
DTypeObject,
DTypeStr,
Enum,
EnumArray,
)

__all__ = ["Array", "DTypeEnum", "Enum", "EnumArray"]
__all__ = [
"Array",
"DTypeBool",
"DTypeEnum",
"DTypeInt",
"DTypeObject",
"DTypeStr",
"Enum",
"EnumArray",
]
5 changes: 0 additions & 5 deletions openfisca_core/periods/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,6 @@ class PeriodStr(str, metaclass=_PeriodStrMeta): # type: ignore[misc]

__all__ = [
"DateUnit",
"ISOCalendarStr",
"ISOFormatStr",
"Instant",
"InstantStr",
"Period",
"PeriodStr",
"SeqInt",
]
23 changes: 13 additions & 10 deletions openfisca_core/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Iterable, Sequence, Sized
from numpy.typing import NDArray
from numpy.typing import DTypeLike, NDArray
from typing import Any, NewType, TypeVar, Union
from typing_extensions import Protocol, Self, TypeAlias

Expand All @@ -14,7 +14,7 @@
#: Generic covariant type var.
_T_co = TypeVar("_T_co", covariant=True)

# Commons
# Arrays

#: Type var for numpy arrays.
_N_co = TypeVar("_N_co", covariant=True, bound="DTypeGeneric")
Expand Down Expand Up @@ -105,6 +105,14 @@ def key(self, /) -> RoleKey: ...
def plural(self, /) -> None | RolePlural: ...


# Holders


class Holder(Protocol):
def clone(self, population: Any, /) -> Holder: ...
def get_memory_usage(self, /) -> Any: ...


# Indexed enums


Expand All @@ -121,14 +129,6 @@ def __new__(
) -> Self: ...


# Holders


class Holder(Protocol):
def clone(self, population: Any, /) -> Holder: ...
def get_memory_usage(self, /) -> Any: ...


# Parameters


Expand Down Expand Up @@ -239,3 +239,6 @@ def __call__(

class Params(Protocol):
def __call__(self, instant: Instant, /) -> ParameterNodeAtInstant: ...


__all__ = ["DTypeLike"]
2 changes: 1 addition & 1 deletion openfisca_core/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@
#
# See: https://www.python.org/dev/peps/pep-0008/#imports

from .config import FORMULA_NAME_PREFIX, VALUE_TYPES # noqa: F401
from .config import FORMULA_NAME_PREFIX # noqa: F401
from .helpers import get_annualized_variable, get_neutralized_variable # noqa: F401
from .variable import Variable # noqa: F401
52 changes: 0 additions & 52 deletions openfisca_core/variables/config.py
Original file line number Diff line number Diff line change
@@ -1,53 +1 @@
import datetime

import numpy

from openfisca_core import indexed_enums
from openfisca_core.indexed_enums import Enum

VALUE_TYPES = {
bool: {
"dtype": numpy.bool_,
"default": False,
"json_type": "boolean",
"formatted_value_type": "Boolean",
"is_period_size_independent": True,
},
int: {
"dtype": numpy.int32,
"default": 0,
"json_type": "integer",
"formatted_value_type": "Int",
"is_period_size_independent": False,
},
float: {
"dtype": numpy.float32,
"default": 0,
"json_type": "number",
"formatted_value_type": "Float",
"is_period_size_independent": False,
},
str: {
"dtype": object,
"default": "",
"json_type": "string",
"formatted_value_type": "String",
"is_period_size_independent": True,
},
Enum: {
"dtype": indexed_enums.ENUM_ARRAY_DTYPE,
"json_type": "string",
"formatted_value_type": "String",
"is_period_size_independent": True,
},
datetime.date: {
"dtype": "datetime64[D]",
"default": datetime.date.fromtimestamp(0), # 0 == 1970-01-01
"json_type": "string",
"formatted_value_type": "Date",
"is_period_size_independent": True,
},
}


FORMULA_NAME_PREFIX = "formula"
Loading

0 comments on commit 738652a

Please sign in to comment.