diff --git a/examples/howto/plot_batch_simulate.py b/examples/howto/plot_batch_simulate.py index f683a6087..470a7debb 100644 --- a/examples/howto/plot_batch_simulate.py +++ b/examples/howto/plot_batch_simulate.py @@ -73,8 +73,8 @@ def set_params(param_values, net=None): param_grid = { - 'weight_basket': np.logspace(-4, -1, 10), - 'weight_pyr': np.logspace(-4, -1, 10) + 'weight_basket': np.logspace(-4, -1, 20), + 'weight_pyr': np.logspace(-4, -1, 20) } ############################################################################### @@ -120,7 +120,7 @@ def summary_func(results): simulation_results = batch_simulation.run(param_grid, n_jobs=n_jobs, combinations=False, - backend='multiprocessing') + backend='loky') # backend='dask' if installed print("Simulation results:", simulation_results) ############################################################################### diff --git a/hnn_core/batch_simulate.py b/hnn_core/batch_simulate.py index b096819cc..89a70091b 100644 --- a/hnn_core/batch_simulate.py +++ b/hnn_core/batch_simulate.py @@ -9,6 +9,7 @@ import os from joblib import Parallel, delayed, parallel_config +from .parallel_backends import JoblibBackend from .network import Network from .externals.mne import _validate_type, _check_option from .dipole import simulate_dipole @@ -201,9 +202,9 @@ def run(self, param_grid, return_output=True, results = [] simulated_data = [] - for i in range(batch_size): - start_idx = i * num_sims_per_batch - end_idx = start_idx + num_sims_per_batch + for i in range(0, total_sims, num_sims_per_batch): + start_idx = i + end_idx = min(i + num_sims_per_batch, total_sims) if i == batch_size - 1: end_idx = len(param_combinations) batch_results = self.simulate_batch( @@ -269,10 +270,10 @@ def simulate_batch(self, param_combinations, n_jobs=1, with parallel_config(backend=backend): res = Parallel(n_jobs=n_jobs, verbose=verbose)( delayed(self._run_single_sim)( - params) for params in param_combinations) + params, n_jobs) for params in param_combinations) return res - def _run_single_sim(self, param_values): + def _run_single_sim(self, param_values, n_jobs=1): """Run a single simulation. Parameters @@ -296,14 +297,15 @@ def _run_single_sim(self, param_values): results = {'net': net, 'param_values': param_values} if self.save_dpl: - dpl = simulate_dipole(net, - tstop=self.tstop, - dt=self.dt, - n_trials=self.n_trials, - record_vsec=self.record_vsec, - record_isec=self.record_isec, - postproc=self.postproc) - results['dpl'] = dpl + with JoblibBackend(n_jobs=n_jobs): + dpl = simulate_dipole(net, + tstop=self.tstop, + dt=self.dt, + n_trials=self.n_trials, + record_vsec=self.record_vsec, + record_isec=self.record_isec, + postproc=self.postproc) + results['dpl'] = dpl if self.save_spiking: results['spiking'] = {