diff --git a/paddle3d/apis/trainer.py b/paddle3d/apis/trainer.py
index cc3cfe85..fae006bb 100644
--- a/paddle3d/apis/trainer.py
+++ b/paddle3d/apis/trainer.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import os
+import time
 from typing import Callable, Optional, Union
 
 import paddle
@@ -23,7 +24,7 @@ from paddle3d.apis.pipeline import training_step, validation_step
 from paddle3d.apis.scheduler import Scheduler, SchedulerABC
 from paddle3d.utils.logger import logger
-from paddle3d.utils.timer import Timer
+from paddle3d.utils.timer import TimeAverager, Timer
 
 
 def default_dataloader_build_fn(**kwargs) -> paddle.io.DataLoader:
@@ -113,14 +114,16 @@ def __init__(
             **dataloader_fn) if isinstance(dataloader_fn, dict) else dataloader_fn
 
-        self.train_dataloader = _dataloader_build_fn(train_dataset, self.model)
+        self.train_dataloader = _dataloader_build_fn(
+            train_dataset, self.model) if train_dataset else None
         self.eval_dataloader = _dataloader_build_fn(
             val_dataset, self.model) if val_dataset else None
         self.val_dataset = val_dataset
 
         self.resume = resume
         vdl_file_name = None
-        self.iters_per_epoch = len(self.train_dataloader)
+        self.iters_per_epoch = len(
+            self.train_dataloader) if train_dataset else 1
 
         if iters is None:
             self.epochs = epochs
@@ -187,6 +190,9 @@ def __init__(
 
     def train(self):
         """ """
+        reader_cost_averager = TimeAverager()
+        batch_cost_averager = TimeAverager()
+        iter_train_dataloader = iter(self.train_dataloader)
 
         sync_bn = (getattr(self.model, 'sync_bn', False) and env.nranks > 1)
         if sync_bn:
@@ -205,65 +211,76 @@
         while self.cur_iter < self.iters:
 
-            for sample in self.train_dataloader:
-                self.cur_iter += 1
+            start_time = time.time()
+            try:
+                sample = next(iter_train_dataloader)
+            except StopIteration:
+                iter_train_dataloader = iter(self.train_dataloader)
+                sample = next(iter_train_dataloader)
+            reader_cost_averager.record(time.time() - start_time)
+
+            self.cur_iter += 1
+
+            if self.cur_iter % self.iters_per_epoch == 1:
+                self.cur_epoch += 1
+
+            if self.cur_iter > self.iters:
+                break
+
+            lr = self.optimizer.get_lr()
+            loss = training_step(model, self.optimizer, sample, self.cur_iter)
+            loss_sum += loss.numpy()[0]
+
+            # make sure the model or the dataset's collate_fn collates a `batch_size` field
+            batch_cost_averager.record(
+                time.time() - start_time,
+                num_samples=sample.get('batch_size', 1))
+            timer.step()
+            status = self.scheduler.step()
+
+            if status.do_log and env.local_rank == 0:
+                loss_sum = float(loss_sum / self.scheduler.log_interval)
+                logger.info(
+                    '[TRAIN] epoch={}/{}, iter={}/{}, loss={:.6f}, lr={:.6f}, reader_cost={:.5f} s, batch_cost={:.5f} s, ips={:.5f} sample/s | ETA {}'
+                    .format(self.cur_epoch, self.epochs, self.cur_iter,
+                            self.iters, loss_sum, lr,
+                            reader_cost_averager.get_average(),
+                            batch_cost_averager.get_average(),
+                            batch_cost_averager.get_ips_average(), timer.eta))
+
+                self.log_writer.add_scalar(
+                    tag='Training/learning_rate', value=lr, step=self.cur_iter)
+                self.log_writer.add_scalar(
+                    tag='Training/loss', value=loss_sum, step=self.cur_iter)
+
+                loss_sum = 0
+
+            if status.do_eval and env.local_rank == 0:
+                # TODO: whether to save a checkpoint based on the metric
+                metrics = self.evaluate()
+                for k, v in metrics.items():
+                    if not isinstance(v, paddle.Tensor) or v.numel() != 1:
+                        continue
-                if self.cur_iter % self.iters_per_epoch == 1:
-                    self.cur_epoch += 1
-
-                if self.cur_iter > self.iters:
-                    break
-
-                lr = self.optimizer.get_lr()
-                loss = training_step(model, self.optimizer, sample,
-                                     self.cur_iter)
-                loss_sum += loss.numpy()[0]
+
+                    self.log_writer.add_scalar(
+                        tag='Evaluation/{}'.format(k),
+                        value=float(v),
+                        step=self.cur_iter)
-                timer.step()
-                status = self.scheduler.step()
+
+            if status.save_checkpoint and env.local_rank == 0:
+                if self.train_by_epoch:
+                    tag = 'epoch_{}'.format(self.cur_epoch)
+                else:
+                    tag = 'iter_{}'.format(self.cur_iter)
-                if status.do_log and env.local_rank == 0:
-                    loss_sum = float(loss_sum / self.scheduler.log_interval)
-                    logger.info(
-                        '[TRAIN] epoch={}/{}, iter={}/{}, loss={:.6f}, lr={:.6f} | ETA {}'
-                        .format(self.cur_epoch, self.epochs, self.cur_iter,
-                                self.iters, loss_sum, lr, timer.eta))
+
+                self.checkpoint.push(
+                    tag=tag,
+                    params_dict=self.model.state_dict(),
+                    opt_dict=self.optimizer.state_dict(),
+                    verbose=True)
-                    self.log_writer.add_scalar(
-                        tag='Training/learning_rate',
-                        value=lr,
-                        step=self.cur_iter)
-                    self.log_writer.add_scalar(
-                        tag='Training/loss', value=loss_sum, step=self.cur_iter)
-
-                    loss_sum = 0
-
-                if status.do_eval and env.local_rank == 0:
-                    # TODO: whether to save a checkpoint based on the metric
-                    metrics = self.evaluate()
-                    for k, v in metrics.items():
-                        if not isinstance(v, paddle.Tensor) or v.numel() != 1:
-                            continue
-
-                        self.log_writer.add_scalar(
-                            tag='Evaluation/{}'.format(k),
-                            value=float(v),
-                            step=self.cur_iter)
-
-                if status.save_checkpoint and env.local_rank == 0:
-                    if self.train_by_epoch:
-                        tag = 'epoch_{}'.format(self.cur_epoch)
-                    else:
-                        tag = 'iter_{}'.format(self.cur_iter)
-
-                    self.checkpoint.push(
-                        tag=tag,
-                        params_dict=self.model.state_dict(),
-                        opt_dict=self.optimizer.state_dict(),
-                        verbose=True)
-
-                    self.checkpoint.record('iters', self.cur_iter)
-                    self.checkpoint.record('epochs', self.cur_epoch)
+
+                self.checkpoint.record('iters', self.cur_iter)
+                self.checkpoint.record('epochs', self.cur_epoch)
 
         logger.info('Training is complete.')
diff --git a/paddle3d/utils/timer.py b/paddle3d/utils/timer.py
index 04d01e38..780e54b2 100644
--- a/paddle3d/utils/timer.py
+++ b/paddle3d/utils/timer.py
@@ -72,3 +72,29 @@ def eta(self):
             remaining_time %= 60**i
 
         return result.format(*arr)
+
+
+class TimeAverager(object):
+    """ Time averager """
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self._cnt = 0
+        self._total_time = 0
+        self._total_samples = 0
+
+    def record(self, usetime, num_samples=None):
+        self._cnt += 1
+        self._total_time += usetime
+        if num_samples:
+            self._total_samples += num_samples
+
+    def get_average(self):
+        if self._cnt == 0:
+            return 0
+        return self._total_time / float(self._cnt)
+
+    def get_ips_average(self):
+        # guard against division by zero before the first record()
+        return self._total_samples / self._total_time if self._total_time else 0
diff --git a/tools/evaluate.py b/tools/evaluate.py
index 692dfa1b..05b7a5ae 100644
--- a/tools/evaluate.py
+++ b/tools/evaluate.py
@@ -64,6 +64,8 @@ def main(args):
         raise RuntimeError("Config file `{}` does not exist!".format(args.cfg))
 
     cfg = Config(path=args.cfg, batch_size=args.batch_size)
+    # drop the train dataset config so it is not instantiated during evaluation
+    cfg.dic.pop('train_dataset', None)
 
     if cfg.val_dataset is None:
         raise RuntimeError(
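Note (outside the patch): a minimal sketch of how the new TimeAverager is driven, mirroring the reader_cost/batch_cost bookkeeping added to Trainer.train above. The load_batch/run_step helpers are hypothetical stand-ins for the dataloader and training step, and the per-window reset() calls are an assumption of this sketch; the patched trainer keeps running averages and never resets.

import time

from paddle3d.utils.timer import TimeAverager

def load_batch():
    # hypothetical stand-in for next(iter_train_dataloader)
    time.sleep(0.01)
    return {'batch_size': 4}

def run_step(batch):
    # hypothetical stand-in for training_step(...)
    time.sleep(0.02)

reader_cost_averager = TimeAverager()
batch_cost_averager = TimeAverager()

for step in range(1, 31):
    start = time.time()
    batch = load_batch()
    # first record() covers data loading only
    reader_cost_averager.record(time.time() - start)

    run_step(batch)
    # second record() covers the whole iteration; num_samples feeds get_ips_average()
    batch_cost_averager.record(
        time.time() - start, num_samples=batch.get('batch_size', 1))

    if step % 10 == 0:
        print('reader_cost={:.5f} s, batch_cost={:.5f} s, ips={:.5f} sample/s'.format(
            reader_cost_averager.get_average(),
            batch_cost_averager.get_average(),
            batch_cost_averager.get_ips_average()))
        # report per log window (the trainer above keeps a running average instead)
        reader_cost_averager.reset()
        batch_cost_averager.reset()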