diff --git a/hnn_core/batch_simulate.py b/hnn_core/batch_simulate.py index 89a70091b..afd47ba13 100644 --- a/hnn_core/batch_simulate.py +++ b/hnn_core/batch_simulate.py @@ -197,16 +197,13 @@ def run(self, param_grid, return_output=True, param_combinations = self._generate_param_combinations( param_grid, combinations) total_sims = len(param_combinations) - num_sims_per_batch = max(total_sims // self.batch_size, 1) batch_size = min(self.batch_size, total_sims) results = [] simulated_data = [] - for i in range(0, total_sims, num_sims_per_batch): + for i in range(0, total_sims, batch_size): start_idx = i - end_idx = min(i + num_sims_per_batch, total_sims) - if i == batch_size - 1: - end_idx = len(param_combinations) + end_idx = min(i + batch_size, total_sims) batch_results = self.simulate_batch( param_combinations[start_idx:end_idx], n_jobs=n_jobs,