
Commit

Merge pull request #426 from KevinMusgrave/dev
v1.1.2
KevinMusgrave authored Feb 16, 2022
2 parents df53e74 + 81d80de commit 0271f52
Showing 6 changed files with 31 additions and 11 deletions.
4 changes: 2 additions & 2 deletions conda_build/pytorch-metric-learning/meta.yaml
@@ -1,13 +1,13 @@
 {% set name = "pytorch-metric-learning" %}
-{% set version = "1.1.1" %}
+{% set version = "1.1.2" %}

 package:
   name: "{{ name|lower }}"
   version: "{{ version }}"

 source:
   url: "https://pypi.io/packages/source/{{ name[0] }}/{{ name }}/{{ name }}-{{ version }}.tar.gz"
-  sha256: 6e572dc54179c762abc333fc4c6f68fcd909e800f9519ca1463235d14b9f5c44
+  sha256: aa2a28b7eb6a3a72f2ab14f59073de286832acc5433863de3c8cfc8e8fed38f4

 build:
   number: 0
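
The conda recipe pins the PyPI sdist by its SHA-256 digest, so the hash changes with every release. Below is a minimal sketch of recomputing it, assuming the 1.1.2 tarball has already been downloaded; the local filename is an assumption for illustration, not taken from this diff.

import hashlib

# Recompute the digest of the downloaded sdist and compare it with the
# sha256 field pinned in meta.yaml. The path is an assumption.
with open("pytorch-metric-learning-1.1.2.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print(digest)  # should match the sha256 value in meta.yaml
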
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
@@ -1 +1 @@
-__version__ = "1.1.1"
+__version__ = "1.1.2"
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/utils/logging_presets.py
@@ -54,7 +54,7 @@ def end_of_iteration_hook(self, trainer):
                 trainer.loss_tracker.loss_weights,
                 {"parent_name": "loss_weights"},
             ],
-            [trainer.loss_funcs, {"recursive_types": [torch.nn.Module]}],
+            [trainer.loss_funcs, {"recursive_types": [torch.nn.Module, dict]}],
             [trainer.mining_funcs, {}],
             [trainer.models, {}],
             [trainer.optimizers, {"custom_attr_func": self.optimizer_custom_attr_func}],
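
trainer.loss_funcs is a dict of loss modules, and each loss nests further modules such as a distance and a reducer; listing dict in recursive_types lets the logging hook descend through the dict to reach them. The sketch below is a hypothetical illustration of that kind of recursion, not the record-keeper implementation; walk() and its naming scheme are assumptions.

import torch

# Hypothetical sketch only -- NOT the record-keeper API. It shows why `dict`
# must appear in recursive_types: trainer.loss_funcs is a dict of losses, and
# the walk may only descend into objects whose type is listed.
def walk(name, obj, recursive_types):
    if not isinstance(obj, tuple(recursive_types)):
        return  # not a type we are allowed to descend into
    if isinstance(obj, dict):
        children = list(obj.items())
    else:  # torch.nn.Module
        children = list(obj.named_children())
    for child_name, child in children:
        yield f"{name}_{child_name}_{type(child).__name__}"
        yield from walk(f"{name}_{child_name}", child, recursive_types)

# With dict included, the walk steps through the loss_funcs dict down to the
# loss's distance and reducer submodules; without it, the dict is a dead end.
# (The generated names only resemble the record table names in the tests.)
from pytorch_metric_learning.losses import NTXentLoss
loss_funcs = {"metric_loss": NTXentLoss()}
print(list(walk("loss_funcs", loss_funcs, [torch.nn.Module, dict])))
print(list(walk("loss_funcs", loss_funcs, [torch.nn.Module])))  # []
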
6 changes: 4 additions & 2 deletions src/pytorch_metric_learning/utils/module_with_records.py
@@ -4,9 +4,11 @@


 class ModuleWithRecords(torch.nn.Module):
-    def __init__(self, collect_stats=c_f.COLLECT_STATS):
+    def __init__(self, collect_stats=None):
         super().__init__()
-        self.collect_stats = collect_stats
+        self.collect_stats = (
+            c_f.COLLECT_STATS if collect_stats is None else collect_stats
+        )

     def add_to_recordable_attributes(
         self, name=None, list_of_names=None, is_stat=False
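
A default such as collect_stats=c_f.COLLECT_STATS is evaluated once, when the def statement runs at import time, so later changes to c_f.COLLECT_STATS never reach new instances; the None sentinel defers the lookup to construction time. A minimal standalone sketch of the difference (the flags class below is a stand-in, not the library's module):

# Standalone illustration; `flags` stands in for the library's common_functions.
class flags:
    COLLECT_STATS = False

class Frozen:
    # Default captured once, at definition time.
    def __init__(self, collect_stats=flags.COLLECT_STATS):
        self.collect_stats = collect_stats

class Deferred:
    # Default resolved at call time via a None sentinel.
    def __init__(self, collect_stats=None):
        self.collect_stats = (
            flags.COLLECT_STATS if collect_stats is None else collect_stats
        )

flags.COLLECT_STATS = True
print(Frozen().collect_stats)    # False: the frozen default ignores the update
print(Deferred().collect_stats)  # True: the deferred default sees the update
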
15 changes: 14 additions & 1 deletion tests/trainers/test_metric_loss_only.py
@@ -8,6 +8,7 @@
 from torchvision import datasets, transforms

 from pytorch_metric_learning.losses import NTXentLoss
+from pytorch_metric_learning.reducers import AvgNonZeroReducer
 from pytorch_metric_learning.samplers import MPerClassSampler
 from pytorch_metric_learning.testers import GlobalEmbeddingSpaceTester
 from pytorch_metric_learning.trainers import MetricLossOnly
@@ -36,7 +37,7 @@ def test_metric_loss_only(self):
            )
        )

-        loss_fn = NTXentLoss()
+        loss_fn = NTXentLoss(reducer=AvgNonZeroReducer())

         normalize_transform = transforms.Normalize(
             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
@@ -166,6 +167,18 @@ def test_metric_loss_only(self):

         num_epochs = 2
         trainer.train(num_epochs=num_epochs)
+
+        for record_name in [
+            "metric_loss_NTXentLoss",
+            "metric_loss_NTXentLoss__modules_distance_CosineSimilarity",
+            "metric_loss_NTXentLoss__modules_reducer_AvgNonZeroReducer",
+        ]:
+            self.assertTrue(record_keeper.table_exists(record_name))
+            self.assertTrue(
+                len(record_keeper.query(f"SELECT * FROM {record_name}"))
+                == num_epochs * iterations_per_epoch / log_freq
+            )
+
         best_epoch, best_accuracy = hooks.get_best_epoch_and_accuracy(
             tester, "val"
         )
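
The row-count assertion expects one row per logging interval in each per-module record table. A tiny worked example of that arithmetic with assumed values (the test defines its own iterations_per_epoch and log_freq earlier in the file):

# Assumed values for illustration only; the test defines its own.
num_epochs = 2
iterations_per_epoch = 100
log_freq = 50

# One row is written every log_freq iterations, so across the whole run:
expected_rows = num_epochs * iterations_per_epoch / log_freq
print(expected_rows)  # 4.0 rows expected in each per-module record table
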
13 changes: 9 additions & 4 deletions tests/utils/test_common_functions.py
@@ -57,10 +57,15 @@ def test_torch_standard_scaler(self):

     def test_collect_stats_flag(self):
         self.assertTrue(c_f.COLLECT_STATS == WITH_COLLECT_STATS)
-        loss_fn = TripletMarginLoss()
-        self.assertTrue(loss_fn.collect_stats == WITH_COLLECT_STATS)
-        self.assertTrue(loss_fn.distance.collect_stats == WITH_COLLECT_STATS)
-        self.assertTrue(loss_fn.reducer.collect_stats == WITH_COLLECT_STATS)
+        for x in [True, False, True, False, WITH_COLLECT_STATS]:
+            c_f.COLLECT_STATS = x
+            loss_fn = TripletMarginLoss()
+            self.assertTrue(loss_fn.collect_stats == x)
+            self.assertTrue(loss_fn.distance.collect_stats == x)
+            self.assertTrue(loss_fn.reducer.collect_stats == x)
+        for x in [True, False]:
+            loss_fn = TripletMarginLoss(collect_stats=x)
+            self.assertTrue(loss_fn.collect_stats == x)

     def test_check_shapes(self):
         embeddings = torch.randn(32, 512, 3)
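
A short usage sketch of the two ways this test exercises the flag; the constructor keyword and c_f.COLLECT_STATS are real names from the diff, while the surrounding script is illustrative and assumes v1.1.2 behavior.

from pytorch_metric_learning.losses import TripletMarginLoss
from pytorch_metric_learning.utils import common_functions as c_f

# Per-instance override: an explicit keyword wins regardless of the global.
loss_a = TripletMarginLoss(collect_stats=True)
assert loss_a.collect_stats is True

# Global default: with the None sentinel, new instances pick up whatever
# c_f.COLLECT_STATS is at construction time.
c_f.COLLECT_STATS = True
loss_b = TripletMarginLoss()
assert loss_b.collect_stats is True
assert loss_b.distance.collect_stats is True  # submodules follow the flag
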
