Skip to content

Commit

Permalink
Added get_name on the mira parameter keys for safety (#218)
Browse files Browse the repository at this point in the history
* Added get_name on the mira parameter keys for safety

* Removed test for SIDARTHE which fails due to gyorilab/mira#195

* checked out scenario1 from main so that it doesn't clobber Vignesh's changes

* Added test that confirms gyorilab/mira#196 solves gyorilab/mira#195

* Added pin to latest mira commit
  • Loading branch information
djinnome authored Jul 11, 2023
1 parent 99d8c28 commit 7dc3033
Show file tree
Hide file tree
Showing 8 changed files with 1,504 additions and 1,052 deletions.

Large diffs are not rendered by default.

28 changes: 14 additions & 14 deletions notebook/integration_demo/demo_ensemble.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,20 @@
"name": "stdout",
"output_type": "stream",
"text": [
"iteration 0: loss = 66.16035544872284\n",
"iteration 25: loss = 42.31934851408005\n",
"iteration 50: loss = 33.17540234327316\n",
"iteration 75: loss = 27.15283751487732\n",
"iteration 100: loss = 21.151540756225586\n",
"iteration 125: loss = 17.88287889957428\n",
"iteration 150: loss = 19.490506768226624\n",
"iteration 175: loss = 16.146777868270874\n",
"iteration 200: loss = 16.498568773269653\n",
"iteration 225: loss = 15.53571480512619\n",
"iteration 250: loss = 15.07484495639801\n",
"iteration 275: loss = 14.595208406448364\n",
"iteration 300: loss = 14.831865012645721\n",
"iteration 325: loss = 15.439243197441101\n"
"iteration 0: loss = 63.51079273223877\n",
"iteration 25: loss = 40.8916078209877\n",
"iteration 50: loss = 37.46487545967102\n",
"iteration 75: loss = 25.852425813674927\n",
"iteration 100: loss = 21.470157742500305\n",
"iteration 125: loss = 17.58678960800171\n",
"iteration 150: loss = 15.267673671245575\n",
"iteration 175: loss = 16.10206639766693\n",
"iteration 200: loss = 15.993813276290894\n",
"iteration 225: loss = 15.860064268112183\n",
"iteration 250: loss = 13.152705788612366\n",
"iteration 275: loss = 16.014341235160828\n",
"iteration 300: loss = 15.707162857055664\n",
"iteration 325: loss = 12.456510186195374\n"
]
}
],
Expand Down
1,000 changes: 500 additions & 500 deletions notebook/integration_demo/results_petri_ensemble/calibrated_sample_results.csv

Large diffs are not rendered by default.

1,000 changes: 500 additions & 500 deletions notebook/integration_demo/results_petri_ensemble/sample_results.csv

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ install_requires =
torchdiffeq
networkx
pandas
mira @ git+https://github.com/indralab/mira.git@3a3a931ee52c9e5b976ea849424b1a5c65f4c8d2
mira @ git+https://github.com/indralab/mira.git@a4995799db6dfb8119ef8c67833a7228c82c81b9
xarray
netcdf4
h5netcdf
Expand Down
5 changes: 3 additions & 2 deletions src/pyciemss/PetriNetODE/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,9 @@ def deriv(self, t: Time, state: State) -> StateDeriv:
# Get the current state
states = {v: state[i] for i, v in enumerate(self.var_order.keys())}
# Get the parameters
parameters = {k: getattr (self, k) for k in self.G.parameters}
parameters = {get_name(param_info): getattr (self, get_name(param_info))
for param_info in self.G.parameters.values()
}

# Evaluate the rate laws for each transition
deriv_tensor = self.compiled_rate_law(**states, **parameters, **dict(t=t))
Expand Down Expand Up @@ -500,7 +502,6 @@ def mass_action_deriv(self, t: Time, state: State) -> StateDeriv:
def param_prior(self):
for param_info in self.G.parameters.values():
param_name = get_name(param_info)

param_value = param_info.value
if isinstance(param_value, torch.nn.Parameter):
setattr(self, param_name, pyro.param(param_name, param_value))
Expand Down
292 changes: 292 additions & 0 deletions test/models/AMR_examples/scenario1_a.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
{
"name": "Scenario 1a",
"schema": "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/petrinet_v0.5/petrinet/petrinet_schema.json",
"schema_name": "petrinet",
"description": "Scenario 1a",
"model_version": "0.1",
"properties": {},
"model": {
"states": [
{
"id": "S",
"name": "S",
"grounding": {
"identifiers": {
"ido": "0000514"
},
"modifiers": {}
},
"units": {
"expression": "person",
"expression_mathml": "<ci>person</ci>"
}
},
{
"id": "I",
"name": "I",
"grounding": {
"identifiers": {
"ido": "0000511"
},
"modifiers": {}
},
"units": {
"expression": "person",
"expression_mathml": "<ci>person</ci>"
}
},
{
"id": "E",
"name": "E",
"grounding": {
"identifiers": {
"apollosv": "0000154"
},
"modifiers": {}
},
"units": {
"expression": "person",
"expression_mathml": "<ci>person</ci>"
}
},
{
"id": "R",
"name": "R",
"grounding": {
"identifiers": {
"ido": "0000592"
},
"modifiers": {}
},
"units": {
"expression": "person",
"expression_mathml": "<ci>person</ci>"
}
},
{
"id": "D",
"name": "D",
"grounding": {
"identifiers": {
"ncit": "C28554"
},
"modifiers": {}
},
"units": {
"expression": "person",
"expression_mathml": "<ci>person</ci>"
}
}
],
"transitions": [
{
"id": "t1",
"input": [
"I",
"S"
],
"output": [
"I",
"E"
],
"properties": {
"name": "t1"
}
},
{
"id": "t2",
"input": [
"E"
],
"output": [
"I"
],
"properties": {
"name": "t2"
}
},
{
"id": "t3",
"input": [
"I"
],
"output": [
"R"
],
"properties": {
"name": "t3"
}
},
{
"id": "t4",
"input": [
"I"
],
"output": [
"D"
],
"properties": {
"name": "t4"
}
}
]
},
"semantics": {
"ode": {
"rates": [
{
"target": "t1",
"expression": "I*S*kappa*(beta_c + (-beta_c + beta_s)/(1 + exp(-k*(-t + t_0))))/N",
"expression_mathml": "<apply><divide/><apply><times/><ci>I</ci><ci>S</ci><ci>kappa</ci><apply><plus/><ci>beta_c</ci><apply><divide/><apply><plus/><apply><minus/><ci>beta_c</ci></apply><ci>beta_s</ci></apply><apply><plus/><cn>1</cn><apply><exp/><apply><minus/><apply><times/><ci>k</ci><apply><minus/><ci>t_0</ci><ci>t</ci></apply></apply></apply></apply></apply></apply></apply></apply><ci>N</ci></apply>"
},
{
"target": "t2",
"expression": "E*delta",
"expression_mathml": "<apply><times/><ci>E</ci><ci>delta</ci></apply>"
},
{
"target": "t3",
"expression": "I*gamma*(1 - alpha)",
"expression_mathml": "<apply><times/><ci>I</ci><ci>gamma</ci><apply><minus/><cn>1</cn><ci>alpha</ci></apply></apply>"
},
{
"target": "t4",
"expression": "I*alpha*rho",
"expression_mathml": "<apply><times/><ci>I</ci><ci>alpha</ci><ci>rho</ci></apply>"
}
],
"initials": [
{
"target": "S",
"expression": "5599999.00000000",
"expression_mathml": "<cn>5599999.0</cn>"
},
{
"target": "I",
"expression": "0.0",
"expression_mathml": "<cn>0.0</cn>"
},
{
"target": "E",
"expression": "1.00000000000000",
"expression_mathml": "<cn>1.0</cn>"
},
{
"target": "R",
"expression": "0.0",
"expression_mathml": "<cn>0.0</cn>"
},
{
"target": "D",
"expression": "0.0",
"expression_mathml": "<cn>0.0</cn>"
}
],
"parameters": [
{
"id": "N",
"value": 5600000.0,
"units": {
"expression": "person",
"expression_mathml": "<ci>person</ci>"
}
},
{
"id": "beta_c",
"value": 0.4,
"units": {
"expression": "1/(day*person)",
"expression_mathml": "<apply><divide/><cn>1</cn><apply><times/><ci>day</ci><ci>person</ci></apply></apply>"
}
},
{
"id": "beta_s",
"value": 1.0,
"units": {
"expression": "1/(day*person)",
"expression_mathml": "<apply><divide/><cn>1</cn><apply><times/><ci>day</ci><ci>person</ci></apply></apply>"
}
},
{
"id": "k",
"value": 5.0,
"units": {
"expression": "1",
"expression_mathml": "<cn>1</cn>"
}
},
{
"id": "kappa",
"value": 0.45454545454545453,
"units": {
"expression": "1/day",
"expression_mathml": "<apply><power/><ci>day</ci><cn>-1</cn></apply>"
}
},
{
"id": "t_0",
"value": 89.0,
"units": {
"expression": "day",
"expression_mathml": "<ci>day</ci>"
}
},
{
"id": "delta",
"value": 0.2,
"units": {
"expression": "1/day",
"expression_mathml": "<apply><power/><ci>day</ci><cn>-1</cn></apply>"
}
},
{
"id": "alpha",
"value": 6.4e-05,
"units": {
"expression": "1",
"expression_mathml": "<cn>1</cn>"
}
},
{
"id": "gamma",
"value": 0.09090909090909091,
"units": {
"expression": "1/day",
"expression_mathml": "<apply><power/><ci>day</ci><cn>-1</cn></apply>"
}
},
{
"id": "rho",
"value": 0.1111111111111111,
"units": {
"expression": "1/day",
"expression_mathml": "<apply><power/><ci>day</ci><cn>-1</cn></apply>"
}
}
],
"observables": [],
"time": {
"id": "t",
"units": {
"expression": "day",
"expression_mathml": "<ci>day</ci>"
}
}
}
},
"metadata": {
"annotations": {
"license": null,
"authors": [],
"references": [],
"time_scale": null,
"time_start": null,
"time_end": null,
"locations": [],
"pathogens": [],
"diseases": [],
"hosts": [],
"model_types": []
}
}
}
9 changes: 6 additions & 3 deletions test/test_petrinet_ode/test_ode_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def test_samples_dtype(self):
for col_name in s.columns[2:]:
self.assertEqual(s[col_name].dtype, np.float64)




class TestODEInterfaces(unittest.TestCase):
Expand Down Expand Up @@ -294,9 +293,13 @@ def test_load_and_calibrate_and_sample_petri_model(self):
actual_intervened_samples = load_and_calibrate_and_sample_petri_model(ASKENET_PATH, data_path, num_samples, timepoints, interventions = interventions, start_state=initial_state, num_iterations=2)
assert_frame_equal(expected_intervened_samples, actual_intervened_samples, check_exact=False, atol=1e-5)

SCENARIO_1a_H2 = 'test/models/AMR_examples/scenario1_a.json'
scenario1a_output = load_and_sample_petri_model(SCENARIO_1a_H2, num_samples, timepoints)
self.assertTrue(isinstance(scenario1a_output, pd.DataFrame))



SIDARTHE = 'test/models/AMR_examples/BIOMD0000000955_askenet.json'
sidarthe_output = load_and_sample_petri_model(SIDARTHE, num_samples, timepoints)
self.assertTrue(isinstance(sidarthe_output, pd.DataFrame))



Expand Down

0 comments on commit 7dc3033

Please sign in to comment.