Skip to content

Commit

Permalink
feat: add method for propagating units to descendants for `ParameterS…
Browse files Browse the repository at this point in the history
…cale` objects
  • Loading branch information
abhcs committed Feb 26, 2024
1 parent 94b78ef commit 357e2a7
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 0 deletions.
11 changes: 11 additions & 0 deletions policyengine_core/parameters/parameter_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ def __repr__(self) -> str:
]
)

def propagate_units(self) -> None:
for unit_key in parameters.ParameterScaleBracket.allowed_unit_keys():
if unit_key in self.metadata:
child_key = unit_key[:-5]
for bracket in self.brackets:
if child_key in bracket.children:
if "unit" not in bracket.children[child_key].metadata:
bracket.children[child_key].metadata["unit"] = (
self.metadata[unit_key]
)

def get_descendants(self) -> Iterable:
for bracket in self.brackets:
yield bracket
Expand Down
4 changes: 4 additions & 0 deletions policyengine_core/parameters/parameter_scale_bracket.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ class ParameterScaleBracket(ParameterNode):
["amount", "threshold", "rate", "average_rate", "base"]
)

@staticmethod
def allowed_unit_keys():
return [key + "_unit" for key in ParameterScaleBracket._allowed_keys]

def get_descendants(self) -> Iterable[Parameter]:
for key in self._allowed_keys:
if key in self.children:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
description: Propagate units in metadata of a scaled parameter to its descendants
metadata:
type: single_amount
threshold_unit: child
amount_unit: currency-USD
label: Test unit propagation

brackets:
- threshold:
values:
1995-01-01: 0
amount:
values:
2017-01-01: 8_340
2018-01-01: 8_510
2019-01-01: 8_650
2020-01-01: 8_790
2021-01-01: 11_610
2022-01-01: 9_160
2023-01-01: 9_800
2024-01-01: 10_330
- threshold:
values:
1995-01-01: 1
amount:
values:
2017-01-01: 18_340
2018-01-01: 18_700
2019-01-01: 19_030
2020-01-01: 19_330
2021-01-01: 19_520
2022-01-01: 20_130
2023-01-01: 21_560
2024-01-01: 22_720
metadata:
a: b
- threshold:
values:
1995-01-01: 2
metadata:
a: b
amount:
values:
2017-01-01: 18_340
2018-01-01: 18_700
2019-01-01: 19_030
2020-01-01: 19_330
2021-01-01: 19_520
2022-01-01: 20_130
2023-01-01: 21_560
2024-01-01: 22_720
- threshold:
values:
1995-01-01: 3
metadata:
a: b
amount:
values:
2017-01-01: 18_340
2018-01-01: 18_700
2019-01-01: 19_030
2020-01-01: 19_330
2021-01-01: 19_520
2022-01-01: 20_130
2023-01-01: 21_560
2024-01-01: 22_720
metadata:
a: b
14 changes: 14 additions & 0 deletions tests/core/parameter_validation/test_propagate_units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import os

from policyengine_core.parameters import load_parameter_file

BASE_DIR = os.path.dirname(os.path.abspath(__file__))


def test_propagate_units():
path = os.path.join(BASE_DIR, "parameter_for_unit_propagation.yaml")
parameter = load_parameter_file(path)
parameter.propagate_units()
for i in range(4):
assert parameter.brackets[i].threshold.metadata["unit"] == "child"
assert parameter.brackets[i].amount.metadata["unit"] == "currency-USD"

0 comments on commit 357e2a7

Please sign in to comment.