Skip to content

Commit

Permalink
CMCCallback fix (#941)
Browse files Browse the repository at this point in the history
* fix args order

* fix conformity matrix shape

* fix

* CHANGELOG.md

* manually checked convergence of samplers and cmc score.

* fix cmc05 threshold
  • Loading branch information
elephantmipt authored Sep 25, 2020
1 parent c816f3d commit 19ae864
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Logging double logging :) ([#936](https://github.com/catalyst-team/catalyst/pull/936))

- CMCCallback ([#941](https://github.com/catalyst-team/catalyst/pull/941))

## [20.09] - 2020-09-07

Expand Down
2 changes: 1 addition & 1 deletion bin/tests/check_dl_core_callbacks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ metrics = utils.load_config('$LOGFILE')
EPS = 0.00001
assert metrics['last']['cmc01'] > 0.1 # slightly better then random
assert metrics['last']['cmc05'] > 0.5
assert metrics['last']['cmc05'] > 0.4
"""

################################ pipeline 22 ################################
Expand Down
10 changes: 5 additions & 5 deletions catalyst/dl/callbacks/metrics/cmc_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,15 @@ def on_loader_end(self, runner: "IRunner"):
self._query_idx == self._query_size
), "An error occurred during the accumulation process."

conformity_matrix = self._query_labels == self._gallery_labels.reshape(
conformity_matrix = self._gallery_labels == self._query_labels.reshape(
-1, 1
)
for key in self.list_args:
metric = self._metric_fn(
self._gallery_embeddings,
self._query_embeddings,
conformity_matrix,
key,
query_embeddings=self._query_embeddings,
gallery_embeddings=self._gallery_embeddings,
conformity_matrix=conformity_matrix,
topk=key,
)
runner.loader_metrics[f"{self._prefix}{key:02}"] = metric
self._gallery_embeddings = None
Expand Down
2 changes: 1 addition & 1 deletion tests/_tests_scripts/dl_z_mvp_mnist_metric_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def main() -> None:
This function checks metric learning pipeline with
different triplets samplers.
"""
cmc_score_th = 0.9
cmc_score_th = 0.85

# Note! cmc_score should be > 0.97
# after 600 epoch. Please check it mannually
Expand Down

0 comments on commit 19ae864

Please sign in to comment.