Skip to content

Commit

Permalink
Include dict in recursive_types in logging_presets
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinMusgrave committed Feb 16, 2022
1 parent c3e59bc commit 48dc52c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.1.2.dev0"
__version__ = "1.1.2.dev1"
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/utils/logging_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def end_of_iteration_hook(self, trainer):
trainer.loss_tracker.loss_weights,
{"parent_name": "loss_weights"},
],
[trainer.loss_funcs, {"recursive_types": [torch.nn.Module]}],
[trainer.loss_funcs, {"recursive_types": [torch.nn.Module, dict]}],
[trainer.mining_funcs, {}],
[trainer.models, {}],
[trainer.optimizers, {"custom_attr_func": self.optimizer_custom_attr_func}],
Expand Down
15 changes: 14 additions & 1 deletion tests/trainers/test_metric_loss_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torchvision import datasets, transforms

from pytorch_metric_learning.losses import NTXentLoss
from pytorch_metric_learning.reducers import AvgNonZeroReducer
from pytorch_metric_learning.samplers import MPerClassSampler
from pytorch_metric_learning.testers import GlobalEmbeddingSpaceTester
from pytorch_metric_learning.trainers import MetricLossOnly
Expand Down Expand Up @@ -36,7 +37,7 @@ def test_metric_loss_only(self):
)
)

loss_fn = NTXentLoss()
loss_fn = NTXentLoss(reducer=AvgNonZeroReducer())

normalize_transform = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
Expand Down Expand Up @@ -166,6 +167,18 @@ def test_metric_loss_only(self):

num_epochs = 2
trainer.train(num_epochs=num_epochs)

for record_name in [
"metric_loss_NTXentLoss",
"metric_loss_NTXentLoss__modules_distance_CosineSimilarity",
"metric_loss_NTXentLoss__modules_reducer_AvgNonZeroReducer",
]:
self.assertTrue(record_keeper.table_exists(record_name))
self.assertTrue(
len(record_keeper.query(f"SELECT * FROM {record_name}"))
== num_epochs * iterations_per_epoch / log_freq
)

best_epoch, best_accuracy = hooks.get_best_epoch_and_accuracy(
tester, "val"
)
Expand Down

0 comments on commit 48dc52c

Please sign in to comment.