Skip to content

Commit

Permalink
docs(tracers): add types to full tracer (#1245)
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Oct 18, 2024
1 parent 41c35c0 commit b0cec15
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 24 deletions.
45 changes: 24 additions & 21 deletions openfisca_core/tracers/flat_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,6 @@ def get_serialized_trace(self) -> t.SerializedNodeMap:
for key, flat_trace in self.get_trace().items()
}

def serialize(
self,
value: None | t.VarArray | t.ArrayLike[object],
) -> None | t.ArrayLike[object]:
if value is None:
return None

if isinstance(value, EnumArray):
return value.decode_to_str().tolist()

if isinstance(value, numpy.ndarray) and numpy.issubdtype(
value.dtype,
numpy.dtype(bytes),
):
return value.astype(numpy.dtype(str)).tolist()

if isinstance(value, numpy.ndarray):
return value.tolist()

return value

def _get_flat_trace(
self,
node: t.TraceNode,
Expand All @@ -83,3 +62,27 @@ def key(node: t.TraceNode) -> t.NodeKey:
name = node.name
period = node.period
return t.NodeKey(f"{name}<{period}>")

@staticmethod
def serialize(
value: None | t.VarArray | t.ArrayLike[object],
) -> None | t.ArrayLike[object]:
if value is None:
return None

if isinstance(value, EnumArray):
return value.decode_to_str().tolist()

if isinstance(value, numpy.ndarray) and numpy.issubdtype(
value.dtype,
numpy.dtype(bytes),
):
return value.astype(numpy.dtype(str)).tolist()

if isinstance(value, numpy.ndarray):
return value.tolist()

return value


__all__ = ["FlatTrace"]
7 changes: 5 additions & 2 deletions openfisca_core/tracers/full_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ def generate_performance_tables(self, dir_path: str) -> None:
def get_nb_requests(self, variable: str) -> int:
return sum(self._get_nb_requests(tree, variable) for tree in self.trees)

def get_flat_trace(self) -> dict:
def get_flat_trace(self) -> t.FlatNodeMap:
return self.flat_trace.get_trace()

def get_serialized_flat_trace(self) -> dict:
def get_serialized_flat_trace(self) -> t.SerializedNodeMap:
return self.flat_trace.get_serialized_trace()

def browse_trace(self) -> Iterator[t.TraceNode]:
Expand Down Expand Up @@ -161,3 +161,6 @@ def _get_nb_requests(self, tree: t.TraceNode, variable: str) -> int:
@staticmethod
def _get_time_in_sec() -> t.Time:
return time.time_ns() / (10**9)


__all__ = ["FullTracer"]
108 changes: 108 additions & 0 deletions openfisca_core/tracers/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from __future__ import annotations

from collections.abc import Iterator
from typing import NewType, Protocol
from typing_extensions import TypeAlias, TypedDict

from openfisca_core.types import (
Array,
ArrayLike,
ParameterNode,
ParameterNodeChild,
Period,
PeriodInt,
VariableName,
)

from numpy import generic as VarDType

#: A type of a generic array.
VarArray: TypeAlias = Array[VarDType]

#: A type representing a unit time.
Time: TypeAlias = float

#: A type representing a mapping of flat traces.
FlatNodeMap: TypeAlias = dict["NodeKey", "FlatTraceMap"]

#: A type representing a mapping of serialized traces.
SerializedNodeMap: TypeAlias = dict["NodeKey", "SerializedTraceMap"]

#: A stack of simple traces.
SimpleStack: TypeAlias = list["SimpleTraceMap"]

#: Key of a trace.
NodeKey = NewType("NodeKey", str)


class FlatTraceMap(TypedDict, total=True):
dependencies: list[NodeKey]
parameters: dict[NodeKey, None | ArrayLike[object]]
value: None | VarArray
calculation_time: Time
formula_time: Time


class SerializedTraceMap(TypedDict, total=True):
dependencies: list[NodeKey]
parameters: dict[NodeKey, None | ArrayLike[object]]
value: None | ArrayLike[object]
calculation_time: Time
formula_time: Time


class SimpleTraceMap(TypedDict, total=True):
name: VariableName
period: int | Period


class ComputationLog(Protocol):
def print_log(self, aggregate: bool = ..., max_depth: int = ..., /) -> None: ...


class FlatTrace(Protocol):
def get_trace(self, /) -> FlatNodeMap: ...
def get_serialized_trace(self, /) -> SerializedNodeMap: ...


class FullTracer(Protocol):
@property
def trees(self, /) -> list[TraceNode]: ...
def browse_trace(self, /) -> Iterator[TraceNode]: ...


class PerformanceLog(Protocol):
def generate_graph(self, dir_path: str, /) -> None: ...
def generate_performance_tables(self, dir_path: str, /) -> None: ...


class SimpleTracer(Protocol):
@property
def stack(self, /) -> SimpleStack: ...
def record_calculation_start(
self, variable: VariableName, period: PeriodInt | Period, /
) -> None: ...
def record_calculation_end(self, /) -> None: ...


class TraceNode(Protocol):
children: list[TraceNode]
end: Time
name: str
parameters: list[TraceNode]
parent: None | TraceNode
period: PeriodInt | Period
start: Time
value: None | VarArray

def calculation_time(self, *, round_: bool = ...) -> Time: ...
def formula_time(self, /) -> Time: ...
def append_child(self, node: TraceNode, /) -> None: ...


__all__ = [
"ArrayLike",
"ParameterNode",
"ParameterNodeChild",
"PeriodInt",
]
25 changes: 24 additions & 1 deletion openfisca_core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,31 @@ class MemoryUsage(TypedDict, total=False):

# Parameters

#: A type representing a node of parameters.
ParameterNode: TypeAlias = Union[
"ParameterNodeAtInstant", "VectorialParameterNodeAtInstant"
]

class ParameterNodeAtInstant(Protocol): ...
#: A type representing a ???
ParameterNodeChild: TypeAlias = Union[ParameterNode, ArrayLike[object]]


class ParameterNodeAtInstant(Protocol):
_instant_str: InstantStr

def __contains__(self, __item: object, /) -> bool: ...
def __getitem__(
self, __index: str | Array[DTypeGeneric], /
) -> ParameterNodeChild: ...


class VectorialParameterNodeAtInstant(Protocol):
_instant_str: InstantStr

def __contains__(self, item: object, /) -> bool: ...
def __getitem__(
self, __index: str | Array[DTypeGeneric], /
) -> ParameterNodeChild: ...


# Periods
Expand Down

0 comments on commit b0cec15

Please sign in to comment.