[🐛 BUG] Incorrect evaluation results due to multi-GPU distributed sampler #1872

Closed (wants to merge 7 commits)
16 changes: 15 additions & 1 deletion recbole/data/dataloader/abstract_dataloader.py
@@ -25,6 +25,20 @@
start_iter = False


class NoDuplicateDistributedSampler(torch.utils.data.DistributedSampler):
"""
A distributed sampler that doesn't add duplicates.
Arguments are the same as DistributedSampler
Refer to https://github.com/pytorch/pytorch/issues/25162#issuecomment-1227647626
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self.drop_last and len(self.dataset) % self.num_replicas != 0:
# Some ranks may have fewer samples, that's fine
if self.rank >= len(self.dataset) % self.num_replicas:
self.num_samples -= 1
self.total_size = len(self.dataset)

class AbstractDataLoader(torch.utils.data.DataLoader):
""":class:`AbstractDataLoader` is an abstract object which would return a batch of data which is loaded by
:class:`~recbole.data.interaction.Interaction` when it is iterated.
@@ -57,7 +71,7 @@ def __init__(self, config, dataset, sampler, shuffle=False):
self.transform = construct_transform(config)
self.is_sequential = config["MODEL_TYPE"] == ModelType.SEQUENTIAL
if not config["single_spec"]:
index_sampler = torch.utils.data.distributed.DistributedSampler(
index_sampler = NoDuplicateDistributedSampler(
list(range(self.sample_size)), shuffle=shuffle, drop_last=False
)
self.step = max(1, self.step // config["world_size"])
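As a sanity check on the new sampler, a minimal single-process sketch (toy data of 10 samples split over 4 ranks; num_replicas and rank are passed explicitly so no process group is needed, and the import path assumes the module above):

    import torch
    from recbole.data.dataloader.abstract_dataloader import NoDuplicateDistributedSampler

    data = list(range(10))
    for rank in range(4):
        default = torch.utils.data.DistributedSampler(
            data, num_replicas=4, rank=rank, shuffle=False, drop_last=False
        )
        no_dup = NoDuplicateDistributedSampler(
            data, num_replicas=4, rank=rank, shuffle=False, drop_last=False
        )
        print(rank, list(default), list(no_dup))
    # The stock sampler pads to 12 indices in total (ranks 2 and 3 re-serve samples 0 and 1),
    # which is what skews the evaluation metrics; the new sampler yields exactly the
    # 10 original indices across the 4 ranks.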
14 changes: 9 additions & 5 deletions recbole/evaluator/collector.py
@@ -14,12 +14,13 @@

from recbole.evaluator.register import Register
import torch
import copy


class DataStruct(object):
def __init__(self):
def __init__(self, init=None):
self._data_dict = {}
if init is not None:
self._data_dict.update(init)
Contributor Author: Used for the "deep copy" in get_data_struct below (copies the dict of entries instead of going through copy.deepcopy).
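A rough sketch of what this enables in get_data_struct below (assuming DataStruct also provides the __contains__ and __delitem__ that get_data_struct relies on):

    import torch
    from recbole.evaluator.collector import DataStruct

    src = DataStruct()
    src.set("rec.topk", torch.ones(3))
    snapshot = DataStruct(src)         # fresh dict, same tensor objects, no deepcopy of the data
    del src["rec.topk"]                # resetting the collector's struct ...
    print(snapshot.get("rec.topk"))    # ... leaves the snapshot intact: tensor([1., 1., 1.])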


def __getitem__(self, name: str):
return self._data_dict[name]
@@ -41,6 +42,9 @@ def get(self, name: str):
def set(self, name: str, value):
self._data_dict[name] = value

def __iter__(self):
Contributor Author: Used for iteration in a for loop, e.g.:

    for key, value in struct:
        ...

return iter(self._data_dict.items())

def update_tensor(self, name: str, value: torch.Tensor):
if name not in self._data_dict:
self._data_dict[name] = value.cpu().clone().detach()
@@ -190,7 +194,7 @@ def eval_batch_collect(
if self.register.need("data.label"):
self.label_field = self.config["LABEL_FIELD"]
self.data_struct.update_tensor(
"data.label", interaction[self.label_field].to(self.device)
Contributor Author: Redundant conversion, because every input is moved to the CPU inside update_tensor anyway.
"data.label", interaction[self.label_field]
)

def model_collect(self, model: torch.nn.Module):
@@ -213,13 +217,13 @@ def eval_collect(self, eval_pred: torch.Tensor, data_label: torch.Tensor):

if self.register.need("data.label"):
self.label_field = self.config["LABEL_FIELD"]
self.data_struct.update_tensor("data.label", data_label.to(self.device))
Contributor Author: Same as above.

self.data_struct.update_tensor("data.label", data_label)

def get_data_struct(self):
"""Get all the evaluation resource that been collected.
And reset some of outdated resource.
"""
returned_struct = copy.deepcopy(self.data_struct)
returned_struct = DataStruct(self.data_struct)
for key in ["rec.topk", "rec.meanrank", "rec.score", "rec.items", "data.label"]:
if key in self.data_struct:
del self.data_struct[key]
61 changes: 31 additions & 30 deletions recbole/trainer/trainer.py
@@ -31,7 +31,7 @@

from recbole.data.interaction import Interaction
from recbole.data.dataloader import FullSortEvalDataLoader
from recbole.evaluator import Evaluator, Collector
from recbole.evaluator import Evaluator, Collector, DataStruct
from recbole.utils import (
ensure_dir,
get_local_time,
@@ -46,6 +46,7 @@
WandbLogger,
)
from torch.nn.parallel import DistributedDataParallel
import torch.distributed as dist


class AbstractTrainer(object):
@@ -577,8 +578,12 @@ def evaluate(
return

if load_best_model:
# Refer to: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#save-and-load-checkpoints
if not self.config["single_spec"]:
dist.barrier()
checkpoint_file = model_file or self.saved_model_file
checkpoint = torch.load(checkpoint_file, map_location=self.device)
map_location = {"cuda:%d" % 0: "cuda:%d" % self.config["local_rank"]}
checkpoint = torch.load(checkpoint_file, map_location=map_location)
Contributor Author: Fixes the "EOFError: Ran out of input" error when using DDP: the barrier keeps the other ranks from reading the checkpoint before rank 0 has finished writing it, and map_location remaps rank 0's GPU tensors onto this rank's device.

Refer to the example in the DDP tutorial.
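Roughly the save/load pattern from that tutorial section (a sketch using the tutorial's names rank, CHECKPOINT_PATH, and ddp_model rather than RecBole's):

    if rank == 0:
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)  # only rank 0 writes the file
    dist.barrier()                                           # other ranks wait until it exists
    map_location = {"cuda:%d" % 0: "cuda:%d" % rank}         # remap rank 0's GPU tensors to this rank's GPU
    ddp_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=map_location))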

self.model.load_state_dict(checkpoint["state_dict"])
self.model.load_other_parameter(checkpoint.get("other_parameter"))
message_output = "Loading model structure and parameters from {}".format(
@@ -608,9 +613,7 @@ def evaluate(
else eval_data
)

num_sample = 0
for batch_idx, batched_data in enumerate(iter_data):
num_sample += len(batched_data)
Contributor Author: No longer used.

interaction, scores, positive_u, positive_i = eval_func(batched_data)
if self.gpu_available and show_progress:
iter_data.set_postfix_str(
@@ -621,35 +624,31 @@
)
self.eval_collector.model_collect(self.model)
struct = self.eval_collector.get_data_struct()
result = self.evaluator.evaluate(struct)
if not self.config["single_spec"]:
result = self._map_reduce(result, num_sample)
struct = self._gather_evaluation_resources(struct)
result = self.evaluator.evaluate(struct)
Contributor Author (ChenglongMa, Sep 21, 2023): Gather the struct from all GPUs and concatenate the pieces into one, so the result can then be evaluated and computed in the same way as on a single GPU.
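Per key, this amounts to roughly the following (a sketch; it assumes the default process group is initialized, and local_slice is a stand-in for this rank's share of one of those tensors):

    world_size = dist.get_world_size()
    pieces = [None for _ in range(world_size)]
    dist.all_gather_object(pieces, local_slice)   # e.g. rank 0 holds [N0, k], rank 1 holds [N1, k]
    full = torch.cat(pieces, dim=0)               # every rank now sees the full [N0 + N1, k] tensor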

self.wandblogger.log_eval_metrics(result, head="eval")
return result

def _map_reduce(self, result, num_sample):
gather_result = {}
total_sample = [
torch.zeros(1).to(self.device) for _ in range(self.config["world_size"])
]
torch.distributed.all_gather(
total_sample, torch.Tensor([num_sample]).to(self.device)
)
total_sample = torch.cat(total_sample, 0)
total_sample = torch.sum(total_sample).item()
for key, value in result.items():
result[key] = torch.Tensor([value * num_sample]).to(self.device)
gather_result[key] = [
torch.zeros_like(result[key]).to(self.device)
for _ in range(self.config["world_size"])
]
torch.distributed.all_gather(gather_result[key], result[key])
gather_result[key] = torch.cat(gather_result[key], dim=0)
gather_result[key] = round(
torch.sum(gather_result[key]).item() / total_sample,
self.config["metric_decimal_place"],
)
return gather_result
Contributor Author: No longer used.

def _gather_evaluation_resources(self, struct: DataStruct) -> DataStruct:
"""
Gather the evaluation resources from all ranks, e.g., 'rec.items', 'rec.score', 'data.label'
Only 'rec.*' and 'data.label' are gathered, because they are distributed into different ranks.
Args:
struct: data struct collected from all ranks

Returns: gathered data struct

"""
gather_struct = DataStruct(struct)
for key, value in struct:
# Adjust the condition according to
# [the key definition in evaluator](/docs/source/developer_guide/customize_metrics.rst#set-metric_need)
if key.startswith("rec.") or key == "data.label":
Contributor Author: Only rec.* and data.label are gathered, because they are split across the different GPUs, while other keys such as data.num_items and data.num_users are identical on every rank.

The keys refer to docs/source/developer_guide/customize_metrics.rst#set-metric_need

gather_struct[key] = [None for _ in range(self.config["world_size"])]
dist.all_gather_object(gather_struct[key], value)
gather_struct[key] = torch.cat(gather_struct[key], dim=0)
return gather_struct

def _spilt_predict(self, interaction, batch_size):
spilt_interaction = dict()
@@ -786,7 +785,7 @@ def pretrain(self, train_data, verbose=True, show_progress=False):
self.logger.info(train_loss_output)
self._add_train_loss_to_tensorboard(epoch_idx, train_loss)

if (epoch_idx + 1) % self.save_step == 0:
if (epoch_idx + 1) % self.save_step == 0 and self.config["local_rank"] == 0:
Contributor Author: Refer to https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#save-and-load-checkpoints: "All processes should see same parameters as they all start from same random parameters and gradients are synchronized in backward passes. Therefore, saving it in one process is sufficient."

Contributor Author: Please double-check if the following code needs to be fixed:

1. In XGBoostTrainer

    if load_best_model:
        if model_file:
            checkpoint_file = model_file
        else:
            checkpoint_file = self.temp_best_file
        self.model.load_model(checkpoint_file)

2. In LightGBMTrainer

    if load_best_model:
        if model_file:
            checkpoint_file = model_file
        else:
            checkpoint_file = self.temp_best_file
        self.model = self.lgb.Booster(model_file=checkpoint_file)

Thanks!

saved_model_file = os.path.join(
self.checkpoint_dir,
"{}-{}-{}.pth".format(
@@ -986,6 +985,8 @@ def _save_checkpoint(self, epoch):
epoch (int): the current epoch id

"""
if not self.config["single_spec"] and self.config["local_rank"] != 0:
return
Contributor Author: Same as above.

state = {
"config": self.config,
"epoch": epoch,
3 changes: 1 addition & 2 deletions recbole/utils/utils.py
@@ -48,8 +48,7 @@ def ensure_dir(dir_path):
dir_path (str): directory path

"""
if not os.path.exists(dir_path):
os.makedirs(dir_path)
os.makedirs(dir_path, exist_ok=True)
Contributor Author: In DDP, multiple processes may check whether the directory exists at the same time; since there is no locking, the check-then-create pattern is racy and the slower makedirs() call raises an error. makedirs(dir_path, exist_ok=True) avoids this.
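A small sketch of the race (hypothetical two-process timeline, with dir_path as in ensure_dir):

    import os

    # Both DDP processes can evaluate this check before either has created the directory:
    if not os.path.exists(dir_path):   # process A and process B both see False
        os.makedirs(dir_path)          # the slower one raises FileExistsError

    # exist_ok=True ignores the "already exists" case, so the call order no longer matters:
    os.makedirs(dir_path, exist_ok=True)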



def get_model(model_name):