Skip to content

Commit

Permalink
Mypy errors test_crn.py (#571)
Browse files Browse the repository at this point in the history
* add types

* Update CHANGELOG.rst

* lint
  • Loading branch information
patricktnast authored Jan 15, 2025
1 parent 16121fc commit 1ffa22c
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 18 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**3.2.15 - 01/10/25**

- Type-hinting: Fix mypy errors in tests/framework/randomness/test_crn.py

**3.2.14 - 01/10/25**

- Type-hinting: Fix mypy errors in vivarium/interface/interactive.py
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ exclude = [
'tests/framework/lookup/test_lookup.py',
'tests/framework/population/test_manager.py',
'tests/framework/population/test_population_view.py',
'tests/framework/randomness/test_crn.py',
'tests/framework/randomness/test_index_map.py',
'tests/framework/randomness/test_manager.py',
'tests/framework/randomness/test_reproducibility.py',
Expand Down
48 changes: 31 additions & 17 deletions tests/framework/randomness/test_crn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from collections.abc import Iterator
from itertools import cycle
from typing import Type
from typing import Callable, Type, TypedDict

import numpy as np
import pandas as pd
Expand All @@ -22,11 +22,23 @@


@pytest.mark.parametrize("initializes_crn_attributes", [True, False])
def test_basic_repeatability(initializes_crn_attributes):
def test_basic_repeatability(initializes_crn_attributes: bool) -> None:
test_idx = pd.Index(range(100))
index_map = IndexMap()

stream_args = {
StreamArgs = TypedDict(
"StreamArgs",
{
"key": str,
"clock": Callable[[], pd.Timestamp],
"seed": str,
"index_map": IndexMap,
"initializes_crn_attributes": bool,
},
total=False,
)

stream_args: StreamArgs = {
"key": "test",
"clock": lambda: pd.Timestamp("2020-01-01"),
"seed": "abc",
Expand All @@ -36,14 +48,14 @@ def test_basic_repeatability(initializes_crn_attributes):

stream_base = RandomnessStream(**stream_args)
draw_base = stream_base.get_draw(test_idx)

for arg_permutation in [
arg_permutations: list[StreamArgs] = [
{},
{"key": "test2"},
{"clock": lambda: pd.Timestamp("2020-01-02")},
{"seed": "123"},
]:
new_stream_args = {**stream_args, **arg_permutation}
]
for arg_permutation in arg_permutations:
new_stream_args: StreamArgs = stream_args | arg_permutation
stream_permutation = RandomnessStream(**new_stream_args)
draw_1 = stream_permutation.get_draw(test_idx)
draw_2 = stream_permutation.get_draw(test_idx)
Expand All @@ -63,14 +75,14 @@ class BasePopulation(Component):
"""

@property
def name(self):
def name(self) -> str:
return "population"

@property
def columns_created(self) -> list[str]:
return ["crn_attr1", "crn_attr2", "other_attr1"]

def __init__(self, with_crn: bool, sims_to_add: Iterator = cycle([0])):
def __init__(self, with_crn: bool, sims_to_add: Iterator[int] = cycle([0])) -> None:
"""
Parameters
----------
Expand Down Expand Up @@ -137,7 +149,7 @@ def setup(self, builder: Builder) -> None:
super().setup(builder)
self.count = 0

def on_initialize_simulants(self, pop_data: SimulantData):
def on_initialize_simulants(self, pop_data: SimulantData) -> None:
new_people = len(pop_data.index)

population = pd.DataFrame(
Expand Down Expand Up @@ -173,10 +185,10 @@ def on_initialize_simulants(self, pop_data: SimulantData):
],
)
def test_multi_sim_basic_reproducibility_with_same_pop_growth(
pop_class: Type,
pop_class: Type[BasePopulation],
with_crn: bool,
sims_to_add: cycle,
):
sims_to_add: Iterator[int],
) -> None:
if with_crn:
configuration = {"randomness": {"key_columns": ["crn_attr1", "crn_attr2"]}}
else:
Expand Down Expand Up @@ -213,7 +225,9 @@ def test_multi_sim_basic_reproducibility_with_same_pop_growth(
pytest.param(SequentialPopulation, False),
],
)
def test_multi_sim_reproducibility_with_different_pop_growth(pop_class: Type, with_crn: bool):
def test_multi_sim_reproducibility_with_different_pop_growth(
pop_class: Type[BasePopulation], with_crn: bool
) -> None:
if with_crn:
configuration = {"randomness": {"key_columns": ["crn_attr1", "crn_attr2"]}}
else:
Expand Down Expand Up @@ -258,7 +272,7 @@ class UnBrokenPopulation(BasePopulation):
This is now a regression testing class.
"""

def on_initialize_simulants(self, pop_data: SimulantData):
def on_initialize_simulants(self, pop_data: SimulantData) -> None:
crn_attr = (1_000_000 * self.randomness_init.get_draw(index=pop_data.index)).astype(
int
)
Expand Down Expand Up @@ -287,8 +301,8 @@ def on_initialize_simulants(self, pop_data: SimulantData):
],
)
def test_prior_failure_path_when_first_crn_attribute_not_datelike(
with_crn: bool, sims_to_add: cycle
):
with_crn: bool, sims_to_add: Iterator[int]
) -> None:
if with_crn:
configuration = {"randomness": {"key_columns": ["crn_attr1", "crn_attr2"]}}
else:
Expand Down

0 comments on commit 1ffa22c

Please sign in to comment.