Skip to content

Commit

Permalink
Merge pull request #224 from SylviaDu99/cache197_1
Browse files Browse the repository at this point in the history
add class: SimulationMacroCache
  • Loading branch information
anth-volk authored Sep 2, 2024
2 parents 17609b5 + 31ebb51 commit 8648c96
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 71 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [3.6.1] - 2024-08-31 16:13:07

### Added

- Added class SimulationMacroCache for macro simulation caching purposes.

## [3.6.0] - 2024-08-28 16:41:26

### Fixed
Expand Down Expand Up @@ -797,6 +803,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0



[3.6.1]: https://github.com/PolicyEngine/policyengine-core/compare/3.6.0...3.6.1
[3.6.0]: https://github.com/PolicyEngine/policyengine-core/compare/3.5.3...3.6.0
[3.5.3]: https://github.com/PolicyEngine/policyengine-core/compare/3.5.2...3.5.3
[3.5.2]: https://github.com/PolicyEngine/policyengine-core/compare/3.5.1...3.5.2
Expand Down
5 changes: 5 additions & 0 deletions changelog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -639,3 +639,8 @@
fixed:
- Bugs in typing and Dataset saving.
date: 2024-08-28 16:41:26
- bump: patch
changes:
added:
- Added class SimulationMacroCache for macro simulation caching purposes.
date: 2024-08-31 16:13:07
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: patch
changes:
added:
- Added class SimulationMacroCache for macro simulation caching purposes.
84 changes: 84 additions & 0 deletions policyengine_core/simulations/sim_macro_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import importlib.metadata
import shutil
from pathlib import Path
from typing import Optional, Union

import h5py
from numpy.typing import ArrayLike

from policyengine_core.taxbenefitsystems import TaxBenefitSystem


class Singleton(type):
    """Metaclass that hands out a single shared instance per class.

    The first call to a class using this metaclass constructs the instance;
    every later call returns that same object and its arguments are not
    forwarded to the constructor again.
    """

    # Registry of already-built instances, keyed by the class they belong to.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """Return the existing instance of ``cls``, building it on first use."""
        registry = cls._instances
        if cls in registry:
            return registry[cls]

        instance = super().__call__(*args, **kwargs)
        registry[cls] = instance
        return instance


class SimulationMacroCache(metaclass=Singleton):
    """On-disk cache of macro simulation results, one HDF5 file per variable.

    Every cache file stores the policyengine-core version and the country
    package version next to the cached values, so that files written by a
    different version can be detected and discarded on read.

    The class uses the ``Singleton`` metaclass: only the first construction
    runs ``__init__``; later calls return the same instance and their
    ``tax_benefit_system`` argument is not used.
    """

    def __init__(self, tax_benefit_system: TaxBenefitSystem):
        """Record the core and country package versions used to stamp files.

        Args:
            tax_benefit_system: System whose ``get_package_metadata()``
                supplies the country package version.
        """
        self.core_version = importlib.metadata.version("policyengine-core")
        self.country_package_metadata = (
            tax_benefit_system.get_package_metadata()
        )
        self.country_version = self.country_package_metadata["version"]
        # Both paths stay None until set_cache_path() has been called.
        self.cache_folder_path: Optional[Path] = None
        self.cache_file_path: Optional[Path] = None

    def set_cache_path(
        self,
        parent_path: Union[Path, str],
        dataset_name: str,
        variable_name: str,
        period: str,
        branch_name: str,
    ):
        """Select (and create) the cache folder and file for one variable.

        The folder is ``<parent_path>/<dataset_name>_variable_cache`` and the
        file inside it is ``<variable_name>_<period>_<branch_name>.h5``.

        Args:
            parent_path: Existing directory that will contain the cache folder.
            dataset_name: Name of the dataset the simulation runs over.
            variable_name: Name of the cached variable.
            period: Period of the cached values, as a string.
            branch_name: Name of the simulation branch.
        """
        storage_folder = Path(parent_path) / f"{dataset_name}_variable_cache"
        self.cache_folder_path = storage_folder
        storage_folder.mkdir(exist_ok=True)
        self.cache_file_path = (
            storage_folder / f"{variable_name}_{period}_{branch_name}.h5"
        )

    def set_cache_value(self, cache_file_path: Path, value: ArrayLike):
        """Write ``value`` and the current version stamps to a cache file.

        Any existing file at ``cache_file_path`` is overwritten.

        Args:
            cache_file_path: Destination HDF5 file.
            value: Array of values to store under the ``values`` dataset.
        """
        with h5py.File(cache_file_path, "w") as f:
            f.create_dataset(
                "metadata:core_version",
                data=self.core_version,
            )
            f.create_dataset(
                "metadata:country_version",
                data=self.country_version,
            )
            f.create_dataset("values", data=value)

    def get_cache_path(self) -> Optional[Path]:
        """Return the cache file path chosen by the last set_cache_path()."""
        return self.cache_file_path

    @staticmethod
    def _read_version(dataset) -> str:
        """Read a scalar string dataset as ``str`` whatever h5py returns."""
        raw = dataset[()]
        # Depending on the h5py version a stored string is read back as
        # bytes or as str; normalise so the comparison below is reliable.
        if isinstance(raw, bytes):
            return raw.decode("utf-8")
        return str(raw)

    def get_cache_value(self, cache_file_path: Path) -> Optional[ArrayLike]:
        """Read cached values, flushing the cache if the file is stale.

        A file is usable only if it holds both version stamps, they match the
        running core and country package versions, and it has a ``values``
        dataset. Otherwise the whole cache folder is cleared.

        Args:
            cache_file_path: HDF5 file to read.

        Returns:
            The stored array, or ``None`` when the file was stale or
            incomplete and the cache was flushed.
        """
        with h5py.File(cache_file_path, "r") as f:
            is_valid = (
                "metadata:core_version" in f
                and "metadata:country_version" in f
                and "values" in f
                and self._read_version(f["metadata:core_version"])
                == self.core_version
                and self._read_version(f["metadata:country_version"])
                == self.country_version
            )
            value = f["values"][()] if is_valid else None

        # Flush only after the ``with`` block has closed the file, so the
        # folder is never removed while one of its files is still open.
        if not is_valid:
            self.clear_cache(self.cache_folder_path)
        return value

    def clear_cache(self, cache_folder_path: Path):
        """Delete the cache folder and everything in it.

        Does nothing when ``cache_folder_path`` is ``None`` or the folder
        does not exist.

        Args:
            cache_folder_path: Folder to remove.
        """
        if cache_folder_path is None:
            return
        folder = Path(cache_folder_path)
        if folder.exists():
            shutil.rmtree(folder)
112 changes: 43 additions & 69 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
from numpy.typing import ArrayLike
import logging
from pathlib import Path

from policyengine_core import commons, periods
from policyengine_core.data.dataset import Dataset
Expand All @@ -21,9 +22,6 @@
SimpleTracer,
TracingParameterNodeAtInstant,
)
import h5py
from pathlib import Path
import shutil

import json

Expand All @@ -36,6 +34,7 @@
from policyengine_core.variables import Variable, QuantityType
from policyengine_core.reforms.reform import Reform
from policyengine_core.parameters import get_parameter
from policyengine_core.simulations.sim_macro_cache import SimulationMacroCache


class Simulation:
Expand Down Expand Up @@ -606,11 +605,30 @@ def _calculate(
if cached_array is not None:
return cached_array

cache_path = self._get_macro_cache(variable_name, str(period))
if cache_path and cache_path.exists():
value = self._get_macro_cache_value(cache_path)
if value is not None:
return self._get_macro_cache_value(cache_path)
smc = SimulationMacroCache(self.tax_benefit_system)

# Check if cache could be used, if available, check if path exists
is_cache_available = self.check_macro_cache(variable_name, str(period))
if is_cache_available:
smc.set_cache_path(
self.dataset.file_path.parent,
self.dataset.name,
variable_name,
str(period),
self.branch_name,
)
cache_path = smc.get_cache_path()
if cache_path.exists():
if (
not self.macro_cache_read
or self.tax_benefit_system.data_modified
):
value = None
else:
value = smc.get_cache_value(cache_path)

if value is not None:
return value

if variable.requires_computation_after is not None:
if variable.requires_computation_after not in [
Expand Down Expand Up @@ -639,8 +657,8 @@ def _calculate(
values = self.calculate_divide(variable_name, period)

if alternate_period_handling:
if cache_path is not None:
self._set_macro_cache_value(cache_path, values)
if is_cache_available:
smc.set_cache_value(cache_path, values)
return values

self._check_period_consistency(period, variable)
Expand Down Expand Up @@ -738,8 +756,8 @@ def _calculate(
f"RecursionError while calculating {variable_name} for period {period}. The full computation stack is:\n{stack_formatted}"
)

if cache_path is not None:
self._set_macro_cache_value(cache_path, array)
if is_cache_available:
smc.set_cache_value(cache_path, array)

return array

Expand Down Expand Up @@ -1396,77 +1414,33 @@ def extract_person(

return json.loads(json.dumps(situation, cls=NpEncoder))

def _get_macro_cache(
self,
variable_name: str,
period: str,
):
def check_macro_cache(self, variable_name: str, period: str) -> bool:
"""
Get the cache location of a variable for a given period, if it exists.
Check if the variable is able to have cached value
"""
if not self.is_over_dataset:
return None
if (
hasattr(self, "dataset")
and self.dataset.data_format == Dataset.FLAT_FILE
):
return False

if self.is_over_dataset:
return True

variable = self.tax_benefit_system.get_variable(variable_name)
parameter_deps = variable.exhaustive_parameter_dependencies

if parameter_deps is None:
return None
return False

for parameter in parameter_deps:
param = get_parameter(
self.tax_benefit_system.parameters, parameter
)
if param.modified:
return None

storage_folder = (
self.dataset.file_path.parent
/ f"{self.dataset.name}_variable_cache"
)
storage_folder.mkdir(exist_ok=True)
return False

cache_file_path = (
storage_folder / f"{variable_name}_{period}_{self.branch_name}.h5"
)

return cache_file_path

def clear_macro_cache(self):
"""
Clear the cache of all variables.
"""
storage_folder = (
self.dataset.file_path.parent
/ f"{self.dataset.name}_variable_cache"
)
if storage_folder.exists():
shutil.rmtree(storage_folder)

def _get_macro_cache_value(
self,
cache_file_path: Path,
):
"""
Get the value of a variable from a cache file.
"""
if not self.macro_cache_read or self.tax_benefit_system.data_modified:
return None
with h5py.File(cache_file_path, "r") as f:
return f["values"][()]

def _set_macro_cache_value(
self,
cache_file_path: Path,
value: ArrayLike,
):
"""
Set the value of a variable in a cache file.
"""
if not self.macro_cache_write or self.tax_benefit_system.data_modified:
return None
with h5py.File(cache_file_path, "w") as f:
f.create_dataset("values", data=value)
return True

def to_input_dataframe(
self,
Expand Down
3 changes: 2 additions & 1 deletion policyengine_core/taxbenefitsystems/tax_benefit_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,8 @@ def get_package_metadata(self) -> dict:

fallback_metadata = {
"name": self.__class__.__name__,
"version": "",
# For testing purposes
"version": "0.0.0",
"repository_url": "",
"location": "",
}
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

setup(
name="policyengine-core",
version="3.6.0",
version="3.6.1",
author="PolicyEngine",
author_email="[email protected]",
classifiers=[
Expand Down
37 changes: 37 additions & 0 deletions tests/core/test_simulations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from policyengine_core.country_template.situation_examples import single
from policyengine_core.simulations import SimulationBuilder
from policyengine_core.simulations.sim_macro_cache import SimulationMacroCache
import importlib.metadata
import numpy as np
from pathlib import Path


def test_calculate_full_tracer(tax_benefit_system):
Expand Down Expand Up @@ -61,3 +65,36 @@ def test_get_memory_usage(tax_benefit_system):
memory_usage = simulation.get_memory_usage(variables=["salary"])
assert memory_usage["total_nb_bytes"] > 0
assert len(memory_usage["by_variable"]) == 1


def test_macro_cache(tax_benefit_system, tmp_path):
    """Round-trip a value through SimulationMacroCache, then clear the cache.

    Uses pytest's ``tmp_path`` fixture so cache files are written to a
    throwaway directory: the test does not depend on the directory pytest is
    launched from and cannot leave files in the source tree if it fails.
    """
    cache = SimulationMacroCache(tax_benefit_system)
    assert cache.core_version == importlib.metadata.version(
        "policyengine-core"
    )
    # NOTE(review): SimulationMacroCache is a singleton, so this presumably
    # relies on the first instance in the process having been built from a
    # system whose package metadata version is "0.0.0" - confirm.
    assert cache.country_version == "0.0.0"

    cache.set_cache_path(
        parent_path=tmp_path,
        dataset_name="test_dataset",
        variable_name="test_variable",
        period="2020",
        branch_name="test_branch",
    )
    expected_folder = tmp_path / "test_dataset_variable_cache"
    assert cache.cache_folder_path == expected_folder
    assert cache.get_cache_path() == (
        expected_folder / "test_variable_2020_test_branch.h5"
    )

    expected_values = np.array([1, 2, 3], dtype=np.float32)
    cache.set_cache_value(
        cache_file_path=cache.cache_file_path,
        value=expected_values,
    )
    assert np.array_equal(
        cache.get_cache_value(cache.cache_file_path),
        expected_values,
    )

    cache.clear_cache(cache.cache_folder_path)
    assert not cache.cache_folder_path.exists()

0 comments on commit 8648c96

Please sign in to comment.