Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
djinnome committed Oct 23, 2024
1 parent 2187a8b commit a87e243
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
7 changes: 3 additions & 4 deletions pyciemss/mira_integration/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import pyro
import torch


ParameterDict = Dict[str, torch.Tensor]


Expand All @@ -29,8 +28,8 @@ def sort_mira_dependencies(src: mira.metamodel.TemplateModel) -> list:
for param_info in src.parameters.values():
param_name = param_info.name
param_dist = getattr(param_info, "distribution", None)
# Check to see if the distribution parameters are sympy expressions
if param_dist is not None:
# Check to see if the distribution parameters are sympy expressions
if param_dist is not None:
for k, v in param_dist.parameters.items():
if isinstance(v, mira.metamodel.utils.SympyExprStr):
for free_symbol in v.free_symbols:
Expand Down Expand Up @@ -334,6 +333,6 @@ def mira_distribution_to_pyro(
f"Conversion from MIRA distribution type {mira_dist.type} to Pyro distribution has not been tested."
)

parameters ={k: torch.as_tensor(v) for k, v in mira_dist.parameters.items()}
parameters = {k: torch.as_tensor(v) for k, v in mira_dist.parameters.items()}

return _MIRA_TO_PYRO[mira_dist.type](parameters)
5 changes: 3 additions & 2 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ def __init__(
os.path.join(MODELS_PATH, "beta_mean_gamma_cycle_sir_model.json"), "beta_mean"
),
ModelFixture(
os.path.join(MODELS_PATH, "gamma_mean_beta_mean_cycle_sir_model.json"), "beta_mean"
)
os.path.join(MODELS_PATH, "gamma_mean_beta_mean_cycle_sir_model.json"),
"beta_mean",
),
]
PETRI_MODELS = [
ModelFixture(
Expand Down
4 changes: 3 additions & 1 deletion tests/test_compiled_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def test_compiled_dynamics_load_json(url, start_time, end_time):
@pytest.mark.parametrize("cyclic_model", CYCLIC_MODELS)
@pytest.mark.parametrize("start_time", START_TIMES)
@pytest.mark.parametrize("end_time", END_TIMES)
def test_hierarchical_compiled_dynamics(acyclic_model, cyclic_model, start_time, end_time):
def test_hierarchical_compiled_dynamics(
acyclic_model, cyclic_model, start_time, end_time
):
"""
Test the loading and dependency analysis of hierarchical compiled dynamics models.
Expand Down

0 comments on commit a87e243

Please sign in to comment.