diff --git a/examples/howto/plot_batch_simulate.py b/examples/howto/plot_batch_simulate.py index f8018e731..35ac2a75c 100644 --- a/examples/howto/plot_batch_simulate.py +++ b/examples/howto/plot_batch_simulate.py @@ -24,13 +24,25 @@ from hnn_core import jones_2009_model # The number of cores may need modifying depending on your current machine. -n_jobs = 10 +n_jobs = 4 ############################################################################### +# The `add_evoked_drive` function simulates external input to the network, +# mimicking sensory stimulation or other external events. +# +# - `evprox` indicates a proximal drive, targeting dendrites near the cell +# bodies. +# - `mu=40` and `sigma=5` define the timing (mean and spread) of the input. +# - `weights_ampa` and `synaptic_delays` control the strength and +# timing of the input. +# +# This evoked drive causes the initial positive deflection in the dipole +# signal, triggering a cascade of activity through the network and +# resulting in the complex waveforms observed. def set_params(param_values, net=None): """ - Set parameters in the network drives. + Set parameters for the network drives. Parameters ---------- @@ -57,16 +69,16 @@ def set_params(param_values, net=None): synaptic_delays=synaptic_delays) ############################################################################### -# Define a parameter grid for the batch simulation. +# Next, we define a parameter grid for the batch simulation. param_grid = { - 'weight_basket': np.logspace(-4 - 1, 10), - 'weight_pyr': np.logspace(-4, -1, 10) + 'weight_basket': np.logspace(-4, -1, 20), + 'weight_pyr': np.logspace(-4, -1, 20) } ############################################################################### -# Define a function to calculate summary statistics +# We then define a function to calculate summary statistics. def summary_func(results): @@ -95,12 +107,8 @@ def summary_func(results): ############################################################################### # Run the batch simulation and collect the results. -# Comment off this code, if dask and distributed Python packages are installed -# from dask.distributed import Client -# client = Client(threads_per_worker=1, n_workers=5, processes=False) - -# Run the batch simulation and collect the results. +# Initialize the network model and run the batch simulation. net = jones_2009_model(mesh_shape=(3, 3)) batch_simulation = BatchSimulate(net=net, set_params=set_params, @@ -108,23 +116,45 @@ def summary_func(results): simulation_results = batch_simulation.run(param_grid, n_jobs=n_jobs, combinations=False, - backend='multiprocessing') -# backend='dask' if installed + backend='loky') + print("Simulation results:", simulation_results) ############################################################################### # This plot shows an overlay of all smoothed dipole waveforms from the -# batch simulation. Each line represents a different set of parameters, -# allowing us to visualize the range of responses across the parameter space. +# batch simulation. Each line represents a different set of synaptic strength +# parameters (`weight_basket`), allowing us to visualize the range of responses +# across the parameter space. +# The colormap represents synaptic strengths, from weaker (purple) +# to stronger (yellow). +# +# As drive strength increases, dipole responses show progressively larger +# amplitudes and more distinct features, reflecting heightened network +# activity. Weak drives (purple lines) produce smaller amplitude signals with +# simpler waveforms, while stronger drives (yellow lines) generate +# larger responses with more pronounced oscillatory features, indicating +# more robust network activity. +# +# The y-axis represents dipole amplitude in nAm (nanoAmpere-meters), which is +# the product of current flow and distance in the neural tissue. +# +# Stronger synaptic connections (yellow lines) generally show larger +# amplitude responses and more pronounced features throughout the simulation. -dpl_waveforms = [] +dpl_waveforms, param_values = [], [] for data_list in simulation_results['simulated_data']: for data in data_list: dpl_smooth = data['dpl'][0].copy().smooth(window_len=30) dpl_waveforms.append(dpl_smooth.data['agg']) + param_values.append(data['param_values']['weight_basket']) plt.figure(figsize=(10, 6)) -for waveform in dpl_waveforms: - plt.plot(waveform, alpha=0.5, linewidth=3) +cmap = plt.get_cmap('viridis') +log_param_values = np.log10(param_values) +norm = plt.Normalize(log_param_values.min(), log_param_values.max()) + +for waveform, log_param in zip(dpl_waveforms, log_param_values): + color = cmap(norm(log_param)) + plt.plot(waveform, color=color, alpha=0.7, linewidth=2) plt.title('Overlay of Dipole Waveforms') plt.xlabel('Time (ms)') plt.ylabel('Dipole Amplitude (nAm)') diff --git a/hnn_core/batch_simulate.py b/hnn_core/batch_simulate.py index 4f153b40e..168c3ea95 100644 --- a/hnn_core/batch_simulate.py +++ b/hnn_core/batch_simulate.py @@ -16,15 +16,14 @@ class BatchSimulate(object): - def __init__(self, set_params, net=jones_2009_model(), tstop=170, - dt=0.025, n_trials=1, record_vsec=False, - record_isec=False, postproc=False, save_outputs=False, + def __init__(self, set_params, net=jones_2009_model(), + tstop=170, dt=0.025, n_trials=1, save_folder='./sim_results', batch_size=100, - overwrite=True, summary_func=None, - save_dpl=True, save_spiking=False, - save_lfp=False, save_voltages=False, - save_currents=False, save_calcium=False, - clear_cache=False): + overwrite=True, save_outputs=False, save_dpl=True, + save_spiking=False, save_lfp=False, save_voltages=False, + save_currents=False, save_calcium=False, record_vsec=False, + record_isec=False, postproc=False, clear_cache=False, + summary_func=None): """Initialize the BatchSimulate class. Parameters @@ -37,34 +36,17 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170, where ``net`` is a Network object and ``params`` is a dictionary of the parameters that will be set inside the function. net : Network object, optional - The network model to use for simulations. Must be an instance of - jones_2009_model, law_2021_model, or calcium_model. - Default is jones_2009_model(). + The network model to use for simulations. Examples include: + - `jones_2009_model`: A network model based on Jones et al. (2009). + - `law_2021_model`: A network model based on Law et al. (2021). + - `calcium_model`: A network model incorporating calcium dynamics. + Default is `jones_2009_model()` tstop : float, optional The stop time for the simulation. Default is 170 ms. dt : float, optional The time step for the simulation. Default is 0.025 ms. n_trials : int, optional The number of trials for the simulation. Default is 1. - record_vsec : 'all' | 'soma' | False - Option to record voltages from all sections ('all'), or just - the soma ('soma'). Default: False. - record_isec : 'all' | 'soma' | False - Option to record voltages from all sections ('all'), or just - the soma ('soma'). Default: False. - postproc : bool - If True, smoothing (``dipole_smooth_win``) and scaling - (``dipole_scalefctr``) values are read from the parameter file, and - applied to the dipole objects before returning. - Note that this setting - only affects the dipole waveforms, and not somatic voltages, - possible extracellular recordings etc. - The preferred way is to use the - :meth:`~hnn_core.dipole.Dipole.smooth` and - :meth:`~hnn_core.dipole.Dipole.scale` methods instead. - Default: False. - save_outputs : bool, optional - Whether to save the simulation outputs to files. Default is False. save_folder : str, optional The path to save the simulation outputs. Default is './sim_results'. @@ -74,9 +56,8 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170, overwrite : bool, optional Whether to overwrite existing files and create file paths if they do not exist. Default is True. - summary_func : callable, optional - A function to calculate summary statistics from the simulation - results. Default is None. + save_outputs : bool, optional + Whether to save the simulation outputs to files. Default is False. save_dpl : bool If True, save dipole results. Note, `save_outputs` must be True. Default: True. @@ -97,9 +78,23 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170, If True, save calcium concentrations. Note, `save_outputs` must be True. Default: False. + record_vsec : 'all' | 'soma' | False + Option to record voltages from all sections ('all'), or just + the soma ('soma'). Default: False. + record_isec : 'all' | 'soma' | False + Option to record voltages from all sections ('all'), or just + the soma ('soma'). Default: False. + postproc : bool + If True, smoothing (``dipole_smooth_win``) and scaling + (``dipole_scalefctr``) values are read from the parameter file, and + applied to the dipole objects before returning. + Default: False. clear_cache : bool, optional Whether to clear the results cache after saving each batch. Default is False. + summary_func : callable, optional + A function to calculate summary statistics from the simulation + results. Default is None. Notes ----- When `save_output=True`, the saved files will appear as @@ -201,16 +196,13 @@ def run(self, param_grid, return_output=True, param_combinations = self._generate_param_combinations( param_grid, combinations) total_sims = len(param_combinations) - num_sims_per_batch = max(total_sims // self.batch_size, 1) batch_size = min(self.batch_size, total_sims) results = [] simulated_data = [] - for i in range(batch_size): - start_idx = i * num_sims_per_batch - end_idx = start_idx + num_sims_per_batch - if i == batch_size - 1: - end_idx = len(param_combinations) + for i in range(0, total_sims, batch_size): + start_idx = i + end_idx = min(i + batch_size, total_sims) batch_results = self.simulate_batch( param_combinations[start_idx:end_idx], n_jobs=n_jobs, @@ -388,6 +380,14 @@ def _save(self, results, start_idx, end_idx): if getattr(self, f'save_{attr}') and attr in results[0]: save_data[attr] = [result[attr] for result in results] + metadata = { + 'batch_size': self.batch_size, + 'n_trials': self.n_trials, + 'tstop': self.tstop, + 'dt': self.dt + } + save_data['metadata'] = metadata + file_name = os.path.join(self.save_folder, f'sim_run_{start_idx}-{end_idx}.npz') if os.path.exists(file_name) and not self.overwrite: diff --git a/hnn_core/hnn_io.py b/hnn_core/hnn_io.py index 667cf82ab..c933ac5dd 100644 --- a/hnn_core/hnn_io.py +++ b/hnn_core/hnn_io.py @@ -397,12 +397,12 @@ def dict_to_network(net_data, Parameters ---------- - fname : str or Path - Path to configuration file - read_drives : bool - Read-in drives to Network object - read_external_biases - Read-in external biases to Network object + net_data : dict + Dictionary containing network configurations. + read_drives : bool, optional + Read-in drives to Network object. Default is True. + read_external_biases : bool, optional + Read-in external biases to Network object. Default is True. Returns : Network ------- diff --git a/hnn_core/tests/test_batch_simulate.py b/hnn_core/tests/test_batch_simulate.py index 3b61ca221..71dac4c43 100644 --- a/hnn_core/tests/test_batch_simulate.py +++ b/hnn_core/tests/test_batch_simulate.py @@ -3,6 +3,8 @@ # Ryan Thorpe # Mainak Jas +from pathlib import Path +import time import pytest import numpy as np import os @@ -10,6 +12,9 @@ from hnn_core.batch_simulate import BatchSimulate from hnn_core import jones_2009_model +hnn_core_root = Path(__file__).parents[1] +assets_path = Path(hnn_core_root, 'tests', 'assets') + @pytest.fixture def batch_simulate_instance(tmp_path): @@ -33,11 +38,12 @@ def set_params(param_values, net): weights_ampa=weights_ampa, synaptic_delays=synaptic_delays) - net = jones_2009_model() + net = jones_2009_model(mesh_shape=(3, 3)) return BatchSimulate(net=net, set_params=set_params, - tstop=1., + tstop=10, save_folder=tmp_path, - batch_size=3) + batch_size=3, + n_trials=3,) @pytest.fixture @@ -75,6 +81,12 @@ def test_parameter_validation(): with pytest.raises(TypeError, match="net must be"): BatchSimulate(net="invalid_network", set_params=lambda x: x) + with pytest.raises(ValueError, match="'record_vsec' parameter"): + BatchSimulate(set_params=lambda x: x, record_vsec="invalid") + + with pytest.raises(ValueError, match="'record_isec' parameter"): + BatchSimulate(set_params=lambda x: x, record_isec="invalid") + def test_generate_param_combinations(batch_simulate_instance, param_grid): """Test generating parameter combinations.""" @@ -280,3 +292,27 @@ def test_load_results(batch_simulate_instance, param_grid, tmp_path): # Validation Tests with pytest.raises(TypeError, match='results must be'): batch_simulate_instance._save("invalid_results", start_idx, end_idx) + + +def test_parallel_execution(batch_simulate_instance, param_grid): + """Test parallel execution of simulations and ensure speedup.""" + + param_combinations = batch_simulate_instance._generate_param_combinations( + param_grid) + + start_time = time.perf_counter() + _ = batch_simulate_instance.simulate_batch( + param_combinations, n_jobs=1, backend='loky') + end_time = time.perf_counter() + serial_time = end_time - start_time + + start_time = time.perf_counter() + _ = batch_simulate_instance.simulate_batch( + param_combinations, + n_jobs=2, + backend='loky') + end_time = time.perf_counter() + parallel_time = end_time - start_time + + assert (serial_time > parallel_time + ), "Parallel execution is not faster than serial execution!"