Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed issues #491, #430 and #602 #643

Open
wants to merge 5 commits into
base: v3.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ dist/
*.egg-info/
site/
venv/
.vscode
.devcontainer
.ipynb_checkpoints
examples/notebooks/dataset
examples/notebooks/CIFAR10_Dataset
Expand Down
5 changes: 4 additions & 1 deletion src/pytorch_metric_learning/losses/margin_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,11 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
if len(anchor_idx) == 0:
return self.zero_losses()

# Ensure self.beta lives on the same device/dtype as the embeddings;
# indexing a CPU beta with CUDA labels raises a device-mismatch error.
self.beta.data = c_f.to_device(
self.beta.data, device=embeddings.device, dtype=embeddings.dtype
)
beta = self.beta if len(self.beta) == 1 else self.beta[labels[anchor_idx]]
beta = c_f.to_device(beta, device=embeddings.device, dtype=embeddings.dtype)

mat = self.distance(embeddings, ref_emb)

Expand Down
5 changes: 4 additions & 1 deletion src/pytorch_metric_learning/utils/accuracy_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,10 @@ def get_accuracy(
)

knn_distances, knn_indices = self.knn_func(
query, num_k, reference, ref_includes_query
query,
reference,
num_k,
ref_includes_query,  # argument order changed to match the faiss knn signature
)

knn_labels = reference_labels[knn_indices]
Expand Down
45 changes: 38 additions & 7 deletions src/pytorch_metric_learning/utils/common_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import logging
import os
import re
from typing import List, Tuple, Union

import numpy as np
import scipy.stats
import torch
from torch import nn

LOGGER_NAME = "PML"
LOGGER = logging.getLogger(LOGGER_NAME)
Expand Down Expand Up @@ -394,13 +396,13 @@ def check_shapes(embeddings, labels):


def assert_distance_type(obj, distance_type=None, **kwargs):
obj_name = obj.__class__.__name__
if distance_type is not None:
if is_list_or_tuple(distance_type):
distance_type_str = ", ".join(x.__name__ for x in distance_type)
distance_type_str = "one of " + distance_type_str
else:
distance_type_str = distance_type.__name__
obj_name = obj.__class__.__name__
assert isinstance(
obj.distance, distance_type
), "{} requires the distance metric to be {}".format(
Expand Down Expand Up @@ -459,13 +461,42 @@ def to_dtype(x, tensor=None, dtype=None):
return x


def to_device(
    x: Union[torch.Tensor, nn.Parameter, List, Tuple],
    tensor=None,
    device=None,
    dtype: Union[torch.dtype, List, Tuple] = None,
):
    """Move ``x`` to a device, optionally casting its dtype.

    Args:
        x: a tensor/parameter, or a list/tuple of tensors, to move.
        tensor: fallback source of the target device when ``device`` is None.
        device: target device; defaults to ``tensor.device``.
        dtype: a single dtype applied to every element, or a list/tuple with
            one dtype per element of ``x``. When None, dtypes are unchanged.

    Returns:
        The moved tensor, or a list of moved tensors if ``x`` was a list/tuple.

    Raises:
        RuntimeError: if ``dtype`` is a sequence whose length differs from
            the number of elements in ``x``.
    """
    dv = device if device is not None else tensor.device
    single = not is_list_or_tuple(x)
    items = [x] if single else list(x)

    if is_list_or_tuple(dtype):
        if len(dtype) != len(items):
            raise RuntimeError(
                f"dtype has length {len(dtype)}, but it must be a single dtype "
                f"or a sequence with one dtype per element of x ({len(items)})"
            )
        dtypes = list(dtype)
    else:
        # A single dtype (or None, meaning "keep each element's dtype")
        # applies to every element. Note: do NOT default to x.dtype here —
        # lists have no .dtype attribute, and forcing a cast when dtype is
        # None would change the long-standing single-tensor behavior.
        dtypes = [dtype] * len(items)

    moved = []
    for xt, dt in zip(items, dtypes):
        xt = xt.to(dv)
        if dt is not None:
            xt = to_dtype(xt, tensor=tensor, dtype=dt)
        moved.append(xt)

    return moved[0] if single else moved


def check_multiple_gpus(gpus):
    """Validate a user-supplied GPU list.

    Args:
        gpus: None (no validation performed) or a non-empty list/tuple of
            GPU identifiers.

    Raises:
        TypeError: if ``gpus`` is neither a list nor a tuple.
        ValueError: if ``gpus`` is an empty list/tuple.
    """
    if gpus is not None:
        if not isinstance(gpus, (list, tuple)):
            # Message previously said "must be a list" even though tuples are
            # accepted; state the actual contract.
            raise TypeError("gpus must be a list or tuple")
        if len(gpus) < 1:
            raise ValueError("gpus must have length greater than 0")


def set_ref_emb(embeddings, labels, ref_emb, ref_labels):
Expand Down
Loading