diff --git a/hnn_core/batch_simulate.py b/hnn_core/batch_simulate.py index 4f153b40e..20db668a0 100644 --- a/hnn_core/batch_simulate.py +++ b/hnn_core/batch_simulate.py @@ -5,6 +5,7 @@ # Ryan Thorpe # Mainak Jas +import json import numpy as np import os from joblib import Parallel, delayed, parallel_config @@ -13,6 +14,7 @@ from .externals.mne import _validate_type, _check_option from .dipole import simulate_dipole from .network_models import jones_2009_model +from .hnn_io import dict_to_network class BatchSimulate(object): @@ -24,7 +26,7 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170, save_dpl=True, save_spiking=False, save_lfp=False, save_voltages=False, save_currents=False, save_calcium=False, - clear_cache=False): + clear_cache=False, net_json=None): """Initialize the BatchSimulate class. Parameters @@ -100,6 +102,9 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170, clear_cache : bool, optional Whether to clear the results cache after saving each batch. Default is False. + net_json : str, optional + The path to a JSON file to create the network model. If provided, + this will override the `net` parameter. Default is None. Notes ----- When `save_output=True`, the saved files will appear as @@ -127,6 +132,8 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170, _validate_type(save_currents, types=(bool,), item_name='save_currents') _validate_type(save_calcium, types=(bool,), item_name='save_calcium') _validate_type(clear_cache, types=(bool,), item_name='clear_cache') + _validate_type(net_json, types=('path-like', None), + item_name='net_json') if set_params is not None and not callable(set_params): raise TypeError("set_params must be a callable function") @@ -154,6 +161,7 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170, self.save_currents = save_currents self.save_calcium = save_calcium self.clear_cache = clear_cache + self.net_json = net_json def run(self, param_grid, return_output=True, combinations=True, n_jobs=1, backend='loky', @@ -295,7 +303,14 @@ def _run_single_sim(self, param_values): - `param_values`: The parameter values used for the simulation. """ - net = self.net.copy() + if isinstance(self.net_json, str): + with open(self.net_json, 'r') as file: + net_data = json.load(file) + net = dict_to_network(net_data) + else: + net = self.net + net = net.copy() + self.set_params(param_values, net) results = {'net': net, 'param_values': param_values} diff --git a/hnn_core/tests/test_batch_simulate.py b/hnn_core/tests/test_batch_simulate.py index 3b61ca221..c0df84771 100644 --- a/hnn_core/tests/test_batch_simulate.py +++ b/hnn_core/tests/test_batch_simulate.py @@ -3,6 +3,7 @@ # Ryan Thorpe # Mainak Jas +from pathlib import Path import pytest import numpy as np import os @@ -10,6 +11,9 @@ from hnn_core.batch_simulate import BatchSimulate from hnn_core import jones_2009_model +hnn_core_root = Path(__file__).parents[1] +assets_path = Path(hnn_core_root, 'tests', 'assets') + @pytest.fixture def batch_simulate_instance(tmp_path): @@ -33,9 +37,9 @@ def set_params(param_values, net): weights_ampa=weights_ampa, synaptic_delays=synaptic_delays) - net = jones_2009_model() + net = jones_2009_model(mesh_shape=(3, 3)) return BatchSimulate(net=net, set_params=set_params, - tstop=1., + tstop=10, save_folder=tmp_path, batch_size=3) @@ -75,6 +79,9 @@ def test_parameter_validation(): with pytest.raises(TypeError, match="net must be"): BatchSimulate(net="invalid_network", set_params=lambda x: x) + with pytest.raises(TypeError, match="net_json must be"): + BatchSimulate(net_json=123, set_params=lambda x: x) + def test_generate_param_combinations(batch_simulate_instance, param_grid): """Test generating parameter combinations.""" @@ -104,6 +111,21 @@ def test_run_single_sim(batch_simulate_instance): assert isinstance(result['net'], type(batch_simulate_instance.net)) +def test_net_json_loading(param_grid): + """Test loading the network from a JSON file.""" + json_path = assets_path / 'jones2009_3x3_drives.json' + + batch_simulate = BatchSimulate(net_json=str(json_path), + set_params=lambda x, y: x, + tstop=70) + + result = batch_simulate._run_single_sim(param_grid) + assert isinstance(result, dict) + assert 'net' in result + assert 'param_values' in result + assert 'dpl' in result + + def test_simulate_batch(batch_simulate_instance, param_grid): """Test simulating a batch of parameter sets.""" param_combinations = batch_simulate_instance._generate_param_combinations(