Skip to content

Commit

Permalink
refactor(trainer): 每训练一轮进行随机打乱
Browse files Browse the repository at this point in the history
  • Loading branch information
zjykzj committed Oct 8, 2020
1 parent 93556b0 commit e000f57
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
4 changes: 4 additions & 0 deletions tsn/data/samplers/iteration_based_batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,7 @@ def __iter__(self):

def __len__(self):
return self.num_iterations

def set_epoch(self, iteration):
if hasattr(self.batch_sampler.sampler, "set_epoch"):
self.batch_sampler.sampler.set_epoch(iteration)
5 changes: 2 additions & 3 deletions tsn/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,8 @@ def do_train(args, cfg, arguments,
optimizer.step()
lr_scheduler.step()

if iteration % len(data_loader) == 0 and \
isinstance(data_loader.batch_sampler.batch_sampler.sampler, DistributedSampler):
data_loader.batch_sampler.batch_sampler.sampler.set_epoch(iteration)
if iteration % len(data_loader) == 0 and hasattr(data_loader.batch_sampler, "set_epoch"):
data_loader.batch_sampler.set_epoch(iteration)

batch_time = time.time() - end
end = time.time()
Expand Down

0 comments on commit e000f57

Please sign in to comment.