From 9dd78cf60e2a58f9fbe74fe5a3c48d9253a4d813 Mon Sep 17 00:00:00 2001
From: dlyakhov
Date: Mon, 18 Sep 2023 18:42:24 +0200
Subject: [PATCH] Fix classification sample metric dumps not in main process

---
 examples/torch/classification/main.py    | 29 ++++++++++++------------
 tests/torch/test_compression_training.py |  4 ++--
 2 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/examples/torch/classification/main.py b/examples/torch/classification/main.py
index 2854c894f0a..c2c89e88505 100644
--- a/examples/torch/classification/main.py
+++ b/examples/torch/classification/main.py
@@ -380,14 +380,15 @@ def train(
         is_best = is_best_by_accuracy or compression_stage > best_compression_stage
         if is_best:
             best_acc1 = acc1
-            config.mlflow.safe_call("log_metric", "best_acc1", best_acc1)
         best_compression_stage = max(compression_stage, best_compression_stage)
-        acc = best_acc1 / 100
-        if config.metrics_dump is not None:
-            write_metrics(acc, config.metrics_dump)
         if is_main_process():
             logger.info(statistics.to_str())
 
+            if config.metrics_dump is not None:
+                acc = best_acc1 / 100
+                write_metrics(acc, config.metrics_dump)
+            config.mlflow.safe_call("log_metric", "best_acc1", best_acc1)
+
         checkpoint_path = osp.join(config.checkpoint_save_dir, get_run_name(config) + "_last.pth")
         checkpoint = {
             "epoch": epoch + 1,
@@ -727,19 +728,19 @@ def validate(val_loader, model, criterion, config, epoch=0, log_validation_info=
                 )
             )
 
-    if is_main_process() and log_validation_info:
-        config.tb.add_scalar("val/loss", losses.avg, len(val_loader) * epoch)
-        config.tb.add_scalar("val/top1", top1.avg, len(val_loader) * epoch)
-        config.tb.add_scalar("val/top5", top5.avg, len(val_loader) * epoch)
-        config.mlflow.safe_call("log_metric", "val/loss", float(losses.avg), epoch)
-        config.mlflow.safe_call("log_metric", "val/top1", float(top1.avg), epoch)
-        config.mlflow.safe_call("log_metric", "val/top5", float(top5.avg), epoch)
+    if is_main_process():
+        if log_validation_info:
+            config.tb.add_scalar("val/loss", losses.avg, len(val_loader) * epoch)
+            config.tb.add_scalar("val/top1", top1.avg, len(val_loader) * epoch)
+            config.tb.add_scalar("val/top5", top5.avg, len(val_loader) * epoch)
+            config.mlflow.safe_call("log_metric", "val/loss", float(losses.avg), epoch)
+            config.mlflow.safe_call("log_metric", "val/top1", float(top1.avg), epoch)
+            config.mlflow.safe_call("log_metric", "val/top5", float(top5.avg), epoch)
 
-    if log_validation_info:
-        logger.info(" * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}\n".format(top1=top1, top5=top5))
+            logger.info(" * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}\n".format(top1=top1, top5=top5))
 
-    acc = top1.avg / 100
     if config.metrics_dump is not None:
+        acc = top1.avg / 100
         write_metrics(acc, config.metrics_dump)
 
     return top1.avg, top5.avg, losses.avg
diff --git a/tests/torch/test_compression_training.py b/tests/torch/test_compression_training.py
index a58b2563caa..4fc77842dc7 100644
--- a/tests/torch/test_compression_training.py
+++ b/tests/torch/test_compression_training.py
@@ -442,7 +442,7 @@ def test_compression_train(self, desc: CompressionTrainingTestDescriptor, tmp_pa
         self._validate_train_metric(desc)
 
     @pytest.mark.dependency(depends=["train"])
-    def test_compression_eval(self, desc: LEGRTrainingTestDescriptor, tmp_path, mocker):
+    def test_compression_eval(self, desc: CompressionTrainingTestDescriptor, tmp_path, mocker):
         validator = desc.get_validator()
         args = validator.get_default_args(tmp_path)
         metric_file_path = self._add_args_for_eval(args, desc, tmp_path)
@@ -497,7 +497,7 @@ def test_compression_nas_eval(self, nas_desc: NASTrainingTestDescriptor, tmp_pat
         self._validate_eval_metric(nas_desc, metric_file_path)
 
     @staticmethod
-    def _validate_eval_metric(desc, metric_file_path):
+    def _validate_eval_metric(desc: CompressionTrainingTestDescriptor, metric_file_path):
         with open(str(metric_file_path), encoding="utf8") as metric_file:
             metrics = json.load(metric_file)
             ref_metric = metrics["Accuracy"]
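
Note on the pattern (illustrative, not part of the patch): the fix routes the metrics-dump file write through the main process only. Below is a minimal standalone sketch of that guard, assuming a torch.distributed-based rank check; is_main_process and write_metrics_file here are hypothetical stand-ins, not the sample's own helpers.

# Illustrative sketch only; helper names are stand-ins for the sample's utilities.
import json
from typing import Optional

import torch.distributed as dist


def is_main_process() -> bool:
    # Non-distributed runs and rank 0 count as the main process.
    return not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0


def write_metrics_file(acc: float, path: str) -> None:
    # Dump the metric as JSON under the "Accuracy" key, matching what the test reads.
    with open(path, "w", encoding="utf8") as metrics_file:
        json.dump({"Accuracy": acc}, metrics_file)


def dump_best_metric(best_acc1: float, metrics_dump: Optional[str]) -> None:
    # Guard the file write so only one rank touches the metrics file.
    if is_main_process() and metrics_dump is not None:
        write_metrics_file(best_acc1 / 100, metrics_dump)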