diff --git a/tsn/data/samplers/iteration_based_batch_sampler.py b/tsn/data/samplers/iteration_based_batch_sampler.py
index dbdde3a..f34739b 100644
--- a/tsn/data/samplers/iteration_based_batch_sampler.py
+++ b/tsn/data/samplers/iteration_based_batch_sampler.py
@@ -28,3 +28,7 @@ def __iter__(self):
 
     def __len__(self):
         return self.num_iterations
+
+    def set_epoch(self, iteration):
+        if hasattr(self.batch_sampler.sampler, "set_epoch"):
+            self.batch_sampler.sampler.set_epoch(iteration)
diff --git a/tsn/engine/trainer.py b/tsn/engine/trainer.py
index 5d88278..c77bd99 100644
--- a/tsn/engine/trainer.py
+++ b/tsn/engine/trainer.py
@@ -63,9 +63,8 @@ def do_train(args, cfg, arguments,
         optimizer.step()
         lr_scheduler.step()
 
-        if iteration % len(data_loader) == 0 and \
-                isinstance(data_loader.batch_sampler.batch_sampler.sampler, DistributedSampler):
-            data_loader.batch_sampler.batch_sampler.sampler.set_epoch(iteration)
+        if iteration % len(data_loader) == 0 and hasattr(data_loader.batch_sampler, "set_epoch"):
+            data_loader.batch_sampler.set_epoch(iteration)
 
         batch_time = time.time() - end
         end = time.time()
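
The net effect of this diff is that the trainer no longer inspects the concrete sampler type: IterationBasedBatchSampler now exposes its own set_epoch and forwards the call to the wrapped sampler only when that sampler supports it. The sketch below illustrates how the call reaches a DistributedSampler at runtime; the IterationBasedBatchSampler constructor signature, the TensorDataset, and the explicit num_replicas/rank arguments are assumptions made for illustration and are not part of the diff.

    import torch
    from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, TensorDataset

    from tsn.data.samplers.iteration_based_batch_sampler import IterationBasedBatchSampler

    dataset = TensorDataset(torch.arange(64))
    # num_replicas/rank given explicitly so no process group is needed for this sketch
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
    batch_sampler = BatchSampler(sampler, batch_size=8, drop_last=True)
    # constructor signature assumed for illustration
    iter_batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iterations=100)
    data_loader = DataLoader(dataset, batch_sampler=iter_batch_sampler)

    # The trainer stays sampler-agnostic: if the wrapped sampler exposes
    # set_epoch (DistributedSampler does), the call is forwarded to it;
    # otherwise the new method is a no-op.
    data_loader.batch_sampler.set_epoch(500)

Using hasattr rather than an isinstance check also keeps the trainer working with custom samplers that implement set_epoch without subclassing DistributedSampler.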