
ENH: Refactor BatchSimulate Example and Improve Documentation #857

Merged
merged 12 commits into from
Dec 12, 2024
62 changes: 48 additions & 14 deletions examples/howto/plot_batch_simulate.py
@@ -24,13 +24,25 @@
from hnn_core import jones_2009_model

# The number of cores may need modifying depending on your current machine.
n_jobs = 10
n_jobs = 4
###############################################################################
# The `add_evoked_drive` function simulates external input to the network,
# mimicking sensory stimulation or other external events.
#
# - `evprox` indicates a proximal drive, targeting dendrites near the cell
# bodies.
# - `mu=40` and `sigma=5` define the timing (mean and spread) of the input.
# - `weights_ampa` and `synaptic_delays` control the strength and
# timing of the input.
#
# This evoked drive causes the initial positive deflection in the dipole
# signal, triggering a cascade of activity through the network and
# resulting in the complex waveforms observed.
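For reference, a call matching the description above could look like the sketch below. It assumes a `net` built with `jones_2009_model()`; the weight and delay values are placeholders, not the ones the example actually sets inside `set_params`.

```python
# Illustrative sketch only; the real call is made inside set_params() below
# and uses the weights drawn from param_grid.
weights_ampa = {'L2_basket': 0.05, 'L2_pyramidal': 0.05,
                'L5_basket': 0.05, 'L5_pyramidal': 0.05}
synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1,
                   'L5_basket': 1.0, 'L5_pyramidal': 1.0}
net.add_evoked_drive('evprox',
                     mu=40., sigma=5., numspikes=1,
                     location='proximal',
                     weights_ampa=weights_ampa,
                     synaptic_delays=synaptic_delays)
```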


def set_params(param_values, net=None):
"""
Set parameters in the network drives.
Set parameters for the network drives.

Parameters
----------
@@ -57,16 +69,16 @@ def set_params(param_values, net=None):
synaptic_delays=synaptic_delays)

###############################################################################
# Define a parameter grid for the batch simulation.
# Next, we define a parameter grid for the batch simulation.


param_grid = {
'weight_basket': np.logspace(-4 - 1, 10),
'weight_pyr': np.logspace(-4, -1, 10)
'weight_basket': np.logspace(-4, -1, 20),
'weight_pyr': np.logspace(-4, -1, 20)
}
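As an aside, the new grid holds two arrays of 20 values each. A small standalone sketch of how we expect it to expand into per-simulation parameter dictionaries, pairing values element-wise when `combinations=False` (as used in `run()` below) versus taking the full Cartesian product:

```python
from itertools import product

keys, values = zip(*param_grid.items())
paired = [dict(zip(keys, combo)) for combo in zip(*values)]       # 20 simulations
crossed = [dict(zip(keys, combo)) for combo in product(*values)]  # 400 simulations
print(len(paired), len(crossed))
```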

###############################################################################
# Define a function to calculate summary statistics
# We then define a function to calculate summary statistics.


def summary_func(results):
@@ -95,36 +107,58 @@ def summary_func(results):
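The body of `summary_func` is collapsed in this hunk. A minimal sketch of a function with the shape this hook accepts, assuming each entry carries a `'dpl'` list as in the plotting loop further down (the name and the exact statistics are illustrative, not the example's actual implementation):

```python
def summary_func_sketch(results):
    """Illustrative stand-in: min/max of each smoothed aggregate dipole."""
    stats = []
    for result in results:
        dpl = result['dpl'][0].copy().smooth(window_len=30)
        agg = dpl.data['agg']
        stats.append({'min_peak': float(np.min(agg)),
                      'max_peak': float(np.max(agg))})
    return stats
```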
###############################################################################
# Run the batch simulation and collect the results.

# Comment off this code, if dask and distributed Python packages are installed
# Uncomment this code if dask and distributed Python packages are installed.
# from dask.distributed import Client
# client = Client(threads_per_worker=1, n_workers=5, processes=False)


# Run the batch simulation and collect the results.
# Initialize the network model and run the batch simulation.
net = jones_2009_model(mesh_shape=(3, 3))
batch_simulation = BatchSimulate(net=net,
set_params=set_params,
summary_func=summary_func)
simulation_results = batch_simulation.run(param_grid,
n_jobs=n_jobs,
combinations=False,
backend='multiprocessing')
backend='loky')
# backend='dask' if installed
print("Simulation results:", simulation_results)
###############################################################################
# This plot shows an overlay of all smoothed dipole waveforms from the
# batch simulation. Each line represents a different set of parameters,
# allowing us to visualize the range of responses across the parameter space.
# batch simulation. Each line represents a different set of synaptic strength
# parameters (`weight_basket`), allowing us to visualize the range of responses
# across the parameter space.
# The colormap represents synaptic strengths, from weaker (purple)
# to stronger (yellow).
#
# As drive strength increases, dipole responses show progressively larger
# amplitudes and more distinct features, reflecting heightened network
# activity. Weak drives (purple lines) produce smaller amplitude signals with
# simpler waveforms, while stronger drives (yellow lines) generate
# larger responses with more pronounced oscillatory features, indicating
# more robust network activity.
#
# The y-axis represents dipole amplitude in nAm (nanoampere-meters), which is
# the product of current flow and distance in the neural tissue.
#
# Stronger synaptic connections (yellow lines) generally show larger
# amplitude responses and more pronounced features throughout the simulation.

dpl_waveforms = []
dpl_waveforms, param_values = [], []
for data_list in simulation_results['simulated_data']:
for data in data_list:
dpl_smooth = data['dpl'][0].copy().smooth(window_len=30)
dpl_waveforms.append(dpl_smooth.data['agg'])
param_values.append(data['param_values']['weight_basket'])

plt.figure(figsize=(10, 6))
for waveform in dpl_waveforms:
plt.plot(waveform, alpha=0.5, linewidth=3)
cmap = plt.get_cmap('viridis')
log_param_values = np.log10(param_values)
norm = plt.Normalize(log_param_values.min(), log_param_values.max())

for waveform, log_param in zip(dpl_waveforms, log_param_values):
color = cmap(norm(log_param))
plt.plot(waveform, color=color, alpha=0.7, linewidth=2)
plt.title('Overlay of Dipole Waveforms')
plt.xlabel('Time (ms)')
plt.ylabel('Dipole Amplitude (nAm)')
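The remaining plotting lines are collapsed below. If a colorbar for the log-scaled `weight_basket` values is wanted, one way to add it with standard matplotlib (a sketch, not necessarily what the example does) is:

```python
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=plt.gca())
cbar.set_label('log10(weight_basket)')
plt.show()
```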
82 changes: 41 additions & 41 deletions hnn_core/batch_simulate.py
@@ -16,15 +16,14 @@


class BatchSimulate(object):
def __init__(self, set_params, net=jones_2009_model(), tstop=170,
dt=0.025, n_trials=1, record_vsec=False,
record_isec=False, postproc=False, save_outputs=False,
def __init__(self, set_params, net=None,
tstop=170, dt=0.025, n_trials=1,
save_folder='./sim_results', batch_size=100,
overwrite=True, summary_func=None,
save_dpl=True, save_spiking=False,
save_lfp=False, save_voltages=False,
save_currents=False, save_calcium=False,
clear_cache=False):
overwrite=True, save_outputs=False, save_dpl=True,
save_spiking=False, save_lfp=False, save_voltages=False,
save_currents=False, save_calcium=False, record_vsec=False,
record_isec=False, postproc=False, clear_cache=False,
summary_func=None):
"""Initialize the BatchSimulate class.

Parameters
@@ -37,34 +36,17 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170,
where ``net`` is a Network object and ``params`` is a dictionary
of the parameters that will be set inside the function.
net : Network object, optional
The network model to use for simulations. Must be an instance of
jones_2009_model, law_2021_model, or calcium_model.
Default is jones_2009_model().
The network model to use for simulations. Examples include:
- `jones_2009_model`: A network model based on Jones et al. (2009).
- `law_2021_model`: A network model based on Law et al. (2021).
- `calcium_model`: A network model incorporating calcium dynamics.
Default is `jones_2009_model()`
tstop : float, optional
The stop time for the simulation. Default is 170 ms.
dt : float, optional
The time step for the simulation. Default is 0.025 ms.
n_trials : int, optional
The number of trials for the simulation. Default is 1.
record_vsec : 'all' | 'soma' | False
Option to record voltages from all sections ('all'), or just
the soma ('soma'). Default: False.
record_isec : 'all' | 'soma' | False
Option to record voltages from all sections ('all'), or just
the soma ('soma'). Default: False.
postproc : bool
If True, smoothing (``dipole_smooth_win``) and scaling
(``dipole_scalefctr``) values are read from the parameter file, and
applied to the dipole objects before returning.
Note that this setting
only affects the dipole waveforms, and not somatic voltages,
possible extracellular recordings etc.
The preferred way is to use the
:meth:`~hnn_core.dipole.Dipole.smooth` and
:meth:`~hnn_core.dipole.Dipole.scale` methods instead.
Default: False.
save_outputs : bool, optional
Whether to save the simulation outputs to files. Default is False.
save_folder : str, optional
The path to save the simulation outputs.
Default is './sim_results'.
@@ -74,9 +56,8 @@
overwrite : bool, optional
Whether to overwrite existing files and create file paths
if they do not exist. Default is True.
summary_func : callable, optional
A function to calculate summary statistics from the simulation
results. Default is None.
save_outputs : bool, optional
Whether to save the simulation outputs to files. Default is False.
save_dpl : bool
If True, save dipole results. Note, `save_outputs` must be True.
Default: True.
@@ -97,9 +78,23 @@
If True, save calcium concentrations.
Note, `save_outputs` must be True.
Default: False.
record_vsec : 'all' | 'soma' | False
Option to record voltages from all sections ('all'), or just
the soma ('soma'). Default: False.
record_isec : 'all' | 'soma' | False
Option to record synaptic currents from all sections ('all'), or just
the soma ('soma'). Default: False.
postproc : bool
If True, smoothing (``dipole_smooth_win``) and scaling
(``dipole_scalefctr``) values are read from the parameter file, and
applied to the dipole objects before returning.
Default: False.
clear_cache : bool, optional
Whether to clear the results cache after saving each batch.
Default is False.
summary_func : callable, optional
A function to calculate summary statistics from the simulation
results. Default is None.
Notes
-----
When `save_outputs=True`, the saved files will appear as
@@ -111,7 +106,7 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170,
will be overwritten.
"""

_validate_type(net, Network, 'net', 'Network')
_validate_type(net, (Network, None), 'net', 'Network')
_validate_type(tstop, types='numeric', item_name='tstop')
_validate_type(dt, types='numeric', item_name='dt')
_validate_type(n_trials, types='int', item_name='n_trials')
@@ -134,7 +129,7 @@
if summary_func is not None and not callable(summary_func):
raise TypeError("summary_func must be a callable function")

self.net = net
self.net = net if net is not None else jones_2009_model()
self.set_params = set_params
self.tstop = tstop
self.dt = dt
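For context, the reordered constructor keeps the same keyword names, so a call that also writes dipoles to disk might look like the sketch below (values are illustrative; `law_2021_model()` or `calcium_model()` could be passed as `net` instead):

```python
from hnn_core import jones_2009_model
from hnn_core.batch_simulate import BatchSimulate

# set_params and summary_func are user-defined callables, as in the example above.
batch = BatchSimulate(set_params=set_params,
                      net=jones_2009_model(mesh_shape=(3, 3)),
                      tstop=170, dt=0.025, n_trials=1,
                      save_outputs=True, save_folder='./sim_results',
                      batch_size=100, overwrite=True,
                      save_dpl=True, clear_cache=True,
                      summary_func=summary_func)
```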
@@ -201,16 +196,13 @@ def run(self, param_grid, return_output=True,
param_combinations = self._generate_param_combinations(
param_grid, combinations)
total_sims = len(param_combinations)
num_sims_per_batch = max(total_sims // self.batch_size, 1)
batch_size = min(self.batch_size, total_sims)

results = []
simulated_data = []
for i in range(batch_size):
start_idx = i * num_sims_per_batch
end_idx = start_idx + num_sims_per_batch
if i == batch_size - 1:
end_idx = len(param_combinations)
for i in range(0, total_sims, batch_size):
start_idx = i
end_idx = min(i + batch_size, total_sims)
batch_results = self.simulate_batch(
param_combinations[start_idx:end_idx],
n_jobs=n_jobs,
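As a worked example of the rewritten loop above: with 20 parameter combinations and `batch_size=3`, the start and end indices step through the list without dropping the remainder (a standalone sketch):

```python
total_sims, batch_size = 20, 3
ranges = [(i, min(i + batch_size, total_sims))
          for i in range(0, total_sims, batch_size)]
# [(0, 3), (3, 6), (6, 9), (9, 12), (12, 15), (15, 18), (18, 20)]
```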
@@ -388,6 +380,14 @@ def _save(self, results, start_idx, end_idx):
if getattr(self, f'save_{attr}') and attr in results[0]:
save_data[attr] = [result[attr] for result in results]

metadata = {
'batch_size': self.batch_size,
'n_trials': self.n_trials,
'tstop': self.tstop,
'dt': self.dt
}
save_data['metadata'] = metadata

file_name = os.path.join(self.save_folder,
f'sim_run_{start_idx}-{end_idx}.npz')
if os.path.exists(file_name) and not self.overwrite:
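Since the batch files carry a `.npz` extension, the metadata block added above can presumably be read back with NumPy. A sketch (the exact save call is not shown in this diff, so `allow_pickle` and the file path are assumptions):

```python
import numpy as np

data = np.load('./sim_results/sim_run_0-100.npz', allow_pickle=True)
metadata = data['metadata'].item()  # {'batch_size': ..., 'n_trials': ..., 'tstop': ..., 'dt': ...}
print(metadata['tstop'], metadata['dt'])
```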
12 changes: 6 additions & 6 deletions hnn_core/hnn_io.py
@@ -397,12 +397,12 @@ def dict_to_network(net_data,

Parameters
----------
fname : str or Path
Path to configuration file
read_drives : bool
Read-in drives to Network object
read_external_biases
Read-in external biases to Network object
net_data : dict
Dictionary containing network configurations.
read_drives : bool, optional
Read-in drives to Network object. Default is True.
read_external_biases : bool, optional
Read-in external biases to Network object. Default is True.

Returns
-------
net : Network
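A short usage sketch for the corrected signature; loading the dictionary from a JSON configuration file is an assumption about how callers typically obtain `net_data`, and the file name is illustrative:

```python
import json

from hnn_core.hnn_io import dict_to_network

with open('network_config.json', 'r') as f:
    net_data = json.load(f)
net = dict_to_network(net_data, read_drives=True,
                      read_external_biases=True)
```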
42 changes: 39 additions & 3 deletions hnn_core/tests/test_batch_simulate.py
@@ -3,13 +3,18 @@
# Ryan Thorpe <[email protected]>
# Mainak Jas <[email protected]>

from pathlib import Path
import time
import pytest
import numpy as np
import os

from hnn_core.batch_simulate import BatchSimulate
from hnn_core import jones_2009_model

hnn_core_root = Path(__file__).parents[1]
assets_path = Path(hnn_core_root, 'tests', 'assets')


@pytest.fixture
def batch_simulate_instance(tmp_path):
@@ -33,11 +38,12 @@ def set_params(param_values, net):
weights_ampa=weights_ampa,
synaptic_delays=synaptic_delays)

net = jones_2009_model()
net = jones_2009_model(mesh_shape=(3, 3))
return BatchSimulate(net=net, set_params=set_params,
tstop=1.,
tstop=10,
save_folder=tmp_path,
batch_size=3)
batch_size=3,
n_trials=3,)


@pytest.fixture
@@ -75,6 +81,12 @@ def test_parameter_validation():
with pytest.raises(TypeError, match="net must be"):
BatchSimulate(net="invalid_network", set_params=lambda x: x)

with pytest.raises(ValueError, match="'record_vsec' parameter"):
BatchSimulate(set_params=lambda x: x, record_vsec="invalid")

with pytest.raises(ValueError, match="'record_isec' parameter"):
BatchSimulate(set_params=lambda x: x, record_isec="invalid")


def test_generate_param_combinations(batch_simulate_instance, param_grid):
"""Test generating parameter combinations."""
@@ -280,3 +292,27 @@ def test_load_results(batch_simulate_instance, param_grid, tmp_path):
# Validation Tests
with pytest.raises(TypeError, match='results must be'):
batch_simulate_instance._save("invalid_results", start_idx, end_idx)


def test_parallel_execution(batch_simulate_instance, param_grid):
"""Test parallel execution of simulations and ensure speedup."""

param_combinations = batch_simulate_instance._generate_param_combinations(
param_grid)

start_time = time.perf_counter()
_ = batch_simulate_instance.simulate_batch(
param_combinations, n_jobs=1, backend='loky')
end_time = time.perf_counter()
serial_time = end_time - start_time

start_time = time.perf_counter()
_ = batch_simulate_instance.simulate_batch(
param_combinations,
n_jobs=2,
backend='loky')
end_time = time.perf_counter()
parallel_time = end_time - start_time

assert (serial_time > parallel_time
), "Parallel execution is not faster than serial execution!"