Skip to content

Commit

Permalink
ensure the sampler do not goes past the file in the last rank.
Browse files Browse the repository at this point in the history
  • Loading branch information
hariharan-devarajan committed Aug 31, 2024
1 parent 682a300 commit 2321952
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions dlio_benchmark/data_loader/torch_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,10 @@ def __init__(self, rank, size, num_samples, epochs):
self.epochs = epochs
samples_per_proc = int(math.ceil(num_samples/size))
start_sample = self.rank * samples_per_proc
end_sample = (self.rank + 1) * samples_per_proc
self.indices = list(range(start_sample, end_sample))
end_sample = (self.rank + 1) * samples_per_proc - 1
if end_sample > num_samples - 1:
end_sample = num_samples - 1
self.indices = list(range(start_sample, end_sample + 1))


def __len__(self):
Expand Down

0 comments on commit 2321952

Please sign in to comment.