Skip to content

Commit

Permalink
Post Generate (#21)
Browse files Browse the repository at this point in the history
* version bump

* post generate and tests for post generate and post load
  • Loading branch information
aloosley authored Dec 30, 2024
1 parent 6f629ef commit aa6b9ec
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 4 deletions.
2 changes: 1 addition & 1 deletion persistable/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .base import Persistable # noqa
from .data import PersistableParams # noqa

__version__ = "1.2.3"
__version__ = "1.3.0"
21 changes: 20 additions & 1 deletion persistable/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ def generate(self, persist: bool = True, **untracked_payload_params: Any) -> Non
if persist:
self.persist()

self._post_generate()

def persist(self) -> None:
"""Persist both payload and parameters."""

Expand Down Expand Up @@ -254,7 +256,7 @@ def _validate_payload(self, payload: PayloadTypeT) -> None:
-------
"""

def _post_load(self) -> None:
def _post_load(self, **untracked_payload_params: Any) -> None:
"""
Define here extra algorithmic steps to run after loading the payload.
Expand All @@ -271,6 +273,23 @@ def _post_load(self) -> None:
"""

def _post_generate(self, **untracked_payload_params: Any) -> None:
"""
Define here extra algorithmic steps to run after generating the payload.
This is sometimes useful to augment the payload with data that is inefficient to persist or
should not be persisted.
Parameters
----------
untracked_payload_params : dict
Payload parameters that the user doesn't want to track (not persisted to file)
Returns
-------
"""

@property
def persist_filepath(self) -> Path:
return self.data_dir / f"{self.payload_name}({self.persist_hash})"
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

setup(
name="persistable",
version="1.2.3",
version="1.3.0",
packages=find_packages(),
url="https://github.com/aloosley/persistable",
download_url="https://github.com/aloosley/persistable/archive/1.2.3.tar.gz",
download_url="https://github.com/aloosley/persistable/archive/1.3.0.tar.gz",
license="",
author="Alex Loosley, Stephan Sahm",
author_email="[email protected], [email protected]",
Expand Down
55 changes: 55 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,29 @@ def _generate_payload(self, **untracked_payload_params: Any) -> Dict[str, Any]:
return dict(a=self.params.a, old=self.dummy_persistable.payload)


class DummyPersistableWithPostLoadGenerate(Persistable[Dict[str, Any], DummyPersistableParams]):
def __init__(
self,
data_dir: Path,
params: DummyPersistableParams,
*,
verbose: bool = False,
logger: Optional[Logger] = None,
) -> None:
super().__init__(data_dir, params, tracked_persistable_dependencies=None, verbose=verbose, logger=logger)
self.post_load_: Optional[int] = None
self.post_generate_: Optional[int] = None

def _generate_payload(self, **untracked_payload_params: Any) -> Dict[str, Any]:
return dict(a=self.params.a, b="b")

def _post_load(self) -> None:
self.post_load_ = 1

def _post_generate(self) -> None:
self.post_generate_ = 1


class TestPersistable:
def test_init(self, tmp_path: Path) -> None:
# GIVEN
Expand Down Expand Up @@ -234,3 +257,35 @@ def test_validate_payload_on_load(self, tmp_path: Path) -> None:
dummy_persistable.load(warn_if_validation_fails=True)
with pytest.raises(InvalidPayloadError):
dummy_persistable.load(warn_if_validation_fails=False)

def test_persistable_with_post_generate_and_post_load(self, tmp_path: Path) -> None:
# GIVEN persistable
data_dir = tmp_path
params = DummyPersistableParams()
dummy_persistable = DummyPersistableWithPostLoadGenerate(data_dir=data_dir, params=params)
dummy_persistable_2 = DummyPersistableWithPostLoadGenerate(data_dir=data_dir, params=params)

# Assert that post load and post generate variables initiated to None
assert dummy_persistable.post_generate_ is None
assert dummy_persistable.post_load_ is None

# WHEN generated
dummy_persistable.generate()

# THEN
assert dummy_persistable.post_generate_ == 1
assert dummy_persistable.post_load_ is None

# WHEN loaded after generated
dummy_persistable.load()

# THEN
assert dummy_persistable.post_generate_ == 1
assert dummy_persistable.post_load_ == 1

# WHEN loaded
dummy_persistable_2.load()

# THEN
assert dummy_persistable_2.post_generate_ is None
assert dummy_persistable_2.post_load_ == 1

0 comments on commit aa6b9ec

Please sign in to comment.