diff --git a/recbole/data/dataloader/abstract_dataloader.py b/recbole/data/dataloader/abstract_dataloader.py
index 5bc630b94..7b60f0810 100644
--- a/recbole/data/dataloader/abstract_dataloader.py
+++ b/recbole/data/dataloader/abstract_dataloader.py
@@ -25,6 +25,20 @@
 start_iter = False
+
+
+class NoDuplicateDistributedSampler(torch.utils.data.DistributedSampler):
+    """
+    A distributed sampler that doesn't add duplicates. Arguments are the same as DistributedSampler
+    Refer to https://github.com/pytorch/pytorch/issues/25162#issuecomment-1227647626
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if not self.drop_last and len(self.dataset) % self.num_replicas != 0:
+            # Some ranks may have fewer samples, that's fine
+            if self.rank >= len(self.dataset) % self.num_replicas:
+                self.num_samples -= 1
+            self.total_size = len(self.dataset)
 
 
 class AbstractDataLoader(torch.utils.data.DataLoader):
     """:class:`AbstractDataLoader` is an abstract object which would return a batch of data
     which is loaded by :class:`~recbole.data.interaction.Interaction` when it is iterated.
@@ -57,7 +71,7 @@ def __init__(self, config, dataset, sampler, shuffle=False):
         self.transform = construct_transform(config)
         self.is_sequential = config["MODEL_TYPE"] == ModelType.SEQUENTIAL
         if not config["single_spec"]:
-            index_sampler = torch.utils.data.distributed.DistributedSampler(
+            index_sampler = NoDuplicateDistributedSampler(
                 list(range(self.sample_size)), shuffle=shuffle, drop_last=False
             )
             self.step = max(1, self.step // config["world_size"])
diff --git a/recbole/evaluator/collector.py b/recbole/evaluator/collector.py
index 07a913721..248bbc233 100644
--- a/recbole/evaluator/collector.py
+++ b/recbole/evaluator/collector.py
@@ -14,12 +14,13 @@
 from recbole.evaluator.register import Register
 
 import torch
-import copy
 
 
 class DataStruct(object):
-    def __init__(self):
+    def __init__(self, init=None):
         self._data_dict = {}
+        if init is not None:
+            self._data_dict.update(init)
 
     def __getitem__(self, name: str):
         return self._data_dict[name]
@@ -41,6 +42,9 @@ def get(self, name: str):
     def set(self, name: str, value):
         self._data_dict[name] = value
 
+    def __iter__(self):
+        return iter(self._data_dict.items())
+
     def update_tensor(self, name: str, value: torch.Tensor):
         if name not in self._data_dict:
             self._data_dict[name] = value.cpu().clone().detach()
@@ -190,7 +194,7 @@ def eval_batch_collect(
         if self.register.need("data.label"):
             self.label_field = self.config["LABEL_FIELD"]
             self.data_struct.update_tensor(
-                "data.label", interaction[self.label_field].to(self.device)
+                "data.label", interaction[self.label_field]
             )
 
     def model_collect(self, model: torch.nn.Module):
@@ -213,13 +217,13 @@ def eval_collect(self, eval_pred: torch.Tensor, data_label: torch.Tensor):
 
         if self.register.need("data.label"):
             self.label_field = self.config["LABEL_FIELD"]
-            self.data_struct.update_tensor("data.label", data_label.to(self.device))
+            self.data_struct.update_tensor("data.label", data_label)
 
     def get_data_struct(self):
         """Get all the evaluation resource that been collected.
         And reset some of outdated resource.
         """
-        returned_struct = copy.deepcopy(self.data_struct)
+        returned_struct = DataStruct(self.data_struct)
         for key in ["rec.topk", "rec.meanrank", "rec.score", "rec.items", "data.label"]:
             if key in self.data_struct:
                 del self.data_struct[key]
diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py
index 99ccec71c..8c4ba2a09 100644
--- a/recbole/trainer/trainer.py
+++ b/recbole/trainer/trainer.py
@@ -31,7 +31,7 @@
 
 from recbole.data.interaction import Interaction
 from recbole.data.dataloader import FullSortEvalDataLoader
-from recbole.evaluator import Evaluator, Collector
+from recbole.evaluator import Evaluator, Collector, DataStruct
 from recbole.utils import (
     ensure_dir,
     get_local_time,
@@ -46,6 +46,7 @@
     WandbLogger,
 )
 from torch.nn.parallel import DistributedDataParallel
+import torch.distributed as dist
 
 
 class AbstractTrainer(object):
@@ -577,8 +578,12 @@ def evaluate(
             return
 
         if load_best_model:
+            # Refer to: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#save-and-load-checkpoints
+            if not self.config["single_spec"]:
+                dist.barrier()
             checkpoint_file = model_file or self.saved_model_file
-            checkpoint = torch.load(checkpoint_file, map_location=self.device)
+            map_location = {"cuda:%d" % 0: "cuda:%d" % self.config["local_rank"]}
+            checkpoint = torch.load(checkpoint_file, map_location=map_location)
             self.model.load_state_dict(checkpoint["state_dict"])
             self.model.load_other_parameter(checkpoint.get("other_parameter"))
             message_output = "Loading model structure and parameters from {}".format(
@@ -608,9 +613,7 @@ def evaluate(
             else eval_data
         )
 
-        num_sample = 0
         for batch_idx, batched_data in enumerate(iter_data):
-            num_sample += len(batched_data)
             interaction, scores, positive_u, positive_i = eval_func(batched_data)
             if self.gpu_available and show_progress:
                 iter_data.set_postfix_str(
@@ -621,35 +624,31 @@
                 )
         self.eval_collector.model_collect(self.model)
         struct = self.eval_collector.get_data_struct()
-        result = self.evaluator.evaluate(struct)
         if not self.config["single_spec"]:
-            result = self._map_reduce(result, num_sample)
+            struct = self._gather_evaluation_resources(struct)
+        result = self.evaluator.evaluate(struct)
         self.wandblogger.log_eval_metrics(result, head="eval")
         return result
 
-    def _map_reduce(self, result, num_sample):
-        gather_result = {}
-        total_sample = [
-            torch.zeros(1).to(self.device) for _ in range(self.config["world_size"])
-        ]
-        torch.distributed.all_gather(
-            total_sample, torch.Tensor([num_sample]).to(self.device)
-        )
-        total_sample = torch.cat(total_sample, 0)
-        total_sample = torch.sum(total_sample).item()
-        for key, value in result.items():
-            result[key] = torch.Tensor([value * num_sample]).to(self.device)
-            gather_result[key] = [
-                torch.zeros_like(result[key]).to(self.device)
-                for _ in range(self.config["world_size"])
-            ]
-            torch.distributed.all_gather(gather_result[key], result[key])
-            gather_result[key] = torch.cat(gather_result[key], dim=0)
-            gather_result[key] = round(
-                torch.sum(gather_result[key]).item() / total_sample,
-                self.config["metric_decimal_place"],
-            )
-        return gather_result
+    def _gather_evaluation_resources(self, struct: DataStruct) -> DataStruct:
+        """
+        Gather the evaluation resources from all ranks, e.g., 'rec.items', 'rec.score', 'data.label'
+        Only 'rec.*' and 'data.label' are gathered, because they are distributed into different ranks.
+        Args:
+            struct: data struct collected from all ranks
+
+        Returns: gathered data struct
+
+        """
+        gather_struct = DataStruct(struct)
+        for key, value in struct:
+            # Adjust the condition according to
+            # [the key definition in evaluator](/docs/source/developer_guide/customize_metrics.rst#set-metric_need)
+            if key.startswith("rec.") or key == "data.label":
+                gather_struct[key] = [None for _ in range(self.config["world_size"])]
+                dist.all_gather_object(gather_struct[key], value)
+                gather_struct[key] = torch.cat(gather_struct[key], dim=0)
+        return gather_struct
 
     def _spilt_predict(self, interaction, batch_size):
         spilt_interaction = dict()
@@ -786,7 +785,7 @@ def pretrain(self, train_data, verbose=True, show_progress=False):
                 self.logger.info(train_loss_output)
             self._add_train_loss_to_tensorboard(epoch_idx, train_loss)
 
-            if (epoch_idx + 1) % self.save_step == 0:
+            if (epoch_idx + 1) % self.save_step == 0 and self.config["local_rank"] == 0:
                 saved_model_file = os.path.join(
                     self.checkpoint_dir,
                     "{}-{}-{}.pth".format(
@@ -986,6 +985,8 @@ def _save_checkpoint(self, epoch):
             epoch (int): the current epoch id
 
         """
+        if not self.config["single_spec"] and self.config["local_rank"] != 0:
+            return
         state = {
             "config": self.config,
             "epoch": epoch,
diff --git a/recbole/utils/utils.py b/recbole/utils/utils.py
index ed202448c..8e7aead16 100644
--- a/recbole/utils/utils.py
+++ b/recbole/utils/utils.py
@@ -48,8 +48,7 @@ def ensure_dir(dir_path):
         dir_path (str): directory path
 
     """
-    if not os.path.exists(dir_path):
-        os.makedirs(dir_path)
+    os.makedirs(dir_path, exist_ok=True)
 
 
 def get_model(model_name):
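For reference, a minimal standalone sketch (not part of the patch) of the per-rank sample counts that NoDuplicateDistributedSampler produces: with drop_last=False the stock DistributedSampler pads every rank up to ceil(N / world_size) samples, so the padded duplicates get evaluated twice; the patched sampler instead gives ranks at or beyond N % world_size one sample fewer, so the shards cover exactly N distinct samples. The helper name per_rank_counts is made up for illustration.

import math


def per_rank_counts(dataset_len, world_size):
    counts = []
    for rank in range(world_size):
        # Default DistributedSampler size with drop_last=False.
        num_samples = math.ceil(dataset_len / world_size)
        # The fix: ranks past the remainder drop the padded duplicate.
        if dataset_len % world_size != 0 and rank >= dataset_len % world_size:
            num_samples -= 1
        counts.append(num_samples)
    return counts


if __name__ == "__main__":
    # e.g. 10 evaluation users on 4 ranks -> [3, 3, 2, 2]; shards sum to 10, no duplicates
    print(per_rank_counts(10, 4))
    assert sum(per_rank_counts(10, 4)) == 10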