From a87e243367a1500ae23a2146882554e5f6b8d1c0 Mon Sep 17 00:00:00 2001 From: "Zucker, Jeremy D" Date: Wed, 23 Oct 2024 10:12:39 -0700 Subject: [PATCH] Lint --- pyciemss/mira_integration/distributions.py | 7 +++---- tests/fixtures.py | 5 +++-- tests/test_compiled_dynamics.py | 4 +++- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pyciemss/mira_integration/distributions.py b/pyciemss/mira_integration/distributions.py index cd775379..4600f6ad 100644 --- a/pyciemss/mira_integration/distributions.py +++ b/pyciemss/mira_integration/distributions.py @@ -7,7 +7,6 @@ import pyro import torch - ParameterDict = Dict[str, torch.Tensor] @@ -29,8 +28,8 @@ def sort_mira_dependencies(src: mira.metamodel.TemplateModel) -> list: for param_info in src.parameters.values(): param_name = param_info.name param_dist = getattr(param_info, "distribution", None) - # Check to see if the distribution parameters are sympy expressions - if param_dist is not None: + # Check to see if the distribution parameters are sympy expressions + if param_dist is not None: for k, v in param_dist.parameters.items(): if isinstance(v, mira.metamodel.utils.SympyExprStr): for free_symbol in v.free_symbols: @@ -334,6 +333,6 @@ def mira_distribution_to_pyro( f"Conversion from MIRA distribution type {mira_dist.type} to Pyro distribution has not been tested." ) - parameters ={k: torch.as_tensor(v) for k, v in mira_dist.parameters.items()} + parameters = {k: torch.as_tensor(v) for k, v in mira_dist.parameters.items()} return _MIRA_TO_PYRO[mira_dist.type](parameters) diff --git a/tests/fixtures.py b/tests/fixtures.py index 9a2400a5..c2abae0a 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -53,8 +53,9 @@ def __init__( os.path.join(MODELS_PATH, "beta_mean_gamma_cycle_sir_model.json"), "beta_mean" ), ModelFixture( - os.path.join(MODELS_PATH, "gamma_mean_beta_mean_cycle_sir_model.json"), "beta_mean" - ) + os.path.join(MODELS_PATH, "gamma_mean_beta_mean_cycle_sir_model.json"), + "beta_mean", + ), ] PETRI_MODELS = [ ModelFixture( diff --git a/tests/test_compiled_dynamics.py b/tests/test_compiled_dynamics.py index f679573f..e99ccef1 100644 --- a/tests/test_compiled_dynamics.py +++ b/tests/test_compiled_dynamics.py @@ -81,7 +81,9 @@ def test_compiled_dynamics_load_json(url, start_time, end_time): @pytest.mark.parametrize("cyclic_model", CYCLIC_MODELS) @pytest.mark.parametrize("start_time", START_TIMES) @pytest.mark.parametrize("end_time", END_TIMES) -def test_hierarchical_compiled_dynamics(acyclic_model, cyclic_model, start_time, end_time): +def test_hierarchical_compiled_dynamics( + acyclic_model, cyclic_model, start_time, end_time +): """ Test the loading and dependency analysis of hierarchical compiled dynamics models.