first fix of issue KevinMusgrave#491
domen committed Jun 21, 2023
1 parent ab47660 commit 9b3e533
Showing 2 changed files with 36 additions and 32 deletions.
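In brief, the commit targets the GPU search path reported in issue KevinMusgrave#491: the gpus argument is now validated through a new shared helper, and add_to_index_and_search only moves the query to the CPU when the faiss index is not a single-GPU index. The sketch below is not part of the commit; it illustrates the call path the fix affects, assuming faiss-gpu, a CUDA device, and that these inference.py lines belong to the library's FaissKNN wrapper whose __call__ takes (query, k, reference) and returns (distances, indices).

    import torch
    from pytorch_metric_learning.utils.inference import FaissKNN

    reference = torch.randn(1000, 128, device="cuda")   # reference embeddings
    queries = torch.randn(16, 128, device="cuda")       # query embeddings

    knn = FaissKNN(gpus=[0])   # gpus is validated by the new c_f.check_multiple_gpus helper
    # Before this commit the query was always moved to CPU before index.search();
    # with the fix, a query against a single-GPU index can skip that round trip.
    distances, indices = knn(queries, 5, reference)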
55 changes: 31 additions & 24 deletions src/pytorch_metric_learning/utils/common_functions.py
@@ -253,7 +253,7 @@ def map(self, labels, hierarchy_level):
 
 
 def add_to_recordable_attributes(
-    input_obj, name=None, list_of_names=None, is_stat=False
+    input_obj, name=None, list_of_names=None, is_stat=False
 ):
     if is_stat:
         attr_name_list_name = "_record_these_stats"
@@ -291,8 +291,8 @@ def modelpath_creator(folder, basename, identifier, extension=".pth"):
 
 def save_model(model, filepath):
     if any(
-        isinstance(model, x)
-        for x in [torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel]
+        isinstance(model, x)
+        for x in [torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel]
     ):
         torch.save(model.module.state_dict(), filepath)
     else:
@@ -317,13 +317,13 @@ def load_model(model_def, model_filename, device):
 
 
 def operate_on_dict_of_models(
-    input_dict,
-    suffix,
-    folder,
-    operation,
-    logging_string="",
-    log_if_successful=False,
-    assert_success=False,
+    input_dict,
+    suffix,
+    folder,
+    operation,
+    logging_string="",
+    log_if_successful=False,
+    assert_success=False,
 ):
     for k, v in input_dict.items():
         model_path = modelpath_creator(folder, k, suffix)
@@ -396,13 +396,13 @@ def check_shapes(embeddings, labels):
 
 
 def assert_distance_type(obj, distance_type=None, **kwargs):
+    obj_name = obj.__class__.__name__
     if distance_type is not None:
         if is_list_or_tuple(distance_type):
             distance_type_str = ", ".join(x.__name__ for x in distance_type)
             distance_type_str = "one of " + distance_type_str
         else:
             distance_type_str = distance_type.__name__
-        obj_name = obj.__class__.__name__
         assert isinstance(
             obj.distance, distance_type
         ), "{} requires the distance metric to be {}".format(
@@ -462,35 +462,42 @@ def to_dtype(x, tensor=None, dtype=None):
 
 
 def to_device(
-    x: Union[torch.Tensor, nn.Parameter, List, Tuple],
-    tensor=None,
-    device=None,
-    dtype: Union[torch.dtype, List, Tuple] = None,
+    x: Union[torch.Tensor, nn.Parameter, List, Tuple],
+    tensor=None,
+    device=None,
+    dtype: Union[torch.dtype, List, Tuple] = None,
 ):
     dv = device if device is not None else tensor.device
-    is_iterable = is_list_or_tuple(x)
-    if not is_iterable:
+    dt = dtype if dtype is not None else x.dtype  # Specify if by default cast to x.dtype or tensor.dtype
+    if not is_list_or_tuple(x):
         x = [x]
 
     xd = x
-    if is_list_or_tuple(dtype):
-        if len(dtype) == len(x):
+    if is_list_or_tuple(dt):
+        if len(dt) == len(x):
             xd = [
-                to_dtype(x[i].to(dv), tensor=tensor, dtype=dtype[i])
+                to_dtype(x[i].to(dv), tensor=tensor, dtype=dt[i])
                 for i in range(len(x))
             ]
         else:
             raise RuntimeError(
-                f"The size of dtype was {len(dtype)}. It is only available 1 or the same of x"
+                f"The size of dtype was {len(dt)}. It is only available 1 or the same of x"
             )
-    elif dtype is not None:
-        xd = [to_dtype(xt.to(dv), tensor=tensor, dtype=dtype) for xt in x]
+    else:
+        xd = [to_dtype(xt.to(dv), tensor=tensor, dtype=dt) for xt in x]
 
     if len(xd) == 1:
         xd = xd[0]
     return xd
 
 
+def check_multiple_gpus(gpus):
+    if gpus is not None:
+        if not isinstance(gpus, (list, tuple)):
+            raise TypeError("gpus must be a list")
+        if len(gpus) < 1:
+            raise ValueError("gpus must have length greater than 0")
 
 
 def set_ref_emb(embeddings, labels, ref_emb, ref_labels):
     if ref_emb is not None:
         if ref_labels is not None:
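The common_functions.py changes above rework to_device so the target dtype (dt) is resolved once, and move the gpus validation out of the k-NN wrapper's __init__ in inference.py into the new check_multiple_gpus helper. Below is a minimal, illustrative sketch (not from the repository) of how these helpers behave once this commit is applied, using the same c_f alias the library uses:

    import torch
    from pytorch_metric_learning.utils import common_functions as c_f

    ref = torch.zeros(1)          # provides the target device (CPU here)
    a = torch.randn(4, 8)
    b = torch.randn(4, 8)

    # A list input with a matching list of dtypes is cast elementwise.
    a16, b64 = c_f.to_device([a, b], tensor=ref, dtype=[torch.float16, torch.float64])

    # A dtype list of mismatched length triggers the RuntimeError raised inside to_device.
    # c_f.to_device([a, b], tensor=ref, dtype=[torch.float16])

    # The new validation helper, previously inlined in the faiss k-NN wrapper's __init__:
    c_f.check_multiple_gpus([0, 1])   # passes
    # c_f.check_multiple_gpus(0)      # TypeError: gpus must be a list
    # c_f.check_multiple_gpus([])     # ValueError: gpus must have length greater than 0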
13 changes: 5 additions & 8 deletions src/pytorch_metric_learning/utils/inference.py
@@ -159,11 +159,7 @@ def __init__(
         self.index_init_fn = (
             faiss.IndexFlatL2 if index_init_fn is None else index_init_fn
         )
-        if gpus is not None:
-            if not isinstance(gpus, (list, tuple)):
-                raise TypeError("gpus must be a list")
-            if len(gpus) < 1:
-                raise ValueError("gpus must have length greater than 0")
+        c_f.check_multiple_gpus(gpus)
         self.gpus = gpus
 
     def __call__(
@@ -232,13 +228,14 @@ def __call__(self, x, nmb_clusters):
         kmeans = faiss.Kmeans(d, nmb_clusters, **self.kwargs)
         kmeans.train(x)
         _, idxs = kmeans.index.search(x, 1)
-        return torch.tensor([int(n[0]) for n in idxs], dtype=int, device=device)
+        return torch.tensor([int(n[0]) for n in idxs], dtype=torch.int, device=device)
 
 
 def add_to_index_and_search(index, query, reference, k):
+    indexOnOnlyOneGPU = (not isinstance(index, faiss.IndexShards)) and (isinstance(index, faiss.GpuIndex) or isinstance(index, faiss.GpuIndexShards))  # Issue #491
     if reference is not None:
         index.add(reference.float().cpu())
-    return index.search(query.float().cpu(), k)
+    return index.search(query.float() if indexOnOnlyOneGPU else query.float().cpu(), k)
 
 
 def convert_to_gpu_index(index, gpus):
@@ -260,8 +257,8 @@ def try_gpu(index, query, reference, k, is_cuda, gpus):
     gpu_index = None
     gpus_are_available = faiss.get_num_gpus() > 0
     gpu_condition = (is_cuda or (gpus is not None)) and gpus_are_available
+    max_k_for_gpu = 1024 if float(torch.version.cuda) < 9.5 else 2048
     if gpu_condition:
-        max_k_for_gpu = 1024 if float(torch.version.cuda) < 9.5 else 2048
         if k <= max_k_for_gpu:
             gpu_index = convert_to_gpu_index(index, gpus)
     try:
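The core of the fix is the indexOnOnlyOneGPU check added to add_to_index_and_search: only when the index is a single-GPU faiss index (a GpuIndex or GpuIndexShards that is not an IndexShards) does the query skip the .cpu() round trip. A standalone sketch of that decision follows; it is not part of the commit, assumes faiss-gpu is installed, and reuses the same faiss classes checked in the diff:

    import faiss

    def query_keeps_device(index):
        # Mirrors the indexOnOnlyOneGPU condition introduced in this commit.
        return (not isinstance(index, faiss.IndexShards)) and isinstance(
            index, (faiss.GpuIndex, faiss.GpuIndexShards)
        )

    cpu_index = faiss.IndexFlatL2(128)
    print(query_keeps_device(cpu_index))        # False: query still goes through .cpu()

    if faiss.get_num_gpus() > 0:
        res = faiss.StandardGpuResources()
        gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
        print(query_keeps_device(gpu_index))    # True: query.float() is searched as-is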
