-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Removed net_json param and update test
Signed-off-by: samadpls <[email protected]>
- Loading branch information
Showing
2 changed files
with
10 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,6 @@ | |
# Ryan Thorpe <[email protected]> | ||
# Mainak Jas <[email protected]> | ||
|
||
import json | ||
import numpy as np | ||
import os | ||
from joblib import Parallel, delayed, parallel_config | ||
|
@@ -14,11 +13,10 @@ | |
from .externals.mne import _validate_type, _check_option | ||
from .dipole import simulate_dipole | ||
from .network_models import jones_2009_model | ||
from .hnn_io import dict_to_network | ||
|
||
|
||
class BatchSimulate(object): | ||
def __init__(self, set_params, net=jones_2009_model(), net_json=None, | ||
def __init__(self, set_params, net=jones_2009_model(), | ||
tstop=170, dt=0.025, n_trials=1, | ||
save_folder='./sim_results', batch_size=100, | ||
overwrite=True, save_outputs=False, save_dpl=True, | ||
|
@@ -41,9 +39,6 @@ def __init__(self, set_params, net=jones_2009_model(), net_json=None, | |
The network model to use for simulations. Must be an instance of | ||
jones_2009_model, law_2021_model, or calcium_model. | ||
Default is jones_2009_model(). | ||
net_json : str, optional | ||
The path to a JSON file to create the network model. If provided, | ||
this will override the `net` parameter. Default is None. | ||
tstop : float, optional | ||
The stop time for the simulation. Default is 170 ms. | ||
dt : float, optional | ||
|
@@ -125,8 +120,6 @@ def __init__(self, set_params, net=jones_2009_model(), net_json=None, | |
_validate_type(save_currents, types=(bool,), item_name='save_currents') | ||
_validate_type(save_calcium, types=(bool,), item_name='save_calcium') | ||
_validate_type(clear_cache, types=(bool,), item_name='clear_cache') | ||
_validate_type(net_json, types=('path-like', None), | ||
item_name='net_json') | ||
|
||
if set_params is not None and not callable(set_params): | ||
raise TypeError("set_params must be a callable function") | ||
|
@@ -154,7 +147,6 @@ def __init__(self, set_params, net=jones_2009_model(), net_json=None, | |
self.save_currents = save_currents | ||
self.save_calcium = save_calcium | ||
self.clear_cache = clear_cache | ||
self.net_json = net_json | ||
|
||
def run(self, param_grid, return_output=True, | ||
combinations=True, n_jobs=1, backend='loky', | ||
|
@@ -296,14 +288,7 @@ def _run_single_sim(self, param_values): | |
- `param_values`: The parameter values used for the simulation. | ||
""" | ||
|
||
if isinstance(self.net_json, str): | ||
with open(self.net_json, 'r') as file: | ||
net_data = json.load(file) | ||
net = dict_to_network(net_data) | ||
else: | ||
net = self.net | ||
net = net.copy() | ||
|
||
net = self.net.copy() | ||
self.set_params(param_values, net) | ||
|
||
results = {'net': net, 'param_values': param_values} | ||
|
@@ -396,6 +381,14 @@ def _save(self, results, start_idx, end_idx): | |
if getattr(self, f'save_{attr}') and attr in results[0]: | ||
save_data[attr] = [result[attr] for result in results] | ||
|
||
metadata = { | ||
'batch_size': self.batch_size, | ||
'n_trials': self.n_trials, | ||
'tstop': self.tstop, | ||
'dt': self.dt | ||
} | ||
save_data['metadata'] = metadata | ||
|
||
file_name = os.path.join(self.save_folder, | ||
f'sim_run_{start_idx}-{end_idx}.npz') | ||
if os.path.exists(file_name) and not self.overwrite: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters