Skip to content

Commit

Permalink
Refactor batch simulation parameters and backend
Browse files Browse the repository at this point in the history
Signed-off-by: samadpls <[email protected]>
  • Loading branch information
samadpls committed Nov 13, 2024
1 parent 68807c7 commit 9f6413a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
6 changes: 3 additions & 3 deletions examples/howto/plot_batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def set_params(param_values, net=None):


param_grid = {
'weight_basket': np.logspace(-4, -1, 10),
'weight_pyr': np.logspace(-4, -1, 10)
'weight_basket': np.logspace(-4, -1, 20),
'weight_pyr': np.logspace(-4, -1, 20)
}

###############################################################################
Expand Down Expand Up @@ -120,7 +120,7 @@ def summary_func(results):
simulation_results = batch_simulation.run(param_grid,
n_jobs=n_jobs,
combinations=False,
backend='multiprocessing')
backend='loky')
# backend='dask' if installed
print("Simulation results:", simulation_results)
###############################################################################
Expand Down
28 changes: 15 additions & 13 deletions hnn_core/batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
from joblib import Parallel, delayed, parallel_config

from .parallel_backends import JoblibBackend
from .network import Network
from .externals.mne import _validate_type, _check_option
from .dipole import simulate_dipole
Expand Down Expand Up @@ -201,9 +202,9 @@ def run(self, param_grid, return_output=True,

results = []
simulated_data = []
for i in range(batch_size):
start_idx = i * num_sims_per_batch
end_idx = start_idx + num_sims_per_batch
for i in range(0, total_sims, num_sims_per_batch):
start_idx = i
end_idx = min(i + num_sims_per_batch, total_sims)
if i == batch_size - 1:
end_idx = len(param_combinations)
batch_results = self.simulate_batch(
Expand Down Expand Up @@ -269,10 +270,10 @@ def simulate_batch(self, param_combinations, n_jobs=1,
with parallel_config(backend=backend):
res = Parallel(n_jobs=n_jobs, verbose=verbose)(
delayed(self._run_single_sim)(
params) for params in param_combinations)
params, n_jobs) for params in param_combinations)
return res

def _run_single_sim(self, param_values):
def _run_single_sim(self, param_values, n_jobs=1):
"""Run a single simulation.
Parameters
Expand All @@ -296,14 +297,15 @@ def _run_single_sim(self, param_values):
results = {'net': net, 'param_values': param_values}

if self.save_dpl:
dpl = simulate_dipole(net,
tstop=self.tstop,
dt=self.dt,
n_trials=self.n_trials,
record_vsec=self.record_vsec,
record_isec=self.record_isec,
postproc=self.postproc)
results['dpl'] = dpl
with JoblibBackend(n_jobs=n_jobs):
dpl = simulate_dipole(net,
tstop=self.tstop,
dt=self.dt,
n_trials=self.n_trials,
record_vsec=self.record_vsec,
record_isec=self.record_isec,
postproc=self.postproc)
results['dpl'] = dpl

if self.save_spiking:
results['spiking'] = {
Expand Down

0 comments on commit 9f6413a

Please sign in to comment.