Skip to content

Commit

Permalink
ENH BatchSimulate for JSON path handling
Browse files Browse the repository at this point in the history
Signed-off-by: samadpls <[email protected]>
  • Loading branch information
samadpls committed Aug 10, 2024
1 parent 100baad commit 86a68e8
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 4 deletions.
19 changes: 17 additions & 2 deletions hnn_core/batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# Ryan Thorpe <[email protected]>
# Mainak Jas <[email protected]>

import json
import numpy as np
import os
from joblib import Parallel, delayed, parallel_config
Expand All @@ -13,6 +14,7 @@
from .externals.mne import _validate_type, _check_option
from .dipole import simulate_dipole
from .network_models import jones_2009_model
from .hnn_io import dict_to_network


class BatchSimulate(object):
Expand All @@ -24,7 +26,7 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170,
save_dpl=True, save_spiking=False,
save_lfp=False, save_voltages=False,
save_currents=False, save_calcium=False,
clear_cache=False):
clear_cache=False, net_json=None):
"""Initialize the BatchSimulate class.
Parameters
Expand Down Expand Up @@ -100,6 +102,9 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170,
clear_cache : bool, optional
Whether to clear the results cache after saving each batch.
Default is False.
net_json : str, optional
The path to a JSON file to create the network model. If provided,
this will override the `net` parameter. Default is None.
Notes
-----
When `save_output=True`, the saved files will appear as
Expand Down Expand Up @@ -127,6 +132,8 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170,
_validate_type(save_currents, types=(bool,), item_name='save_currents')
_validate_type(save_calcium, types=(bool,), item_name='save_calcium')
_validate_type(clear_cache, types=(bool,), item_name='clear_cache')
_validate_type(net_json, types=('path-like', None),
item_name='net_json')

if set_params is not None and not callable(set_params):
raise TypeError("set_params must be a callable function")
Expand Down Expand Up @@ -154,6 +161,7 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170,
self.save_currents = save_currents
self.save_calcium = save_calcium
self.clear_cache = clear_cache
self.net_json = net_json

def run(self, param_grid, return_output=True,
combinations=True, n_jobs=1, backend='loky',
Expand Down Expand Up @@ -295,7 +303,14 @@ def _run_single_sim(self, param_values):
- `param_values`: The parameter values used for the simulation.
"""

net = self.net.copy()
if isinstance(self.net_json, str):
with open(self.net_json, 'r') as file:
net_data = json.load(file)
net = dict_to_network(net_data)
else:
net = self.net
net = net.copy()

self.set_params(param_values, net)

results = {'net': net, 'param_values': param_values}
Expand Down
26 changes: 24 additions & 2 deletions hnn_core/tests/test_batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
# Ryan Thorpe <[email protected]>
# Mainak Jas <[email protected]>

from pathlib import Path
import pytest
import numpy as np
import os

from hnn_core.batch_simulate import BatchSimulate
from hnn_core import jones_2009_model

hnn_core_root = Path(__file__).parents[1]
assets_path = Path(hnn_core_root, 'tests', 'assets')


@pytest.fixture
def batch_simulate_instance(tmp_path):
Expand All @@ -33,9 +37,9 @@ def set_params(param_values, net):
weights_ampa=weights_ampa,
synaptic_delays=synaptic_delays)

net = jones_2009_model()
net = jones_2009_model(mesh_shape=(3, 3))
return BatchSimulate(net=net, set_params=set_params,
tstop=1.,
tstop=10,
save_folder=tmp_path,
batch_size=3)

Expand Down Expand Up @@ -75,6 +79,9 @@ def test_parameter_validation():
with pytest.raises(TypeError, match="net must be"):
BatchSimulate(net="invalid_network", set_params=lambda x: x)

with pytest.raises(TypeError, match="net_json must be"):
BatchSimulate(net_json=123, set_params=lambda x: x)


def test_generate_param_combinations(batch_simulate_instance, param_grid):
"""Test generating parameter combinations."""
Expand Down Expand Up @@ -104,6 +111,21 @@ def test_run_single_sim(batch_simulate_instance):
assert isinstance(result['net'], type(batch_simulate_instance.net))


def test_net_json_loading(param_grid):
"""Test loading the network from a JSON file."""
json_path = assets_path / 'jones2009_3x3_drives.json'

batch_simulate = BatchSimulate(net_json=str(json_path),
set_params=lambda x, y: x,
tstop=70)

result = batch_simulate._run_single_sim(param_grid)
assert isinstance(result, dict)
assert 'net' in result
assert 'param_values' in result
assert 'dpl' in result


def test_simulate_batch(batch_simulate_instance, param_grid):
"""Test simulating a batch of parameter sets."""
param_combinations = batch_simulate_instance._generate_param_combinations(
Expand Down

0 comments on commit 86a68e8

Please sign in to comment.