Skip to content

Commit

Permalink
bugfix: remove out large samples from the multi pack batch sampler
Browse files Browse the repository at this point in the history
larger entries, longer than max_batch_len, may be present in the
output of the sampler which gets filtered out in the collator

in rare instances, this can lead to empty batches which can be
problematic for the training loop

Signed-off-by: ChanderG <[email protected]>
  • Loading branch information
ChanderG committed Jun 19, 2024
1 parent 1cb03c5 commit f8ae22b
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/instructlab/training/multipack_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,9 @@ def generate_batches(self, set_stats=False):
len(self.lengths)
)

# remove indices where the entries are longer than batch max length
indices = indices[self.lengths[indices] <= self.batch_max_length]

lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)

Expand Down

0 comments on commit f8ae22b

Please sign in to comment.