diff --git a/policyengine_core/parameters/parameter_scale.py b/policyengine_core/parameters/parameter_scale.py index 87251bf3..2e750ac3 100644 --- a/policyengine_core/parameters/parameter_scale.py +++ b/policyengine_core/parameters/parameter_scale.py @@ -71,6 +71,17 @@ def __repr__(self) -> str: ] ) + def propagate_units(self) -> None: + for unit_key in parameters.ParameterScaleBracket.allowed_unit_keys(): + if unit_key in self.metadata: + child_key = unit_key[:-5] + for bracket in self.brackets: + if child_key in bracket.children: + if "unit" not in bracket.children[child_key].metadata: + bracket.children[child_key].metadata["unit"] = ( + self.metadata[unit_key] + ) + def get_descendants(self) -> Iterable: for bracket in self.brackets: yield bracket diff --git a/policyengine_core/parameters/parameter_scale_bracket.py b/policyengine_core/parameters/parameter_scale_bracket.py index e8b8210a..0546962d 100644 --- a/policyengine_core/parameters/parameter_scale_bracket.py +++ b/policyengine_core/parameters/parameter_scale_bracket.py @@ -11,6 +11,10 @@ class ParameterScaleBracket(ParameterNode): ["amount", "threshold", "rate", "average_rate", "base"] ) + @staticmethod + def allowed_unit_keys(): + return [key + "_unit" for key in ParameterScaleBracket._allowed_keys] + def get_descendants(self) -> Iterable[Parameter]: for key in self._allowed_keys: if key in self.children: diff --git a/tests/core/parameter_validation/parameter_for_unit_propagation.yaml b/tests/core/parameter_validation/parameter_for_unit_propagation.yaml new file mode 100644 index 00000000..c689db03 --- /dev/null +++ b/tests/core/parameter_validation/parameter_for_unit_propagation.yaml @@ -0,0 +1,68 @@ +description: Propagate units in metadata of a scaled parameter to its descendants +metadata: + type: single_amount + threshold_unit: child + amount_unit: currency-USD + label: Test unit propagation + +brackets: + - threshold: + values: + 1995-01-01: 0 + amount: + values: + 2017-01-01: 8_340 + 2018-01-01: 8_510 + 2019-01-01: 8_650 + 2020-01-01: 8_790 + 2021-01-01: 11_610 + 2022-01-01: 9_160 + 2023-01-01: 9_800 + 2024-01-01: 10_330 + - threshold: + values: + 1995-01-01: 1 + amount: + values: + 2017-01-01: 18_340 + 2018-01-01: 18_700 + 2019-01-01: 19_030 + 2020-01-01: 19_330 + 2021-01-01: 19_520 + 2022-01-01: 20_130 + 2023-01-01: 21_560 + 2024-01-01: 22_720 + metadata: + a: b + - threshold: + values: + 1995-01-01: 2 + metadata: + a: b + amount: + values: + 2017-01-01: 18_340 + 2018-01-01: 18_700 + 2019-01-01: 19_030 + 2020-01-01: 19_330 + 2021-01-01: 19_520 + 2022-01-01: 20_130 + 2023-01-01: 21_560 + 2024-01-01: 22_720 + - threshold: + values: + 1995-01-01: 3 + metadata: + a: b + amount: + values: + 2017-01-01: 18_340 + 2018-01-01: 18_700 + 2019-01-01: 19_030 + 2020-01-01: 19_330 + 2021-01-01: 19_520 + 2022-01-01: 20_130 + 2023-01-01: 21_560 + 2024-01-01: 22_720 + metadata: + a: b diff --git a/tests/core/parameter_validation/test_propagate_units.py b/tests/core/parameter_validation/test_propagate_units.py new file mode 100644 index 00000000..59325469 --- /dev/null +++ b/tests/core/parameter_validation/test_propagate_units.py @@ -0,0 +1,14 @@ +import os + +from policyengine_core.parameters import load_parameter_file + +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def test_propagate_units(): + path = os.path.join(BASE_DIR, "parameter_for_unit_propagation.yaml") + parameter = load_parameter_file(path) + parameter.propagate_units() + for i in range(4): + assert parameter.brackets[i].threshold.metadata["unit"] == "child" + assert parameter.brackets[i].amount.metadata["unit"] == "currency-USD"