bugfix: filter out oversized samples in the multipack batch sampler
Entries longer than max_batch_len can appear in the sampler's output;
until now they were only filtered out later, in the collator.

In rare cases this leaves a batch empty, which can break the
training loop. Drop such entries in the sampler instead.

Signed-off-by: ChanderG <[email protected]>
ChanderG committed Jun 24, 2024
1 parent 1cb03c5 commit 72fde51
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/instructlab/training/multipack_sampler.py
@@ -282,6 +282,11 @@ def generate_batches(self, set_stats=False):
len(self.lengths)
)

# remove indices where the entries are longer than batch max length
indices = indices[self.lengths[indices] <= self.batch_max_length]
if len(indices) < len(self.lengths):
    print(f"\033[33mDropping {len(self.lengths) - len(indices)} samples longer than batch_max_length. Ensure that the right max_batch_length is used during data processing.\033[0m")

lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
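
For readers who want to try the filtering step in isolation, here is a minimal standalone sketch of the same boolean-mask idea. The helper name `drop_oversized` and the sample data are hypothetical and not part of the repository; only the mask expression mirrors the change above.

```python
import numpy as np


def drop_oversized(indices, lengths, batch_max_length):
    """Return the indices whose sample length fits in a batch, plus the drop count.

    Hypothetical helper for illustration; the commit applies the same mask
    inline inside generate_batches.
    """
    keep = lengths[indices] <= batch_max_length  # boolean mask aligned with indices
    return indices[keep], int((~keep).sum())


# Made-up sample lengths; sample 3 (2048 tokens) exceeds the limit of 1024.
lengths = np.array([120, 900, 45, 2048, 300])
indices = np.random.default_rng(0).permutation(len(lengths))

kept, dropped = drop_oversized(indices, lengths, batch_max_length=1024)
if dropped:
    print(f"Dropping {dropped} samples longer than batch_max_length.")
```

Boolean-mask indexing keeps the surviving indices in their shuffled order, so the later `np.cumsum` over `lengths` still sees the same permutation, just without the oversized entries.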
