diff --git a/examples/howto/plot_batch_simulate.py b/examples/howto/plot_batch_simulate.py index 470a7debb..c157f0751 100644 --- a/examples/howto/plot_batch_simulate.py +++ b/examples/howto/plot_batch_simulate.py @@ -24,7 +24,7 @@ from hnn_core import jones_2009_model # The number of cores may need modifying depending on your current machine. -n_jobs = 10 +n_jobs = 4 ############################################################################### # The `add_evoked_drive` function simulates external input to the network, # mimicking sensory stimulation or other external events. @@ -116,7 +116,8 @@ def summary_func(results): net = jones_2009_model(mesh_shape=(3, 3)) batch_simulation = BatchSimulate(net=net, set_params=set_params, - summary_func=summary_func) + summary_func=summary_func, + n_trials=10) simulation_results = batch_simulation.run(param_grid, n_jobs=n_jobs, combinations=False, diff --git a/hnn_core/batch_simulate.py b/hnn_core/batch_simulate.py index afd47ba13..424976187 100644 --- a/hnn_core/batch_simulate.py +++ b/hnn_core/batch_simulate.py @@ -9,7 +9,6 @@ import os from joblib import Parallel, delayed, parallel_config -from .parallel_backends import JoblibBackend from .network import Network from .externals.mne import _validate_type, _check_option from .dipole import simulate_dipole @@ -294,15 +293,14 @@ def _run_single_sim(self, param_values, n_jobs=1): results = {'net': net, 'param_values': param_values} if self.save_dpl: - with JoblibBackend(n_jobs=n_jobs): - dpl = simulate_dipole(net, - tstop=self.tstop, - dt=self.dt, - n_trials=self.n_trials, - record_vsec=self.record_vsec, - record_isec=self.record_isec, - postproc=self.postproc) - results['dpl'] = dpl + dpl = simulate_dipole(net, + tstop=self.tstop, + dt=self.dt, + n_trials=self.n_trials, + record_vsec=self.record_vsec, + record_isec=self.record_isec, + postproc=self.postproc) + results['dpl'] = dpl if self.save_spiking: results['spiking'] = { diff --git a/hnn_core/tests/test_batch_simulate.py b/hnn_core/tests/test_batch_simulate.py index aa0266dcb..80a1312f7 100644 --- a/hnn_core/tests/test_batch_simulate.py +++ b/hnn_core/tests/test_batch_simulate.py @@ -4,6 +4,7 @@ # Mainak Jas from pathlib import Path +import time import pytest import numpy as np import os @@ -290,3 +291,27 @@ def test_load_results(batch_simulate_instance, param_grid, tmp_path): # Validation Tests with pytest.raises(TypeError, match='results must be'): batch_simulate_instance._save("invalid_results", start_idx, end_idx) + + +def test_parallel_execution(batch_simulate_instance, param_grid): + """Test parallel execution of simulations and ensure speedup.""" + + param_combinations = batch_simulate_instance._generate_param_combinations( + param_grid) + + start_time = time.perf_counter() + results_serial = batch_simulate_instance.simulate_batch( + param_combinations, n_jobs=1, backend='loky') + end_time = time.perf_counter() + serial_time = end_time - start_time + + start_time = time.perf_counter() + results_parallel = batch_simulate_instance.simulate_batch( + param_combinations, + n_jobs=4, + backend='loky') + end_time = time.perf_counter() + parallel_time = end_time - start_time + + assert (serial_time > parallel_time + ), "Parallel execution is not faster than serial execution!"