Skip to content

Commit

Permalink
chore: Format
Browse files Browse the repository at this point in the history
  • Loading branch information
anth-volk committed Jan 2, 2025
1 parent 3adcdf8 commit 4a47959
Showing 1 changed file with 57 additions and 28 deletions.
85 changes: 57 additions & 28 deletions policyengine_core/reforms/structural_reform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
from datetime import datetime
from dataclasses import dataclass
from policyengine_core.variables import Variable
from policyengine_core.parameters import Parameter, ParameterNode, ParameterAtInstant
from policyengine_core.parameters import (
Parameter,
ParameterNode,
ParameterAtInstant,
)
from policyengine_core.periods import config
from policyengine_core.taxbenefitsystems import TaxBenefitSystem
from policyengine_core.errors import (
Expand All @@ -22,7 +26,7 @@ class TransformationLogItem:
transformation: Literal["neutralize", "add", "update"]


class StructuralReform:
class StructuralReform:

DEFAULT_START_INSTANT = "0000-01-01"
transformation_log: list[TransformationLogItem] = []
Expand Down Expand Up @@ -53,22 +57,26 @@ def activate(
tax_benefit_system: The tax benefit system to which the structural reform will be applied
"""
if tax_benefit_system is None:
raise ValueError("Tax benefit system must be provided.")
raise ValueError("Tax benefit system must be provided.")

if not isinstance(tax_benefit_system, TaxBenefitSystem):
raise TypeError(
"Tax benefit system must be an instance of the TaxBenefitSystem class."
)

self.tax_benefit_system = tax_benefit_system

# Fetch the trigger parameter
trigger_parameter: Parameter = self._fetch_parameter(self.trigger_parameter_path)
trigger_parameter: Parameter = self._fetch_parameter(
self.trigger_parameter_path
)

# Parse date out of trigger parameter and set
start_instant: Annotated[str, "YYYY-MM-DD"] | None
end_instant: Annotated[str, "YYYY-MM-DD"] | None
start_instant, end_instant = self._parse_activation_period(trigger_parameter)
start_instant, end_instant = self._parse_activation_period(
trigger_parameter
)

self.start_instant = start_instant
self.end_instant = end_instant
Expand Down Expand Up @@ -112,8 +120,8 @@ def add_variable(self, variable: Variable):

def update_variable(self, variable: Variable):
"""
When structural reform is activated, update a
variable in the tax benefit system; if the variable
When structural reform is activated, update a
variable in the tax benefit system; if the variable
does not yet exist, it will be added.
Args:
Expand Down Expand Up @@ -259,10 +267,10 @@ def _fetch_variable(self, name: str) -> Variable | None:
name: The name of the variable
"""
return self.tax_benefit_system.get_variable(name)

def _fetch_parameter(self, parameter_path: str) -> Parameter:
"""
Given a dot-notated string, fetch a parameter by
Given a dot-notated string, fetch a parameter by
reference from the tax benefit system.
Args:
Expand All @@ -280,11 +288,15 @@ def _fetch_parameter(self, parameter_path: str) -> Parameter:
try:
current = getattr(current, key)
except AttributeError:
raise AttributeError(f"Unable to find parameter at path '{full_path}'") from None

raise AttributeError(
f"Unable to find parameter at path '{full_path}'"
) from None

if not isinstance(current, Parameter):
raise AttributeError(f"Parameter at path '{full_path}' is not a Parameter, but a {type(current)}")

raise AttributeError(
f"Parameter at path '{full_path}' is not a Parameter, but a {type(current)}"
)

return current

# Method to modify metadata based on new items?
Expand Down Expand Up @@ -371,51 +383,68 @@ def _validate_instant(self, instant: Any) -> bool:

return True

def _parse_activation_period(self, trigger_parameter: Parameter) -> tuple[Annotated[str, "YYYY-MM-DD"] | None, Annotated[str, "YYYY-MM-DD"] | None]:
def _parse_activation_period(self, trigger_parameter: Parameter) -> tuple[
Annotated[str, "YYYY-MM-DD"] | None,
Annotated[str, "YYYY-MM-DD"] | None,
]:
"""
Given a trigger parameter, parse the reform start and end dates and return them.
Returns:
A tuple containing the start and end dates of the reform,
A tuple containing the start and end dates of the reform,
or None if the reform is not triggered
"""

# Crash if trigger param isn't Boolean; this shouldn't be used as a trigger
if (trigger_parameter.metadata is None) or (trigger_parameter.metadata["unit"] != "bool"):
if (trigger_parameter.metadata is None) or (
trigger_parameter.metadata["unit"] != "bool"
):
raise ValueError("Trigger parameter must be a Boolean.")

# Build custom representation of trigger parameter instants and values
values_dict: dict[Annotated[str, "YYYY-MM-DD"], int | float] = self._generate_param_values_dict(trigger_parameter.values_list)
values_dict: dict[Annotated[str, "YYYY-MM-DD"], int | float] = (
self._generate_param_values_dict(trigger_parameter.values_list)
)

if list(values_dict.values()).count(True) > 1:
raise ValueError("Trigger parameter must only be activated once.")

if list(values_dict.values()).count(True) == 0:
return (None, None)

# Now that True only occurs once, find it
start_instant_index: int = list(values_dict.values()).index(True)
start_instant: Annotated[str, "YYYY-MM-DD"] = list(values_dict.keys())[start_instant_index]
start_instant: Annotated[str, "YYYY-MM-DD"] = list(values_dict.keys())[
start_instant_index
]
self._validate_instant(start_instant)

# If it's the last item, the reform occurs into perpetuity, else
# If it's the last item, the reform occurs into perpetuity, else
# the reform ends at the next instant
if start_instant_index == len(values_dict) - 1:
return (start_instant, None)

end_instant: Annotated[str, "YYYY-MM-DD"] = list(values_dict.keys())[start_instant_index + 1]
end_instant: Annotated[str, "YYYY-MM-DD"] = list(values_dict.keys())[
start_instant_index + 1
]
self._validate_instant(end_instant)
return (start_instant, end_instant)

def _generate_param_values_dict(self, values_list: list[ParameterAtInstant]) -> dict[Annotated[str, "YYYY-MM-DD"], int | float]:
def _generate_param_values_dict(
self, values_list: list[ParameterAtInstant]
) -> dict[Annotated[str, "YYYY-MM-DD"], int | float]:
"""
Given a list of ParameterAtInstant objects, generate a dictionary of the form {instant: value}.
Args:
values_list: The list of ParameterAtInstant objects
"""
unsorted_dict = {value.instant_str: value.value for value in values_list}
sorted_dict = dict(sorted(unsorted_dict.items(), key=lambda item: item[0]))
unsorted_dict = {
value.instant_str: value.value for value in values_list
}
sorted_dict = dict(
sorted(unsorted_dict.items(), key=lambda item: item[0])
)
return sorted_dict

# Default outputs method of some sort?

0 comments on commit 4a47959

Please sign in to comment.