diff --git a/conda_build/pytorch-metric-learning/meta.yaml b/conda_build/pytorch-metric-learning/meta.yaml index 8c8eb8e8..f633649c 100644 --- a/conda_build/pytorch-metric-learning/meta.yaml +++ b/conda_build/pytorch-metric-learning/meta.yaml @@ -1,5 +1,5 @@ {% set name = "pytorch-metric-learning" %} -{% set version = "1.1.1" %} +{% set version = "1.1.2" %} package: name: "{{ name|lower }}" @@ -7,7 +7,7 @@ package: source: url: "https://pypi.io/packages/source/{{ name[0] }}/{{ name }}/{{ name }}-{{ version }}.tar.gz" - sha256: 6e572dc54179c762abc333fc4c6f68fcd909e800f9519ca1463235d14b9f5c44 + sha256: aa2a28b7eb6a3a72f2ab14f59073de286832acc5433863de3c8cfc8e8fed38f4 build: number: 0 diff --git a/src/pytorch_metric_learning/__init__.py b/src/pytorch_metric_learning/__init__.py index a82b376d..72f26f59 100644 --- a/src/pytorch_metric_learning/__init__.py +++ b/src/pytorch_metric_learning/__init__.py @@ -1 +1 @@ -__version__ = "1.1.1" +__version__ = "1.1.2" diff --git a/src/pytorch_metric_learning/utils/logging_presets.py b/src/pytorch_metric_learning/utils/logging_presets.py index 35b96e94..fa4d137f 100644 --- a/src/pytorch_metric_learning/utils/logging_presets.py +++ b/src/pytorch_metric_learning/utils/logging_presets.py @@ -54,7 +54,7 @@ def end_of_iteration_hook(self, trainer): trainer.loss_tracker.loss_weights, {"parent_name": "loss_weights"}, ], - [trainer.loss_funcs, {"recursive_types": [torch.nn.Module]}], + [trainer.loss_funcs, {"recursive_types": [torch.nn.Module, dict]}], [trainer.mining_funcs, {}], [trainer.models, {}], [trainer.optimizers, {"custom_attr_func": self.optimizer_custom_attr_func}], diff --git a/src/pytorch_metric_learning/utils/module_with_records.py b/src/pytorch_metric_learning/utils/module_with_records.py index 3f540f2b..9fa039f6 100644 --- a/src/pytorch_metric_learning/utils/module_with_records.py +++ b/src/pytorch_metric_learning/utils/module_with_records.py @@ -4,9 +4,11 @@ class ModuleWithRecords(torch.nn.Module): - def __init__(self, collect_stats=c_f.COLLECT_STATS): + def __init__(self, collect_stats=None): super().__init__() - self.collect_stats = collect_stats + self.collect_stats = ( + c_f.COLLECT_STATS if collect_stats is None else collect_stats + ) def add_to_recordable_attributes( self, name=None, list_of_names=None, is_stat=False diff --git a/tests/trainers/test_metric_loss_only.py b/tests/trainers/test_metric_loss_only.py index b01b6c18..55baa109 100644 --- a/tests/trainers/test_metric_loss_only.py +++ b/tests/trainers/test_metric_loss_only.py @@ -8,6 +8,7 @@ from torchvision import datasets, transforms from pytorch_metric_learning.losses import NTXentLoss +from pytorch_metric_learning.reducers import AvgNonZeroReducer from pytorch_metric_learning.samplers import MPerClassSampler from pytorch_metric_learning.testers import GlobalEmbeddingSpaceTester from pytorch_metric_learning.trainers import MetricLossOnly @@ -36,7 +37,7 @@ def test_metric_loss_only(self): ) ) - loss_fn = NTXentLoss() + loss_fn = NTXentLoss(reducer=AvgNonZeroReducer()) normalize_transform = transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] @@ -166,6 +167,18 @@ def test_metric_loss_only(self): num_epochs = 2 trainer.train(num_epochs=num_epochs) + + for record_name in [ + "metric_loss_NTXentLoss", + "metric_loss_NTXentLoss__modules_distance_CosineSimilarity", + "metric_loss_NTXentLoss__modules_reducer_AvgNonZeroReducer", + ]: + self.assertTrue(record_keeper.table_exists(record_name)) + self.assertTrue( + len(record_keeper.query(f"SELECT * FROM {record_name}")) + == num_epochs * iterations_per_epoch / log_freq + ) + best_epoch, best_accuracy = hooks.get_best_epoch_and_accuracy( tester, "val" ) diff --git a/tests/utils/test_common_functions.py b/tests/utils/test_common_functions.py index 19c0051c..9fa8024e 100644 --- a/tests/utils/test_common_functions.py +++ b/tests/utils/test_common_functions.py @@ -57,10 +57,15 @@ def test_torch_standard_scaler(self): def test_collect_stats_flag(self): self.assertTrue(c_f.COLLECT_STATS == WITH_COLLECT_STATS) - loss_fn = TripletMarginLoss() - self.assertTrue(loss_fn.collect_stats == WITH_COLLECT_STATS) - self.assertTrue(loss_fn.distance.collect_stats == WITH_COLLECT_STATS) - self.assertTrue(loss_fn.reducer.collect_stats == WITH_COLLECT_STATS) + for x in [True, False, True, False, WITH_COLLECT_STATS]: + c_f.COLLECT_STATS = x + loss_fn = TripletMarginLoss() + self.assertTrue(loss_fn.collect_stats == x) + self.assertTrue(loss_fn.distance.collect_stats == x) + self.assertTrue(loss_fn.reducer.collect_stats == x) + for x in [True, False]: + loss_fn = TripletMarginLoss(collect_stats=x) + self.assertTrue(loss_fn.collect_stats == x) def test_check_shapes(self): embeddings = torch.randn(32, 512, 3)