Fix classification sample metric dumps not in main process
daniil-lyakhov committed Sep 19, 2023
1 parent 8ceff78 · commit 9dd78cf
Showing 2 changed files with 17 additions and 16 deletions.
examples/torch/classification/main.py (29 changes: 15 additions & 14 deletions)
@@ -380,14 +380,15 @@ def train(
         is_best = is_best_by_accuracy or compression_stage > best_compression_stage
         if is_best:
             best_acc1 = acc1
-            config.mlflow.safe_call("log_metric", "best_acc1", best_acc1)
         best_compression_stage = max(compression_stage, best_compression_stage)
-        acc = best_acc1 / 100
-        if config.metrics_dump is not None:
-            write_metrics(acc, config.metrics_dump)
         if is_main_process():
             logger.info(statistics.to_str())
 
+            if config.metrics_dump is not None:
+                acc = best_acc1 / 100
+                write_metrics(acc, config.metrics_dump)
+            config.mlflow.safe_call("log_metric", "best_acc1", best_acc1)
+
             checkpoint_path = osp.join(config.checkpoint_save_dir, get_run_name(config) + "_last.pth")
             checkpoint = {
                 "epoch": epoch + 1,
@@ -727,19 +728,19 @@ def validate(val_loader, model, criterion, config, epoch=0, log_validation_info=
                     )
                 )
 
-    if is_main_process() and log_validation_info:
-        config.tb.add_scalar("val/loss", losses.avg, len(val_loader) * epoch)
-        config.tb.add_scalar("val/top1", top1.avg, len(val_loader) * epoch)
-        config.tb.add_scalar("val/top5", top5.avg, len(val_loader) * epoch)
-        config.mlflow.safe_call("log_metric", "val/loss", float(losses.avg), epoch)
-        config.mlflow.safe_call("log_metric", "val/top1", float(top1.avg), epoch)
-        config.mlflow.safe_call("log_metric", "val/top5", float(top5.avg), epoch)
+    if is_main_process():
+        if log_validation_info:
+            config.tb.add_scalar("val/loss", losses.avg, len(val_loader) * epoch)
+            config.tb.add_scalar("val/top1", top1.avg, len(val_loader) * epoch)
+            config.tb.add_scalar("val/top5", top5.avg, len(val_loader) * epoch)
+            config.mlflow.safe_call("log_metric", "val/loss", float(losses.avg), epoch)
+            config.mlflow.safe_call("log_metric", "val/top1", float(top1.avg), epoch)
+            config.mlflow.safe_call("log_metric", "val/top5", float(top5.avg), epoch)
 
-    if log_validation_info:
-        logger.info(" * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}\n".format(top1=top1, top5=top5))
+            logger.info(" * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}\n".format(top1=top1, top5=top5))
 
-        acc = top1.avg / 100
         if config.metrics_dump is not None:
+            acc = top1.avg / 100
             write_metrics(acc, config.metrics_dump)
 
     return top1.avg, top5.avg, losses.avg
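For context, the pattern this commit applies is guarding side-effectful reporting (the metrics dump, MLflow and TensorBoard calls) with a main-process check, so that in a distributed run only rank 0 touches the shared metrics file. The following is a minimal, self-contained sketch of that pattern, not the sample's actual code: is_main_process and write_metrics here are illustrative stand-ins assumed to behave like the helpers used in examples/torch/classification/main.py, and the "Accuracy" key mirrors what the changed test below reads from the dumped file.

import json
from typing import Optional

import torch.distributed as dist


def is_main_process() -> bool:
    # Assumed behavior of the sample's helper: True for rank 0 or a non-distributed run.
    return not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0


def write_metrics(acc: float, metrics_dump_path: str) -> None:
    # Hypothetical stand-in for the sample's write_metrics helper; the "Accuracy" key
    # matches what _validate_eval_metric loads in the test change below.
    with open(metrics_dump_path, "w", encoding="utf8") as metrics_file:
        json.dump({"Accuracy": acc}, metrics_file)


def dump_best_metric(best_acc1: float, metrics_dump_path: Optional[str]) -> None:
    # Mirrors the moved train() block: dump only from the main process, so that
    # concurrent workers do not race on, or overwrite, the same metrics file.
    if is_main_process():
        if metrics_dump_path is not None:
            write_metrics(best_acc1 / 100, metrics_dump_path)

In the sample itself the same intent is expressed by nesting the existing write_metrics calls under if is_main_process(): rather than introducing a new helper.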
tests/torch/test_compression_training.py (4 changes: 2 additions & 2 deletions)
@@ -442,7 +442,7 @@ def test_compression_train(self, desc: CompressionTrainingTestDescriptor, tmp_pa
         self._validate_train_metric(desc)
 
     @pytest.mark.dependency(depends=["train"])
-    def test_compression_eval(self, desc: LEGRTrainingTestDescriptor, tmp_path, mocker):
+    def test_compression_eval(self, desc: CompressionTrainingTestDescriptor, tmp_path, mocker):
         validator = desc.get_validator()
         args = validator.get_default_args(tmp_path)
         metric_file_path = self._add_args_for_eval(args, desc, tmp_path)
@@ -497,7 +497,7 @@ def test_compression_nas_eval(self, nas_desc: NASTrainingTestDescriptor, tmp_pat
         self._validate_eval_metric(nas_desc, metric_file_path)
 
     @staticmethod
-    def _validate_eval_metric(desc, metric_file_path):
+    def _validate_eval_metric(desc: CompressionTrainingTestDescriptor, metric_file_path):
         with open(str(metric_file_path), encoding="utf8") as metric_file:
             metrics = json.load(metric_file)
             ref_metric = metrics["Accuracy"]
