Skip to content

Commit

Permalink
Refactor torch training test to stop using metrics from checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Sep 15, 2023
1 parent 1e7a0c2 commit 8848d75
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions tests/torch/test_compression_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,7 @@ def finalize(self, dataset_dir, tmp_path_factory, weekly_models_path) -> "Compre
return self

def get_metric(self):
return self.sample_handler.get_metric_value_from_checkpoint(
self.checkpoint_save_dir, self.checkpoint_name, self.config_path
)
return self.expected_accuracy_

def _get_weight_path(self, weekly_models_path):
if self.weights_filename_ is None:
Expand Down Expand Up @@ -249,9 +247,7 @@ def subnet_expected_accuracy(self, subnet_expected_accuracy: float):
return self

def get_subnet_metric(self):
return self.sample_handler.get_metric_value_from_checkpoint(
self.checkpoint_save_dir, self.subnet_checkpoint_name
)
return self.subnet_expected_accuracy_

def _get_weight_path(self, weekly_models_path):
return os.path.join(
Expand Down

0 comments on commit 8848d75

Please sign in to comment.