diff --git a/src/instructlab/training/multipack_sampler.py b/src/instructlab/training/multipack_sampler.py index 42fd7d17..56245bd6 100644 --- a/src/instructlab/training/multipack_sampler.py +++ b/src/instructlab/training/multipack_sampler.py @@ -437,6 +437,11 @@ def generate_batches(self, set_stats=False): len(self.lengths) ) + # remove indices where the entries are longer than batch max length + indices = indices[self.lengths[indices] <= self.batch_max_length] + if len(indices) < len(self.lengths): + print(f"\033[33mDropping {len(self.lengths) - len(indices)} samples longer than batch_max_length. Ensure that the right max_batch_length is used during data processing.\033[0m") + lengths = self.lengths[indices] lengths_cumsum = np.cumsum(lengths)