Commit

Merge pull request #1875 from ChenglongMa/fix-collect-result-from-multi-gpu-bug

[🐛 BUG] Fix bugs when collecting results from `mp.spawn` in multi-GPU training
Ethan-TZ authored Oct 14, 2023
2 parents a9259e3 + d7fd793 commit 22d8ad8
Showing 6 changed files with 204 additions and 146 deletions.
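
For context, the pattern this fix relies on can be sketched as a small, self-contained example (a minimal sketch with hypothetical names such as `worker`; not code from this commit): `mp.spawn` does not return the workers' results, so each spawned process writes its result into a `SimpleQueue` created from the 'spawn' context, and the parent reads it back after `mp.spawn` returns.

# Minimal sketch (assumed example): collecting a result from processes
# started with torch.multiprocessing.spawn.
import torch.multiprocessing as mp


def worker(rank, queue):
    # ... training would happen here ...
    result = {"rank": rank, "best_valid_score": 0.0}  # placeholder result
    if rank == 0 and queue is not None:
        queue.put(result)  # only rank 0 reports, mirroring the fix


if __name__ == "__main__":
    # The queue must come from the 'spawn' context so it can be shared
    # with the processes created by mp.spawn.
    queue = mp.get_context("spawn").SimpleQueue()
    mp.spawn(worker, args=(queue,), nprocs=2, join=True)
    res = None if queue.empty() else queue.get()
    print(res)
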
68 changes: 46 additions & 22 deletions docs/source/get_started/distributed_training.rst
@@ -121,21 +121,33 @@ In above example, you can create a new python file (e.g., `run_a.py`) on node A,
nproc = 4,
group_offset = 0
)
# Optional, only needed if you want to get the result of each process.
queue = mp.get_context('spawn').SimpleQueue()
config_dict = config_dict or {}
config_dict.update({
"world_size": args.world_size,
"ip": args.ip,
"port": args.port,
"nproc": args.nproc,
"offset": args.group_offset,
})
kwargs = {
"config_dict": config_dict,
"queue": queue, # Optional
}
mp.spawn(
run_recboles,
args=(
args.model,
args.dataset,
config_file_list,
args.ip,
args.port,
args.world_size,
args.nproc,
args.group_offset,
),
nprocs=args.nproc,
args=(args.model, args.dataset, args.config_file_list, kwargs),
nprocs=nproc,
join=True,
)
# Normally, there should be only one item in the queue
res = None if queue.empty() else queue.get()
Then run the following command:

@@ -159,21 +171,33 @@ Similarly, you can create a new python file (e.g., `run_b.py`) on node B, and wr
nproc = 4,
group_offset = 4
)
# Optional, only needed if you want to get the result of each process.
queue = mp.get_context('spawn').SimpleQueue()
config_dict = config_dict or {}
config_dict.update({
"world_size": args.world_size,
"ip": args.ip,
"port": args.port,
"nproc": args.nproc,
"offset": args.group_offset,
})
kwargs = {
"config_dict": config_dict,
"queue": queue, # Optional
}
mp.spawn(
run_recboles,
args=(
args.model,
args.dataset,
config_file_list,
args.ip,
args.port,
args.world_size,
args.nproc,
args.group_offset,
),
nprocs=args.nproc,
args=(args.model, args.dataset, args.config_file_list, kwargs),
nprocs=nproc,
join=True,
)
# Normally, there should be only one item in the queue
res = None if queue.empty() else queue.get()
Then run the following command:

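Put together, each node's updated example corresponds to a launcher script roughly like the sketch below (shown for node A; node B only differs in its group_offset). Hedged: the model, dataset, and cluster settings are illustrative placeholders, not values from the commit; the argument layout follows the diff above.

# Hypothetical stand-alone version of the updated run_a.py example.
import torch.multiprocessing as mp

from recbole.quick_start import run_recboles

if __name__ == "__main__":
    world_size = 8                      # assumed: 4 processes on node A + 4 on node B
    ip, port = "192.168.0.1", "5678"    # placeholder address of the master node
    nproc, group_offset = 4, 0
    config_file_list = None

    # Optional, only needed if you want to get the result of each process.
    queue = mp.get_context("spawn").SimpleQueue()

    config_dict = {
        "world_size": world_size,
        "ip": ip,
        "port": port,
        "nproc": nproc,
        "offset": group_offset,
    }
    kwargs = {"config_dict": config_dict, "queue": queue}

    mp.spawn(
        run_recboles,
        args=("BPR", "ml-100k", config_file_list, kwargs),  # assumed model/dataset
        nprocs=nproc,
        join=True,
    )

    # Normally, there should be only one item in the queue (put by rank 0).
    res = None if queue.empty() else queue.get()
    print(res)
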
1 change: 1 addition & 0 deletions recbole/quick_start/__init__.py
@@ -1,4 +1,5 @@
from recbole.quick_start.quick_start import (
run,
run_recbole,
objective_function,
load_data_and_model,
103 changes: 84 additions & 19 deletions recbole/quick_start/quick_start.py
@@ -12,20 +12,17 @@
########################
"""
import logging
from logging import getLogger

import sys
import torch.distributed as dist
from collections.abc import MutableMapping
from logging import getLogger


import pickle
from ray import tune

from recbole.config import Config
from recbole.data import (
create_dataset,
data_preparation,
save_split_dataloaders,
load_split_dataloaders,
)
from recbole.data.transform import construct_transform
from recbole.utils import (
@@ -39,8 +36,69 @@
)


def run(
model,
dataset,
config_file_list=None,
config_dict=None,
saved=True,
nproc=1,
world_size=-1,
ip="localhost",
port="5678",
group_offset=0,
):
if nproc == 1 and world_size <= 0:
res = run_recbole(
model=model,
dataset=dataset,
config_file_list=config_file_list,
config_dict=config_dict,
saved=saved,
)
else:
if world_size == -1:
world_size = nproc
import torch.multiprocessing as mp

# Refer to https://discuss.pytorch.org/t/problems-with-torch-multiprocess-spawn-and-simplequeue/69674/2
# https://discuss.pytorch.org/t/return-from-mp-spawn/94302/2
queue = mp.get_context('spawn').SimpleQueue()

config_dict = config_dict or {}
config_dict.update(
{
"world_size": world_size,
"ip": ip,
"port": port,
"nproc": nproc,
"offset": group_offset,
}
)
kwargs = {
"config_dict": config_dict,
"queue": queue,
}

mp.spawn(
run_recboles,
args=(model, dataset, config_file_list, kwargs),
nprocs=nproc,
join=True,
)

# Normally, there should be only one item in the queue
res = None if queue.empty() else queue.get()
return res


def run_recbole(
model=None, dataset=None, config_file_list=None, config_dict=None, saved=True
model=None,
dataset=None,
config_file_list=None,
config_dict=None,
saved=True,
queue=None,
):
r"""A fast running api, which includes the complete process of
training and testing a model on a specified dataset
@@ -51,6 +109,7 @@ def run_recbole(
config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``.
config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``.
saved (bool, optional): Whether to save the model. Defaults to ``True``.
queue (torch.multiprocessing.Queue, optional): The queue used to pass the result to the main process. Defaults to ``None``.
"""
# configurations initialization
config = Config(
@@ -104,27 +163,33 @@ def run_recbole(
logger.info(set_color("best valid ", "yellow") + f": {best_valid_result}")
logger.info(set_color("test result", "yellow") + f": {test_result}")

return {
result = {
"best_valid_score": best_valid_score,
"valid_score_bigger": config["valid_metric_bigger"],
"best_valid_result": best_valid_result,
"test_result": test_result,
}

if not config["single_spec"]:
dist.destroy_process_group()

if config["local_rank"] == 0 and queue is not None:
queue.put(result) # for multiprocessing, e.g., mp.spawn

return result # for the single process


def run_recboles(rank, *args):
ip, port, world_size, nproc, offset = args[3:]
args = args[:3]
kwargs = args[-1]
if not isinstance(kwargs, MutableMapping):
raise ValueError(
f"The last argument of run_recboles should be a dict, but got {type(kwargs)}"
)
kwargs["config_dict"] = kwargs.get("config_dict", {})
kwargs["config_dict"]["local_rank"] = rank
run_recbole(
*args,
config_dict={
"local_rank": rank,
"world_size": world_size,
"ip": ip,
"port": port,
"nproc": nproc,
"offset": offset,
},
*args[:3],
**kwargs,
)


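With the new top-level `run()` exported from `recbole.quick_start`, callers no longer need to drive `mp.spawn` themselves. A rough usage sketch (the model and dataset names are examples, not part of the commit):

# Assumed usage of the new entry point: run() falls back to run_recbole()
# when nproc == 1 and world_size <= 0, and spawns worker processes otherwise.
from recbole.quick_start import run

# Single-process run.
result = run("BPR", "ml-100k")

# Multi-GPU run on one node; the result travels back through the internal queue.
result = run("BPR", "ml-100k", nproc=4, world_size=4)

print(result["best_valid_result"], result["test_result"])
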
36 changes: 11 additions & 25 deletions run_recbole.py
@@ -8,9 +8,8 @@
# @Email : [email protected], [email protected], [email protected]

import argparse
from ast import arg

from recbole.quick_start import run_recbole, run_recboles
from recbole.quick_start import run

if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -44,26 +43,13 @@
args.config_files.strip().split(" ") if args.config_files else None
)

if args.nproc == 1 and args.world_size <= 0:
run_recbole(
model=args.model, dataset=args.dataset, config_file_list=config_file_list
)
else:
if args.world_size == -1:
args.world_size = args.nproc
import torch.multiprocessing as mp

mp.spawn(
run_recboles,
args=(
args.model,
args.dataset,
config_file_list,
args.ip,
args.port,
args.world_size,
args.nproc,
args.group_offset,
),
nprocs=args.nproc,
)
run(
args.model,
args.dataset,
config_file_list=config_file_list,
nproc=args.nproc,
world_size=args.world_size,
ip=args.ip,
port=args.port,
group_offset=args.group_offset,
)
44 changes: 11 additions & 33 deletions run_recbole_group.py
@@ -4,41 +4,10 @@


import argparse
from ast import arg

from recbole.quick_start import run_recbole, run_recboles
from recbole.quick_start import run
from recbole.utils import list_to_latex


def run(args, model, config_file_list):
if args.nproc == 1 and args.world_size <= 0:
res = run_recbole(
model=model,
dataset=args.dataset,
config_file_list=config_file_list,
)
else:
if args.world_size == -1:
args.world_size = args.nproc
import torch.multiprocessing as mp

res = mp.spawn(
run_recboles,
args=(
args.model,
args.dataset,
config_file_list,
args.ip,
args.port,
args.world_size,
args.nproc,
args.group_offset,
),
nprocs=args.nproc,
)
return res


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -92,7 +61,16 @@ def run(args, model, config_file_list):

valid_res_dict = {"Model": model}
test_res_dict = {"Model": model}
result = run(args, model, config_file_list)
result = run(
model,
args.dataset,
config_file_list=config_file_list,
nproc=args.nproc,
world_size=args.world_size,
ip=args.ip,
port=args.port,
group_offset=args.group_offset,
)
valid_res_dict.update(result["best_valid_result"])
test_res_dict.update(result["test_result"])
bigger_flag = result["valid_score_bigger"]
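
The group script now delegates to the same `run()` helper for each model; conceptually, its per-model aggregation reduces to something like this sketch (the model list and dataset are illustrative assumptions):

# Illustrative reduction of run_recbole_group.py's main loop (assumed values).
from recbole.quick_start import run

model_list = ["BPR", "LightGCN"]  # assumed example models
valid_result_list, test_result_list = [], []

for model in model_list:
    result = run(model, "ml-100k", config_file_list=None, nproc=1)
    valid_res_dict = {"Model": model}
    test_res_dict = {"Model": model}
    valid_res_dict.update(result["best_valid_result"])
    test_res_dict.update(result["test_result"])
    valid_result_list.append(valid_res_dict)
    test_result_list.append(test_res_dict)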
(diff for the remaining changed file not loaded)
