Commit

Merge pull request #52 from ChanderG/bugfix-large-datapoints
bugfix: filter out overly large samples from the multipack batch sampler
aldopareja authored Jun 24, 2024
2 parents 01cbcd7 + 72fde51 commit ee76c17
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/instructlab/training/multipack_sampler.py
@@ -437,6 +437,11 @@ def generate_batches(self, set_stats=False):
            len(self.lengths)
        )

        # remove indices where the entries are longer than batch max length
        indices = indices[self.lengths[indices] <= self.batch_max_length]
        if len(indices) < len(self.lengths):
            print(f"\033[33mDropping {len(self.lengths) - len(indices)} samples longer than batch_max_length. Ensure that the right max_batch_length is used during data processing.\033[0m")

        lengths = self.lengths[indices]
        lengths_cumsum = np.cumsum(lengths)
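
The filtering step in this hunk can be tried in isolation. The following is a minimal sketch, not the sampler's actual code: `drop_oversized` and the sample lengths are hypothetical names and values made up for illustration, and the only behavior taken from the diff is the boolean-mask comparison against `batch_max_length`. Note that the comparison is `<=`, so a sample whose length equals the limit is kept.

```python
import numpy as np


def drop_oversized(lengths, batch_max_length, seed=0):
    """Shuffle sample indices, then drop those longer than batch_max_length.

    Hypothetical helper for illustration; it mirrors the mask used in the
    diff but is not part of multipack_sampler.py.
    """
    lengths = np.asarray(lengths)
    # Random permutation of all sample indices, as generate_batches does.
    indices = np.random.default_rng(seed=seed).permutation(len(lengths))
    # Boolean mask: keep only indices whose sample fits within the limit.
    kept = indices[lengths[indices] <= batch_max_length]
    dropped = len(lengths) - len(kept)
    return kept, dropped


# Made-up token counts; only the 5000-token sample exceeds the limit.
lengths = [120, 4096, 900, 5000, 64]
kept, dropped = drop_oversized(lengths, batch_max_length=4096)
if dropped:
    print(f"Dropping {dropped} samples longer than batch_max_length.")
```

Filtering before `np.cumsum` matters because a single sample larger than the batch capacity can never be placed in any batch, so it has to be removed before the cumulative lengths are computed.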
