Skip to content

Commit

Permalink
Refactor: Removed joblib from simulate_dipole, and added parallel exe…
Browse files Browse the repository at this point in the history
…cution test.

Signed-off-by: samadpls <[email protected]>
  • Loading branch information
samadpls committed Nov 26, 2024
1 parent 2f79384 commit c6ab1ae
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 13 deletions.
5 changes: 3 additions & 2 deletions examples/howto/plot_batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from hnn_core import jones_2009_model

# The number of cores may need modifying depending on your current machine.
n_jobs = 10
n_jobs = 4
###############################################################################
# The `add_evoked_drive` function simulates external input to the network,
# mimicking sensory stimulation or other external events.
Expand Down Expand Up @@ -116,7 +116,8 @@ def summary_func(results):
net = jones_2009_model(mesh_shape=(3, 3))
batch_simulation = BatchSimulate(net=net,
set_params=set_params,
summary_func=summary_func)
summary_func=summary_func,
n_trials=10)
simulation_results = batch_simulation.run(param_grid,
n_jobs=n_jobs,
combinations=False,
Expand Down
18 changes: 8 additions & 10 deletions hnn_core/batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import os
from joblib import Parallel, delayed, parallel_config

from .parallel_backends import JoblibBackend
from .network import Network
from .externals.mne import _validate_type, _check_option
from .dipole import simulate_dipole
Expand Down Expand Up @@ -294,15 +293,14 @@ def _run_single_sim(self, param_values, n_jobs=1):
results = {'net': net, 'param_values': param_values}

if self.save_dpl:
with JoblibBackend(n_jobs=n_jobs):
dpl = simulate_dipole(net,
tstop=self.tstop,
dt=self.dt,
n_trials=self.n_trials,
record_vsec=self.record_vsec,
record_isec=self.record_isec,
postproc=self.postproc)
results['dpl'] = dpl
dpl = simulate_dipole(net,
tstop=self.tstop,
dt=self.dt,
n_trials=self.n_trials,
record_vsec=self.record_vsec,
record_isec=self.record_isec,
postproc=self.postproc)
results['dpl'] = dpl

if self.save_spiking:
results['spiking'] = {
Expand Down
28 changes: 27 additions & 1 deletion hnn_core/tests/test_batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Mainak Jas <[email protected]>

from pathlib import Path
import time
import pytest
import numpy as np
import os
Expand Down Expand Up @@ -41,7 +42,8 @@ def set_params(param_values, net):
return BatchSimulate(net=net, set_params=set_params,
tstop=10,
save_folder=tmp_path,
batch_size=3)
batch_size=3,
n_trials=3,)


@pytest.fixture
Expand Down Expand Up @@ -290,3 +292,27 @@ def test_load_results(batch_simulate_instance, param_grid, tmp_path):
# Validation Tests
with pytest.raises(TypeError, match='results must be'):
batch_simulate_instance._save("invalid_results", start_idx, end_idx)


def test_parallel_execution(batch_simulate_instance, param_grid):
"""Test parallel execution of simulations and ensure speedup."""

param_combinations = batch_simulate_instance._generate_param_combinations(
param_grid)

start_time = time.perf_counter()
_ = batch_simulate_instance.simulate_batch(
param_combinations, n_jobs=1, backend='loky')
end_time = time.perf_counter()
serial_time = end_time - start_time

start_time = time.perf_counter()
_ = batch_simulate_instance.simulate_batch(
param_combinations,
n_jobs=2,
backend='loky')
end_time = time.perf_counter()
parallel_time = end_time - start_time

assert (serial_time > parallel_time
), "Parallel execution is not faster than serial execution!"

0 comments on commit c6ab1ae

Please sign in to comment.