-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor: Removed joblib from simulate_dipole, and added parallel exe…
…cution test. Signed-off-by: samadpls <[email protected]>
- Loading branch information
Showing
3 changed files
with
39 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
# Mainak Jas <[email protected]> | ||
|
||
from pathlib import Path | ||
import time | ||
import pytest | ||
import numpy as np | ||
import os | ||
|
@@ -41,7 +42,8 @@ def set_params(param_values, net): | |
return BatchSimulate(net=net, set_params=set_params, | ||
tstop=10, | ||
save_folder=tmp_path, | ||
batch_size=3) | ||
batch_size=3, | ||
n_trials=3,) | ||
|
||
|
||
@pytest.fixture | ||
|
@@ -52,7 +54,7 @@ def param_grid(): | |
'weight_basket': np.logspace(-4, -1, 2), | ||
'weight_pyr': np.logspace(-4, -1, 2), | ||
'mu': [40], | ||
'sigma': [5] | ||
'sigma': [5], | ||
} | ||
|
||
|
||
|
@@ -290,3 +292,27 @@ def test_load_results(batch_simulate_instance, param_grid, tmp_path): | |
# Validation Tests | ||
with pytest.raises(TypeError, match='results must be'): | ||
batch_simulate_instance._save("invalid_results", start_idx, end_idx) | ||
|
||
|
||
def test_parallel_execution(batch_simulate_instance, param_grid): | ||
"""Test parallel execution of simulations and ensure speedup.""" | ||
|
||
param_combinations = batch_simulate_instance._generate_param_combinations( | ||
param_grid) | ||
|
||
start_time = time.perf_counter() | ||
_ = batch_simulate_instance.simulate_batch( | ||
param_combinations, n_jobs=1, backend='loky') | ||
end_time = time.perf_counter() | ||
serial_time = end_time - start_time | ||
|
||
start_time = time.perf_counter() | ||
_ = batch_simulate_instance.simulate_batch( | ||
param_combinations, | ||
n_jobs=2, | ||
backend='loky') | ||
end_time = time.perf_counter() | ||
parallel_time = end_time - start_time | ||
|
||
assert (serial_time > parallel_time | ||
), "Parallel execution is not faster than serial execution!" |