diff --git a/hnn_core/batch_simulate.py b/hnn_core/batch_simulate.py index 89a70091b..d53347005 100644 --- a/hnn_core/batch_simulate.py +++ b/hnn_core/batch_simulate.py @@ -197,16 +197,13 @@ def run(self, param_grid, return_output=True, param_combinations = self._generate_param_combinations( param_grid, combinations) total_sims = len(param_combinations) - num_sims_per_batch = max(total_sims // self.batch_size, 1) batch_size = min(self.batch_size, total_sims) results = [] simulated_data = [] - for i in range(0, total_sims, num_sims_per_batch): + for i in range(0, total_sims, batch_size): start_idx = i - end_idx = min(i + num_sims_per_batch, total_sims) - if i == batch_size - 1: - end_idx = len(param_combinations) + end_idx = min(i + batch_size, total_sims) batch_results = self.simulate_batch( param_combinations[start_idx:end_idx], n_jobs=n_jobs, @@ -297,15 +294,14 @@ def _run_single_sim(self, param_values, n_jobs=1): results = {'net': net, 'param_values': param_values} if self.save_dpl: - with JoblibBackend(n_jobs=n_jobs): - dpl = simulate_dipole(net, - tstop=self.tstop, - dt=self.dt, - n_trials=self.n_trials, - record_vsec=self.record_vsec, - record_isec=self.record_isec, - postproc=self.postproc) - results['dpl'] = dpl + dpl = simulate_dipole(net, + tstop=self.tstop, + dt=self.dt, + n_trials=self.n_trials, + record_vsec=self.record_vsec, + record_isec=self.record_isec, + postproc=self.postproc) + results['dpl'] = dpl if self.save_spiking: results['spiking'] = {